├── requirements.txt ├── LICENSE ├── token-classification ├── scripts │ ├── run_pos_tagging_test.sh │ ├── run_pos_tagging_pred.sh │ ├── run_ner.sh │ └── run_pos_tagging.sh ├── utils.py └── run_token_classification.py ├── text-classification ├── scripts │ ├── run_text_classification_test_poetry.sh │ ├── run_text_classification_test.sh │ ├── run_text_classification_madar_tweet.sh │ ├── run_text_classification.sh │ ├── run_text_classification_test_NADI.sh │ ├── run_text_classification_test_madar_tweet.sh │ └── run_text_classification_test_madar_tweet_pred.sh ├── utils │ ├── vote_did.py │ ├── metrics │ │ └── __init__.py │ └── data_utils.py └── run_text_classification.py ├── .gitignore └── README.md /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.5.1 2 | transformers==3.1.0 3 | seqeval==1.2.2 4 | scikit-learn==0.24.0 5 | camel_tools 6 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 New York University Abu Dhabi 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /token-classification/scripts/run_pos_tagging_test.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH -p condo 3 | # use gpus 4 | #SBATCH --gres=gpu:1 5 | # memory 6 | #SBATCH --mem=120000 7 | # Walltime format hh:mm:ss 8 | #SBATCH --time=11:30:00 9 | # Output and error files 10 | #SBATCH -o job.%J.out 11 | #SBATCH -e job.%J.err 12 | 13 | nvidia-smi 14 | module purge 15 | 16 | ################################# 17 | # POS TAGGING TEST EVAL SCRIPT 18 | ################################# 19 | 20 | # aubmindlab/bert-base-arabertv01 21 | # lanwuwei/GigaBERT-v4-Arabic-and-English 22 | # bashar-talafha/multi-dialect-bert-base-arabic 23 | # asafaya/bert-base-arabic 24 | # /scratch/nlp/CAMeLBERT/model/bert-base-wp-30k_msl-512-CA-full-1000000-step 25 | # /scratch/nlp/CAMeLBERT/model/bert-base-wp-30k_msl-512-MSA-full-1000000-step 26 | # /scratch/nlp/CAMeLBERT/model/bert-base-wp-30k_msl-512-MIX-full-1000000-step 27 | # /scratch/nlp/CAMeLBERT/model/bert-base-wp-30k_msl-512-DA-full-1000000-step 28 | # /scratch/nlp/CAMeLBERT/model/bert-base-wp-30k_msl-512-MSA-half-1000000-step 29 | # /scratch/nlp/CAMeLBERT/model/bert-base-wp-30k_msl-512-MSA-quarter-1000000-step 30 | # /scratch/nlp/CAMeLBERT/model/bert-base-wp-30k_msl-512-MSA-eighth-1000000-step 31 | # /scratch/nlp/CAMeLBERT/model/bert-base-wp-30k_msl-512-MSA-sixteenth-1000000-step 32 | # bert-base-multilingual-cased 33 | # /scratch/ba63/UBC-NLP/MARBERT 34 | # /scratch/ba63/UBC-NLP/ARBERT 35 | # /scratch/ba63/bert-base-arabertv02/ 36 | # /scratch/ba63/bert-base-arabertv01/ 37 | 38 | export 
DATA_DIR=/scratch/ba63/magold_files/GULF 39 | export MAX_LENGTH=512 40 | export OUTPUT_DIR=/scratch/ba63/fine_tuned_models/pos_models_new/GULF/ARBERT_POS/checkpoint-3500-best 41 | export BATCH_SIZE=32 42 | export SEED=12345 43 | 44 | 45 | python run_token_classification.py \ 46 | --data_dir $DATA_DIR \ 47 | --task_type pos \ 48 | --labels $DATA_DIR/labels.txt \ 49 | --model_name_or_path $OUTPUT_DIR \ 50 | --output_dir $OUTPUT_DIR \ 51 | --max_seq_length $MAX_LENGTH \ 52 | --per_device_eval_batch_size $BATCH_SIZE \ 53 | --seed $SEED \ 54 | --overwrite_cache \ 55 | --do_pred 56 | -------------------------------------------------------------------------------- /token-classification/scripts/run_pos_tagging_pred.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH -p nvidia 3 | # use gpus 4 | #SBATCH --gres=gpu:1 5 | # memory 6 | #SBATCH --mem=120000 7 | # Walltime format hh:mm:ss 8 | #SBATCH --time=11:30:00 9 | # Output and error files 10 | #SBATCH -o job.%J.out 11 | #SBATCH -e job.%J.err 12 | 13 | nvidia-smi 14 | module purge 15 | 16 | ################################# 17 | # POS TAGGING DEV EVAL SCRIPT 18 | ################################# 19 | 20 | # aubmindlab/bert-base-arabertv01 21 | # lanwuwei/GigaBERT-v4-Arabic-and-English 22 | # bashar-talafha/multi-dialect-bert-base-arabic 23 | # asafaya/bert-base-arabic 24 | # /scratch/nlp/CAMeLBERT/model/bert-base-wp-30k_msl-512-CA-full-1000000-step 25 | # /scratch/nlp/CAMeLBERT/model/bert-base-wp-30k_msl-512-MSA-full-1000000-step 26 | # /scratch/nlp/CAMeLBERT/model/bert-base-wp-30k_msl-512-MIX-full-1000000-step 27 | # /scratch/nlp/CAMeLBERT/model/bert-base-wp-30k_msl-512-DA-full-1000000-step 28 | # /scratch/nlp/CAMeLBERT/model/bert-base-wp-30k_msl-512-MSA-half-1000000-step 29 | # /scratch/nlp/CAMeLBERT/model/bert-base-wp-30k_msl-512-MSA-quarter-1000000-step 30 | # /scratch/nlp/CAMeLBERT/model/bert-base-wp-30k_msl-512-MSA-eighth-1000000-step 31 | # 
/scratch/nlp/CAMeLBERT/model/bert-base-wp-30k_msl-512-MSA-sixteenth-1000000-step 32 | # bert-base-multilingual-cased 33 | # /scratch/ba63/UBC-NLP/MARBERT 34 | # /scratch/ba63/UBC-NLP/ARBERT 35 | # /scratch/ba63/bert-base-arabertv02/ 36 | # /scratch/ba63/bert-base-arabertv01 37 | 38 | export DATA_DIR=/scratch/ba63/magold_files/GULF 39 | export MAX_LENGTH=512 40 | export OUTPUT_DIR=/scratch/ba63/fine_tuned_models/pos_models_new/GULF/ARBERT_POS 41 | export BATCH_SIZE=32 42 | export SAVE_STEPS=500 43 | export SEED=12345 44 | 45 | for f in $OUTPUT_DIR/checkpoint-*/ 46 | 47 | do 48 | echo $f 49 | python run_token_classification.py \ 50 | --data_dir $DATA_DIR \ 51 | --task_type pos \ 52 | --labels $DATA_DIR/labels.txt \ 53 | --model_name_or_path $f \ 54 | --output_dir $f \ 55 | --max_seq_length $MAX_LENGTH \ 56 | --per_device_eval_batch_size $BATCH_SIZE \ 57 | --seed $SEED \ 58 | --overwrite_cache \ 59 | --do_eval 60 | done 61 | -------------------------------------------------------------------------------- /text-classification/scripts/run_text_classification_test_poetry.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH -p nvidia 3 | # use gpus 4 | #SBATCH --gres=gpu:1 5 | # memory 6 | #SBATCH --mem=120000 7 | # Walltime format hh:mm:ss 8 | #SBATCH --time=11:30:00 9 | # Output and error files 10 | #SBATCH -o job.%J.out 11 | #SBATCH -e job.%J.err 12 | 13 | nvidia-smi 14 | module purge 15 | 16 | # export ARABIC_DATA=data/test 17 | # export TASK_NAME=arabic_sentiment 18 | export ARABIC_DATA=/scratch/ba63/arabic_poetry_dataset/ 19 | export TASK_NAME=arabic_poetry 20 | 21 | # aubmindlab/bert-base-arabertv01 22 | # lanwuwei/GigaBERT-v4-Arabic-and-English 23 | # bashar-talafha/multi-dialect-bert-base-arabic 24 | # asafaya/bert-base-arabic 25 | # /scratch/nlp/CAMeLBERT/model/bert-base-wp-30k_msl-512-CA-full-1000000-step 26 | # /scratch/nlp/CAMeLBERT/model/bert-base-wp-30k_msl-512-MSA-full-1000000-step 27 | # 
/scratch/nlp/CAMeLBERT/model/bert-base-wp-30k_msl-512-MIX-full-1000000-step 28 | # /scratch/nlp/CAMeLBERT/model/bert-base-wp-30k_msl-512-DA-full-1000000-step 29 | # /scratch/nlp/CAMeLBERT/model/bert-base-wp-30k_msl-512-MSA-half-1000000-step 30 | # /scratch/nlp/CAMeLBERT/model/bert-base-wp-30k_msl-512-MSA-quarter-1000000-step 31 | # /scratch/nlp/CAMeLBERT/model/bert-base-wp-30k_msl-512-MSA-eighth-1000000-step 32 | # /scratch/nlp/CAMeLBERT/model/bert-base-wp-30k_msl-512-MSA-sixteenth-1000000-step 33 | # bert-base-multilingual-cased 34 | # /scratch/ba63/UBC-NLP/MARBERT 35 | # /scratch/ba63/UBC-NLP/ARBERT 36 | # /scratch/ba63/bert-base-arabertv02 37 | 38 | # export ARABIC_DATA=/scratch/ba63/MADAR-SHARED-TASK-final-release-25Jul2019/MADAR-Shared-Task-Subtask-1/ 39 | # export TASK_NAME=arabic_did 40 | 41 | python run_text_classification.py \ 42 | --model_type bert \ 43 | --model_name_or_path /scratch/nlp/CAMeLBERT/model/bert-base-wp-30k_msl-512-MSA-sixteenth-1000000-step \ 44 | --task_name $TASK_NAME \ 45 | --do_eval \ 46 | --eval_all_checkpoints \ 47 | --data_dir $ARABIC_DATA \ 48 | --max_seq_length 128 \ 49 | --per_gpu_eval_batch_size 32 \ 50 | --learning_rate 3e-5 \ 51 | --overwrite_cache \ 52 | --output_dir /scratch/ba63/fine_tuned_models/poetry_models/CAMeLBERT_MSA_sixteenth_poetry/$TASK_NAME \ 53 | --seed 12345 54 | -------------------------------------------------------------------------------- /token-classification/scripts/run_ner.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH -p nvidia 3 | # use gpus 4 | #SBATCH --gres=gpu:1 5 | # memory 6 | #SBATCH --mem=120000 7 | # Walltime format hh:mm:ss 8 | #SBATCH --time=11:30:00 9 | # Output and error files 10 | #SBATCH -o job.%J.out 11 | #SBATCH -e job.%J.err 12 | 13 | nvidia-smi 14 | module purge 15 | 16 | 17 | ########################## 18 | # NER FINE-TUNING SCRIPT 19 | ########################## 20 | 21 | # aubmindlab/bert-base-arabertv01 22 | # 
lanwuwei/GigaBERT-v4-Arabic-and-English 23 | # bashar-talafha/multi-dialect-bert-base-arabic 24 | # asafaya/bert-base-arabic 25 | # /scratch/nlp/CAMeLBERT/model/bert-base-wp-30k_msl-512-CA-full-1000000-step 26 | # /scratch/nlp/CAMeLBERT/model/bert-base-wp-30k_msl-512-MSA-full-1000000-step 27 | # /scratch/nlp/CAMeLBERT/model/bert-base-wp-30k_msl-512-MIX-full-1000000-step 28 | # /scratch/nlp/CAMeLBERT/model/bert-base-wp-30k_msl-512-DA-full-1000000-step 29 | # /scratch/nlp/CAMeLBERT/model/bert-base-wp-30k_msl-512-MSA-half-1000000-step 30 | # /scratch/nlp/CAMeLBERT/model/bert-base-wp-30k_msl-512-MSA-quarter-1000000-step 31 | # /scratch/nlp/CAMeLBERT/model/bert-base-wp-30k_msl-512-MSA-eighth-1000000-step 32 | # /scratch/nlp/CAMeLBERT/model/bert-base-wp-30k_msl-512-MSA-sixteenth-1000000-step 33 | # bert-base-multilingual-cased 34 | # /scratch/ba63/UBC-NLP/MARBERT 35 | # /scratch/ba63/UBC-NLP/ARBERT 36 | # /scratch/ba63/bert-base-arabertv02/ 37 | # /scratch/ba63/bert-base-arabertv01 38 | 39 | export DATA_DIR=ANERCorp-CamelLabSplits/ 40 | export MAX_LENGTH=512 41 | export BERT_MODEL=/scratch/nlp/CAMeLBERT/model/bert-base-wp-30k_msl-512-MSA-sixteenth-1000000-step 42 | export OUTPUT_DIR=/scratch/ba63/fine_tuned_models/ner_models/CAMeLBERT_MSA_sixteenth_NER 43 | export BATCH_SIZE=32 44 | export NUM_EPOCHS=3 45 | export SAVE_STEPS=750 46 | export SEED=12345 47 | 48 | 49 | python run_token_classification.py \ 50 | --data_dir $DATA_DIR \ 51 | --task_type ner \ 52 | --labels $DATA_DIR/labels.txt \ 53 | --model_name_or_path $BERT_MODEL \ 54 | --output_dir $OUTPUT_DIR \ 55 | --max_seq_length $MAX_LENGTH \ 56 | --num_train_epochs $NUM_EPOCHS \ 57 | --per_device_train_batch_size $BATCH_SIZE \ 58 | --save_steps $SAVE_STEPS \ 59 | --seed $SEED \ 60 | --overwrite_output_dir \ 61 | --overwrite_cache \ 62 | --do_train \ 63 | --do_predict 64 | -------------------------------------------------------------------------------- /token-classification/scripts/run_pos_tagging.sh: 
-------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH -p condo 3 | # use gpus 4 | #SBATCH --gres=gpu:1 5 | # memory 6 | #SBATCH --mem=120000 7 | # Walltime format hh:mm:ss 8 | #SBATCH --time=11:30:00 9 | # Output and error files 10 | #SBATCH -o job.%J.out 11 | #SBATCH -e job.%J.err 12 | 13 | nvidia-smi 14 | module purge 15 | 16 | 17 | ################################ 18 | # POS TAGGING FINE-TUNING SCRIPT 19 | ################################ 20 | 21 | 22 | # aubmindlab/bert-base-arabertv01 23 | # lanwuwei/GigaBERT-v4-Arabic-and-English 24 | # bashar-talafha/multi-dialect-bert-base-arabic 25 | # asafaya/bert-base-arabic 26 | # /scratch/nlp/CAMeLBERT/model/bert-base-wp-30k_msl-512-CA-full-1000000-step 27 | # /scratch/nlp/CAMeLBERT/model/bert-base-wp-30k_msl-512-MSA-full-1000000-step 28 | # /scratch/nlp/CAMeLBERT/model/bert-base-wp-30k_msl-512-MIX-full-1000000-step 29 | # /scratch/nlp/CAMeLBERT/model/bert-base-wp-30k_msl-512-DA-full-1000000-step 30 | # /scratch/nlp/CAMeLBERT/model/bert-base-wp-30k_msl-512-MSA-half-1000000-step 31 | # /scratch/nlp/CAMeLBERT/model/bert-base-wp-30k_msl-512-MSA-quarter-1000000-step 32 | # /scratch/nlp/CAMeLBERT/model/bert-base-wp-30k_msl-512-MSA-eighth-1000000-step 33 | # /scratch/nlp/CAMeLBERT/model/bert-base-wp-30k_msl-512-MSA-sixteenth-1000000-step 34 | # bert-base-multilingual-cased 35 | # /scratch/ba63/UBC-NLP/MARBERT 36 | # /scratch/ba63/UBC-NLP/ARBERT 37 | # /scratch/ba63/bert-base-arabertv02/ 38 | # /scratch/ba63/bert-base-arabertv01 39 | 40 | export DATA_DIR=/scratch/ba63/magold_files/EGY 41 | export MAX_LENGTH=512 42 | export BERT_MODEL=/scratch/nlp/CAMeLBERT/model/bert-base-wp-30k_msl-512-MSA-full-1000000-step 43 | export OUTPUT_DIR=/scratch/ba63/fine_tuned_models/pos_models/EGY/CAMeLBERT_MSA_POS_EGY 44 | export BATCH_SIZE=32 45 | export NUM_EPOCHS=10 46 | export SAVE_STEPS=500 47 | export SEED=12345 48 | 49 | 50 | python run_token_classification.py \ 51 | --data_dir 
$DATA_DIR \ 52 | --task_type pos \ 53 | --labels $DATA_DIR/labels.txt \ 54 | --model_name_or_path $BERT_MODEL \ 55 | --output_dir $OUTPUT_DIR \ 56 | --max_seq_length $MAX_LENGTH \ 57 | --num_train_epochs $NUM_EPOCHS \ 58 | --per_device_train_batch_size $BATCH_SIZE \ 59 | --save_steps $SAVE_STEPS \ 60 | --seed $SEED \ 61 | --overwrite_output_dir \ 62 | --overwrite_cache \ 63 | --do_train \ 64 | --do_eval 65 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /text-classification/scripts/run_text_classification_test.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH -p condo 3 | # use gpus 4 | #SBATCH --gres=gpu:1 5 | # memory 6 | #SBATCH --mem=120000 7 | # Walltime format hh:mm:ss 8 | #SBATCH --time=11:30:00 9 | # Output and error files 10 | #SBATCH -o job.%J.out 11 | #SBATCH -e job.%J.err 12 | 13 | nvidia-smi 14 | module purge 15 | 16 | 17 | ####################################### 18 | # Text classification evaluation script 19 | ####################################### 20 | 21 | 22 | # export ARABIC_DATA=data/test 23 | # export TASK_NAME=arabic_sentiment 24 | 25 | # export ARABIC_DATA=/scratch/ba63/arabic_poetry_dataset/ 26 | # export TASK_NAME=arabic_poetry 27 | 28 | # export ARABIC_DATA=/scratch/ba63/MADAR-SHARED-TASK-final-release-25Jul2019/MADAR-Shared-Task-Subtask-1/ 29 | # export TASK_NAME=arabic_did_madar_26 30 | 31 | export ARABIC_DATA=/scratch/ba63/MADAR-SHARED-TASK-final-release-25Jul2019/MADAR-Shared-Task-Subtask-1/ 32 | export TASK_NAME=arabic_did_madar_6 33 | 34 | # export ARABIC_DATA=/scratch/ba63/MADAR-SHARED-TASK-final-release-25Jul2019/MADAR-Shared-Task-Subtask-2/MADAR-tweets/ 35 | # export TASK_NAME=arabic_did_madar_twitter 36 | 37 | 38 | # export 
ARABIC_DATA=/scratch/ba63/NADI/NADI_release/ 39 | # export TASK_NAME=arabic_did_nadi_country 40 | 41 | # aubmindlab/bert-base-arabertv01 42 | # lanwuwei/GigaBERT-v4-Arabic-and-English 43 | # bashar-talafha/multi-dialect-bert-base-arabic 44 | # asafaya/bert-base-arabic 45 | # /scratch/nlp/CAMeLBERT/model/bert-base-wp-30k_msl-512-CA-full-1000000-step 46 | # /scratch/nlp/CAMeLBERT/model/bert-base-wp-30k_msl-512-MSA-full-1000000-step 47 | # /scratch/nlp/CAMeLBERT/model/bert-base-wp-30k_msl-512-MIX-full-1000000-step 48 | # /scratch/nlp/CAMeLBERT/model/bert-base-wp-30k_msl-512-DA-full-1000000-step 49 | # /scratch/nlp/CAMeLBERT/model/bert-base-wp-30k_msl-512-MSA-half-1000000-step 50 | # /scratch/nlp/CAMeLBERT/model/bert-base-wp-30k_msl-512-MSA-quarter-1000000-step 51 | # /scratch/nlp/CAMeLBERT/model/bert-base-wp-30k_msl-512-MSA-eighth-1000000-step 52 | # /scratch/nlp/CAMeLBERT/model/bert-base-wp-30k_msl-512-MSA-sixteenth-1000000-step 53 | # bert-base-multilingual-cased 54 | # /scratch/ba63/UBC-NLP/MARBERT 55 | # /scratch/ba63/UBC-NLP/ARBERT 56 | # /scratch/ba63/bert-base-arabertv02 57 | 58 | python run_text_classification.py \ 59 | --model_type bert \ 60 | --model_name_or_path aubmindlab/bert-base-arabertv01 \ 61 | --task_name $TASK_NAME \ 62 | --do_pred \ 63 | --data_dir $ARABIC_DATA \ 64 | --max_seq_length 128 \ 65 | --per_gpu_eval_batch_size 32 \ 66 | --learning_rate 3e-5 \ 67 | --overwrite_cache \ 68 | --output_dir /scratch/ba63/fine_tuned_models/did_models_MADAR_6/arabert_DID_MADAR_6/$TASK_NAME/checkpoint-15000-best \ 69 | --seed 12345 70 | -------------------------------------------------------------------------------- /text-classification/scripts/run_text_classification_madar_tweet.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH -p condo 3 | # use gpus 4 | #SBATCH --gres=gpu:1 5 | # memory 6 | #SBATCH --mem=120000 7 | # Walltime format hh:mm:ss 8 | #SBATCH --time=11:55:00 9 | # Output and error files 
10 | #SBATCH -o job.%J.out 11 | #SBATCH -e job.%J.err 12 | 13 | nvidia-smi 14 | module purge 15 | 16 | # export ARABIC_DATA=data/train 17 | # export TASK_NAME=arabic_sentiment 18 | 19 | # export ARABIC_DATA=/scratch/ba63/arabic_poetry_dataset/ 20 | # export TASK_NAME=arabic_poetry 21 | 22 | # export ARABIC_DATA=/scratch/ba63/MADAR-SHARED-TASK-final-release-25Jul2019/MADAR-Shared-Task-Subtask-1/ 23 | # export TASK_NAME=arabic_did_madar_26 24 | 25 | # export ARABIC_DATA=/scratch/ba63/MADAR-SHARED-TASK-final-release-25Jul2019/MADAR-Shared-Task-Subtask-1/ 26 | # export TASK_NAME=arabic_did_madar_6 27 | 28 | export ARABIC_DATA=/scratch/ba63/MADAR-SHARED-TASK-final-release-25Jul2019/MADAR-Shared-Task-Subtask-2/MADAR-tweets/ 29 | export TASK_NAME=arabic_did_madar_twitter 30 | 31 | 32 | # export ARABIC_DATA=/scratch/ba63/NADI/NADI_release/ 33 | # export TASK_NAME=arabic_did_nadi_country 34 | 35 | # lanwuwei/GigaBERT-v4-Arabic-and-English 36 | # bashar-talafha/multi-dialect-bert-base-arabic 37 | # asafaya/bert-base-arabic 38 | # /scratch/nlp/CAMeLBERT/model/bert-base-wp-30k_msl-512-CA-full-1000000-step 39 | # /scratch/nlp/CAMeLBERT/model/bert-base-wp-30k_msl-512-MSA-full-1000000-step 40 | # /scratch/nlp/CAMeLBERT/model/bert-base-wp-30k_msl-512-MIX-full-1000000-step 41 | # /scratch/nlp/CAMeLBERT/model/bert-base-wp-30k_msl-512-DA-full-1000000-step 42 | # /scratch/nlp/CAMeLBERT/model/bert-base-wp-30k_msl-512-MSA-quarter-1000000-step 43 | # /scratch/nlp/CAMeLBERT/model/bert-base-wp-30k_msl-512-MSA-half-1000000-step 44 | # /scratch/nlp/CAMeLBERT/model/bert-base-wp-30k_msl-512-MSA-eighth-1000000-step 45 | # /scratch/nlp/CAMeLBERT/model/bert-base-wp-30k_msl-512-MSA-sixteenth-1000000-step 46 | # bert-base-multilingual-cased 47 | 48 | # /scratch/ba63/UBC-NLP/MARBERT 49 | # /scratch/ba63/UBC-NLP/ARBERT 50 | # /scratch/ba63/bert-base-arabertv02 51 | # /scratch/ba63/bert-base-arabertv01 52 | 53 | python run_text_classification.py \ 54 | --model_type bert \ 55 | --model_name_or_path 
/scratch/nlp/CAMeLBERT/model/bert-base-wp-30k_msl-512-MSA-sixteenth-1000000-step \ 56 | --task_name $TASK_NAME \ 57 | --do_train \ 58 | --save_steps 500 \ 59 | --data_dir $ARABIC_DATA \ 60 | --max_seq_length 128 \ 61 | --per_gpu_train_batch_size 32 \ 62 | --per_gpu_eval_batch_size 32 \ 63 | --learning_rate 3e-5 \ 64 | --num_train_epochs 10.0 \ 65 | --overwrite_output_dir \ 66 | --overwrite_cache \ 67 | --output_dir /scratch/ba63/fine_tuned_models/did_models_MADAR_twitter/CAMeLBERT_MSA_sixteenth_DID/$TASK_NAME \ 68 | --seed 12345 69 | -------------------------------------------------------------------------------- /text-classification/scripts/run_text_classification.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH -p condo 3 | # use gpus 4 | #SBATCH --gres=gpu:1 5 | # memory 6 | #SBATCH --mem=120000 7 | # Walltime format hh:mm:ss 8 | #SBATCH --time=11:55:00 9 | # Output and error files 10 | #SBATCH -o job.%J.out 11 | #SBATCH -e job.%J.err 12 | 13 | nvidia-smi 14 | module purge 15 | 16 | ####################################### 17 | # Text classification fine-tuning script 18 | ####################################### 19 | 20 | export ARABIC_DATA=data/train 21 | export TASK_NAME=arabic_sentiment 22 | 23 | # export ARABIC_DATA=/scratch/ba63/arabic_poetry_dataset/ 24 | # export TASK_NAME=arabic_poetry 25 | 26 | # export ARABIC_DATA=/scratch/ba63/MADAR-SHARED-TASK-final-release-25Jul2019/MADAR-Shared-Task-Subtask-1/ 27 | # export TASK_NAME=arabic_did_madar_26 28 | 29 | # export ARABIC_DATA=/scratch/ba63/MADAR-SHARED-TASK-final-release-25Jul2019/MADAR-Shared-Task-Subtask-1/ 30 | # export TASK_NAME=arabic_did_madar_6 31 | 32 | # export ARABIC_DATA=/scratch/ba63/MADAR-SHARED-TASK-final-release-25Jul2019/MADAR-Shared-Task-Subtask-2/MADAR-tweets/ 33 | # export TASK_NAME=arabic_did_madar_twitter 34 | 35 | 36 | # export ARABIC_DATA=/scratch/ba63/NADI/NADI_release/ 37 | # export TASK_NAME=arabic_did_nadi_country 38 | 
39 | # aubmindlab/bert-base-arabertv01 40 | # lanwuwei/GigaBERT-v4-Arabic-and-English 41 | # bashar-talafha/multi-dialect-bert-base-arabic 42 | # asafaya/bert-base-arabic 43 | # /scratch/nlp/CAMeLBERT/model/bert-base-wp-30k_msl-512-CA-full-1000000-step 44 | # /scratch/nlp/CAMeLBERT/model/bert-base-wp-30k_msl-512-MSA-full-1000000-step 45 | # /scratch/nlp/CAMeLBERT/model/bert-base-wp-30k_msl-512-MIX-full-1000000-step 46 | # /scratch/nlp/CAMeLBERT/model/bert-base-wp-30k_msl-512-DA-full-1000000-step 47 | # /scratch/nlp/CAMeLBERT/model/bert-base-wp-30k_msl-512-MSA-half-1000000-step 48 | # /scratch/nlp/CAMeLBERT/model/bert-base-wp-30k_msl-512-MSA-quarter-1000000-step 49 | # /scratch/nlp/CAMeLBERT/model/bert-base-wp-30k_msl-512-MSA-eighth-1000000-step 50 | # /scratch/nlp/CAMeLBERT/model/bert-base-wp-30k_msl-512-MSA-sixteenth-1000000-step 51 | # bert-base-multilingual-cased 52 | 53 | # /scratch/ba63/UBC-NLP/MARBERT 54 | # /scratch/ba63/UBC-NLP/ARBERT 55 | # /scratch/ba63/bert-base-arabertv02 56 | 57 | python run_text_classification.py \ 58 | --model_type bert \ 59 | --model_name_or_path /scratch/nlp/CAMeLBERT/model/bert-base-wp-30k_msl-512-MSA-full-1000000-step \ 60 | --task_name $TASK_NAME \ 61 | --do_train \ 62 | --do_eval \ 63 | --eval_all_checkpoints \ 64 | --save_steps 500 \ 65 | --data_dir $ARABIC_DATA \ 66 | --max_seq_length 128 \ 67 | --per_gpu_train_batch_size 32 \ 68 | --per_gpu_eval_batch_size 32 \ 69 | --learning_rate 3e-5 \ 70 | --num_train_epochs 3.0 \ 71 | --overwrite_output_dir \ 72 | --overwrite_cache \ 73 | --output_dir /scratch/ba63/fine_tuned_models/sentiment_models/CAMeLBERT_MSA_arabic_sentiment/$TASK_NAME \ 74 | --seed 12345 75 | -------------------------------------------------------------------------------- /text-classification/scripts/run_text_classification_test_NADI.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH -p condo 3 | # use gpus 4 | #SBATCH --gres=gpu:1 5 | # memory 6 | 
#SBATCH --mem=120000 7 | # Walltime format hh:mm:ss 8 | #SBATCH --time=11:30:00 9 | # Output and error files 10 | #SBATCH -o job.%J.out 11 | #SBATCH -e job.%J.err 12 | 13 | nvidia-smi 14 | module purge 15 | 16 | #################################### 17 | # NADI DID TEST EVAL SCRIPT 18 | # Note: We upload the predictions 19 | # to codalab as the test gold labels 20 | # were not available at the time of 21 | # writing this paper 22 | #################################### 23 | 24 | # export ARABIC_DATA=data/test 25 | # export TASK_NAME=arabic_sentiment 26 | 27 | # export ARABIC_DATA=/scratch/ba63/arabic_poetry_dataset/ 28 | # export TASK_NAME=arabic_poetry 29 | 30 | # export ARABIC_DATA=/scratch/ba63/MADAR-SHARED-TASK-final-release-25Jul2019/MADAR-Shared-Task-Subtask-1/ 31 | # export TASK_NAME=arabic_did_madar_26 32 | 33 | # export ARABIC_DATA=/scratch/ba63/MADAR-SHARED-TASK-final-release-25Jul2019/MADAR-Shared-Task-Subtask-1/ 34 | # export TASK_NAME=arabic_did_madar_6 35 | 36 | # export ARABIC_DATA=/scratch/ba63/MADAR-SHARED-TASK-final-release-25Jul2019/MADAR-Shared-Task-Subtask-2/MADAR-tweets/ 37 | # export TASK_NAME=arabic_did_madar_twitter 38 | 39 | 40 | export ARABIC_DATA=/scratch/ba63/NADI/NADI_release/ 41 | export TASK_NAME=arabic_did_nadi_country 42 | 43 | # aubmindlab/bert-base-arabertv01 44 | # lanwuwei/GigaBERT-v4-Arabic-and-English 45 | # bashar-talafha/multi-dialect-bert-base-arabic 46 | # asafaya/bert-base-arabic 47 | # /scratch/nlp/CAMeLBERT/model/bert-base-wp-30k_msl-512-CA-full-1000000-step 48 | # /scratch/nlp/CAMeLBERT/model/bert-base-wp-30k_msl-512-MSA-full-1000000-step 49 | # /scratch/nlp/CAMeLBERT/model/bert-base-wp-30k_msl-512-MIX-full-1000000-step 50 | # /scratch/nlp/CAMeLBERT/model/bert-base-wp-30k_msl-512-DA-full-1000000-step 51 | # /scratch/nlp/CAMeLBERT/model/bert-base-wp-30k_msl-512-MSA-half-1000000-step 52 | # /scratch/nlp/CAMeLBERT/model/bert-base-wp-30k_msl-512-MSA-quarter-1000000-step 53 | # 
/scratch/nlp/CAMeLBERT/model/bert-base-wp-30k_msl-512-MSA-eighth-1000000-step 54 | # /scratch/nlp/CAMeLBERT/model/bert-base-wp-30k_msl-512-MSA-sixteenth-1000000-step 55 | # bert-base-multilingual-cased 56 | # /scratch/ba63/UBC-NLP/MARBERT 57 | # /scratch/ba63/UBC-NLP/ARBERT 58 | # /scratch/ba63/bert-base-arabertv02 59 | 60 | python run_text_classification.py \ 61 | --model_type bert \ 62 | --model_name_or_path /scratch/nlp/CAMeLBERT/model/bert-base-wp-30k_msl-512-MSA-sixteenth-1000000-step \ 63 | --task_name $TASK_NAME \ 64 | --do_pred \ 65 | --write_preds \ 66 | --data_dir $ARABIC_DATA \ 67 | --max_seq_length 128 \ 68 | --per_gpu_eval_batch_size 32 \ 69 | --learning_rate 3e-5 \ 70 | --overwrite_cache \ 71 | --output_dir /scratch/ba63/fine_tuned_models/did_models_NADI_country/CAMeLBERT_MSA_sixteenth_DID/$TASK_NAME/checkpoint-1500-best/ \ 72 | --seed 12345 73 | -------------------------------------------------------------------------------- /text-classification/scripts/run_text_classification_test_madar_tweet.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH -p condo 3 | # use gpus 4 | #SBATCH --gres=gpu:1 5 | # memory 6 | #SBATCH --mem=120000 7 | # Walltime format hh:mm:ss 8 | #SBATCH --time=11:55:00 9 | # Output and error files 10 | #SBATCH -o job.%J.out 11 | #SBATCH -e job.%J.err 12 | 13 | nvidia-smi 14 | module purge 15 | 16 | ################################### 17 | # MADAR-Twitter-5 dev eval script 18 | ################################### 19 | 20 | 21 | # export ARABIC_DATA=data/train 22 | # export TASK_NAME=arabic_sentiment 23 | 24 | # export ARABIC_DATA=/scratch/ba63/arabic_poetry_dataset/ 25 | # export TASK_NAME=arabic_poetry 26 | 27 | # export ARABIC_DATA=/scratch/ba63/MADAR-SHARED-TASK-final-release-25Jul2019/MADAR-Shared-Task-Subtask-1/ 28 | # export TASK_NAME=arabic_did_madar_26 29 | 30 | # export 
ARABIC_DATA=/scratch/ba63/MADAR-SHARED-TASK-final-release-25Jul2019/MADAR-Shared-Task-Subtask-1/ 31 | # export TASK_NAME=arabic_did_madar_6 32 | 33 | export ARABIC_DATA=/scratch/ba63/MADAR-SHARED-TASK-final-release-25Jul2019/MADAR-Shared-Task-Subtask-2/MADAR-tweets/ 34 | export TASK_NAME=arabic_did_madar_twitter 35 | 36 | 37 | # export ARABIC_DATA=/scratch/ba63/NADI/NADI_release/ 38 | # export TASK_NAME=arabic_did_nadi_country 39 | 40 | # aubmindlab/bert-base-arabertv01 41 | # lanwuwei/GigaBERT-v4-Arabic-and-English 42 | # bashar-talafha/multi-dialect-bert-base-arabic 43 | # asafaya/bert-base-arabic 44 | # /scratch/nlp/CAMeLBERT/model/bert-base-wp-30k_msl-512-CA-full-1000000-step 45 | # /scratch/nlp/CAMeLBERT/model/bert-base-wp-30k_msl-512-MSA-full-1000000-step 46 | # /scratch/nlp/CAMeLBERT/model/bert-base-wp-30k_msl-512-MIX-full-1000000-step 47 | # /scratch/nlp/CAMeLBERT/model/bert-base-wp-30k_msl-512-DA-full-1000000-step 48 | # /scratch/nlp/CAMeLBERT/model/bert-base-wp-30k_msl-512-MSA-half-1000000-step 49 | # /scratch/nlp/CAMeLBERT/model/bert-base-wp-30k_msl-512-MSA-quarter-1000000-step 50 | # /scratch/nlp/CAMeLBERT/model/bert-base-wp-30k_msl-512-MSA-eighth-1000000-step 51 | # /scratch/nlp/CAMeLBERT/model/bert-base-wp-30k_msl-512-MSA-sixteenth-1000000-step 52 | # bert-base-multilingual-cased 53 | 54 | # /scratch/ba63/UBC-NLP/MARBERT 55 | # /scratch/ba63/UBC-NLP/ARBERT 56 | # /scratch/ba63/bert-base-arabertv02 57 | 58 | export OUTPUT_DIR=/scratch/ba63/fine_tuned_models/did_models_MADAR_twitter/CAMeLBERT_MSA_sixteenth_DID/$TASK_NAME 59 | for f in $OUTPUT_DIR $OUTPUT_DIR/checkpoint-*/ 60 | do 61 | python run_text_classification.py \ 62 | --model_type bert \ 63 | --model_name_or_path /scratch/nlp/CAMeLBERT/model/bert-base-wp-30k_msl-512-MSA-sixteenth-1000000-step \ 64 | --task_name $TASK_NAME \ 65 | --do_eval \ 66 | --write_preds \ 67 | --data_dir $ARABIC_DATA \ 68 | --max_seq_length 128 \ 69 | --per_gpu_eval_batch_size 32 \ 70 | --overwrite_cache \ 71 | --output_dir 
    $f \
    --seed 12345

# Pair each dev tweet's user id with its predicted label, majority-vote
# per user, then score the voted labels against the gold dev labels.
export dev_user_ids=/scratch/ba63/MADAR-SHARED-TASK-final-release-25Jul2019/MADAR-Shared-Task-Subtask-2/MADAR-tweets/grouped_tweets.dev.users
export preds=$f/predictions.txt

paste -d '\t' $dev_user_ids $preds > $f/users_and_preds

python utils/vote_did.py --preds_file_path $f/users_and_preds --output_file_path $f/users_and_preds.voting


python /scratch/ba63/MADAR-SHARED-TASK-final-release-25Jul2019/MADAR-DID-Scorer.py /scratch/ba63/MADAR-SHARED-TASK-final-release-25Jul2019/MADAR-Shared-Task-Subtask-2/MADAR-tweets/dev.gold.CAMeL.labels $f/users_and_preds.voting > $f/eval_results.voting.txt

done
--------------------------------------------------------------------------------
/text-classification/scripts/run_text_classification_test_madar_tweet_pred.sh:
--------------------------------------------------------------------------------
#!/bin/bash
#SBATCH -p nvidia
# use gpus
#SBATCH --gres=gpu:1
# memory
#SBATCH --mem=120000
# Walltime format hh:mm:ss
#SBATCH --time=11:55:00
# Output and error files
#SBATCH -o job.%J.out
#SBATCH -e job.%J.err

nvidia-smi
module purge

###################################
# MADAR-Twitter-5 Test Eval script
###################################


# export ARABIC_DATA=data/train
# export TASK_NAME=arabic_sentiment

# export ARABIC_DATA=/scratch/ba63/arabic_poetry_dataset/
# export TASK_NAME=arabic_poetry

# export ARABIC_DATA=/scratch/ba63/MADAR-SHARED-TASK-final-release-25Jul2019/MADAR-Shared-Task-Subtask-1/
# export TASK_NAME=arabic_did_madar_26

# export ARABIC_DATA=/scratch/ba63/MADAR-SHARED-TASK-final-release-25Jul2019/MADAR-Shared-Task-Subtask-1/
# export TASK_NAME=arabic_did_madar_6

export ARABIC_DATA=/scratch/ba63/MADAR-SHARED-TASK-final-release-25Jul2019/MADAR-Shared-Task-Subtask-2/MADAR-tweets/
export TASK_NAME=arabic_did_madar_twitter


# export ARABIC_DATA=/scratch/ba63/NADI/NADI_release/
# export TASK_NAME=arabic_did_nadi_country

# Candidate pre-trained models (pick one as --model_name_or_path below):
# aubmindlab/bert-base-arabertv01
# lanwuwei/GigaBERT-v4-Arabic-and-English
# bashar-talafha/multi-dialect-bert-base-arabic
# asafaya/bert-base-arabic
# /scratch/nlp/CAMeLBERT/model/bert-base-wp-30k_msl-512-CA-full-1000000-step
# /scratch/nlp/CAMeLBERT/model/bert-base-wp-30k_msl-512-MSA-full-1000000-step
# /scratch/nlp/CAMeLBERT/model/bert-base-wp-30k_msl-512-MIX-full-1000000-step
# /scratch/nlp/CAMeLBERT/model/bert-base-wp-30k_msl-512-DA-full-1000000-step
# /scratch/nlp/CAMeLBERT/model/bert-base-wp-30k_msl-512-MSA-half-1000000-step
# /scratch/nlp/CAMeLBERT/model/bert-base-wp-30k_msl-512-MSA-quarter-1000000-step
# /scratch/nlp/CAMeLBERT/model/bert-base-wp-30k_msl-512-MSA-eighth-1000000-step
# /scratch/nlp/CAMeLBERT/model/bert-base-wp-30k_msl-512-MSA-sixteenth-1000000-step
# bert-base-multilingual-cased

# /scratch/ba63/UBC-NLP/MARBERT
# /scratch/ba63/UBC-NLP/ARBERT
# /scratch/ba63/bert-base-arabertv02

# Best dev checkpoint of the fine-tuned model; test predictions are
# written into this directory.
export OUTPUT_DIR=/scratch/ba63/fine_tuned_models/did_models_MADAR_twitter/CAMeLBERT_MSA_sixteenth_DID/$TASK_NAME/checkpoint-10500-best/

python run_text_classification.py \
    --model_type bert \
    --model_name_or_path /scratch/nlp/CAMeLBERT/model/bert-base-wp-30k_msl-512-MSA-sixteenth-1000000-step \
    --task_name $TASK_NAME \
    --do_pred \
    --write_preds \
    --data_dir $ARABIC_DATA \
    --max_seq_length 128 \
    --per_gpu_eval_batch_size 32 \
    --overwrite_cache \
    --output_dir $OUTPUT_DIR \
    --seed 12345

# Pair each test tweet's user id with its predicted label, majority-vote
# per user, then score against the gold test labels.
export test_user_ids=/scratch/ba63/MADAR-SHARED-TASK-final-release-25Jul2019/MADAR-Shared-Task-Subtask-2/MADAR-tweets/grouped_tweets.test.users
export preds=$OUTPUT_DIR/predictions.txt

paste -d '\t' $test_user_ids $preds > $OUTPUT_DIR/users_and_preds

python utils/vote_did.py --preds_file_path $OUTPUT_DIR/users_and_preds --output_file_path $OUTPUT_DIR/users_and_preds.voting


python /scratch/ba63/MADAR-SHARED-TASK-final-release-25Jul2019/MADAR-DID-Scorer.py /scratch/ba63/MADAR-SHARED-TASK-final-release-25Jul2019/MADAR-Shared-Task-Subtask-2/MADAR-tweets/test.gold.labels $OUTPUT_DIR/users_and_preds.voting > $OUTPUT_DIR/test_results.voting.txt
--------------------------------------------------------------------------------
/text-classification/utils/vote_did.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-

# MIT License
#
# Copyright 2018-2021 New York University Abu Dhabi
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

"""
Dialect identification voting evaluation.
"""
Dialect identification voting evaluation.
This is used specifically for the MADAR-Twitter-5 task
"""

from collections import Counter
import argparse


# Number of training examples per country label in the MADAR Twitter-5
# training data. Used as a tie-breaker during per-user majority voting:
# when several labels receive the same number of votes, we back off to
# the label that is most frequent in the training data.
TRAIN_LABEL_COUNTS = {"Algeria": 1096, "Bahrain": 1443, "Djibouti": 30,
                      "Egypt": 2250, "Iraq": 1380, "Jordan": 1326,
                      "Kuwait": 2997, "Lebanon": 831, "Libya": 1024,
                      "Mauritania": 507, "Morocco": 505, "Oman": 1921,
                      "Palestine": 957, "Qatar": 1758, "Saudi_Arabia": 14875,
                      "Somalia": 486, "Sudan": 1310, "Syria": 623,
                      "Tunisia": 575, "United_Arab_Emirates": 2074,
                      "Yemen": 1868}


def read_data(path):
    """Read a predictions file and return its lines."""
    with open(path) as f:
        return f.readlines()


def user_preds(predictions):
    """Group per-tweet predictions by user.

    Args:
        predictions: list of 'user_id<TAB>label' lines.

    Returns:
        dict mapping each user id to a Counter of its predicted labels.
    """
    users_preds = {}
    for line in predictions:
        user_id, pred = line.strip().split('\t')[:2]
        users_preds.setdefault(user_id, []).append(pred)

    for user in users_preds:
        users_preds[user] = Counter(users_preds[user])
    return users_preds


def write_final_preds(preds_per_user, output_path):
    """Write one majority-voted label per user to `output_path`.

    If there is more than one prediction with the same (maximal) vote
    count, pick the tied label with the maximum count in the Twitter-5
    training data (TRAIN_LABEL_COUNTS).
    """
    with open(output_path, mode='w') as outfile:
        for user in preds_per_user:
            most_common_preds = preds_per_user[user].most_common()
            max_count = most_common_preds[0][1]
            max_pred = most_common_preds[0][0]
            check = [p for p in most_common_preds if p[1] == max_count]
            # BUG FIX: the original keyed max() on the vote count (x[1]),
            # which is identical for every tied entry, so the tie-break
            # was a no-op and TRAIN_LABEL_COUNTS went unused. Key on the
            # training-data frequency of the label instead.
            if len(check) > 1:
                max_pred = max(
                    check, key=lambda x: TRAIN_LABEL_COUNTS.get(x[0], 0))[0]
            outfile.write(max_pred)
            outfile.write('\n')


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--preds_file_path', type=str, help='Predictions file')
    parser.add_argument('--output_file_path', type=str, help='Output file')
    args = parser.parse_args()
    preds = read_data(args.preds_file_path)
    preds_per_user = user_preds(preds)
    write_final_preds(preds_per_user, args.output_file_path)
24 | 25 | """ 26 | Text classification tasks evaluation utils 27 | """ 28 | 29 | from sklearn.metrics import f1_score, precision_score, recall_score 30 | 31 | 32 | SENTIMENT_LABELS = ["positive", "negative", "neutral"] 33 | 34 | POETRY_LABELS = ["شعر حر", "شعر التفعيلة", "عامي", "موشح", "الرجز", 35 | "الرمل", "الهزج", "البسيط", "الخفيف", "السريع", 36 | "الطويل", "الكامل", "المجتث", "المديد", "الوافر", 37 | "الدوبيت", "السلسلة", "المضارع", "المقتضب", "المنسرح", 38 | "المتدارك", "المتقارب", "المواليا"] 39 | 40 | MADAR_26_LABELS = ["KHA", "TUN", "MOS", "CAI", "BAG", "ALE", "DOH", "ALX", 41 | "SAN", "BAS", "TRI", "ALG", "MSA", "FES", "BEN", "SAL", 42 | "JER", "BEI", "SFX", "MUS", "JED", "RIY", "RAB", "DAM", 43 | "ASW", "AMM"] 44 | 45 | MADAR_6_LABELS = ["TUN", "CAI", "DOH", "MSA", "BEI", "RAB"] 46 | 47 | MADAR_TWITTER_LABELS = ["Algeria", "Bahrain", "Djibouti", "Egypt", "Iraq", 48 | "Jordan", "Kuwait", "Lebanon", "Libya", "Mauritania", 49 | "Morocco", "Oman", "Palestine", "Qatar", 50 | "Saudi_Arabia", "Somalia", "Sudan", "Syria", 51 | "Tunisia", "United_Arab_Emirates", "Yemen"] 52 | 53 | NADI_COUNTRY_LABELS = ["Algeria", "Bahrain", "Djibouti", "Egypt", 54 | "Iraq", "Jordan", "Kuwait", "Lebanon", 55 | "Libya", "Mauritania", "Morocco", "Oman", 56 | "Palestine", "Qatar", "Saudi_Arabia", 57 | "Somalia", "Sudan", "Syria", "Tunisia", 58 | "United_Arab_Emirates", "Yemen"] 59 | 60 | 61 | def acc_and_f1_poetry(preds, labels): 62 | acc = (preds == labels).mean() 63 | f1 = f1_score(y_true=labels, y_pred=preds, average='macro') 64 | precision = precision_score(y_true=labels, y_pred=preds, average='macro') 65 | recall = recall_score(y_true=labels, y_pred=preds, average='macro') 66 | 67 | return { 68 | "acc": acc, 69 | "f1": f1, 70 | "precision": precision, 71 | "recall": recall 72 | } 73 | 74 | def acc_and_f1_sentiment(preds, labels): 75 | acc = (preds == labels).mean() 76 | f1 = f1_score(y_true=labels, y_pred=preds, average=None, labels=[0, 1, 2]) 77 | precision = 
precision_score(y_true=labels, y_pred=preds, average=None, labels=[0, 1, 2]) 78 | recall = recall_score(y_true=labels, y_pred=preds, average=None, labels=[0, 1, 2]) 79 | 80 | f1_macro = float(f1[0] + f1[1] + f1[2]) / 3.0 81 | f1_pn = float(f1[0] + f1[1]) / 2.0 82 | precision_macro = float(precision[0] + precision[1] + precision[2]) / 3.0 83 | precision_pn = float(precision[0] + precision[1]) / 2.0 84 | recall_macro = float(recall[0] + recall[1] + recall[2]) / 3.0 85 | recall_pn = float(recall[0] + recall[1]) / 2.0 86 | 87 | return { 88 | "acc": acc, 89 | "f1": f1_macro, 90 | "f1_pn": f1_pn, 91 | "precision": precision_macro, 92 | "precision_pn": precision_pn, 93 | "recall": recall_macro, 94 | "recall_pn": recall_pn 95 | } 96 | 97 | def acc_and_f1_DID(preds, labels, labels_str): 98 | # gold_labels = list(set(labels)) 99 | # f1 = f1_score(labels, preds, labels=gold_labels, average=None) * 100 100 | # recall = recall_score(labels, preds, labels=gold_labels, average=None) * 100 101 | # precision = precision_score(labels, preds, labels=gold_labels, average=None) * 100 102 | # print(f1, flush=True) 103 | # print(recall, flush=True) 104 | # print(precision, flush=True) 105 | # print(list(labels), flush=True) 106 | # print(2 in list(labels), flush=True) 107 | # print(list(preds), flush=True) 108 | # print(2 in list(preds), flush=True) 109 | # print(gold_labels, flush=True) 110 | # individual_scores = {} 111 | # precisions = {} 112 | # recalls = {} 113 | # f_scores = {} 114 | 115 | # for x in gold_labels: 116 | # precisions[labels_str[x]] = precision[x] 117 | # recalls[labels_str[x]] = recall[x] 118 | # f_scores[labels_str[x]] = f1[x] 119 | 120 | # individual_scores['INDIVIDUAL PRECISION SCORE'] = precisions 121 | # individual_scores['INDIVIDUAL RECALL SCORE'] = recalls 122 | # individual_scores['INDIVIDUAL F1 SCORE'] = f_scores 123 | 124 | ## computes overall scores (accuracy, f1, recall, precision) 125 | accuracy = (preds == labels).mean() * 100 126 | f1 = 
f1_score(labels, preds, average="macro") * 100 127 | recall = recall_score(labels, preds, average="macro") * 100 128 | precision = precision_score(labels, preds, average="macro") * 100 129 | 130 | return { 131 | "precision": precision, 132 | "recall": recall, 133 | "f1": f1, 134 | "acc": accuracy 135 | # "INDIVIDUAL SCORES": individual_scores 136 | } 137 | 138 | def write_predictions(path_dir, task_name, preds): 139 | predictions_file = open(path_dir, mode='w') 140 | if task_name == "arabic_did_madar_twitter": 141 | for pred in preds: 142 | predictions_file.write(MADAR_TWITTER_LABELS[pred]) 143 | predictions_file.write('\n') 144 | 145 | elif task_name == "arabic_did_nadi_country": 146 | for pred in preds: 147 | predictions_file.write(NADI_COUNTRY_LABELS[pred]) 148 | predictions_file.write('\n') 149 | 150 | predictions_file.close() 151 | 152 | def compute_metrics(task_name, preds, labels): 153 | assert len(preds) == len(labels) 154 | 155 | if task_name == "arabic_sentiment": 156 | return acc_and_f1_sentiment(preds, labels) 157 | 158 | elif task_name == "arabic_poetry": 159 | return acc_and_f1_poetry(preds, labels) 160 | 161 | elif task_name == "arabic_did_madar_26": 162 | return acc_and_f1_DID(preds, labels, labels_str=MADAR_26_LABELS) 163 | 164 | elif task_name == "arabic_did_madar_6": 165 | return acc_and_f1_DID(preds, labels, labels_str=MADAR_6_LABELS) 166 | 167 | elif task_name == "arabic_did_madar_twitter": 168 | return acc_and_f1_DID(preds, labels, labels_str=MADAR_TWITTER_LABELS) 169 | 170 | elif task_name == "arabic_did_nadi_country": 171 | return acc_and_f1_DID(preds, labels, labels_str=NADI_COUNTRY_LABELS) 172 | 173 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # CAMeLBERT: A collection of pre-trained models for Arabic NLP tasks: 2 | 3 | This repo contains code for the experiments presented in our paper: [The Interplay of 
Variant, Size, and Task Type in Arabic Pre-trained Language Models](https://arxiv.org/pdf/2103.06678.pdf). 4 | 5 | ## Requirements: 6 | 7 | This code was written for python>=3.7, pytorch 1.5.1, and transformers 3.1.0. You will also need few additional packages. Here's how you can set up the environment using conda (assuming you have conda and cuda installed): 8 | 9 | ```bash 10 | git clone https://github.com/CAMeL-Lab/CAMeLBERT.git 11 | cd CAMeLBERT 12 | 13 | conda create -n CAMeLBERT python=3.7 14 | conda activate CAMeLBERT 15 | 16 | pip install -r requirements.txt 17 | ``` 18 | 19 | ## CAMeLBERT: 20 | 21 | ### Pretrained Models 22 | Our eight CAMeLBERT models are available on Hugging Face's [model hub](https://huggingface.co/CAMeL-Lab) along with their detailed descriptions. Note: to download our models as described in the model hub, you would need transformers>=3.5.0. Otherwise, you could download the models manually. 23 | 24 | ### Arabic Frequency Lists 25 | We also provide a frequency lists dataset derived from the pretraining datasets (17.3B tokens) used to pretrain the family of CAMeLBert models. 26 | The frequency dataset is available at https://github.com/CAMeL-Lab/Camel_Arabic_Frequency_Lists. 27 | 28 | 29 | ## Fine-tuning Experiments: 30 | 31 | All fine-tuned models can be found [here](https://drive.google.com/drive/folders/15feD46cPcRBybdUUKKrzR9zTxj2QBJ5w?usp=sharing). 32 | 33 | ## Text Classification: 34 | 35 | ### Sentiment Analysis: 36 | 37 | For the sentiment analysis experiments, we combined four datasets: 1) [ArSAS](http://lrec-conf.org/workshops/lrec2018/W30/pdf/22_W30.pdf); 2) [ASTD](https://www.aclweb.org/anthology/D15-1299.pdf); 3) [SemEval-2017 4A](https://www.aclweb.org/anthology/S17-2088.pdf); 4) [ArSenTD](https://arxiv.org/pdf/1906.01830.pdf).
The models were fine-tuned on ArSenTD and the train splits of ArSAS, ASTD, and SemEval-2017. We then evaluate all the checkpoints on
a single dev split from ArSAS, ASTD, and SemEval-2017 and pick the best checkpoint to report the results on the test splits of ArSAS, ASTD, and SemEval-2017 respectively. To run the fine-tuning:

```bash
export DATA_DIR=/path/to/data
export TASK_NAME=arabic_sentiment

python run_text_classification.py \
    --model_type bert \
    --model_name_or_path /path/to/pretrained_model/ \ # Or huggingface model id
    --task_name $TASK_NAME \
    --do_train \
    --do_eval \
    --eval_all_checkpoints \
    --save_steps 500 \
    --data_dir $DATA_DIR \
    --max_seq_length 128 \
    --per_gpu_train_batch_size 32 \
    --per_gpu_eval_batch_size 32 \
    --learning_rate 3e-5 \
    --num_train_epochs 3.0 \
    --overwrite_output_dir \
    --overwrite_cache \
    --output_dir /path/to/output_dir \
    --seed 12345
```

### Dialect Identification:

For the dialect identification experiments, we fine-tuned the models on four different dialect identification datasets: 1) [MADAR Corpus 26](https://www.aclweb.org/anthology/C18-1113.pdf); 2) [MADAR Corpus 6](https://www.aclweb.org/anthology/C18-1113.pdf); 3) [MADAR Twitter-5](https://www.aclweb.org/anthology/W19-4622.pdf); 4) [NADI Country-level](https://www.aclweb.org/anthology/2020.wanlp-1.9.pdf). We fine-tuned the models across the four datasets and we pick the best checkpoints on the dev sets to report results on the test sets.
To run the fine-tuning: 68 | 69 | 70 | ```bash 71 | export DATA_DIR=/path/to/data 72 | export TASK_NAME=arabic_did_madar_26 # or arabic_did_madar_6, arabic_did_madar_twitter, arabic_did_nadi_country 73 | 74 | python run_text_classification.py \ 75 | --model_type bert \ 76 | --model_name_or_path /path/to/pretrained_model/ \ # Or huggingface model id 77 | --task_name $TASK_NAME \ 78 | --do_train \ 79 | --do_eval \ 80 | --eval_all_checkpoints \ 81 | --save_steps 500 \ 82 | --data_dir $DATA_DIR \ 83 | --max_seq_length 128 \ 84 | --per_gpu_train_batch_size 32 \ 85 | --per_gpu_eval_batch_size 32 \ 86 | --learning_rate 3e-5 \ 87 | --num_train_epochs 10.0 \ 88 | --overwrite_output_dir \ 89 | --overwrite_cache \ 90 | --output_dir /path/to/output_dir \ 91 | --seed 12345 92 | ``` 93 | 94 | ### Poetry Classification: 95 | 96 | For the poetry classification experiments, we fine-tuned the models on the [APCD](https://arxiv.org/pdf/1905.05700.pdf) dataset. For each model, we pick the best checkpoint based on the dev set to report results on the test set. To run the fine-tuning: 97 | 98 | ```bash 99 | export DATA_DIR=/path/to/data 100 | export TASK_NAME=arabic_poetry 101 | 102 | python run_text_classification.py \ 103 | --model_type bert \ 104 | --model_name_or_path /path/to/pretrained_model/ \ # Or huggingface model id 105 | --task_name $TASK_NAME \ 106 | --do_train \ 107 | --do_eval \ 108 | --eval_all_checkpoints \ 109 | --save_steps 5000 \ 110 | --data_dir $DATA_DIR \ 111 | --max_seq_length 128 \ 112 | --per_gpu_train_batch_size 32 \ 113 | --per_gpu_eval_batch_size 32 \ 114 | --learning_rate 3e-5 \ 115 | --num_train_epochs 3.0 \ 116 | --overwrite_output_dir \ 117 | --overwrite_cache \ 118 | --output_dir /path/to/output_dir \ 119 | --seed 12345 120 | ``` 121 | 122 | Bash scripts to run text-classification fine-tuning and evaluation can be found in `text-classification/scripts/`. 
123 | 124 | 125 | ## Token Classification: 126 | 127 | ### NER: 128 | 129 | For the NER experiments, we used the [ANERCorp](https://link.springer.com/chapter/10.1007/978-3-540-70939-8_13) dataset and followed the splits defined by [Obeid et al., 2020](https://camel.abudhabi.nyu.edu/anercorp/). 130 | The dataset doesn't have a dev split, so we fine-tune the models on the train split and evaluate the last checkpoint on the test split. 131 | To run the fine-tuning: 132 | 133 | 134 | ```bash 135 | export DATA_DIR=/path/to/data # Should contain train/dev/test/labels files 136 | export MAX_LENGTH=512 137 | export BERT_MODEL=/path/to/pretrained_model/ # Or huggingface model id 138 | export OUTPUT_DIR=/path/to/output_dir 139 | export BATCH_SIZE=32 140 | export NUM_EPOCHS=3 141 | export SAVE_STEPS=750 142 | export SEED=12345 143 | 144 | python run_token_classification.py \ 145 | --data_dir $DATA_DIR \ 146 | --labels $DATA_DIR/labels.txt \ 147 | --model_name_or_path $BERT_MODEL \ 148 | --output_dir $OUTPUT_DIR \ 149 | --max_seq_length $MAX_LENGTH \ 150 | --num_train_epochs $NUM_EPOCHS \ 151 | --per_device_train_batch_size $BATCH_SIZE \ 152 | --save_steps $SAVE_STEPS \ 153 | --seed $SEED \ 154 | --overwrite_output_dir \ 155 | --overwrite_cache \ 156 | --do_train \ 157 | --do_predict 158 | ``` 159 | 160 | ### POS Tagging: 161 | 162 | For the POS tagging experiments, we fine-tuned the models on three different datasets:
163 | 164 | 1. Penn Arabic Treebank ([PATB](https://www.ldc.upenn.edu/sites/www.ldc.upenn.edu/files/nemlar2004-penn-arabic-treebank.pdf)): in MSA and has 32 POS tags 165 | 2. Egyptian Arabic Treebank ([ARZATB](https://catalog.ldc.upenn.edu/LDC2018T23)): in EGY and has 33 POS tags 166 | 3. [GUMAR](https://www.aclweb.org/anthology/L18-1607.pdf) corpus: in GLF and includes 35 POS tags 167 | 168 | We used the same hyperparameters for the 3 datasets and report results on the test sets by using the best checkpoints on the dev sets. To run the fine-tuning: 169 | 170 | ```bash 171 | export DATA_DIR=/path/to/data # Should contain train/dev/test/labels files 172 | export MAX_LENGTH=512 173 | export BERT_MODEL=/path/to/pretrained_model/ # Or huggingface model id 174 | export OUTPUT_DIR=/path/to/output_dir 175 | export BATCH_SIZE=32 176 | export NUM_EPOCHS=10 177 | export SAVE_STEPS=500 178 | export SEED=12345 179 | 180 | python run_token_classification.py \ 181 | --data_dir $DATA_DIR \ 182 | --labels $DATA_DIR/labels.txt \ 183 | --model_name_or_path $BERT_MODEL \ 184 | --output_dir $OUTPUT_DIR \ 185 | --max_seq_length $MAX_LENGTH \ 186 | --num_train_epochs $NUM_EPOCHS \ 187 | --per_device_train_batch_size $BATCH_SIZE \ 188 | --save_steps $SAVE_STEPS \ 189 | --seed $SEED \ 190 | --overwrite_output_dir \ 191 | --overwrite_cache \ 192 | --do_train \ 193 | --do_eval 194 | ``` 195 | 196 | Bash scripts to run token-classification fine-tuning and evaluation can be found in `token-classification/scripts/`. 
197 | 198 | ## Citation: 199 | 200 | If you find any of the CAMeLBERT or the fine-tuned models useful in your work, please cite [our paper](https://arxiv.org/pdf/2103.06678.pdf): 201 | ```bibtex 202 | @inproceedings{inoue-etal-2021-interplay, 203 | title = "The Interplay of Variant, Size, and Task Type in {A}rabic Pre-trained Language Models", 204 | author = "Inoue, Go and 205 | Alhafni, Bashar and 206 | Baimukan, Nurpeiis and 207 | Bouamor, Houda and 208 | Habash, Nizar", 209 | booktitle = "Proceedings of the Sixth Arabic Natural Language Processing Workshop", 210 | month = apr, 211 | year = "2021", 212 | address = "Kyiv, Ukraine (Online)", 213 | publisher = "Association for Computational Linguistics", 214 | abstract = "In this paper, we explore the effects of language variants, data sizes, and fine-tuning task types in Arabic pre-trained language models. To do so, we build three pre-trained language models across three variants of Arabic: Modern Standard Arabic (MSA), dialectal Arabic, and classical Arabic, in addition to a fourth language model which is pre-trained on a mix of the three. We also examine the importance of pre-training data size by building additional models that are pre-trained on a scaled-down set of the MSA variant. We compare our different models to each other, as well as to eight publicly available models by fine-tuning them on five NLP tasks spanning 12 datasets. Our results suggest that the variant proximity of pre-training data to fine-tuning data is more important than the pre-training data size. 
We exploit this insight in defining an optimized system selection model for the studied tasks.", 215 | } 216 | ``` 217 | -------------------------------------------------------------------------------- /token-classification/utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # MIT License 4 | # 5 | # Copyright 2018-2021 New York University Abu Dhabi 6 | # 7 | # Permission is hereby granted, free of charge, to any person obtaining a copy 8 | # of this software and associated documentation files (the "Software"), to deal 9 | # in the Software without restriction, including without limitation the rights 10 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 11 | # copies of the Software, and to permit persons to whom the Software is 12 | # furnished to do so, subject to the following conditions: 13 | # 14 | # The above copyright notice and this permission notice shall be included in 15 | # all copies or substantial portions of the Software. 16 | # 17 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 18 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 19 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 20 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 21 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 22 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 23 | # SOFTWARE. 24 | 25 | """ Token classification fine-tuning: utilities to work with token 26 | classification tasks (NER, POS tagging, etc.) 
Heavily adapted from: https://github.com/huggingface/transformers/blob/
v3.0.1/examples/token-classification/utils_ner.py"""


import logging
import os
from dataclasses import dataclass
from filelock import FileLock
from enum import Enum
from typing import List, Optional, Union

import torch
import torch.nn as nn
from torch.utils.data import Dataset
from transformers import PreTrainedTokenizer


logger = logging.getLogger(__name__)


@dataclass
class InputExample:
    """
    A single training/test example for token classification.

    Args:
        guid: Unique id for the example.
        words: list. The words of the sequence.
        labels: (Optional) list. The labels for each word of the sequence.
        This should be specified for train and dev examples, but not for test
        examples.
    """

    guid: str
    words: List[str]
    labels: Optional[List[str]]


@dataclass
class InputFeatures:
    """
    A single set of features of data.
    Property names are the same names as the corresponding inputs to a model.
    """

    input_ids: List[int]
    attention_mask: List[int]
    token_type_ids: Optional[List[int]] = None
    label_ids: Optional[List[int]] = None


class Split(Enum):
    # Split values double as the dataset file-name stems (see
    # read_examples_from_file: "<value>.txt").
    train = "train"
    dev = "dev"
    test = "test"


class TokenClassificationDataSet(Dataset):
    """
    This will be superseded by a framework-agnostic approach
    soon.

    Tokenizes and featurizes one split of a CoNLL-style token
    classification dataset, caching the features on disk next to the
    data (cache key: split, tokenizer class, max_seq_length).
    """

    features: List[InputFeatures]
    pad_token_label_id: int = nn.CrossEntropyLoss().ignore_index
    # Use cross entropy ignore_index as padding label id so that only
    # real label ids contribute to the loss later.

    def __init__(
        self,
        data_dir: str,
        tokenizer: PreTrainedTokenizer,
        labels: List[str],
        model_type: str,
        max_seq_length: Optional[int] = None,
        overwrite_cache=False,
        mode: Split = Split.train,
    ):
        # NOTE(review): `model_type` is accepted but never read in this
        # class -- presumably kept for parity with the upstream example;
        # confirm before removing.
        # Load data features from cache or dataset file
        cached_features_file = os.path.join(
            data_dir, "cached_{}_{}_{}".format(mode.value,
                                               tokenizer.__class__.__name__,
                                               str(max_seq_length)),)

        # Make sure only the first process in distributed training
        # processes the dataset, and the others will use the cache.
        lock_path = cached_features_file + ".lock"
        with FileLock(lock_path):

            if os.path.exists(cached_features_file) and not overwrite_cache:
                logger.info(f"Loading features from cached file {cached_features_file}")
                self.features = torch.load(cached_features_file)
            else:
                logger.info(f"Creating features from dataset file at {data_dir}")
                examples = read_examples_from_file(data_dir, mode)
                self.features = convert_examples_to_features(
                    examples,
                    labels,
                    max_seq_length,
                    tokenizer,
                    cls_token=tokenizer.cls_token,
                    cls_token_segment_id=0,
                    sep_token=tokenizer.sep_token,
                    pad_token=tokenizer.pad_token_id,
                    pad_token_segment_id=tokenizer.pad_token_type_id,
                    pad_token_label_id=self.pad_token_label_id,
                )
                logger.info(f"Saving features into cached file {cached_features_file}")
                torch.save(self.features, cached_features_file)

    def __len__(self):
        # Number of featurized examples in this split.
        return len(self.features)

    def __getitem__(self, i) -> InputFeatures:
        return self.features[i]


def read_examples_from_file(data_dir, mode: Union[Split, str]) -> List[InputExample]:
    """Read `<mode>.txt` (CoNLL-style: one "word label" pair per line,
    blank line or -DOCSTART- between sentences) into InputExamples."""
    if isinstance(mode, Split):
        mode = mode.value
    file_path = os.path.join(data_dir, f"{mode}.txt")
    guid_index = 1
    examples = []
    with open(file_path, encoding="utf-8") as f:
        words = []
        labels = []
        for line in f:
            # A blank line or -DOCSTART- marker ends the current sentence.
            if line.startswith("-DOCSTART-") or line == "" or line == "\n":
                if words:
                    examples.append(InputExample(guid=f"{mode}-{guid_index}",
                                                 words=words, labels=labels))
                    guid_index += 1
                    words = []
                    labels = []
            else:
                splits = line.split(" ")
                words.append(splits[0])
                #if len(splits) > 1:
                labels.append(splits[-1].replace("\n", ""))
                #else:
                #    # Examples could have no label for mode = "test"
                #    # This is needed to get around the Trainer evaluation
                #    labels.append("O")
        # Flush the last sentence if the file doesn't end with a blank line.
        if words:
            examples.append(InputExample(guid=f"{mode}-{guid_index}",
                                         words=words, labels=labels))
    return examples


def convert_examples_to_features(
    examples: List[InputExample],
    label_list: List[str],
    max_seq_length: int,
    tokenizer: PreTrainedTokenizer,
    cls_token="[CLS]",
    cls_token_segment_id=0,
    sep_token="[SEP]",
    pad_token=0,
    pad_token_segment_id=0,
    pad_token_label_id=-100,
    sequence_a_segment_id=0,
    mask_padding_with_zero=True,
) -> List[InputFeatures]:
    """ Loads a data file into a list of `InputFeatures'

    Each word keeps its label on its first sub-token; continuation
    sub-tokens, special tokens, and padding get `pad_token_label_id` so
    they are ignored by the loss.
    """

    label_map = {label: i for i, label in enumerate(label_list)}

    features = []
    for (ex_index, example) in enumerate(examples):
        if ex_index % 10_000 == 0:
            logger.info("Writing example %d of %d", ex_index, len(examples))

        tokens = []
        label_ids = []
        for word, label in zip(example.words, example.labels):
            word_tokens = tokenizer.tokenize(word)

            # bert-base-multilingual-cased sometimes output "nothing ([])
            # when calling tokenize with just a space.
            if len(word_tokens) > 0:
                tokens.extend(word_tokens)
                # Use the real label id for the first token of the word,
                # and padding ids for the remaining tokens
                label_ids.extend([label_map[label]] +
                                 [pad_token_label_id] *
                                 (len(word_tokens) - 1))

        # Account for [CLS] and [SEP] with "- 2" and with "- 3" for RoBERTa.
        special_tokens_count = tokenizer.num_special_tokens_to_add()
        if len(tokens) > max_seq_length - special_tokens_count:
            # Truncate; no sliding window, so overflow tokens are dropped.
            tokens = tokens[: (max_seq_length - special_tokens_count)]
            label_ids = label_ids[: (max_seq_length - special_tokens_count)]


        tokens += [sep_token]
        label_ids += [pad_token_label_id]
        segment_ids = [sequence_a_segment_id] * len(tokens)

        tokens = [cls_token] + tokens
        label_ids = [pad_token_label_id] + label_ids
        segment_ids = [cls_token_segment_id] + segment_ids

        input_ids = tokenizer.convert_tokens_to_ids(tokens)

        # The mask has 1 for real tokens and 0 for padding tokens. Only real
        # tokens are attended to.
        input_mask = [1 if mask_padding_with_zero else 0] * len(input_ids)

        # Zero-pad up to the sequence length.
        padding_length = max_seq_length - len(input_ids)
        input_ids += [pad_token] * padding_length
        input_mask += [0 if mask_padding_with_zero else 1] * padding_length
        segment_ids += [pad_token_segment_id] * padding_length
        label_ids += [pad_token_label_id] * padding_length

        assert len(input_ids) == max_seq_length
        assert len(input_mask) == max_seq_length
        assert len(segment_ids) == max_seq_length
        assert len(label_ids) == max_seq_length

        if ex_index < 5:
            # Log the first few examples for manual sanity-checking.
            logger.info("*** Example ***")
            logger.info("guid: %s", example.guid)
            logger.info("tokens: %s", " ".join([str(x) for x in tokens]))
            logger.info("input_ids: %s", " ".join([str(x) for x in input_ids]))
            logger.info("input_mask: %s", " ".join([str(x) for x in input_mask]))
            logger.info("segment_ids: %s", " ".join([str(x) for x in segment_ids]))
            logger.info("label_ids: %s", " ".join([str(x) for x in label_ids]))

        # Models without token_type_ids (e.g. DistilBERT-style) must not
        # receive segment ids.
        if "token_type_ids" not in tokenizer.model_input_names:
            segment_ids = None

        features.append(
            InputFeatures(input_ids=input_ids,
                          attention_mask=input_mask,
                          token_type_ids=segment_ids,
                          label_ids=label_ids))

    return features


def get_labels(path: str) -> List[str]:
    """Read one label per line from `path`."""
    with open(path, "r") as f:
        labels = f.read().splitlines()
    # Adding O to the labels to get around the
    # Trainer eval at test time if tokens don't
    # have any labels
    # if "O" not in labels:
    #     labels = ["O"] + labels
    return labels
a copy 8 | # of this software and associated documentation files (the "Software"), to deal 9 | # in the Software without restriction, including without limitation the rights 10 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 11 | # copies of the Software, and to permit persons to whom the Software is 12 | # furnished to do so, subject to the following conditions: 13 | # 14 | # The above copyright notice and this permission notice shall be included in 15 | # all copies or substantial portions of the Software. 16 | # 17 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 18 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 19 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 20 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 21 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 22 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 23 | # SOFTWARE. 24 | 25 | """ Fine-tuning pre-trained models for token classification tasks. 
@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are
    going to fine-tune from.
    """

    # Required: a huggingface.co model id or a local checkpoint directory.
    model_name_or_path: str = field(
        metadata={"help": "Path to pretrained model or model identifier from "
                  "huggingface.co/models"}
    )

    # Falls back to model_name_or_path when not given (see main()).
    config_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained config name or path if "
                                "not the same as model_name"}
    )

    # If you want to tweak more attributes on your tokenizer, you should do it
    # in a distinct script, or just modify its tokenizer_config.json.
    tokenizer_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained tokenizer name or path if "
                                "not the same as model_name"}
    )

    # Whether to use the fast (Rust-backed) tokenizer implementation.
    use_fast: bool = field(default=False, metadata={"help": "Set this flag to "
                                                            "use fast "
                                                            "tokenization."})
    # Selects the metric family in compute_metrics: seqeval span-level
    # metrics for "ner", scikit-learn token-level metrics otherwise.
    task_type: Optional[str] = field(
        default="ner", metadata={"help": "the name of the task (ner or pos)"}
    )

    cache_dir: Optional[str] = field(
        default=None, metadata={"help": "Where do you want to store the "
                                "pretrained models downloaded from s3"}
    )
@dataclass
class DataTrainingArguments:
    """
    Arguments pertaining to what data we are going to input our model for
    training and eval.
    """

    # Required: directory with CoNLL-formatted .txt splits (the predict
    # path in main() reads data_dir/test.txt directly).
    data_dir: str = field(
        metadata={"help": "The input data dir. Should contain the .txt files "
                  "for a CoNLL-2003-formatted task."}
    )
    labels: Optional[str] = field(
        default=None,
        metadata={"help": "Path to a file containing all labels."},
    )
    max_seq_length: int = field(
        default=128,
        metadata={
            "help": "The maximum total input sequence length after "
            "tokenization. Sequences longer than this will be truncated, "
            "sequences shorter will be padded."
        },
    )
    # NOTE(review): presumably controls the feature cache inside
    # utils.TokenClassificationDataSet — confirm there.
    overwrite_cache: bool = field(
        default=False, metadata={"help": "Overwrite the cached training and "
                                 "evaluation sets"}
    )
def main():
    """Fine-tune, evaluate, and/or predict with a token-classification model.

    Driven entirely by command-line flags (or a single JSON config file):
    ModelArguments select the checkpoint/tokenizer, DataTrainingArguments
    the CoNLL-formatted data, and TrainingArguments the Trainer behavior.
    Returns the dict of evaluation results (empty when --do_eval is off).
    """
    # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.
    # We now keep distinct sets of args, for a cleaner separation of concerns.
    parser = HfArgumentParser((ModelArguments,
                               DataTrainingArguments,
                               TrainingArguments))
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a
        # json file, let's parse it to get our arguments.
        model_args, data_args, training_args = parser.parse_json_file(
            json_file=os.path.abspath(
                sys.argv[1]))
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    # Refuse to clobber a non-empty output directory unless explicitly
    # allowed via --overwrite_output_dir.
    if (
        os.path.exists(training_args.output_dir)
        and os.listdir(training_args.output_dir)
        and training_args.do_train
        and not training_args.overwrite_output_dir
    ):
        raise ValueError(
            f"Output directory ({training_args.output_dir}) already exists "
            "and is not empty. Use --overwrite_output_dir to overcome."
        )

    # Setup logging (only rank -1/0 logs at INFO in distributed runs)
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=(logging.INFO if training_args.local_rank in [-1, 0]
               else logging.WARN),
    )
    logger.warning(
        "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, "
        "16-bits training: %s",
        training_args.local_rank,
        training_args.device,
        training_args.n_gpu,
        bool(training_args.local_rank != -1),
        training_args.fp16,
    )
    logger.info("Training/evaluation parameters %s", training_args)

    # Set seed
    set_seed(training_args.seed)

    # Prepare task: the label file defines the model's output space.
    labels = get_labels(data_args.labels)
    label_map: Dict[int, str] = {i: label for i, label in enumerate(labels)}
    num_labels = len(labels)

    # Load pretrained model and tokenizer
    #
    # Distributed training:
    # The .from_pretrained methods guarantee that only one local process can
    # concurrently download model & vocab.
    config = AutoConfig.from_pretrained(
        (model_args.config_name if model_args.config_name
         else model_args.model_name_or_path),
        num_labels=num_labels,
        id2label=label_map,
        label2id={label: i for i, label in enumerate(labels)},
        cache_dir=model_args.cache_dir,
    )
    tokenizer = AutoTokenizer.from_pretrained(
        (model_args.tokenizer_name if model_args.tokenizer_name
         else model_args.model_name_or_path),
        cache_dir=model_args.cache_dir,
        use_fast=model_args.use_fast,
    )
    model = AutoModelForTokenClassification.from_pretrained(
        model_args.model_name_or_path,
        from_tf=bool(".ckpt" in model_args.model_name_or_path),
        config=config,
        cache_dir=model_args.cache_dir,
    )

    # Get datasets (built lazily: only the splits the run actually needs)
    train_dataset = (
        TokenClassificationDataSet(
            data_dir=data_args.data_dir,
            tokenizer=tokenizer,
            labels=labels,
            model_type=config.model_type,
            max_seq_length=data_args.max_seq_length,
            overwrite_cache=data_args.overwrite_cache,
            mode=Split.train,
        )
        if training_args.do_train
        else None
    )
    eval_dataset = (
        TokenClassificationDataSet(
            data_dir=data_args.data_dir,
            tokenizer=tokenizer,
            labels=labels,
            model_type=config.model_type,
            max_seq_length=data_args.max_seq_length,
            overwrite_cache=data_args.overwrite_cache,
            mode=Split.dev,
        )
        if training_args.do_eval
        else None
    )

    def align_predictions(predictions: np.ndarray,
                          label_ids: np.ndarray) -> Tuple[List[int], List[int]]:
        """Map argmax logits back to label strings, dropping every position
        whose gold id equals the CrossEntropyLoss ignore_index (i.e.
        padding/special-token positions that carry no label)."""
        preds = np.argmax(predictions, axis=2)

        batch_size, seq_len = preds.shape

        out_label_list = [[] for _ in range(batch_size)]
        preds_list = [[] for _ in range(batch_size)]

        for i in range(batch_size):
            for j in range(seq_len):
                if label_ids[i, j] != nn.CrossEntropyLoss().ignore_index:
                    out_label_list[i].append(label_map[label_ids[i][j]])
                    preds_list[i].append(label_map[preds[i][j]])

        return preds_list, out_label_list

    def compute_metrics(p: EvalPrediction) -> Dict:
        preds_list, out_label_list = align_predictions(p.predictions,
                                                       p.label_ids)
        # If task type is NER, use seqeval metrics.
        # Otherwise, use scikit learn
        if model_args.task_type == "ner":
            return {
                "accuracy": seq_accuracy_score(out_label_list, preds_list),
                "precision": seq_precision_score(out_label_list, preds_list),
                "recall": seq_recall_score(out_label_list, preds_list),
                "f1": seq_f1_score(out_label_list, preds_list),
            }
        else:
            # Flatten the preds_list and out_label_list
            # (sklearn metrics operate on flat token-level sequences)
            preds_list = [p for sublist in preds_list for p in sublist]
            out_label_list = [p for sublist in out_label_list for p in sublist]
            return {
                "accuracy": accuracy_score(out_label_list, preds_list),
                "precision_micro": precision_score(out_label_list, preds_list,
                                                   average="micro"),
                "recall_micro": recall_score(out_label_list, preds_list,
                                             average="micro"),
                "f1_micro": f1_score(out_label_list, preds_list,
                                     average="micro"),
                "precision_macro": precision_score(out_label_list, preds_list,
                                                   average="macro"),
                "recall_macro": recall_score(out_label_list, preds_list,
                                             average="macro"),
                "f1_macro": f1_score(out_label_list, preds_list,
                                     average="macro"),
            }

    # Initialize our Trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        compute_metrics=compute_metrics,
    )

    # Training
    if training_args.do_train:
        # Resume from the checkpoint dir only when a local path was given.
        trainer.train(
            model_path=(model_args.model_name_or_path
                        if os.path.isdir(model_args.model_name_or_path)
                        else None)
        )
        trainer.save_model()
        # For convenience, we also re-save the tokenizer to the same directory,
        # so that you can share your model easily on huggingface.co/models =)
        if trainer.is_world_master():
            tokenizer.save_pretrained(training_args.output_dir)

    # Evaluation
    results = {}
    if training_args.do_eval:
        logger.info("*** Evaluate ***")

        result = trainer.evaluate()

        output_eval_file = os.path.join(training_args.output_dir,
                                        "eval_results.txt")
        if trainer.is_world_master():
            with open(output_eval_file, "w") as writer:
                logger.info("***** Eval results *****")
                for key, value in result.items():
                    logger.info("  %s = %s", key, value)
                    writer.write("%s = %s\n" % (key, value))

            results.update(result)

    # Predict
    if training_args.do_predict:
        test_dataset = TokenClassificationDataSet(
            data_dir=data_args.data_dir,
            tokenizer=tokenizer,
            labels=labels,
            model_type=config.model_type,
            max_seq_length=data_args.max_seq_length,
            overwrite_cache=data_args.overwrite_cache,
            mode=Split.test,
        )

        predictions, label_ids, metrics = trainer.predict(test_dataset)
        preds_list, _ = align_predictions(predictions, label_ids)

        output_test_results_file = os.path.join(training_args.output_dir,
                                                "test_results.txt")
        if trainer.is_world_master():
            with open(output_test_results_file, "w") as writer:
                for key, value in metrics.items():
                    logger.info("  %s = %s", key, value)
                    writer.write("%s = %s\n" % (key, value))

        # Save predictions: re-walk the raw CoNLL test file and zip each
        # token line with the next prediction for its sentence.
        output_test_predictions_file = os.path.join(training_args.output_dir,
                                                    "test_predictions.txt")
        if trainer.is_world_master():
            with open(output_test_predictions_file, "w") as writer:
                with open(os.path.join(data_args.data_dir, "test.txt"), "r") as f:
                    example_id = 0
                    for line in f:
                        if (line.startswith("-DOCSTART-") or line == ""
                                or line == "\n"):
                            # Sentence boundary: copy it through and advance
                            # once this sentence's predictions are consumed.
                            writer.write(line)
                            if not preds_list[example_id]:
                                example_id += 1
                        elif preds_list[example_id]:
                            output_line = (line.split()[0] + " " +
                                           preds_list[example_id].pop(0) + "\n")
                            writer.write(output_line)
                        else:
                            # Token fell beyond max_seq_length, so it has no
                            # aligned prediction.
                            logger.warning(
                                "Maximum sequence length exceeded: "
                                "No prediction for '%s'.", line.split()[0])

    return results


if __name__ == "__main__":
    main()
logger = logging.getLogger(__name__)

# Label inventories for each supported text-classification task; the
# index of a label in its list is the integer class id (see label_map).
SENTIMENT_LABELS = ["positive", "negative", "neutral"]

POETRY_LABELS = ["شعر حر", "شعر التفعيلة", "عامي", "موشح", "الرجز",
                 "الرمل", "الهزج", "البسيط", "الخفيف", "السريع",
                 "الطويل", "الكامل", "المجتث", "المديد", "الوافر",
                 "الدوبيت", "السلسلة", "المضارع", "المقتضب", "المنسرح",
                 "المتدارك", "المتقارب", "المواليا"]

MADAR_26_LABELS = ["KHA", "TUN", "MOS", "CAI", "BAG", "ALE", "DOH", "ALX",
                   "SAN", "BAS", "TRI", "ALG", "MSA", "FES", "BEN", "SAL",
                   "JER", "BEI", "SFX", "MUS", "JED", "RIY", "RAB", "DAM",
                   "ASW", "AMM"]

MADAR_6_LABELS = ["TUN", "CAI", "DOH", "MSA", "BEI", "RAB"]

MADAR_TWITTER_LABELS = ["Algeria", "Bahrain", "Djibouti", "Egypt", "Iraq",
                        "Jordan", "Kuwait", "Lebanon", "Libya", "Mauritania",
                        "Morocco", "Oman", "Palestine", "Qatar",
                        "Saudi_Arabia", "Somalia", "Sudan", "Syria",
                        "Tunisia", "United_Arab_Emirates", "Yemen"]

NADI_COUNTRY_LABELS = ["Algeria", "Bahrain", "Djibouti", "Egypt",
                       "Iraq", "Jordan", "Kuwait", "Lebanon",
                       "Libya", "Mauritania", "Morocco", "Oman",
                       "Palestine", "Qatar", "Saudi_Arabia",
                       "Somalia", "Sudan", "Syria", "Tunisia",
                       "United_Arab_Emirates", "Yemen"]


def convert_examples_to_features(examples, tokenizer,
                                 max_length=512,
                                 task=None,
                                 label_list=None,
                                 output_mode=None,
                                 pad_on_left=False,
                                 pad_token=0,
                                 pad_token_segment_id=0,
                                 mask_padding_with_zero=True):
    """
    Loads a data file into a list of ``InputFeatures``

    Args:
        examples: List of ``InputExamples`` containing the examples.
        tokenizer: Instance of a tokenizer that will tokenize the examples
        max_length: Maximum example length
        task: Arabic Sentiment Analysis task
        label_list: List of labels.
                    Can be obtained from the processor using the
                    ``processor.get_labels()`` method
        output_mode: String indicating the output mode.
                     Either ``regression`` or ``classification``
        pad_on_left: If set to ``True``,
                     the examples will be padded on the left rather
                     than on the right (default)
        pad_token: Padding token
        pad_token_segment_id: The segment ID for the padding token
                              (It is usually 0, but can vary such as for
                              XLNet where it is 4)
        mask_padding_with_zero: If set to ``True``,
                                the attention mask will be filled by ``1``
                                for actual values
                                and by ``0`` for padded values.
                                If set to ``False``, inverts it (``1`` for
                                padded values, ``0`` for actual values)

    Returns:
        list of task-specific ``InputFeatures`` which can be fed to the model.

    """

    # When a task name is given, fill in label list / output mode from the
    # module-level processor registries.
    if task is not None:
        processor = processors[task]()
        if label_list is None:
            label_list = processor.get_labels()
            logger.info("Using label list %s for task %s" %
                        (label_list, task))
        if output_mode is None:
            output_mode = output_modes[task]
            logger.info("Using output mode %s for task %s" %
                        (output_mode, task))

    # label string -> integer class id (order taken from label_list)
    label_map = {label: i for i, label in enumerate(label_list)}
    logger.info('**LABEL MAP**')
    logger.info(label_map)
    features = []
    for (ex_index, example) in enumerate(examples):
        if ex_index % 10000 == 0:
            logger.info("Writing example %d" % (ex_index))

        inputs = tokenizer.encode_plus(
            example.text_a,
            example.text_b,
            add_special_tokens=True,
            max_length=max_length,
            truncation=True
        )
        # NOTE(review): assumes the tokenizer emits "token_type_ids"
        # (true for BERT-style tokenizers) — confirm before using with a
        # tokenizer that produces no segment ids.
        input_ids, token_type_ids = inputs["input_ids"], inputs["token_type_ids"]

        # The mask has 1 for real tokens and 0 for padding tokens. Only real
        # tokens are attended to.
        attention_mask = [1 if mask_padding_with_zero else 0] * len(input_ids)

        # Zero-pad up to the sequence length.
        padding_length = max_length - len(input_ids)
        if pad_on_left:
            input_ids = ([pad_token] * padding_length) + input_ids
            attention_mask = ([0 if mask_padding_with_zero else 1]
                              * padding_length) + attention_mask
            token_type_ids = (([pad_token_segment_id] * padding_length)
                              + token_type_ids)
        else:
            input_ids = input_ids + ([pad_token] * padding_length)
            attention_mask = attention_mask + ([0 if mask_padding_with_zero
                                                else 1] * padding_length)
            token_type_ids = token_type_ids + ([pad_token_segment_id]
                                               * padding_length)

        assert len(input_ids) == max_length, (
            "Error with input length {} vs {}".format(len(input_ids),
                                                      max_length))

        assert len(attention_mask) == max_length, (
            "Error with input length {} vs {}".format(len(attention_mask),
                                                      max_length))

        assert len(token_type_ids) == max_length, (
            "Error with input length {} vs {}".format(len(token_type_ids),
                                                      max_length))

        if output_mode == "classification":
            # DUMMY GOLD LABEL FOR NADI TEST
            # BECAUSE WE DON'T HAVE THE GOLD LABELS
            # AND WE RUN THE EVAL ON TEST USING
            # CODALAB
            label = (label_map[example.label] if example.label is not None
                     else label_map['Syria'])
        elif output_mode == "regression":
            label = float(example.label)
        else:
            raise KeyError(output_mode)

        # Log the first few examples verbatim for sanity checking.
        if ex_index < 5:
            logger.info("*** Example ***")
            logger.info("guid: %s" % (example.guid))
            logger.info("input_ids: %s" %
                        " ".join([str(x) for x in input_ids]))
            logger.info("attention_mask: %s" %
                        " ".join([str(x) for x in attention_mask]))
            logger.info("token_type_ids: %s" %
                        " ".join([str(x) for x in token_type_ids]))

            logger.info("label: %s (id = %d)" % (example.label, label))

        features.append(
            InputFeatures(input_ids=input_ids,
                          attention_mask=attention_mask,
                          token_type_ids=token_type_ids,
                          label=label))

    return features
class ArabicSentimentProcessor(DataProcessor):
    """Processor for Arabic Sentiment Analysis (ArSAS, 3-way labels)."""

    def get_example_from_tensor_dict(self, tensor_dict):
        """See base class."""
        return InputExample(tensor_dict['idx'].numpy(),
                            tensor_dict['tweet'].numpy().decode('utf-8'),
                            str(tensor_dict['label'].numpy()))

    def _create_examples(self, lines, set_type):
        """Creates examples for the training and dev sets.

        Each TSV row is expected to be (id, tweet, label); the raw tweet
        is cleaned by process_tweet before tokenization.
        """
        examples = []
        for (i, line) in enumerate(lines):
            guid = "%s-%s" % (set_type, i)
            # (The original bound line[1] to text_a and then immediately
            # overwrote it — the dead assignment has been removed.)
            text_a = self.process_tweet(line[1])
            label = line[2]
            examples.append(InputExample(guid=guid, text_a=text_a,
                                         label=label))
        return examples

    def get_train_examples(self, data_dir):
        """See base class."""
        logger.info("LOOKING AT {}".format(os.path.join(data_dir,
                                                        "train.tsv")))
        return self._create_examples(
            self._read_tsv(os.path.join(data_dir, "train.tsv")),
            "train")

    def get_dev_examples(self, data_dir):
        """See base class."""
        logger.info("LOOKING AT {}".format(os.path.join(data_dir,
                                                        "dev.tsv")))
        return self._create_examples(
            self._read_tsv(os.path.join(data_dir, "dev.tsv")),
            "dev")

    def get_test_examples(self, data_dir):
        """See base class."""
        # Log the file we actually read: the test split ships as
        # ArSAS.test.tsv (the original logged "test.tsv", which was
        # misleading).
        logger.info("LOOKING AT {}".format(os.path.join(data_dir,
                                                        "ArSAS.test.tsv")))
        return self._create_examples(
            self._read_tsv(os.path.join(data_dir, "ArSAS.test.tsv")),
            "test")

    def get_labels(self):
        """Return the sentiment label inventory."""
        return SENTIMENT_LABELS

    def process_tweet(self, tweet):
        """
        processes a tweet by normalizing letters, removing urls,
        and removing diacritics

        Args:
            input tweet
        Returns:
            processed tweet
        """

        # NOTE(review): '[https|http|@|RT]' is a character CLASS (any single
        # one of these characters), not alternation — it also strips any
        # word starting with h/t/p/s/R/T/@/| followed by non-space. Kept
        # as-is to preserve the exact preprocessing the models were
        # trained with.
        URL_REG = re.compile(r'[https|http|@|RT]([^\s]+)')
        # space regex: collapse runs of whitespace to a single space
        SPACE_REG = re.compile(r'[\s]+')

        tweet = normalize.normalize_unicode(tweet)
        tweet = dediac.dediac_ar(tweet)
        tweet = URL_REG.sub(' ', tweet)
        tweet = SPACE_REG.sub(' ', tweet)
        return tweet
class ArabicPoetryProcessor(DataProcessor):
    """Processor for Arabic Poetry Classification (meter labels)."""

    def get_example_from_tensor_dict(self, tensor_dict):
        """See base class."""
        return InputExample(tensor_dict['idx'].numpy(),
                            tensor_dict['verse_1'].numpy().decode('utf-8'),
                            tensor_dict['verse_2'].numpy().decode('utf-8'),
                            str(tensor_dict['label'].numpy()))

    def _create_examples(self, lines, set_type):
        """Creates examples for the training and dev sets.

        Rows may contain one verse (text, label) or a verse pair
        (verse_1, verse_2, label); pairs become text_a/text_b examples.

        Raises:
            ValueError: if a row has neither 2 nor 3 columns.
        """
        examples = []
        for (i, line) in enumerate(lines):
            guid = "%s-%s" % (set_type, i)
            # check if line contains 2 verses or not
            if len(line) == 3:
                text_a = self.process_verse(line[0])
                text_b = self.process_verse(line[1])
                label = line[2]
            elif len(line) == 2:
                text_a = self.process_verse(line[0])
                text_b = None
                label = line[1]
            else:
                # The original left text_a/label unbound here and crashed
                # later with an UnboundLocalError; fail fast with a clear
                # message instead.
                raise ValueError(
                    "Expected 2 or 3 columns but got %d in %s row %d"
                    % (len(line), set_type, i))
            examples.append(InputExample(guid=guid, text_a=text_a,
                                         text_b=text_b, label=label))

        return examples

    def get_train_examples(self, data_dir):
        """See base class."""
        logger.info("LOOKING AT {}".format(os.path.join(data_dir,
                                                        "train.tsv")))
        return self._create_examples(
            self._read_tsv(os.path.join(data_dir, "train.tsv")),
            "train")

    def get_dev_examples(self, data_dir):
        """See base class."""
        logger.info("LOOKING AT {}".format(os.path.join(data_dir,
                                                        "dev.tsv")))
        return self._create_examples(
            self._read_tsv(os.path.join(data_dir, "dev.tsv")),
            "dev")

    def get_test_examples(self, data_dir):
        """See base class."""
        logger.info("LOOKING AT {}".format(os.path.join(data_dir,
                                                        "test.tsv")))
        return self._create_examples(
            self._read_tsv(os.path.join(data_dir, "test.tsv")),
            "test")

    def get_labels(self):
        """Return the poetry meter label inventory."""
        return POETRY_LABELS

    def process_verse(self, verse):
        """
        processes a verse by removing diacritics

        Args:
            input verse
        Returns:
            processed verse
        """

        verse = dediac.dediac_ar(verse)
        return verse
"test.tsv"))) 321 | return self._create_examples( 322 | self._read_tsv(os.path.join(data_dir, "test.tsv")), 323 | "test") 324 | 325 | def get_labels(self): 326 | return POETRY_LABELS 327 | 328 | def process_verse(self, verse): 329 | """ 330 | processes a verse by removing diacritics 331 | 332 | Args: 333 | input verse 334 | Returns: 335 | processed verse 336 | """ 337 | 338 | verse = dediac.dediac_ar(verse) 339 | return verse 340 | 341 | class ArabicDIDProcessor_MADAR_26(DataProcessor): 342 | """Processor for Arabic Dialect ID Classification 343 | on MADAR Corpus 26""" 344 | 345 | def get_example_from_tensor_dict(self, tensor_dict): 346 | """See base class.""" 347 | return InputExample(tensor_dict['idx'].numpy(), 348 | tensor_dict['text'].numpy().decode('utf-8'), 349 | str(tensor_dict['label'].numpy())) 350 | 351 | def _create_examples(self, lines, set_type): 352 | """Creates examples for the training and dev sets.""" 353 | examples = [] 354 | for (i, line) in enumerate(lines): 355 | guid = "%s-%s" % (set_type, i) 356 | text_a = self.process_text(line[0]) 357 | label = line[1] 358 | examples.append(InputExample(guid=guid, text_a=text_a,label=label)) 359 | 360 | return examples 361 | 362 | def get_train_examples(self, data_dir): 363 | """See base class.""" 364 | logger.info("LOOKING AT {}".format(os.path.join(data_dir, 365 | "MADAR-Corpus-26-train.tsv"))) 366 | return self._create_examples( 367 | self._read_tsv(os.path.join(data_dir, "MADAR-Corpus-26-train.tsv")), 368 | "train") 369 | 370 | def get_dev_examples(self, data_dir): 371 | """See base class.""" 372 | logger.info("LOOKING AT {}".format(os.path.join(data_dir, 373 | "MADAR-Corpus-26-dev.tsv"))) 374 | return self._create_examples( 375 | self._read_tsv(os.path.join(data_dir, "MADAR-Corpus-26-dev.tsv")), 376 | "dev") 377 | 378 | def get_test_examples(self, data_dir): 379 | logger.info("LOOKING AT {}".format(os.path.join(data_dir, 380 | "MADAR-Corpus-26-test.tsv"))) 381 | return self._create_examples( 382 | 
class ArabicDIDProcessor_MADAR_6(DataProcessor):
    """Processor for Arabic Dialect ID Classification on
    MADAR Corpus 6"""

    def get_example_from_tensor_dict(self, tensor_dict):
        """See base class."""
        return InputExample(tensor_dict['idx'].numpy(),
                            tensor_dict['text'].numpy().decode('utf-8'),
                            str(tensor_dict['label'].numpy()))

    def _create_examples(self, rows, set_type):
        """Build one InputExample per (text, label) row."""
        return [
            InputExample(guid=f"{set_type}-{row_idx}",
                         text_a=self.process_text(row[0]),
                         label=row[1])
            for row_idx, row in enumerate(rows)
        ]

    def get_train_examples(self, data_dir):
        """See base class."""
        train_path = os.path.join(data_dir, "MADAR-Corpus-6-train.tsv")
        logger.info("LOOKING AT {}".format(train_path))
        return self._create_examples(self._read_tsv(train_path), "train")

    def get_dev_examples(self, data_dir):
        """See base class."""
        dev_path = os.path.join(data_dir, "MADAR-Corpus-6-dev.tsv")
        logger.info("LOOKING AT {}".format(dev_path))
        return self._create_examples(self._read_tsv(dev_path), "dev")

    def get_test_examples(self, data_dir):
        """See base class."""
        test_path = os.path.join(data_dir, "MADAR-Corpus-6-test.tsv")
        logger.info("LOOKING AT {}".format(test_path))
        return self._create_examples(self._read_tsv(test_path), "test")

    def get_labels(self):
        """Return the 6-city MADAR label inventory."""
        return MADAR_6_LABELS

    def process_text(self, text):
        """
        processes the input text by removing diacritics

        Args:
            input text
        Returns:
            processed text
        """
        return dediac.dediac_ar(text)
class ArabicDIDProcessor_MADAR_Twitter(DataProcessor):
    """Processor for Arabic Dialect ID Classification on
    MADAR Shared Task 2"""

    def get_example_from_tensor_dict(self, tensor_dict):
        """See base class."""
        return InputExample(tensor_dict['idx'].numpy(),
                            tensor_dict['user_id'].numpy().decode('utf-8'),
                            tensor_dict['tweet'].numpy().decode('utf-8'),
                            str(tensor_dict['label'].numpy()))

    def _create_examples(self, lines, set_type):
        """Creates examples for the training and dev sets.

        Rows are (user_id, tweet, label); only the cleaned tweet text is
        used as model input. (The original bound line[0] to an unused
        user_id local — removed as dead code.)
        """
        examples = []
        # lines[1:] skips the first row — presumably a header; confirm
        # against the grouped_tweets TSVs.
        for (i, line) in enumerate(lines[1:]):
            guid = "%s-%s" % (set_type, i)
            text_a = self.process_tweet(line[1])
            label = line[2]
            examples.append(InputExample(guid=guid, text_a=text_a,
                                         label=label))
        return examples

    def get_train_examples(self, data_dir):
        """See base class."""
        logger.info("LOOKING AT {}".format(os.path.join(data_dir,
                                                        "grouped_tweets.train.tsv")))
        return self._create_examples(
            self._read_tsv(os.path.join(data_dir, "grouped_tweets.train.tsv")),
            "train")

    def get_dev_examples(self, data_dir):
        """See base class."""
        logger.info("LOOKING AT {}".format(os.path.join(data_dir,
                                                        "grouped_tweets.dev.tsv")))
        return self._create_examples(
            self._read_tsv(os.path.join(data_dir, "grouped_tweets.dev.tsv")),
            "dev")

    def get_test_examples(self, data_dir):
        """See base class."""
        logger.info("LOOKING AT {}".format(os.path.join(data_dir,
                                                        "grouped_tweets.test.tsv")))
        return self._create_examples(
            self._read_tsv(os.path.join(data_dir, "grouped_tweets.test.tsv")),
            "test")

    def get_labels(self):
        """Return the 21-country MADAR Twitter label inventory."""
        return MADAR_TWITTER_LABELS

    def process_tweet(self, tweet):
        """
        processes a tweet by normalizing letters, removing urls,
        and removing diacritics

        Args:
            input tweet
        Returns:
            processed tweet
        """

        # NOTE(review): '[https|http|@|RT]' is a character CLASS, not
        # alternation (see ArabicSentimentProcessor.process_tweet); kept
        # as-is to preserve the published preprocessing.
        URL_REG = re.compile(r'[https|http|@|RT]([^\s]+)')
        # space regex: collapse runs of whitespace to a single space
        SPACE_REG = re.compile(r'[\s]+')

        tweet = normalize.normalize_unicode(tweet)
        tweet = dediac.dediac_ar(tweet)
        tweet = URL_REG.sub(' ', tweet)
        tweet = SPACE_REG.sub(' ', tweet)
        # Unlike the other processors, this one also trims the result.
        return tweet.strip()
class ArabicDIDProcessor_NADI_COUNTRY(DataProcessor):
    """Processor for Arabic Dialect ID Classification on
    NADI Shared Task 1"""

    def get_example_from_tensor_dict(self, tensor_dict):
        """See base class."""
        return InputExample(tensor_dict['idx'].numpy(),
                            tensor_dict['tweet'].numpy().decode('utf-8'),
                            str(tensor_dict['label'].numpy()))

    def _create_examples(self, lines, set_type):
        """Creates examples for the training and dev sets."""
        examples = []
        # lines[0] is the TSV header row, hence the slice.
        for idx, row in enumerate(lines[1:]):
            text = self.process_tweet(row[1])
            # WE DON'T HAVE GOLD LABELS FOR NADI TEST
            gold = None if set_type == 'test' else row[2]
            examples.append(InputExample(guid="%s-%s" % (set_type, idx),
                                         text_a=text,
                                         label=gold))
        return examples

    def get_train_examples(self, data_dir):
        """See base class."""
        train_path = os.path.join(data_dir, "train_labeled.tsv")
        logger.info("LOOKING AT {}".format(train_path))
        return self._create_examples(self._read_tsv(train_path), "train")

    def get_dev_examples(self, data_dir):
        """See base class."""
        dev_path = os.path.join(data_dir, "dev_labeled.tsv")
        logger.info("LOOKING AT {}".format(dev_path))
        return self._create_examples(self._read_tsv(dev_path), "dev")

    def get_test_examples(self, data_dir):
        """See base class."""
        test_path = os.path.join(data_dir, "test_unlabeled.tsv")
        logger.info("LOOKING AT {}".format(test_path))
        return self._create_examples(self._read_tsv(test_path), "test")

    def get_labels(self):
        """See base class."""
        return NADI_COUNTRY_LABELS

    def process_tweet(self, tweet):
        """
        processes a tweet by normalizing letters, removing urls,
        and removing diacritics

        Args:
            input tweet
        Returns:
            processed tweet
        """
        # NOTE(review): `[https|http|@|RT]` is a character *class*, not
        # alternation — kept byte-identical to match the trained models.
        # URL regex
        URL_REG = re.compile(r'[https|http|@|RT]([^\s]+)')
        # space regex
        SPACE_REG = re.compile(r'[\s]+')

        tweet = normalize.normalize_unicode(tweet)
        tweet = dediac.dediac_ar(tweet)
        tweet = URL_REG.sub(' ', tweet)
        tweet = SPACE_REG.sub(' ', tweet)
        # NOTE(review): unlike the MADAR-Twitter processor this does NOT
        # .strip() — preserved deliberately to keep behavior unchanged.
        return tweet


# Registry mapping --task_name values to their data processors.
processors = {
    "arabic_sentiment": ArabicSentimentProcessor,
    "arabic_poetry": ArabicPoetryProcessor,
    "arabic_did_madar_26": ArabicDIDProcessor_MADAR_26,
    "arabic_did_madar_6": ArabicDIDProcessor_MADAR_6,
    "arabic_did_madar_twitter": ArabicDIDProcessor_MADAR_Twitter,
    "arabic_did_nadi_country": ArabicDIDProcessor_NADI_COUNTRY
}

# All tasks here are single-label classification (no regression tasks).
output_modes = {
    "arabic_sentiment": "classification",
    "arabic_poetry": "classification",
    "arabic_did_madar_26": "classification",
    "arabic_did_madar_6": "classification",
    "arabic_did_madar_twitter": "classification",
    "arabic_did_nadi_country": "classification"

}
-------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # MIT License 4 | # 5 | # Copyright 2018-2021 New York University Abu Dhabi 6 | # 7 | # Permission is hereby granted, free of charge, to any person obtaining a copy 8 | # of this software and associated documentation files (the "Software"), to deal 9 | # in the Software without restriction, including without limitation the rights 10 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 11 | # copies of the Software, and to permit persons to whom the Software is 12 | # furnished to do so, subject to the following conditions: 13 | # 14 | # The above copyright notice and this permission notice shall be included in 15 | # all copies or substantial portions of the Software. 16 | # 17 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 18 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 19 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 20 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 21 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 22 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 23 | # SOFTWARE. 
def set_seed(args):
    """Seed every RNG in play (Python, NumPy, PyTorch CPU/CUDA) so runs
    are reproducible.

    Args:
        args: parsed namespace providing ``seed`` (int) and ``n_gpu`` (int).
    """
    seed = args.seed
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if args.n_gpu > 0:
        # Seed every visible CUDA device, not just the current one.
        torch.cuda.manual_seed_all(seed)
def train(args, train_dataset, model, tokenizer):
    """Runs the full fine-tuning loop.

    Args:
        args: parsed command-line namespace (batch sizes, lr, fp16,
            distributed settings, logging/saving cadence, ...).
        train_dataset: TensorDataset of (input_ids, attention_mask,
            token_type_ids, labels).
        model: a transformers sequence-classification model, already on
            args.device.
        tokenizer: matching tokenizer; saved alongside each checkpoint.

    Returns:
        (global_step, tr_loss / global_step): total optimizer steps taken
        and the average training loss per step.
    """
    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    train_sampler = (RandomSampler(train_dataset) if args.local_rank == -1
                     else DistributedSampler(train_dataset))
    train_dataloader = DataLoader(train_dataset,
                                  sampler=train_sampler,
                                  batch_size=args.train_batch_size)

    # t_total = number of optimizer updates; if --max_steps is set it wins
    # and the epoch count is back-computed from it.
    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs =(args.max_steps //
                                (len(train_dataloader) //
                                 args.gradient_accumulation_steps) + 1)
    else:
        t_total = (len(train_dataloader) //
                   args.gradient_accumulation_steps *
                   args.num_train_epochs)

    # Prepare optimizer and schedule (linear warmup and decay)
    # Biases and LayerNorm weights are exempt from weight decay.
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in model.named_parameters()
                       if not any(nd in n for nd in no_decay)],
            "weight_decay": args.weight_decay,
        },
        {
            "params": [p for n, p in model.named_parameters()
                       if any(nd in n for nd in no_decay)],
            "weight_decay": 0.0
        },
    ]

    optimizer = AdamW(optimizer_grouped_parameters,
                      lr=args.learning_rate,
                      eps=args.adam_epsilon)
    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=args.warmup_steps,
        num_training_steps=t_total)

    # Check if saved optimizer or scheduler states exist
    # (present when --model_name_or_path is a checkpoint dir saved below).
    if (os.path.isfile(os.path.join(args.model_name_or_path, "optimizer.pt"))
        and os.path.isfile(os.path.join(args.model_name_or_path, "scheduler.pt")
        )):
        # Load in optimizer and scheduler states
        optimizer.load_state_dict(torch.load(
            os.path.join(args.model_name_or_path,
                         "optimizer.pt")))
        scheduler.load_state_dict(torch.load(
            os.path.join(args.model_name_or_path,
                         "scheduler.pt")))

    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError("Please install apex from "
                              "https://www.github.com/nvidia/apex to use "
                              "fp16 training.")
        model, optimizer = amp.initialize(model, optimizer,
                                          opt_level=args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.local_rank],
            output_device=args.local_rank, find_unused_parameters=True,
        )

    # Train!
    logger.info("***** Running training *****")
    logger.info(" Num examples = %d", len(train_dataset))
    logger.info(" Num Epochs = %d", args.num_train_epochs)
    logger.info(" Instantaneous batch size per GPU = %d",
                args.per_gpu_train_batch_size)
    logger.info(
        " Total train batch size "
        "(w. parallel, distributed & accumulation) = %d",
        args.train_batch_size
        * args.gradient_accumulation_steps
        * (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
    )
    logger.info(" Gradient Accumulation steps = %d",
                args.gradient_accumulation_steps)
    logger.info(" Total optimization steps = %d", t_total)

    global_step = 0
    epochs_trained = 0
    steps_trained_in_current_epoch = 0
    # Check if continuing training from a checkpoint
    if os.path.exists(args.model_name_or_path):
        # set global_step to global_step of last saved checkpoint
        # from model path (dir name convention: .../checkpoint-<step>)
        try:
            global_step = int(args.model_name_or_path.split("-")[-1].split("/")[0])
        except ValueError:
            global_step = 0
        epochs_trained = (global_step //
                          (len(train_dataloader) //
                           args.gradient_accumulation_steps))
        steps_trained_in_current_epoch = (global_step %
                                          (len(train_dataloader) //
                                           args.gradient_accumulation_steps))

        logger.info(" Continuing training from checkpoint, "
                    "will skip to saved global_step")
        logger.info(" Continuing training from epoch %d", epochs_trained)
        logger.info(" Continuing training from global step %d", global_step)
        logger.info(" Will skip the first %d steps in the first epoch",
                    steps_trained_in_current_epoch)

    tr_loss, logging_loss = 0.0, 0.0
    model.zero_grad()
    train_iterator = trange(epochs_trained,
                            int(args.num_train_epochs), desc="Epoch",
                            disable=args.local_rank not in [-1, 0],)
    set_seed(args)  # Added here for reproductibility
    for _ in train_iterator:
        epoch_iterator = tqdm(train_dataloader, desc="Iteration",
                              disable=args.local_rank not in [-1, 0])
        for step, batch in enumerate(epoch_iterator):

            # Skip past any already trained steps if resuming training
            if steps_trained_in_current_epoch > 0:
                steps_trained_in_current_epoch -= 1
                continue

            model.train()
            batch = tuple(t.to(args.device) for t in batch)
            inputs = {"input_ids": batch[0], "attention_mask": batch[1],
                      "labels": batch[3]}
            # XLM, DistilBERT, RoBERTa, and XLM-RoBERTa don't use segment_ids
            if args.model_type != "distilbert":
                inputs["token_type_ids"] = (
                    batch[2] if args.model_type in ["bert", "xlnet", "albert"]
                    else None)
            outputs = model(**inputs)
            # model outputs are always tuple in transformers (see doc)
            loss = outputs[0]

            if args.n_gpu > 1:
                # mean() to average on multi-gpu parallel training
                loss = loss.mean()
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            tr_loss += loss.item()
            # Only step the optimizer every gradient_accumulation_steps
            # micro-batches.
            if (step + 1) % args.gradient_accumulation_steps == 0:
                if args.fp16:
                    torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer),
                                                   args.max_grad_norm)
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                   args.max_grad_norm)

                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()
                global_step += 1

                if (args.local_rank in [-1, 0] and args.logging_steps > 0 and
                    global_step % args.logging_steps == 0):
                    logs = {}
                    # Only evaluate when single GPU otherwise metrics may not
                    # average well
                    if (
                        args.local_rank == -1 and args.evaluate_during_training
                    ):
                        # NOTE(review): called without mode=..., so
                        # load_and_cache_examples receives mode="" — on a
                        # cache miss no train/dev/test branch matches;
                        # verify before enabling --evaluate_during_training.
                        results = evaluate(args, model, tokenizer)
                        for key, value in results.items():
                            eval_key = "eval_{}".format(key)
                            logs[eval_key] = value

                    loss_scalar = (tr_loss - logging_loss) / args.logging_steps
                    learning_rate_scalar = scheduler.get_lr()[0]
                    logs["learning_rate"] = learning_rate_scalar
                    logs["loss"] = loss_scalar
                    logging_loss = tr_loss

                    print(json.dumps({**logs, **{"step": global_step}}))

                if (args.local_rank in [-1, 0] and args.save_steps > 0 and
                    global_step % args.save_steps == 0):
                    # Save model checkpoint
                    output_dir = os.path.join(args.output_dir,
                                              "checkpoint-{}".format(global_step))
                    if not os.path.exists(output_dir):
                        os.makedirs(output_dir)
                    model_to_save = (
                        model.module if hasattr(model, "module") else model
                    )  # Take care of distributed/parallel training
                    model_to_save.save_pretrained(output_dir)
                    tokenizer.save_pretrained(output_dir)

                    torch.save(args, os.path.join(output_dir,
                                                  "training_args.bin"))
                    logger.info("Saving model checkpoint to %s", output_dir)

                    torch.save(optimizer.state_dict(),
                               os.path.join(output_dir, "optimizer.pt"))
                    torch.save(scheduler.state_dict(),
                               os.path.join(output_dir, "scheduler.pt"))
                    logger.info("Saving optimizer and scheduler states to %s",
                                output_dir)

            # Strict '>' means one extra step runs past max_steps before
            # the loops break.
            if args.max_steps > 0 and global_step > args.max_steps:
                epoch_iterator.close()
                break
        if args.max_steps > 0 and global_step > args.max_steps:
            train_iterator.close()
            break

    # NOTE(review): divides by global_step — raises ZeroDivisionError if no
    # optimizer step ever ran (e.g. empty dataloader); confirm acceptable.
    return global_step, tr_loss / global_step
def evaluate(args, model, tokenizer, mode="", prefix=""):
    """Evaluates `model` on the split selected by `mode` ("dev" or "test").

    Writes eval_results.txt (and predictions.txt when --write_preds) under
    the output dir, and returns a dict of metrics from compute_metrics.

    Args:
        args: parsed command-line namespace.
        model: sequence-classification model on args.device.
        tokenizer: matching tokenizer (passed through to feature creation).
        mode: dataset split forwarded to load_and_cache_examples.
        prefix: checkpoint subdirectory name used in log lines and output
            paths ("" for the top-level model).

    Returns:
        dict mapping metric names to values (merged over all eval tasks).
    """
    # Loop to handle MNLI double evaluation (matched, mis-matched)
    eval_task_names = (("mnli", "mnli-mm") if args.task_name == "mnli"
                       else (args.task_name,))
    eval_outputs_dirs = ((args.output_dir, args.output_dir + "-MM")
                         if args.task_name == "mnli" else (args.output_dir,))

    results = {}
    for eval_task, eval_output_dir in zip(eval_task_names, eval_outputs_dirs):
        eval_dataset = load_and_cache_examples(args, eval_task,
                                               tokenizer, mode)

        if not os.path.exists(eval_output_dir) and args.local_rank in [-1, 0]:
            os.makedirs(eval_output_dir)

        args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
        # Note that DistributedSampler samples randomly
        eval_sampler = SequentialSampler(eval_dataset)
        eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler,
                                     batch_size=args.eval_batch_size)

        # multi-gpu eval
        if args.n_gpu > 1 and not isinstance(model, torch.nn.DataParallel):
            model = torch.nn.DataParallel(model)


        # Eval!
        logger.info("***** Running evaluation {} *****".format(prefix))
        logger.info(" Num examples = %d", len(eval_dataset))
        logger.info(" Batch size = %d", args.eval_batch_size)
        eval_loss = 0.0
        nb_eval_steps = 0
        # Logits (later argmaxed) and gold label ids, accumulated over
        # batches as numpy arrays.
        preds = None
        out_label_ids = None
        for batch in tqdm(eval_dataloader, desc="Evaluating"):
            model.eval()
            batch = tuple(t.to(args.device) for t in batch)
            with torch.no_grad():
                inputs = {"input_ids": batch[0], "attention_mask": batch[1],
                          "labels": batch[3]}
                # XLM, DistilBERT, RoBERTa, and XLM-RoBERTa don't use
                # segment_ids
                if args.model_type != "distilbert":
                    inputs["token_type_ids"] = (
                        batch[2] if args.model_type in ["bert", "xlnet", "albert"]
                        else None
                    )
                outputs = model(**inputs)
                tmp_eval_loss, logits = outputs[:2]

                eval_loss += tmp_eval_loss.mean().item()
            nb_eval_steps += 1
            if preds is None:
                preds = logits.detach().cpu().numpy()
                out_label_ids = inputs["labels"].detach().cpu().numpy()
            else:
                preds = np.append(preds, logits.detach().cpu().numpy(), axis=0)
                out_label_ids = np.append(
                    out_label_ids,
                    inputs["labels"].detach().cpu().numpy(),
                    axis=0)

        eval_loss = eval_loss / nb_eval_steps
        if args.output_mode == "classification":
            preds = np.argmax(preds, axis=1)
        elif args.output_mode == "regression":
            preds = np.squeeze(preds)

        if args.write_preds:
            # NOTE(review): when prefix is non-empty this joins a
            # subdirectory that os.makedirs above did not create — confirm
            # the checkpoint subdir always exists before writing.
            output_path_file = os.path.join(eval_output_dir, prefix,
                                            "predictions.txt")
            logger.info("***** Writing Predictions to "
                        "{} *****".format(output_path_file))
            write_predictions(output_path_file, eval_task, preds)

        result = compute_metrics(eval_task, preds, out_label_ids)
        results.update(result)

        output_eval_file = os.path.join(eval_output_dir, prefix,
                                        "eval_results.txt")
        with open(output_eval_file, "w") as writer:
            logger.info("***** Eval results {} *****".format(prefix))
            for key in sorted(result.keys()):
                logger.info(" %s = %s", key, str(result[key]))
                writer.write("%s = %s\n" % (key, str(result[key])))

    return results
def load_and_cache_examples(args, task, tokenizer, mode=""):
    """Builds (or loads from a cache file) the TensorDataset for a task split.

    Features are cached in args.data_dir keyed on mode, model name,
    max_seq_length and task, so repeated runs skip tokenization.

    Args:
        args: parsed command-line namespace (data_dir, model_name_or_path,
            max_seq_length, local_rank, overwrite_cache, ...).
        task: task name; key into `processors` / `output_modes`.
        tokenizer: tokenizer used when converting examples to features.
        mode: "train", "dev" or "test" — selects which split to read.

    Returns:
        torch.utils.data.TensorDataset of (input_ids, attention_mask,
        token_type_ids, labels).

    Raises:
        ValueError: if `mode` is not one of "train"/"dev"/"test" on a
            cache miss.
    """
    # BUG FIX: both distributed barriers below originally tested
    # `not evaluate`, where `evaluate` is the module-level *function*
    # (always truthy) — so the barriers never fired. The upstream
    # run_glue.py intent is: while training, rank 0 builds the feature
    # cache and the other ranks wait at the barrier, then read the cache.
    if args.local_rank not in [-1, 0] and mode == "train":
        # Make sure only the first process in distributed training process
        # the dataset, and the others will use the cache
        torch.distributed.barrier()

    processor = processors[task]()
    output_mode = output_modes[task]
    # Load data features from cache or dataset file
    cached_features_file = os.path.join(
        args.data_dir,
        "cached_{}_{}_{}_{}".format(
            mode,
            list(filter(None, args.model_name_or_path.split("/"))).pop(),
            str(args.max_seq_length),
            str(task),
        ),
    )
    if os.path.exists(cached_features_file) and not args.overwrite_cache:
        logger.info("Loading features from cached file %s",
                    cached_features_file)
        features = torch.load(cached_features_file)
    else:
        logger.info("Creating features from dataset file at %s", args.data_dir)
        label_list = processor.get_labels()
        if task in ["mnli", "mnli-mm"] and args.model_type in ["roberta", "xlmroberta"]:
            # HACK(label indices are swapped in RoBERTa pretrained model)
            label_list[1], label_list[2] = label_list[2], label_list[1]
        if mode == "train":
            examples = processor.get_train_examples(args.data_dir)
        elif mode == "dev":
            examples = processor.get_dev_examples(args.data_dir)
        elif mode == "test":
            examples = processor.get_test_examples(args.data_dir)
        else:
            # Fail fast with a clear message instead of the NameError the
            # original raised on `examples` below.
            raise ValueError(
                "mode must be 'train', 'dev' or 'test', got {!r}".format(mode))

        features = convert_examples_to_features(
            examples,
            tokenizer,
            label_list=label_list,
            max_length=args.max_seq_length,
            output_mode=output_mode,
            pad_on_left=bool(args.model_type in ["xlnet"]),
            pad_token=tokenizer.convert_tokens_to_ids([tokenizer.pad_token])[0],
            pad_token_segment_id=4 if args.model_type in ["xlnet"] else 0,
        )
        if args.local_rank in [-1, 0]:
            logger.info("Saving features into cached file %s",
                        cached_features_file)
            torch.save(features, cached_features_file)

    # See BUG FIX note above: rank 0 releases the waiting ranks here.
    if args.local_rank == 0 and mode == "train":
        # Make sure only the first process in distributed
        # training process the dataset, and the others will use the cache
        torch.distributed.barrier()

    # Convert to Tensors and build dataset
    all_input_ids = torch.tensor([f.input_ids for f in features],
                                 dtype=torch.long)
    all_attention_mask = torch.tensor([f.attention_mask for f in features],
                                      dtype=torch.long)
    all_token_type_ids = torch.tensor([f.token_type_ids for f in features],
                                      dtype=torch.long)
    if output_mode == "classification":
        all_labels = torch.tensor([f.label for f in features],
                                  dtype=torch.long)
    elif output_mode == "regression":
        all_labels = torch.tensor([f.label for f in features],
                                  dtype=torch.float)

    dataset = TensorDataset(all_input_ids, all_attention_mask,
                            all_token_type_ids, all_labels)
    return dataset
def main():
    """Command-line entry point.

    Parses arguments, sets up (optionally distributed) CUDA and logging,
    seeds RNGs, loads the pretrained model/tokenizer, then runs any
    combination of training (--do_train), dev-set evaluation (--do_eval)
    and test-set prediction (--do_pred). Returns the merged metrics dict.
    """
    parser = argparse.ArgumentParser()

    # Required parameters
    parser.add_argument(
        "--data_dir",
        default=None,
        type=str,
        required=True,
        help="The input data dir. Should contain the .tsv files "
        "(or other data files) for the task.",
    )
    parser.add_argument(
        "--model_type",
        default=None,
        type=str,
        required=True,
        help="Model type selected in the list: ",
    )
    parser.add_argument(
        "--model_name_or_path",
        default=None,
        type=str,
        required=True,
        help="Path to pretrained model or model identifier "
        "from huggingface.co/models",
    )
    parser.add_argument(
        "--task_name",
        default=None,
        type=str,
        required=True,
        help=("The name of the task to train selected in the list: "
              + ", ".join(processors.keys())),
    )
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        required=True,
        help="The output directory where the model predictions and "
        "checkpoints will be written.",
    )

    # Other parameters
    parser.add_argument(
        "--config_name",
        default="",
        type=str,
        help="Pretrained config name or path if not the same as model_name",
    )
    parser.add_argument(
        "--tokenizer_name",
        default="",
        type=str,
        help="Pretrained tokenizer name or path if not the same as model_name",
    )
    parser.add_argument(
        "--cache_dir",
        default="",
        type=str,
        help="Where do you want to store the pre-trained models "
        "downloaded from s3",
    )
    parser.add_argument(
        "--max_seq_length",
        default=128,
        type=int,
        help="The maximum total input sequence length after tokenization. "
        "Sequences longer than this will be truncated, sequences shorter "
        "will be padded.",
    )
    parser.add_argument(
        "--do_train",
        action="store_true",
        help="Whether to run training."
    )
    parser.add_argument(
        "--do_eval",
        action="store_true",
        help="Whether to run eval on the dev set."
    )
    parser.add_argument(
        "--do_pred",
        action="store_true",
        help="Whether to run eval on the test set."
    )
    parser.add_argument(
        "--evaluate_during_training",
        action="store_true",
        help="Run evaluation during training at each logging step.",
    )

    parser.add_argument(
        "--per_gpu_train_batch_size",
        default=8,
        type=int,
        help="Batch size per GPU/CPU for training.",
    )
    parser.add_argument(
        "--per_gpu_eval_batch_size",
        default=8, type=int,
        help="Batch size per GPU/CPU for evaluation.",
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of updates steps to accumulate "
        "before performing a backward/update pass.",
    )
    parser.add_argument(
        "--learning_rate",
        default=5e-5,
        type=float,
        help="The initial learning rate for Adam."
    )
    parser.add_argument(
        "--weight_decay",
        default=0.0,
        type=float,
        help="Weight decay if we apply some."
    )
    parser.add_argument(
        "--adam_epsilon",
        default=1e-8,
        type=float,
        help="Epsilon for Adam optimizer."
    )
    parser.add_argument(
        "--max_grad_norm",
        default=1.0,
        type=float,
        help="Max gradient norm."
    )
    parser.add_argument(
        "--num_train_epochs",
        default=3.0,
        type=float,
        help="Total number of training epochs to perform.",
    )
    parser.add_argument(
        "--max_steps",
        default=-1,
        type=int,
        help="If > 0: set total number of training steps to perform. "
        "Override num_train_epochs.",
    )
    parser.add_argument(
        "--warmup_steps",
        default=0,
        type=int,
        help="Linear warmup over warmup_steps."
    )
    parser.add_argument(
        "--logging_steps",
        type=int,
        default=500,
        help="Log every X updates steps."
    )
    parser.add_argument(
        "--save_steps",
        type=int,
        default=500,
        help="Save checkpoint every X updates steps."
    )
    parser.add_argument(
        "--eval_all_checkpoints",
        action="store_true",
        help="Evaluate all checkpoints starting with the same prefix as "
        "model_name ending and ending with step number",
    )
    parser.add_argument(
        "--write_preds",
        action="store_true",
        help="Write predictions to a file"
    )
    parser.add_argument(
        "--no_cuda",
        action="store_true",
        help="Avoid using CUDA when available"
    )
    parser.add_argument(
        "--overwrite_output_dir",
        action="store_true",
        help="Overwrite the content of the output directory",
    )
    parser.add_argument(
        "--overwrite_cache",
        action="store_true",
        help="Overwrite the cached training and evaluation sets",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=42,
        help="random seed for initialization")

    parser.add_argument(
        "--fp16",
        action="store_true",
        help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) "
        "instead of 32-bit",
    )
    parser.add_argument(
        "--fp16_opt_level",
        type=str,
        default="O1",
        help="For fp16: Apex AMP optimization level selected in "
        "['O0', 'O1', 'O2', and 'O3']. See details at "
        "https://nvidia.github.io/apex/amp.html",
    )
    parser.add_argument(
        "--local_rank",
        type=int,
        default=-1,
        help="For distributed training: local_rank"
    )
    parser.add_argument("--server_ip",
                        type=str,
                        default="",
                        help="For distant debugging."
    )
    parser.add_argument("--server_port",
                        type=str,
                        default="",
                        help="For distant debugging."
    )
    args = parser.parse_args()

    # Refuse to clobber an existing non-empty output dir when training.
    if (
        os.path.exists(args.output_dir)
        and os.listdir(args.output_dir)
        and args.do_train
        and not args.overwrite_output_dir
    ):
        raise ValueError(
            "Output directory ({}) already exists and is not empty. "
            "Use --overwrite_output_dir to overcome.".format(
                args.output_dir
            )
        )

    # Setup distant debugging if needed
    if args.server_ip and args.server_port:
        # Distant debugging
        import ptvsd

        print("Waiting for debugger attach")
        ptvsd.enable_attach(address=(args.server_ip, args.server_port),
                            redirect_output=True)
        ptvsd.wait_for_attach()

    # Setup CUDA, GPU & distributed training
    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")
        args.n_gpu = 0 if args.no_cuda else torch.cuda.device_count()
    else:
        # Initializes the distributed backend which
        # will take care of sychronizing nodes/GPUs
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        torch.distributed.init_process_group(backend="nccl")
        args.n_gpu = 1
    args.device = device

    # Setup logging (non-main ranks log at WARN to keep output readable)
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN,
    )
    logger.warning(
        "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, "
        "16-bits training: %s",
        args.local_rank,
        device,
        args.n_gpu,
        bool(args.local_rank != -1),
        args.fp16,
    )

    # Set seed
    set_seed(args)

    # Prepare task
    args.task_name = args.task_name.lower()
    if args.task_name not in processors:
        raise ValueError("Task not found: %s" % (args.task_name))

    processor = processors[args.task_name]()
    args.output_mode = output_modes[args.task_name]
    label_list = processor.get_labels()
    num_labels = len(label_list)

    # Load pretrained model and tokenizer
    if args.local_rank not in [-1, 0]:
        # Make sure only the first process in distributed training will
        # download model & vocab
        torch.distributed.barrier()

    args.model_type = args.model_type.lower()

    config = AutoConfig.from_pretrained(
        args.config_name if args.config_name else args.model_name_or_path,
        num_labels=num_labels,
        label2id={label: i for i, label in enumerate(label_list)},
        id2label={str(i): label for i, label in enumerate(label_list)},
        finetuning_task=args.task_name,
        cache_dir=args.cache_dir if args.cache_dir else None,
    )
    tokenizer = AutoTokenizer.from_pretrained(
        args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
        cache_dir=args.cache_dir if args.cache_dir else None
    )
    model = AutoModelForSequenceClassification.from_pretrained(
        args.model_name_or_path,
        from_tf=bool(".ckpt" in args.model_name_or_path),
        config=config,
        cache_dir=args.cache_dir if args.cache_dir else None,
    )


    if args.local_rank == 0:
        # Make sure only the first process in distributed
        # training will download model & vocab
        torch.distributed.barrier()

    model.to(args.device)

    logger.info("Training/evaluation parameters %s", args)

    # Training
    if args.do_train:
        train_dataset = load_and_cache_examples(args, args.task_name,
                                                tokenizer, mode="train")
        global_step, tr_loss = train(args, train_dataset, model, tokenizer)
        logger.info(" global_step = %s, average loss = %s",
                    global_step, tr_loss)

    # Saving best-practices: if you use defaults names for the model,
    # you can reload it using from_pretrained()
    if (args.do_train
        and (args.local_rank == -1 or torch.distributed.get_rank() == 0)):
        # Create output directory if needed
        if not os.path.exists(args.output_dir) and args.local_rank in [-1, 0]:
            os.makedirs(args.output_dir)

        logger.info("Saving model checkpoint to %s", args.output_dir)
        # Save a trained model, configuration and tokenizer
        # using `save_pretrained()`. They can then be reloaded using
        # `from_pretrained()`
        model_to_save = (
            model.module if hasattr(model, "module") else model
        )  # Take care of distributed/parallel training
        model_to_save.save_pretrained(args.output_dir)
        tokenizer.save_pretrained(args.output_dir)

        # Good practice: save your training arguments
        # together with the trained model
        torch.save(args, os.path.join(args.output_dir, "training_args.bin"))

        # Load a trained model and vocabulary that you have fine-tuned
        model = AutoModelForSequenceClassification.from_pretrained(args.output_dir)
        tokenizer = AutoTokenizer.from_pretrained(args.output_dir)
        model.to(args.device)

    # Evaluation
    results = {}
    best_f1_eval = 0
    best_model_checkpoint = args.output_dir
    if args.do_eval and args.local_rank in [-1, 0]:
        tokenizer = AutoTokenizer.from_pretrained(
            args.output_dir)
        checkpoints = [args.output_dir]
        if args.eval_all_checkpoints:
            checkpoints = list(
                os.path.dirname(c)
                for c in sorted(glob.glob(args.output_dir + "/**/" + WEIGHTS_NAME,
                                          recursive=True))
            )
            # Reduce logging
            logging.getLogger("transformers.modeling_utils").setLevel(logging.WARN)
        logger.info("Evaluate the following checkpoints: %s", checkpoints)
        for checkpoint in checkpoints:
            global_step = checkpoint.split("-")[-1] if len(checkpoints) > 1 else ""
            prefix = checkpoint.split("/")[-1] if checkpoint.find("checkpoint") != -1 else ""

            model = AutoModelForSequenceClassification.from_pretrained(checkpoint)
            model.to(args.device)
            result = evaluate(args, model, tokenizer, mode="dev", prefix=prefix)
            # getting the best model checkpoint
            # NOTE(review): assumes compute_metrics always returns an 'f1'
            # entry for these tasks — raises KeyError otherwise; confirm.
            if result['f1'] > best_f1_eval:
                best_f1_eval = result['f1']
                best_model_checkpoint = checkpoint
            result = dict((k + "_{}".format(global_step), v) for k, v in result.items())
            results.update(result)
        # renaming the best model checkpoint folder
        # os.rename(best_model_checkpoint, best_model_checkpoint + "-best")

    if args.do_pred and args.local_rank in [-1, 0]:
        tokenizer = AutoTokenizer.from_pretrained(
            args.output_dir
        )
        model = AutoModelForSequenceClassification.from_pretrained(
            args.output_dir
        )
        model.to(args.device)
        result = evaluate(args, model, tokenizer, mode="test")
        result = dict((k, v) for k, v in result.items())
        results.update(result)

    return results


if __name__ == "__main__":
    main()