├── .gitignore ├── .readthedocs.yml ├── LICENSE ├── README.md ├── STORE_RUN_FILE ├── Train_bert │ ├── node1gpu1 │ │ └── main.sh │ └── node2gpu4 │ │ ├── node2gpu4_main.sh │ │ ├── node2gpu4_sub_1.sh │ │ └── sub.sh ├── Train_bert_fine_tuning_ner │ ├── bert_from_reddit_to_aida_ner │ │ ├── run_bert_fine_tuning_ner.o1724525 │ │ └── run_bert_fine_tuning_ner.sh │ └── bert_from_transformers_to_aida_ner │ │ └── run_bert_fine_tuning_ner.sh └── Train_mnist │ ├── node1gpu4 │ ├── eval.sh │ └── main.sh │ ├── node2gpu8 │ ├── eval.sh │ ├── main.sh │ ├── run.sh │ └── sub1.sh │ ├── node2gpu8_het │ ├── eval.sh │ ├── main.sh │ ├── run.sh │ └── sub1.sh │ ├── node4gpu16 │ ├── eval.sh │ ├── main.sh │ ├── run.sh │ ├── sub1.sh │ ├── sub2.sh │ └── sub3.sh │ ├── node4gpu16_het │ ├── eval.sh │ ├── main.sh │ ├── run.sh │ ├── sub1.sh │ ├── sub2.sh │ └── sub3.sh │ └── node8gpu32_het │ ├── eval.sh │ ├── main.sh │ ├── run.sh │ ├── sub1.sh │ ├── sub2.sh │ ├── sub3.sh │ ├── sub4.sh │ ├── sub5.sh │ ├── sub6.sh │ └── sub7.sh ├── docs ├── Makefile ├── make.bat └── source │ ├── conf.py │ ├── dataset.rst │ ├── distribute.rst │ ├── examples.rst │ ├── extending.rst │ ├── index.rst │ ├── installing.rst │ ├── lr_scheduler.rst │ ├── meters.rst │ ├── model.rst │ ├── optimizer.rst │ ├── parameter.rst │ ├── progress_bar.rst │ ├── support.rst │ └── task.rst ├── hetseq ├── __init__.py ├── bert_modeling.py ├── checkpoint_utils.py ├── controller.py ├── data │ ├── BERT_DATA.py │ ├── BERT_DATA_test.py │ ├── __init__.py │ ├── bert_el_dataset.py │ ├── bert_ner_dataset.py │ ├── data_utils.py │ ├── data_utils_fast.pyx │ ├── h5pyDataset.py │ ├── iterators.py │ └── mnist_dataset.py ├── data_collator │ ├── __init__.py │ └── data_collator.py ├── distributed_utils.py ├── eval_bert_fine_tuning_ner.py ├── eval_mnist.py ├── file_utils.py ├── lr_scheduler.py ├── meters.py ├── model │ ├── __init__.py │ └── bert_for_EL_classification.py ├── optim.py ├── options.py ├── progress_bar.py ├── tasks │ ├── __init__.py │ ├── bert_for_el_classification_task.py │ ├── bert_for_token_classification_task.py │ └── tasks.py ├── train.py ├── transformers_tasks.py └── utils.py ├── requirements.txt ├── setup.py └── test ├── TRANFORMERS_test_eval_bert_fine_tuning.py ├── TRANFORMERS_test_eval_bert_fine_tuning.sh ├── WLGPU_test_conll_2003_dataset.sh ├── YD_test_mention_detection_conll.py ├── test_conll_2003_NER ├── test_conll_2003_dataset.py └── test_conll_2003_dataset.sh ├── test_entity_vocab ├── test_CONLL_2003_EL.py └── test_diff_dict.py ├── test_eval_bert_fine_tuning.py ├── test_eval_bert_fine_tuning.sh └── test_eval_bert_fine_tuning_aida.sh /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | CRC_RUN_FILE/ 3 | preprocessing/ 4 | Nvidia_BERT/ 5 | hetseq.egg-info 6 | build/ 7 | support/ 8 | __pycache__/ 9 | *.so 10 | *.cpp 11 | hetseq/data/BERT_DATA.py 12 | hetseq/data/BERT_DATA_test.py 13 | test/script 14 | -------------------------------------------------------------------------------- /.readthedocs.yml: -------------------------------------------------------------------------------- 1 | # .readthedocs.yml 2 | # Read the Docs configuration file 3 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details 4 | 5 | # Required 6 | version: 2 7 | 8 | # Build documentation in the docs/ directory with Sphinx 9 | sphinx: 10 | configuration: docs/source/conf.py 11 | 12 | # Build documentation with MkDocs 13 | #mkdocs: 14 | # configuration: mkdocs.yml 15 | 16 | # Optionally build your docs in additional 
formats such as PDF
formats:
  - pdf

# Optionally set the version of Python and requirements required to build your docs
python:
  version: 3.7
  install:
    - requirements: requirements.txt

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------

MIT License

Copyright (c) 2021 Yifan Ding and Weninger Lab

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------

# HetSeq: Distributed GPU Training on Heterogeneous Infrastructure
This is the code implementation for the paper:

>Yifan Ding, Nicholas Botzer, Tim Weninger.
HetSeq: Distributed GPU Training on Heterogeneous Infrastructure, Proc. of Association for the Advancement of Artificial Intelligence (AAAI) Innovative Application of Artificial Intelligence, February 2021.

Author: Yifan Ding (yding4@nd.edu)

**arXiv paper available**: https://arxiv.org/abs/2009.14783

**Documentation available**: https://hetseq.readthedocs.io

**Medium Towards Data Science post:** [Training BERT at a University](https://towardsdatascience.com/training-bert-at-a-university-eedcf940c754)

Documentation includes [Distributed Setting](https://hetseq.readthedocs.io/en/master/distribute.html), [Scripts to Run HetSeq](https://hetseq.readthedocs.io/en/master/examples.html), [Extending HetSeq](https://hetseq.readthedocs.io/en/master/extending.html), [Parameter Explanation](https://hetseq.readthedocs.io/en/master/parameter.html) and [Code Reference](https://hetseq.readthedocs.io/en/master/task.html).

## Overview
HetSeq is a distributed neural network platform designed to run on heterogeneous infrastructure with a common scientific shared file system.
It can be run directly from the command line with SSH or through a task queue submission system, without root privilege or any extra packages. It takes care of data index randomization and assignment to different GPUs in the multi-node and multi-GPU setting. Users can easily extend HetSeq to many other models with minimal effort.

HetSeq requires installation of [PyTorch](https://github.com/pytorch/pytorch) with GPU support and [NCCL](https://developer.nvidia.com/nccl).

## Installation
1) Create and activate a conda virtual environment with Python 3.7.4 (recommended):
```
$ conda create --name hetseq
$ conda activate hetseq
$ conda install python=3.7.4
```

2) Clone the repository and install the necessary packages:
```
$ git clone https://github.com/yifding/hetseq.git
$ cd /path/to/hetseq
$ pip install -r requirements.txt
$ pip install --editable .
```

3) **To run BERT:** Download the data files, including the training corpus, model configuration, and BPE dictionary: the test corpus is available [here](https://drive.google.com/file/d/1ZPJVAiV7PsewChi7xKACrjuniJ2N9Sry/view?usp=sharing) and the full data at [this link](https://drive.google.com/file/d/1Vq_UO-T9345uYs8a7zloukGfhDXSDd2A/view?usp=sharing). Download test_DATA.zip for a test run or DATA.zip for a full run, unzip it, and place the ```preprocessing/``` directory inside the package directory. The corpora available under ```preprocessing/``` are:

* phase one of the BERT training corpus: ```preprocessing/hdf5_lower_case_1_seq_len_128.../wikicorpus_en/```
* phase two of the BERT training corpus: ```preprocessing/hdf5_lower_case_1_seq_len_512.../wikicorpus_en/```
* sample test for phase one: ```preprocessing/test128/```
* sample test for phase two: ```preprocessing/test512/```
* see [NVIDIA-pytorch-BERT](https://github.com/NVIDIA/DeepLearningExamples/tree/master/PyTorch/LanguageModeling/BERT), [google_original_BERT](https://github.com/google-research/bert) and the [BERT paper](https://arxiv.org/abs/1810.04805) for more information.
* the corpora currently provided are generated with [NVIDIA-pytorch-BERT](https://github.com/NVIDIA/DeepLearningExamples/tree/master/PyTorch/LanguageModeling/BERT) from Wikipedia data (book data is not available).

4) Scripts for running HetSeq are available at https://hetseq.readthedocs.io/en/master/examples.html.

### Distributed Configuration
HetSeq can be executed on a single GPU on a single node, multiple GPUs on a single node, or multiple GPUs across multiple nodes. The main logic is defined at [train.py](https://github.com/yifding/hetseq/blob/master/train.py#L213). A two-node example is sketched below, after the notes.

* **--distributed-init-method**: defines the initialization method, e.g. "tcp://10.32.82.207:11111" (TCP for multiple nodes) or "file:///hetseq/communicate.txt" (shared file for multiple nodes).
* **--distributed-world-size**: total number of GPUs used in the training.
* **--distributed-gpus**: the number of GPUs on the current node.
* **--distributed-rank**: the rank/index of the first GPU used on the current node.




## Performance table
Running BERT on nodes with 4 GPUs each.
| nodes | GPUs | epochs | batch size | steps | avg. time per step | training time | training loss | expansion | speedup |
|:-----:|:----:|:------:|:----------:|:-------:|:------------------:|:-------------:|:-------------:|:---------:|:-------:|
| 1 | 4 | 5 | 128 | 267,139 | 2.60s | 7.19d | 0.026 | 1 | 1 |
| 2 | 8 | 5 | 256 | 133,570 | 2.69s | 4.19d | 0.028 | 0.86 | 1.72 |
| 4 | 16 | 5 | 512 | 66,785 | 2.794s | 2.23d | 0.031 | 0.81 | 3.22 |
| 8 | 32 | 5 | 1024 | 33,393 | 3.126s | 1.21d | 0.055 | 0.74 | 5.94 |

## Notice and tips
> Loading BERT data takes a while.

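As the two-node example referenced above, here is a sketch for the second of two 4-GPU nodes (assuming the first node's IP is `10.32.82.207` and port `11111`; adapt the address and paths to your cluster). The first node runs the identical command with `--distributed-rank 0`:

```
$ DIST=/path/to/hetseq
$ python3 ${DIST}/hetseq/train.py \
    --task bert --data ${DIST}/preprocessing/test_128/ \
    --dict ${DIST}/preprocessing/uncased_L-12_H-768_A-12/vocab.txt \
    --config_file ${DIST}/preprocessing/uncased_L-12_H-768_A-12/bert_config.json \
    --max-sentences 32 --fast-stat-sync --max-update 900000 --update-freq 4 \
    --valid-subset test --num-workers 4 \
    --warmup-updates 10000 --total-num-update 1000000 --lr 0.0001 \
    --weight-decay 0.01 --save-dir bert_node2gpu4 \
    --distributed-init-method tcp://10.32.82.207:11111 \
    --distributed-world-size 8 --distributed-gpus 4 --distributed-rank 4
```
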
## Known issues
- [ ] resuming training from a saved checkpoint is not yet supported
- [ ] the MNIST dataset download does not support multiple GPUs

## Future patches
- [ ] BERT preprocessing pipeline is not included
- [ ] interface to the datasets/transformers libraries is not included
- [ ] HetSeq does not yet support installation from pip
- [ ] separate/combined evaluation is not included
- [ ] fp16 support

## License
This repository is MIT-licensed. It is based on [fairseq](https://github.com/pytorch/fairseq),
[NVIDIA-BERT](https://github.com/NVIDIA/DeepLearningExamples/tree/master/PyTorch/LanguageModeling/BERT), and
[pytorch](https://github.com/pytorch/pytorch).

Please send us an e-mail or leave comments on GitHub if you have any questions.


Copyright (c) 2020 Yifan Ding and [Weninger Lab](https://www3.nd.edu/~tweninge/)

--------------------------------------------------------------------------------
/STORE_RUN_FILE/Train_bert/node1gpu1/main.sh:
--------------------------------------------------------------------------------

#!/bin/bash

#$-m abe
#$-M yding4@nd.edu
#$-q gpu@qa-p100-001     # specify the queue
#$-l gpu_card=1
#$-N bert_bsz_512-num_sentece_32_gpu_8-update_4-phase_1_main

export PATH=/afs/crc.nd.edu/user/y/yding4/.conda/envs/hetseq/bin:$PATH
export LD_LIBRARY_PATH=/afs/crc.nd.edu/user/y/yding4/.conda/envs/hetseq/lib:$LD_LIBRARY_PATH

DIST=/scratch365/yding4/hetseq

python3 ${DIST}/hetseq/train.py \
    --task bert --data ${DIST}/preprocessing/test_128/ \
    --dict ${DIST}/preprocessing/uncased_L-12_H-768_A-12/vocab.txt \
    --config_file ${DIST}/preprocessing/uncased_L-12_H-768_A-12/bert_config.json \
    --max-sentences 32 --fast-stat-sync --max-update 900000 --update-freq 4 \
    --valid-subset test --num-workers 4 \
    --warmup-updates 10000 --total-num-update 1000000 --lr 0.0001 \
    --weight-decay 0.01 --distributed-world-size 1 \
    --device-id 1 --save-dir sing_gpu_test \
    2>&1 | tee sing_gpu_test.log

--------------------------------------------------------------------------------
/STORE_RUN_FILE/Train_bert/node2gpu4/node2gpu4_main.sh:
--------------------------------------------------------------------------------

#!/bin/bash

#$-m abe
#$-M yding4@nd.edu
#$-q gpu@qa-xp-004     # specify the queue
#$-l gpu_card=4
#$-N bert_main

# conda environment is on ~/Transformer
export PATH=/afs/crc.nd.edu/user/y/yding4/.conda/envs/hetseq/bin:$PATH
export LD_LIBRARY_PATH=/afs/crc.nd.edu/user/y/yding4/.conda/envs/hetseq/lib:$LD_LIBRARY_PATH

DIST=/scratch365/yding4/hetseq

python3 ${DIST}/hetseq/train.py \
    --task bert --data ${DIST}/preprocessing/test_128/ \
    --dict ${DIST}/preprocessing/uncased_L-12_H-768_A-12/vocab.txt \
    --config_file ${DIST}/preprocessing/uncased_L-12_H-768_A-12/bert_config.json \
    --max-sentences 32 --fast-stat-sync --max-update 450000 --update-freq 4 \
    --valid-subset test --num-workers 4 \
    --warmup-updates 10000 --total-num-update 1000000 --lr 0.0001 \
    --weight-decay 0.01 \
    --save-dir bsz_512-num_sentece_32_gpu_8-update_4-phase_1 \
    --distributed-init-method tcp://10.32.82.211:12345 \
    --distributed-world-size 8 \
    --distributed-gpus 4 \
    --distributed-rank 0 \
    2>&1 | tee bert_bsz_512-num_sentece_32_gpu_8-update_4-phase_1_main.log
# bsz_4-gpu_4-update_16-phase_1-file_cent.log
-------------------------------------------------------------------------------- /STORE_RUN_FILE/Train_bert/node2gpu4/node2gpu4_sub_1.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #$-m abe 4 | #$-M yding4@nd.edu 5 | #$-q gpu@qa-xp-008 # specify the queue 6 | #$-l gpu_card=4 7 | #$-N bert_sub1 8 | 9 | #conda environment is on ~/Transformer 10 | export PATH=/afs/crc.nd.edu/user/y/yding4/.conda/envs/hetseq/bin:$PATH 11 | export LD_LIBRARY_PATH=/afs/crc.nd.edu/user/y/yding4/.conda/envs/hetseq/lib:$LD_LIBRARY_PATH 12 | 13 | DIST=/scratch365/yding4/hetseq 14 | 15 | python3 ${DIST}/hetseq/train.py \ 16 | --task bert --data ${DIST}/preprocessing/test_128/ \ 17 | --dict ${DIST}/preprocessing/uncased_L-12_H-768_A-12/vocab.txt \ 18 | --config_file ${DIST}/preprocessing/uncased_L-12_H-768_A-12/bert_config.json \ 19 | --max-sentences 32 --fast-stat-sync --max-update 450000 --update-freq 4 \ 20 | --valid-subset test --num-workers 4 \ 21 | --warmup-updates 10000 --total-num-update 1000000 --lr 0.0001 \ 22 | --weight-decay 0.01 \ 23 | --save-dir bsz_512-num_sentece_32_gpu_8-update_4-phase_1 \ 24 | --distributed-init-method tcp://10.32.82.211:12345 \ 25 | --distributed-world-size 8 \ 26 | --distributed-gpus 4 \ 27 | --distributed-rank 4 \ 28 | 2>&1 | tee bert_bsz_512-num_sentece_32_gpu_8-update_4-phase_1_sub.log 29 | # bsz_4-gpu_4-update_16-phase_1-file_cent.log 30 | -------------------------------------------------------------------------------- /STORE_RUN_FILE/Train_bert/node2gpu4/sub.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | qsub node2gpu4_main.sh 4 | qsub node2gpu4_sub_1.sh 5 | -------------------------------------------------------------------------------- /STORE_RUN_FILE/Train_bert_fine_tuning_ner/bert_from_reddit_to_aida_ner/run_bert_fine_tuning_ner.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #$-m abe 4 | #$-M yding4@nd.edu 5 | #$-q gpu@qa-p100-007 # specify the queue 6 | #$-l gpu_card=1 7 | #$-N run_bert_fine_tuning_ner 8 | 9 | export PATH=/afs/crc.nd.edu/user/y/yding4/.conda/envs/hetseq/bin:$PATH 10 | export LD_LIBRARY_PATH=/afs/crc.nd.edu/user/y/yding4/.conda/envs/hetseq/lib:$LD_LIBRARY_PATH 11 | 12 | EXTENSION_FILE=/scratch365/yding4/EL_resource/preprocess/build_datasets_huggingface/BI_local_conll2003.py 13 | DATA_DIR=/scratch365/yding4/EL_resource/data/CONLL2003/ 14 | TRAIN_FILE=train.txt 15 | VALIDATION_FILE=valid.txt 16 | TEST_FILE=test.txt 17 | HETSEQ_STATE_DICT=/scratch365/yding4/Reddit_EL/RUN_FILE/Reddit_1st_test/node2gpu4/Reddit_512-num_sentece_6_gpu_8-update_4-phase_1/checkpoint3.pt 18 | 19 | DIST=/scratch365/yding4/hetseq 20 | 21 | python3 ${DIST}/hetseq/train.py \ 22 | --task BertForTokenClassification --optimizer adam --lr-scheduler PolynomialDecayScheduler \ 23 | --fast-stat-sync --max-update 5000 --update-freq 1 \ 24 | --valid-subset test --num-workers 4 \ 25 | --warmup-updates 0 --total-num-update 50000 --lr 0.0001 \ 26 | --dict ${DIST}/preprocessing/uncased_L-12_H-768_A-12/vocab.txt \ 27 | --config_file ${DIST}/preprocessing/uncased_L-12_H-768_A-12/bert_config.json \ 28 | --hetseq_state_dict ${HETSEQ_STATE_DICT} \ 29 | --train_file ${DATA_DIR}${TRAIN_FILE} \ 30 | --validation_file ${DATA_DIR}${VALIDATION_FILE} \ 31 | --test_file ${DATA_DIR}${TEST_FILE} \ 32 | --extension_file ${EXTENSION_FILE} \ 33 | --max-sentences 32 \ 34 | --num-workers 8 \ 35 | --load_state_dict_strict "False" \ 36 | 
--find-unused-parameters \ 37 | --save-dir bert_fine_tuning_ner 38 | # --distributed-world-size 1 --device-id 1 \ 39 | 40 | 41 | -------------------------------------------------------------------------------- /STORE_RUN_FILE/Train_bert_fine_tuning_ner/bert_from_transformers_to_aida_ner/run_bert_fine_tuning_ner.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #$-m abe 4 | #$-M yding4@nd.edu 5 | #$-q gpu@qa-p100-007 # specify the queue 6 | #$-l gpu_card=1 7 | #$-N run_bert_finu_tuning_ner 8 | 9 | export PATH=/afs/crc.nd.edu/user/y/yding4/.conda/envs/hetseq/bin:$PATH 10 | export LD_LIBRARY_PATH=/afs/crc.nd.edu/user/y/yding4/.conda/envs/hetseq/lib:$LD_LIBRARY_PATH 11 | 12 | EXTENSION_FILE=/scratch365/yding4/EL_resource/preprocess/build_datasets_huggingface/BI_local_conll2003.py 13 | DATA_DIR=/scratch365/yding4/EL_resource/data/CONLL2003/ 14 | TRAIN_FILE=train.txt 15 | VALIDATION_FILE=valid.txt 16 | TEST_FILE=test.txt 17 | TRANSFORMERS_STATE_DICT=/scratch365/yding4/hetseq/CRC_RUN_FILE/Train_bert_fine_tuning_ner/bert_from_transformers_to_aida_ner/pytorch_model.bin 18 | 19 | DIST=/scratch365/yding4/hetseq 20 | 21 | python3 ${DIST}/hetseq/train.py \ 22 | --task BertForTokenClassification --optimizer adam --lr-scheduler PolynomialDecayScheduler \ 23 | --fast-stat-sync --max-update 5000 --update-freq 1 \ 24 | --valid-subset test --num-workers 4 \ 25 | --warmup-updates 0 --total-num-update 5000 --lr 0.0001 \ 26 | --dict ${DIST}/preprocessing/uncased_L-12_H-768_A-12/vocab.txt \ 27 | --config_file ${DIST}/preprocessing/uncased_L-12_H-768_A-12/bert_config.json \ 28 | --transformers_state_dict ${TRANSFORMERS_STATE_DICT} \ 29 | --train_file ${DATA_DIR}${TRAIN_FILE} \ 30 | --validation_file ${DATA_DIR}${VALIDATION_FILE} \ 31 | --test_file ${DATA_DIR}${TEST_FILE} \ 32 | --extension_file ${EXTENSION_FILE} \ 33 | --max-sentences 32 \ 34 | --num-workers 8 \ 35 | --load_state_dict_strict "True" \ 36 | --save-dir bert_fine_tuning_ner \ 37 | # --find-unused-parameters \ 38 | # --distributed-world-size 1 --device-id 1 \ 39 | 40 | 41 | -------------------------------------------------------------------------------- /STORE_RUN_FILE/Train_mnist/node1gpu4/eval.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #$-m abe 4 | #$-M yding4@nd.edu 5 | #$-q gpu@qa-xp-002 # specify the queue 6 | #$-l gpu_card=4 7 | #$-N eval 8 | 9 | export PATH=/afs/crc.nd.edu/user/y/yding4/Transformer/bin:$PATH 10 | export LD_LIBRARY_PATH=/afs/crc.nd.edu/user/y/yding4/Transformer/lib:$LD_LIBRARY_PATH 11 | 12 | code=/scratch365/yding4/hetseq/eval_mnist.py 13 | ckpt=/scratch365/yding4/hetseq/CRC_RUN_FILE/Train_mnist/node1gpu4/node1gpu4/checkpoint_last.pt 14 | mnist=/scratch365/yding4/mnist/MNIST/processed 15 | 16 | python3 ${code} --model_ckpt ${ckpt} --mnist_dir ${mnist} 2>&1 |tee eval.log -------------------------------------------------------------------------------- /STORE_RUN_FILE/Train_mnist/node1gpu4/main.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #$-m abe 4 | #$-M yding4@nd.edu 5 | #$-q gpu@qa-xp-002 # specify the queue 6 | #$-l gpu_card=4 7 | #$-N node1gpu4_mnist_main 8 | 9 | #conda environment is on ~/Transformer 10 | export PATH=/afs/crc.nd.edu/user/y/yding4/.conda/envs/hetseq/bin:$PATH 11 | export LD_LIBRARY_PATH=/afs/crc.nd.edu/user/y/yding4/.conda/envs/hetseq/lib:$LD_LIBRARY_PATH 12 | 13 | DIST=/scratch365/yding4/hetseq 14 | 15 | python3 
${DIST}/hetseq/train.py \ 16 | --task mnist --optimizer adadelta --lr-scheduler PolynomialDecayScheduler \ 17 | --data /scratch365/yding4/mnist --clip-norm 100 \ 18 | --max-sentences 64 --fast-stat-sync --max-epoch 20 --update-freq 1 \ 19 | --valid-subset test --num-workers 4 \ 20 | --warmup-updates 0 --total-num-update 50000 --lr 1.01 \ 21 | --save-dir node1gpu4 2>&1 | tee node1gpu4.log -------------------------------------------------------------------------------- /STORE_RUN_FILE/Train_mnist/node2gpu8/eval.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #$-m abe 4 | #$-M yding4@nd.edu 5 | #$-q gpu@qa-xp-002 # specify the queue 6 | #$-l gpu_card=4 7 | #$-N node2gpu8_mnist_main1 8 | 9 | export PATH=/afs/crc.nd.edu/user/y/yding4/Transformer/bin:$PATH 10 | export LD_LIBRARY_PATH=/afs/crc.nd.edu/user/y/yding4/Transformer/lib:$LD_LIBRARY_PATH 11 | 12 | code=/scratch365/yding4/hetseq/eval_mnist.py 13 | ckpt=/scratch365/yding4/hetseq/CRC_RUN_FILE/Train_mnist/node2gpu8/node2gpu8/checkpoint_last.pt 14 | mnist=/scratch365/yding4/mnist/MNIST/processed 15 | 16 | python3 ${code} --model_ckpt ${ckpt} --mnist_dir ${mnist} 2>&1 |tee eval.log -------------------------------------------------------------------------------- /STORE_RUN_FILE/Train_mnist/node2gpu8/main.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #$-m abe 4 | #$-M yding4@nd.edu 5 | #$-q gpu@qa-xp-004 # specify the queue 6 | #$-l gpu_card=4 7 | #$-N node2gpu8_mnist_main1 8 | 9 | #conda environment is on ~/Transformer 10 | export PATH=/afs/crc.nd.edu/user/y/yding4/.conda/envs/hetseq/bin:$PATH 11 | export LD_LIBRARY_PATH=/afs/crc.nd.edu/user/y/yding4/.conda/envs/hetseq/lib:$LD_LIBRARY_PATH 12 | 13 | DIST=/scratch365/yding4/hetseq 14 | AD=tcp://10.32.82.211:11111 15 | 16 | python3 ${DIST}/hetseq/train.py \ 17 | --task mnist --optimizer adadelta --lr-scheduler PolynomialDecayScheduler \ 18 | --data /scratch365/yding4/mnist/ --clip-norm 100 \ 19 | --max-sentences 64 --fast-stat-sync --max-epoch 20 --update-freq 1 \ 20 | --valid-subset test --num-workers 4 \ 21 | --warmup-updates 0 --total-num-update 50000 --lr 1.01 \ 22 | --distributed-init-method ${AD} --distributed-world-size 8 \ 23 | --distributed-gpus 4 --distributed-rank 0 --save-dir node2gpu8 2>&1 | tee node2gpu8.log -------------------------------------------------------------------------------- /STORE_RUN_FILE/Train_mnist/node2gpu8/run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | qsub main.sh 4 | qsub sub1.sh -------------------------------------------------------------------------------- /STORE_RUN_FILE/Train_mnist/node2gpu8/sub1.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #$-m abe 4 | #$-M yding4@nd.edu 5 | #$-q gpu@qa-xp-008 # specify the queue 6 | #$-l gpu_card=4 7 | #$-N node2gpu8_mnist_sub1 8 | 9 | #conda environment is on ~/Transformer 10 | export PATH=/afs/crc.nd.edu/user/y/yding4/.conda/envs/hetseq/bin:$PATH 11 | export LD_LIBRARY_PATH=/afs/crc.nd.edu/user/y/yding4/.conda/envs/hetseq/lib:$LD_LIBRARY_PATH 12 | 13 | DIST=/scratch365/yding4/hetseq 14 | AD=tcp://10.32.82.211:11111 15 | 16 | python3 ${DIST}/hetseq/train.py \ 17 | --task mnist --optimizer adadelta --lr-scheduler PolynomialDecayScheduler \ 18 | --data /scratch365/yding4/mnist/ --clip-norm 100 \ 19 | --max-sentences 64 --fast-stat-sync --max-epoch 20 --update-freq 1 \ 20 | 
--valid-subset test --num-workers 4 \ 21 | --warmup-updates 0 --total-num-update 50000 --lr 1.01 \ 22 | --distributed-init-method ${AD} --distributed-world-size 8 \ 23 | --distributed-gpus 4 --distributed-rank 4 --save-dir node2gpu8 -------------------------------------------------------------------------------- /STORE_RUN_FILE/Train_mnist/node2gpu8_het/eval.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #$-m abe 4 | #$-M yding4@nd.edu 5 | #$-q gpu@qa-xp-002 # specify the queue 6 | #$-l gpu_card=4 7 | #$-N node2gpu8_mnist_main1 8 | 9 | export PATH=/afs/crc.nd.edu/user/y/yding4/Transformer/bin:$PATH 10 | export LD_LIBRARY_PATH=/afs/crc.nd.edu/user/y/yding4/Transformer/lib:$LD_LIBRARY_PATH 11 | 12 | code=/scratch365/yding4/hetseq/eval_mnist.py 13 | ckpt=/scratch365/yding4/hetseq/CRC_RUN_FILE/Train_mnist/node2gpu8_het/node2gpu8/checkpoint_last.pt 14 | mnist=/scratch365/yding4/mnist/MNIST/processed 15 | 16 | python3 ${code} --model_ckpt ${ckpt} --mnist_dir ${mnist} 2>&1 |tee eval.log -------------------------------------------------------------------------------- /STORE_RUN_FILE/Train_mnist/node2gpu8_het/main.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #$-m abe 4 | #$-M yding4@nd.edu 5 | #$-q gpu@qa-xp-002 # specify the queue 6 | #$-l gpu_card=4 7 | #$-N node2gpu8_het_mnist_main 8 | 9 | export PATH=/afs/crc.nd.edu/user/y/yding4/Transformer/bin:$PATH 10 | export LD_LIBRARY_PATH=/afs/crc.nd.edu/user/y/yding4/Transformer/lib:$LD_LIBRARY_PATH 11 | 12 | DIST=/scratch365/yding4/hetseq 13 | AD=tcp://10.32.82.207:11111 14 | 15 | python3 ${DIST}/train.py \ 16 | --task mnist --optimizer adadelta --lr-scheduler PolynomialDecayScheduler \ 17 | --data /scratch365/yding4/mnist/MNIST/processed/ --clip-norm 100 \ 18 | --max-sentences 64 --fast-stat-sync --max-epoch 20 --update-freq 1 \ 19 | --valid-subset test --num-workers 4 \ 20 | --warmup-updates 0 --total-num-update 50000 --lr 1.01 \ 21 | --distributed-init-method ${AD} --distributed-world-size 8 \ 22 | --distributed-gpus 4 --distributed-rank 0 --save-dir node2gpu8 2>&1 | tee node2gpu8.log -------------------------------------------------------------------------------- /STORE_RUN_FILE/Train_mnist/node2gpu8_het/run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | qsub main.sh 4 | qsub sub1.sh -------------------------------------------------------------------------------- /STORE_RUN_FILE/Train_mnist/node2gpu8_het/sub1.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #$-m abe 4 | #$-M yding4@nd.edu 5 | #$-q gpu@qa-xp-002 # specify the queue 6 | #$-l gpu_card=4 7 | #$-N node2gpu8_het_mnist_sub1 8 | 9 | export PATH=/afs/crc.nd.edu/user/y/yding4/Transformer/bin:$PATH 10 | export LD_LIBRARY_PATH=/afs/crc.nd.edu/user/y/yding4/Transformer/lib:$LD_LIBRARY_PATH 11 | 12 | DIST=/scratch365/yding4/hetseq 13 | AD=tcp://10.32.82.207:11111 14 | 15 | python3 ${DIST}/train.py \ 16 | --task mnist --optimizer adadelta --lr-scheduler PolynomialDecayScheduler \ 17 | --data /scratch365/yding4/mnist/MNIST/processed/ --clip-norm 100 \ 18 | --max-sentences 64 --fast-stat-sync --max-epoch 20 --update-freq 1 \ 19 | --valid-subset test --num-workers 4 \ 20 | --warmup-updates 0 --total-num-update 50000 --lr 1.01 \ 21 | --distributed-init-method ${AD} --distributed-world-size 8 \ 22 | --distributed-gpus 4 --distributed-rank 4 --save-dir 
node2gpu8 -------------------------------------------------------------------------------- /STORE_RUN_FILE/Train_mnist/node4gpu16/eval.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #$-m abe 4 | #$-M yding4@nd.edu 5 | #$-q gpu@qa-xp-002 # specify the queue 6 | #$-l gpu_card=4 7 | #$-N node2gpu8_mnist_main1 8 | 9 | export PATH=/afs/crc.nd.edu/user/y/yding4/Transformer/bin:$PATH 10 | export LD_LIBRARY_PATH=/afs/crc.nd.edu/user/y/yding4/Transformer/lib:$LD_LIBRARY_PATH 11 | 12 | code=/scratch365/yding4/hetseq/eval_mnist.py 13 | ckpt=/scratch365/yding4/hetseq/CRC_RUN_FILE/Train_mnist/node4gpu16/node4gpu16/checkpoint_last.pt 14 | mnist=/scratch365/yding4/mnist/MNIST/processed 15 | 16 | python3 ${code} --model_ckpt ${ckpt} --mnist_dir ${mnist} 2>&1 |tee eval.log -------------------------------------------------------------------------------- /STORE_RUN_FILE/Train_mnist/node4gpu16/main.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #$-m abe 4 | #$-M yding4@nd.edu 5 | #$-q gpu@qa-xp-002 # specify the queue 6 | #$-l gpu_card=4 7 | #$-N node4gpu16_mnist_main 8 | 9 | export PATH=/afs/crc.nd.edu/user/y/yding4/Transformer/bin:$PATH 10 | export LD_LIBRARY_PATH=/afs/crc.nd.edu/user/y/yding4/Transformer/lib:$LD_LIBRARY_PATH 11 | 12 | DIST=/scratch365/yding4/hetseq 13 | AD=tcp://10.32.82.207:11111 14 | 15 | python3 ${DIST}/train.py \ 16 | --task mnist --optimizer adadelta --lr-scheduler PolynomialDecayScheduler \ 17 | --data /scratch365/yding4/mnist/MNIST/processed/ --clip-norm 100 \ 18 | --max-sentences 64 --fast-stat-sync --max-epoch 20 --update-freq 1 \ 19 | --valid-subset test --num-workers 4 \ 20 | --warmup-updates 0 --total-num-update 50000 --lr 1.01 \ 21 | --distributed-init-method ${AD} --distributed-world-size 16 \ 22 | --distributed-gpus 4 --distributed-rank 0 --save-dir node4gpu16 | tee node4gpu16.log -------------------------------------------------------------------------------- /STORE_RUN_FILE/Train_mnist/node4gpu16/run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | qsub main.sh 4 | qsub sub1.sh 5 | qsub sub2.sh 6 | qsub sub3.sh -------------------------------------------------------------------------------- /STORE_RUN_FILE/Train_mnist/node4gpu16/sub1.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #$-m abe 4 | #$-M yding4@nd.edu 5 | #$-q gpu@qa-xp-003 # specify the queue 6 | #$-l gpu_card=4 7 | #$-N node4gpu16_mnist_sub1 8 | 9 | export PATH=/afs/crc.nd.edu/user/y/yding4/Transformer/bin:$PATH 10 | export LD_LIBRARY_PATH=/afs/crc.nd.edu/user/y/yding4/Transformer/lib:$LD_LIBRARY_PATH 11 | 12 | DIST=/scratch365/yding4/hetseq 13 | AD=tcp://10.32.82.207:11111 14 | 15 | python3 ${DIST}/train.py \ 16 | --task mnist --optimizer adadelta --lr-scheduler PolynomialDecayScheduler \ 17 | --data /scratch365/yding4/mnist/MNIST/processed/ --clip-norm 100 \ 18 | --max-sentences 64 --fast-stat-sync --max-epoch 20 --update-freq 1 \ 19 | --valid-subset test --num-workers 4 \ 20 | --warmup-updates 0 --total-num-update 50000 --lr 1.01 \ 21 | --distributed-init-method ${AD} --distributed-world-size 16 \ 22 | --distributed-gpus 4 --distributed-rank 4 --save-dir node4gpu16 -------------------------------------------------------------------------------- /STORE_RUN_FILE/Train_mnist/node4gpu16/sub2.sh: 
-------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #$-m abe 4 | #$-M yding4@nd.edu 5 | #$-q gpu@qa-xp-004 # specify the queue 6 | #$-l gpu_card=4 7 | #$-N node4gpu16_mnist_sub2 8 | 9 | export PATH=/afs/crc.nd.edu/user/y/yding4/Transformer/bin:$PATH 10 | export LD_LIBRARY_PATH=/afs/crc.nd.edu/user/y/yding4/Transformer/lib:$LD_LIBRARY_PATH 11 | 12 | DIST=/scratch365/yding4/hetseq 13 | AD=tcp://10.32.82.207:11111 14 | 15 | python3 ${DIST}/train.py \ 16 | --task mnist --optimizer adadelta --lr-scheduler PolynomialDecayScheduler \ 17 | --data /scratch365/yding4/mnist/MNIST/processed/ --clip-norm 100 \ 18 | --max-sentences 64 --fast-stat-sync --max-epoch 20 --update-freq 1 \ 19 | --valid-subset test --num-workers 4 \ 20 | --warmup-updates 0 --total-num-update 50000 --lr 1.01 \ 21 | --distributed-init-method ${AD} --distributed-world-size 16 \ 22 | --distributed-gpus 4 --distributed-rank 8 --save-dir node4gpu16 -------------------------------------------------------------------------------- /STORE_RUN_FILE/Train_mnist/node4gpu16/sub3.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #$-m abe 4 | #$-M yding4@nd.edu 5 | #$-q gpu@qa-xp-004 # specify the queue 6 | #$-l gpu_card=4 7 | #$-N node4gpu16_mnist_sub3 8 | 9 | export PATH=/afs/crc.nd.edu/user/y/yding4/Transformer/bin:$PATH 10 | export LD_LIBRARY_PATH=/afs/crc.nd.edu/user/y/yding4/Transformer/lib:$LD_LIBRARY_PATH 11 | 12 | DIST=/scratch365/yding4/hetseq 13 | AD=tcp://10.32.82.207:11111 14 | 15 | python3 ${DIST}/train.py \ 16 | --task mnist --optimizer adadelta --lr-scheduler PolynomialDecayScheduler \ 17 | --data /scratch365/yding4/mnist/MNIST/processed/ --clip-norm 100 \ 18 | --max-sentences 64 --fast-stat-sync --max-epoch 20 --update-freq 1 \ 19 | --valid-subset test --num-workers 4 \ 20 | --warmup-updates 0 --total-num-update 50000 --lr 1.01 \ 21 | --distributed-init-method ${AD} --distributed-world-size 16 \ 22 | --distributed-gpus 4 --distributed-rank 12 --save-dir node4gpu16 -------------------------------------------------------------------------------- /STORE_RUN_FILE/Train_mnist/node4gpu16_het/eval.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #$-m abe 4 | #$-M yding4@nd.edu 5 | #$-q gpu@qa-xp-002 # specify the queue 6 | #$-l gpu_card=4 7 | #$-N node2gpu8_mnist_main1 8 | 9 | export PATH=/afs/crc.nd.edu/user/y/yding4/Transformer/bin:$PATH 10 | export LD_LIBRARY_PATH=/afs/crc.nd.edu/user/y/yding4/Transformer/lib:$LD_LIBRARY_PATH 11 | 12 | code=/scratch365/yding4/hetseq/eval_mnist.py 13 | ckpt=/scratch365/yding4/hetseq/CRC_RUN_FILE/Train_mnist/node4gpu16_het/node4gpu16/checkpoint_last.pt 14 | mnist=/scratch365/yding4/mnist/MNIST/processed 15 | 16 | python3 ${code} --model_ckpt ${ckpt} --mnist_dir ${mnist} 2>&1 |tee eval.log -------------------------------------------------------------------------------- /STORE_RUN_FILE/Train_mnist/node4gpu16_het/main.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #$-m abe 4 | #$-M yding4@nd.edu 5 | #$-q gpu@qa-xp-002 # specify the queue 6 | #$-l gpu_card=4 7 | #$-N node4gpu16_mnist_main 8 | 9 | export PATH=/afs/crc.nd.edu/user/y/yding4/Transformer/bin:$PATH 10 | export LD_LIBRARY_PATH=/afs/crc.nd.edu/user/y/yding4/Transformer/lib:$LD_LIBRARY_PATH 11 | 12 | DIST=/scratch365/yding4/hetseq 13 | AD=tcp://10.32.82.207:11111 14 | 15 | python3 ${DIST}/train.py \ 16 | --task 
mnist --optimizer adadelta --lr-scheduler PolynomialDecayScheduler \ 17 | --data /scratch365/yding4/mnist/MNIST/processed/ --clip-norm 100 \ 18 | --max-sentences 64 --fast-stat-sync --max-epoch 20 --update-freq 1 \ 19 | --valid-subset test --num-workers 4 \ 20 | --warmup-updates 0 --total-num-update 50000 --lr 1.01 \ 21 | --distributed-init-method ${AD} --distributed-world-size 16 \ 22 | --distributed-gpus 4 --distributed-rank 0 --save-dir node4gpu16 2>&1 | tee node4gpu16.log -------------------------------------------------------------------------------- /STORE_RUN_FILE/Train_mnist/node4gpu16_het/run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | qsub main.sh 4 | qsub sub1.sh 5 | qsub sub2.sh 6 | qsub sub3.sh -------------------------------------------------------------------------------- /STORE_RUN_FILE/Train_mnist/node4gpu16_het/sub1.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #$-m abe 4 | #$-M yding4@nd.edu 5 | #$-q gpu@qa-xp-003 # specify the queue 6 | #$-l gpu_card=4 7 | #$-N node4gpu16_mnist_sub1 8 | 9 | export PATH=/afs/crc.nd.edu/user/y/yding4/Transformer/bin:$PATH 10 | export LD_LIBRARY_PATH=/afs/crc.nd.edu/user/y/yding4/Transformer/lib:$LD_LIBRARY_PATH 11 | 12 | DIST=/scratch365/yding4/hetseq 13 | AD=tcp://10.32.82.207:11111 14 | 15 | python3 ${DIST}/train.py \ 16 | --task mnist --optimizer adadelta --lr-scheduler PolynomialDecayScheduler \ 17 | --data /scratch365/yding4/mnist/MNIST/processed/ --clip-norm 100 \ 18 | --max-sentences 64 --fast-stat-sync --max-epoch 20 --update-freq 1 \ 19 | --valid-subset test --num-workers 4 \ 20 | --warmup-updates 0 --total-num-update 50000 --lr 1.01 \ 21 | --distributed-init-method ${AD} --distributed-world-size 16 \ 22 | --distributed-gpus 4 --distributed-rank 4 --save-dir node4gpu16 -------------------------------------------------------------------------------- /STORE_RUN_FILE/Train_mnist/node4gpu16_het/sub2.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #$-m abe 4 | #$-M yding4@nd.edu 5 | #$-q gpu@qa-p100-005 # specify the queue 6 | #$-l gpu_card=4 7 | #$-N node4gpu16_mnist_sub2 8 | 9 | export PATH=/afs/crc.nd.edu/user/y/yding4/Transformer/bin:$PATH 10 | export LD_LIBRARY_PATH=/afs/crc.nd.edu/user/y/yding4/Transformer/lib:$LD_LIBRARY_PATH 11 | 12 | DIST=/scratch365/yding4/hetseq 13 | AD=tcp://10.32.82.207:11111 14 | 15 | python3 ${DIST}/train.py \ 16 | --task mnist --optimizer adadelta --lr-scheduler PolynomialDecayScheduler \ 17 | --data /scratch365/yding4/mnist/MNIST/processed/ --clip-norm 100 \ 18 | --max-sentences 64 --fast-stat-sync --max-epoch 20 --update-freq 1 \ 19 | --valid-subset test --num-workers 4 \ 20 | --warmup-updates 0 --total-num-update 50000 --lr 1.01 \ 21 | --distributed-init-method ${AD} --distributed-world-size 16 \ 22 | --distributed-gpus 4 --distributed-rank 8 --save-dir node4gpu16 -------------------------------------------------------------------------------- /STORE_RUN_FILE/Train_mnist/node4gpu16_het/sub3.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #$-m abe 4 | #$-M yding4@nd.edu 5 | #$-q gpu@qa-p100-006 # specify the queue 6 | #$-l gpu_card=4 7 | #$-N node4gpu16_mnist_sub3 8 | 9 | export PATH=/afs/crc.nd.edu/user/y/yding4/Transformer/bin:$PATH 10 | export LD_LIBRARY_PATH=/afs/crc.nd.edu/user/y/yding4/Transformer/lib:$LD_LIBRARY_PATH 11 | 12 | 
DIST=/scratch365/yding4/hetseq 13 | AD=tcp://10.32.82.207:11111 14 | 15 | python3 ${DIST}/train.py \ 16 | --task mnist --optimizer adadelta --lr-scheduler PolynomialDecayScheduler \ 17 | --data /scratch365/yding4/mnist/MNIST/processed/ --clip-norm 100 \ 18 | --max-sentences 64 --fast-stat-sync --max-epoch 20 --update-freq 1 \ 19 | --valid-subset test --num-workers 4 \ 20 | --warmup-updates 0 --total-num-update 50000 --lr 1.01 \ 21 | --distributed-init-method ${AD} --distributed-world-size 16 \ 22 | --distributed-gpus 4 --distributed-rank 12 --save-dir node4gpu16 -------------------------------------------------------------------------------- /STORE_RUN_FILE/Train_mnist/node8gpu32_het/eval.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #$-m abe 4 | #$-M yding4@nd.edu 5 | #$-q gpu@qa-xp-002 # specify the queue 6 | #$-l gpu_card=4 7 | #$-N eval 8 | 9 | export PATH=/afs/crc.nd.edu/user/y/yding4/Transformer/bin:$PATH 10 | export LD_LIBRARY_PATH=/afs/crc.nd.edu/user/y/yding4/Transformer/lib:$LD_LIBRARY_PATH 11 | 12 | code=/scratch365/yding4/hetseq/eval_mnist.py 13 | ckpt=/scratch365/yding4/hetseq/CRC_RUN_FILE/Train_mnist/node8gpu32_het/node8gpu32/checkpoint_last.pt 14 | mnist=/scratch365/yding4/mnist/MNIST/processed 15 | 16 | python3 ${code} --model_ckpt ${ckpt} --mnist_dir ${mnist} 2>&1 |tee eval.log -------------------------------------------------------------------------------- /STORE_RUN_FILE/Train_mnist/node8gpu32_het/main.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #$-m abe 4 | #$-M yding4@nd.edu 5 | #$-q gpu@qa-xp-002 # specify the queue 6 | #$-l gpu_card=4 7 | #$-N node8gpu32_mnist_main 8 | 9 | export PATH=/afs/crc.nd.edu/user/y/yding4/Transformer/bin:$PATH 10 | export LD_LIBRARY_PATH=/afs/crc.nd.edu/user/y/yding4/Transformer/lib:$LD_LIBRARY_PATH 11 | 12 | DIST=/scratch365/yding4/hetseq 13 | AD=tcp://10.32.82.207:11111 14 | 15 | python3 ${DIST}/train.py \ 16 | --task mnist --optimizer adadelta --lr-scheduler PolynomialDecayScheduler \ 17 | --data /scratch365/yding4/mnist/MNIST/processed/ --clip-norm 100 \ 18 | --max-sentences 64 --fast-stat-sync --max-epoch 20 --update-freq 1 \ 19 | --valid-subset test --num-workers 4 \ 20 | --warmup-updates 0 --total-num-update 50000 --lr 1.01 \ 21 | --distributed-init-method ${AD} --distributed-world-size 32 \ 22 | --distributed-gpus 4 --distributed-rank 0 --save-dir node8gpu32 2>&1 | tee node8gpu32.log -------------------------------------------------------------------------------- /STORE_RUN_FILE/Train_mnist/node8gpu32_het/run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | qsub main.sh 4 | qsub sub1.sh 5 | qsub sub2.sh 6 | qsub sub3.sh 7 | qsub sub4.sh 8 | qsub sub5.sh 9 | qsub sub6.sh 10 | qsub sub7.sh -------------------------------------------------------------------------------- /STORE_RUN_FILE/Train_mnist/node8gpu32_het/sub1.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #$-m abe 4 | #$-M yding4@nd.edu 5 | #$-q gpu@qa-xp-003 # specify the queue 6 | #$-l gpu_card=4 7 | #$-N node8gpu32_mnist_sub1 8 | 9 | export PATH=/afs/crc.nd.edu/user/y/yding4/Transformer/bin:$PATH 10 | export LD_LIBRARY_PATH=/afs/crc.nd.edu/user/y/yding4/Transformer/lib:$LD_LIBRARY_PATH 11 | 12 | DIST=/scratch365/yding4/hetseq 13 | AD=tcp://10.32.82.207:11111 14 | 15 | python3 ${DIST}/train.py \ 16 | --task mnist --optimizer 
adadelta --lr-scheduler PolynomialDecayScheduler \ 17 | --data /scratch365/yding4/mnist/MNIST/processed/ --clip-norm 100 \ 18 | --max-sentences 64 --fast-stat-sync --max-epoch 20 --update-freq 1 \ 19 | --valid-subset test --num-workers 4 \ 20 | --warmup-updates 0 --total-num-update 50000 --lr 1.01 \ 21 | --distributed-init-method ${AD} --distributed-world-size 32 \ 22 | --distributed-gpus 4 --distributed-rank 4 --save-dir node8gpu32 -------------------------------------------------------------------------------- /STORE_RUN_FILE/Train_mnist/node8gpu32_het/sub2.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #$-m abe 4 | #$-M yding4@nd.edu 5 | #$-q gpu@qa-xp-004 # specify the queue 6 | #$-l gpu_card=4 7 | #$-N node8gpu32_mnist_sub2 8 | 9 | export PATH=/afs/crc.nd.edu/user/y/yding4/Transformer/bin:$PATH 10 | export LD_LIBRARY_PATH=/afs/crc.nd.edu/user/y/yding4/Transformer/lib:$LD_LIBRARY_PATH 11 | 12 | DIST=/scratch365/yding4/hetseq 13 | AD=tcp://10.32.82.207:11111 14 | 15 | python3 ${DIST}/train.py \ 16 | --task mnist --optimizer adadelta --lr-scheduler PolynomialDecayScheduler \ 17 | --data /scratch365/yding4/mnist/MNIST/processed/ --clip-norm 100 \ 18 | --max-sentences 64 --fast-stat-sync --max-epoch 20 --update-freq 1 \ 19 | --valid-subset test --num-workers 4 \ 20 | --warmup-updates 0 --total-num-update 50000 --lr 1.01 \ 21 | --distributed-init-method ${AD} --distributed-world-size 32 \ 22 | --distributed-gpus 4 --distributed-rank 8 --save-dir node8gpu32 -------------------------------------------------------------------------------- /STORE_RUN_FILE/Train_mnist/node8gpu32_het/sub3.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #$-m abe 4 | #$-M yding4@nd.edu 5 | #$-q gpu@qa-xp-005 # specify the queue 6 | #$-l gpu_card=4 7 | #$-N node8gpu32_mnist_sub3 8 | 9 | export PATH=/afs/crc.nd.edu/user/y/yding4/Transformer/bin:$PATH 10 | export LD_LIBRARY_PATH=/afs/crc.nd.edu/user/y/yding4/Transformer/lib:$LD_LIBRARY_PATH 11 | 12 | DIST=/scratch365/yding4/hetseq 13 | AD=tcp://10.32.82.207:11111 14 | 15 | python3 ${DIST}/train.py \ 16 | --task mnist --optimizer adadelta --lr-scheduler PolynomialDecayScheduler \ 17 | --data /scratch365/yding4/mnist/MNIST/processed/ --clip-norm 100 \ 18 | --max-sentences 64 --fast-stat-sync --max-epoch 20 --update-freq 1 \ 19 | --valid-subset test --num-workers 4 \ 20 | --warmup-updates 0 --total-num-update 50000 --lr 1.01 \ 21 | --distributed-init-method ${AD} --distributed-world-size 32 \ 22 | --distributed-gpus 4 --distributed-rank 12 --save-dir node8gpu32 -------------------------------------------------------------------------------- /STORE_RUN_FILE/Train_mnist/node8gpu32_het/sub4.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #$-m abe 4 | #$-M yding4@nd.edu 5 | #$-q gpu@qa-p100-005 # specify the queue 6 | #$-l gpu_card=4 7 | #$-N node8gpu32_mnist_sub4 8 | 9 | export PATH=/afs/crc.nd.edu/user/y/yding4/Transformer/bin:$PATH 10 | export LD_LIBRARY_PATH=/afs/crc.nd.edu/user/y/yding4/Transformer/lib:$LD_LIBRARY_PATH 11 | 12 | DIST=/scratch365/yding4/hetseq 13 | AD=tcp://10.32.82.207:11111 14 | 15 | python3 ${DIST}/train.py \ 16 | --task mnist --optimizer adadelta --lr-scheduler PolynomialDecayScheduler \ 17 | --data /scratch365/yding4/mnist/MNIST/processed/ --clip-norm 100 \ 18 | --max-sentences 64 --fast-stat-sync --max-epoch 20 --update-freq 1 \ 19 | --valid-subset test --num-workers 
4 \ 20 | --warmup-updates 0 --total-num-update 50000 --lr 1.01 \ 21 | --distributed-init-method ${AD} --distributed-world-size 32 \ 22 | --distributed-gpus 4 --distributed-rank 16 --save-dir node8gpu32 -------------------------------------------------------------------------------- /STORE_RUN_FILE/Train_mnist/node8gpu32_het/sub5.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #$-m abe 4 | #$-M yding4@nd.edu 5 | #$-q gpu@qa-p100-006 # specify the queue 6 | #$-l gpu_card=4 7 | #$-N node8gpu32_mnist_sub5 8 | 9 | export PATH=/afs/crc.nd.edu/user/y/yding4/Transformer/bin:$PATH 10 | export LD_LIBRARY_PATH=/afs/crc.nd.edu/user/y/yding4/Transformer/lib:$LD_LIBRARY_PATH 11 | 12 | DIST=/scratch365/yding4/hetseq 13 | AD=tcp://10.32.82.207:11111 14 | 15 | python3 ${DIST}/train.py \ 16 | --task mnist --optimizer adadelta --lr-scheduler PolynomialDecayScheduler \ 17 | --data /scratch365/yding4/mnist/MNIST/processed/ --clip-norm 100 \ 18 | --max-sentences 64 --fast-stat-sync --max-epoch 20 --update-freq 1 \ 19 | --valid-subset test --num-workers 4 \ 20 | --warmup-updates 0 --total-num-update 50000 --lr 1.01 \ 21 | --distributed-init-method ${AD} --distributed-world-size 32 \ 22 | --distributed-gpus 4 --distributed-rank 20 --save-dir node8gpu32 -------------------------------------------------------------------------------- /STORE_RUN_FILE/Train_mnist/node8gpu32_het/sub6.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #$-m abe 4 | #$-M yding4@nd.edu 5 | #$-q gpu@qa-p100-007 # specify the queue 6 | #$-l gpu_card=4 7 | #$-N node8gpu32_mnist_sub6 8 | 9 | export PATH=/afs/crc.nd.edu/user/y/yding4/Transformer/bin:$PATH 10 | export LD_LIBRARY_PATH=/afs/crc.nd.edu/user/y/yding4/Transformer/lib:$LD_LIBRARY_PATH 11 | 12 | DIST=/scratch365/yding4/hetseq 13 | AD=tcp://10.32.82.207:11111 14 | 15 | python3 ${DIST}/train.py \ 16 | --task mnist --optimizer adadelta --lr-scheduler PolynomialDecayScheduler \ 17 | --data /scratch365/yding4/mnist/MNIST/processed/ --clip-norm 100 \ 18 | --max-sentences 64 --fast-stat-sync --max-epoch 20 --update-freq 1 \ 19 | --valid-subset test --num-workers 4 \ 20 | --warmup-updates 0 --total-num-update 50000 --lr 1.01 \ 21 | --distributed-init-method ${AD} --distributed-world-size 32 \ 22 | --distributed-gpus 4 --distributed-rank 24 --save-dir node8gpu32 -------------------------------------------------------------------------------- /STORE_RUN_FILE/Train_mnist/node8gpu32_het/sub7.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #$-m abe 4 | #$-M yding4@nd.edu 5 | #$-q gpu@qa-p100-008 # specify the queue 6 | #$-l gpu_card=4 7 | #$-N node8gpu32_mnist_sub7 8 | 9 | export PATH=/afs/crc.nd.edu/user/y/yding4/Transformer/bin:$PATH 10 | export LD_LIBRARY_PATH=/afs/crc.nd.edu/user/y/yding4/Transformer/lib:$LD_LIBRARY_PATH 11 | 12 | DIST=/scratch365/yding4/hetseq 13 | AD=tcp://10.32.82.207:11111 14 | 15 | python3 ${DIST}/train.py \ 16 | --task mnist --optimizer adadelta --lr-scheduler PolynomialDecayScheduler \ 17 | --data /scratch365/yding4/mnist/MNIST/processed/ --clip-norm 100 \ 18 | --max-sentences 64 --fast-stat-sync --max-epoch 20 --update-freq 1 \ 19 | --valid-subset test --num-workers 4 \ 20 | --warmup-updates 0 --total-num-update 50000 --lr 1.01 \ 21 | --distributed-init-method ${AD} --distributed-world-size 32 \ 22 | --distributed-gpus 4 --distributed-rank 28 --save-dir node8gpu32 
-------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = source 9 | BUILDDIR = build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=source 11 | set BUILDDIR=build 12 | 13 | if "%1" == "" goto help 14 | 15 | %SPHINXBUILD% >NUL 2>NUL 16 | if errorlevel 9009 ( 17 | echo. 18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 19 | echo.installed, then set the SPHINXBUILD environment variable to point 20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 21 | echo.may add the Sphinx directory to PATH. 22 | echo. 23 | echo.If you don't have Sphinx installed, grab it from 24 | echo.http://sphinx-doc.org/ 25 | exit /b 1 26 | ) 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /docs/source/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 2 | # 3 | # This file only contains a selection of the most common options. For a full 4 | # list see the documentation: 5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 6 | 7 | # -- Path setup -------------------------------------------------------------- 8 | 9 | # If extensions (or modules to document with autodoc) are in another directory, 10 | # add these directories to sys.path here. If the directory is relative to the 11 | # documentation root, use os.path.abspath to make it absolute, like shown here. 12 | # 13 | import os 14 | import sys 15 | 16 | sys.path.insert(0, os.path.abspath('..')) 17 | sys.path.insert(0, os.path.abspath('../..')) 18 | 19 | 20 | # -- Project information ----------------------------------------------------- 21 | 22 | project = 'HetSeq' 23 | copyright = '2020, Yifan Ding' 24 | author = 'Yifan Ding' 25 | 26 | # The full version, including alpha/beta/rc tags 27 | release = '0.0.1' 28 | 29 | 30 | # -- General configuration --------------------------------------------------- 31 | 32 | # Add any Sphinx extension module names here, as strings. They can be 33 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 34 | # ones. 
extensions = [
    "sphinx.ext.viewcode",
    "sphinx.ext.autodoc",
    "sphinx_rtd_theme",
]

# Add any paths that contain templates here, relative to this directory.
templates_path = ['_templates']

# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.
# This pattern also affects html_static_path and html_extra_path.
exclude_patterns = []


# -- Options for HTML output -------------------------------------------------

# The theme to use for HTML and HTML Help pages.  See the documentation for
# a list of builtin themes.
#
html_theme = 'sphinx_rtd_theme'
master_doc = 'index'

# Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css".
html_static_path = ['_static']

autodoc_mock_imports = ["numpy", "torch", "h5py", "torchvision", "PIL", "boto3", "botocore", "tqdm"]

--------------------------------------------------------------------------------
/docs/source/dataset.rst:
--------------------------------------------------------------------------------

Dataset
-------

.. autoclass:: hetseq.data.h5pyDataset.BertH5pyData

.. autoclass:: hetseq.data.h5pyDataset.ConBertH5pyData

.. autoclass:: hetseq.data.mnist_dataset.MNISTDataset

--------------------------------------------------------------------------------
/docs/source/distribute.rst:
--------------------------------------------------------------------------------

.. _distribute:

*******************
Distributed Setting
*******************


HetSeq can be executed on a single GPU on a single node, multiple GPUs on a single node, or multiple GPUs across multiple nodes. The main logic is defined at `train.py <https://github.com/yifding/hetseq/blob/master/train.py#L213>`__.


Control Parameters
------------------
``--distributed-init-method``: defines the initialization method.

* ``tcp://10.32.82.207:11111`` (IP address:port; a TCP example for multiple nodes) or
* ``file:///hetseq/communicate.txt`` (a shared-file example for multiple nodes).

``--distributed-world-size``: total number of GPUs used in the training.

``--distributed-gpus``: the number of GPUs on the current node.

``--distributed-rank``: the rank/index of the first GPU used on the current node.


Different Distributed Settings
------------------------------

1. Single GPU:

.. code:: console

    $ --distributed-world-size 1 --device-id 1

2. Four GPUs on a single node:

.. code:: console

    $ --distributed-world-size 4

3. Four nodes with four GPUs each (16 GPUs in total), where ``10.00.123.456`` is the IP address of the first node and ``11111`` is the port number:

* 1st node

.. code:: console

    $ --distributed-init-method tcp://10.00.123.456:11111 --distributed-world-size 16 --distributed-gpus 4 --distributed-rank 0

* 2nd node

.. code:: console

    $ --distributed-init-method tcp://10.00.123.456:11111 --distributed-world-size 16 --distributed-gpus 4 --distributed-rank 4

* 3rd node

.. code:: console
    $ --distributed-init-method tcp://10.00.123.456:11111 --distributed-world-size 16 --distributed-gpus 4 --distributed-rank 8

* 4th node

.. code:: console

    $ --distributed-init-method tcp://10.00.123.456:11111 --distributed-world-size 16 --distributed-gpus 4 --distributed-rank 12


Main Logic
----------

.. code:: python

    if args.distributed_init_method is not None:
        assert args.distributed_gpus <= torch.cuda.device_count()

        if args.distributed_gpus > 1 and not args.distributed_no_spawn:
            start_rank = args.distributed_rank
            args.distributed_rank = None  # assign automatically
            torch.multiprocessing.spawn(
                fn=distributed_main,
                args=(args, start_rank),
                nprocs=args.distributed_gpus,
            )
        else:
            distributed_main(args.device_id, args)

    elif args.distributed_world_size > 1:
        assert args.distributed_world_size <= torch.cuda.device_count()
        port = random.randint(10000, 20000)
        args.distributed_init_method = 'tcp://localhost:{port}'.format(port=port)
        args.distributed_rank = None  # set based on device id
        torch.multiprocessing.spawn(
            fn=distributed_main,
            args=(args, ),
            nprocs=args.distributed_world_size,
        )
    else:
        main(args)

--------------------------------------------------------------------------------
/docs/source/examples.rst:
--------------------------------------------------------------------------------

**************
Running Script
**************


BERT Task
---------


1. Single GPU

.. code-block:: console

    $ DIST=~/hetseq
    $ python3 ${DIST}/hetseq/train.py \
        --task bert --data ${DIST}/preprocessing/test_128/ \
        --dict ${DIST}/preprocessing/uncased_L-12_H-768_A-12/vocab.txt \
        --config_file ${DIST}/preprocessing/uncased_L-12_H-768_A-12/bert_config.json \
        --max-sentences 32 --fast-stat-sync --max-update 900000 --update-freq 4 \
        --valid-subset test --num-workers 4 \
        --warmup-updates 10000 --total-num-update 1000000 --lr 0.0001 \
        --weight-decay 0.01 --distributed-world-size 1 \
        --device-id 0 --save-dir bert_single_gpu

2. Multiple GPUs on a single node; the example uses all four GPUs on one node.

.. code:: console

    $ DIST=~/hetseq
    $ python3 ${DIST}/hetseq/train.py \
        --task bert --data ${DIST}/preprocessing/test_128/ \
        --dict ${DIST}/preprocessing/uncased_L-12_H-768_A-12/vocab.txt \
        --config_file ${DIST}/preprocessing/uncased_L-12_H-768_A-12/bert_config.json \
        --max-sentences 32 --fast-stat-sync --max-update 900000 --update-freq 4 \
        --valid-subset test --num-workers 4 \
        --warmup-updates 10000 --total-num-update 1000000 --lr 0.0001 \
        --weight-decay 0.01 \
        --save-dir bert_node1gpu4

3. Multiple GPUs on multiple nodes; the example uses two nodes with four GPUs each.


* on the main node

.. code:: console
--------------------------------------------------------------------------------
/docs/source/examples.rst:
--------------------------------------------------------------------------------
1 | **************
2 | Running Script
3 | **************
4 | 
5 | 
6 | BERT Task
7 | ---------
8 | 
9 | 
10 | 1. Single GPU
11 | 
12 | .. code-block:: console
13 | 
14 |     $ DIST=~/hetseq
15 |     $ python3 ${DIST}/hetseq/train.py \
16 |     --task bert --data ${DIST}/preprocessing/test_128/ \
17 |     --dict ${DIST}/preprocessing/uncased_L-12_H-768_A-12/vocab.txt \
18 |     --config_file ${DIST}/preprocessing/uncased_L-12_H-768_A-12/bert_config.json \
19 |     --max-sentences 32 --fast-stat-sync --max-update 900000 --update-freq 4 \
20 |     --valid-subset test --num-workers 4 \
21 |     --warmup-updates 10000 --total-num-update 1000000 --lr 0.0001 \
22 |     --weight-decay 0.01 --distributed-world-size 1 \
23 |     --device-id 0 --save-dir bert_single_gpu
24 | 
25 | 2. Multiple GPUs on a single node; this example uses all four GPUs on one node.
26 | 
27 | .. code:: console
28 | 
29 |     $ DIST=~/hetseq
30 |     $ python3 ${DIST}/hetseq/train.py \
31 |     --task bert --data ${DIST}/preprocessing/test_128/ \
32 |     --dict ${DIST}/preprocessing/uncased_L-12_H-768_A-12/vocab.txt \
33 |     --config_file ${DIST}/preprocessing/uncased_L-12_H-768_A-12/bert_config.json \
34 |     --max-sentences 32 --fast-stat-sync --max-update 900000 --update-freq 4 \
35 |     --valid-subset test --num-workers 4 \
36 |     --warmup-updates 10000 --total-num-update 1000000 --lr 0.0001 \
37 |     --weight-decay 0.01 \
38 |     --save-dir bert_node1gpu4
39 | 
40 | 3. Multiple GPUs on multiple nodes; this example uses two nodes with four GPUs each.
41 | 
42 | 
43 | * on the main node
44 | 
45 | .. code:: console
46 | 
47 |     $ DIST=~/hetseq
48 |     $ python3 ${DIST}/hetseq/train.py \
49 |     --task bert --data ${DIST}/preprocessing/test_128/ \
50 |     --dict ${DIST}/preprocessing/uncased_L-12_H-768_A-12/vocab.txt \
51 |     --config_file ${DIST}/preprocessing/uncased_L-12_H-768_A-12/bert_config.json \
52 |     --max-sentences 32 --fast-stat-sync --max-update 900000 --update-freq 4 \
53 |     --valid-subset test --num-workers 4 \
54 |     --warmup-updates 10000 --total-num-update 1000000 --lr 0.0001 \
55 |     --weight-decay 0.01 --save-dir bert_node2gpu4 \
56 |     --distributed-init-method tcp://10.00.123.456:11111 \
57 |     --distributed-world-size 8 --distributed-gpus 4 --distributed-rank 0
58 | 
59 | * on the second node
60 | 
61 | .. code:: console
62 | 
63 |     $ DIST=~/hetseq
64 |     $ python3 ${DIST}/hetseq/train.py \
65 |     --task bert --data ${DIST}/preprocessing/test_128/ \
66 |     --dict ${DIST}/preprocessing/uncased_L-12_H-768_A-12/vocab.txt \
67 |     --config_file ${DIST}/preprocessing/uncased_L-12_H-768_A-12/bert_config.json \
68 |     --max-sentences 32 --fast-stat-sync --max-update 900000 --update-freq 4 \
69 |     --valid-subset test --num-workers 4 \
70 |     --warmup-updates 10000 --total-num-update 1000000 --lr 0.0001 \
71 |     --weight-decay 0.01 \
72 |     --distributed-init-method tcp://10.00.123.456:11111 \
73 |     --distributed-world-size 8 --distributed-gpus 4 --distributed-rank 4
74 | 
75 | 
76 | 
77 | MNIST Task
78 | ----------
79 | 
80 | 1. Single GPU
81 | 
82 | .. code:: console
83 | 
84 |     $ DIST=~/hetseq
85 |     $ python3 ${DIST}/hetseq/train.py \
86 |     --task mnist --optimizer adadelta --lr-scheduler PolynomialDecayScheduler \
87 |     --data ${DIST} --clip-norm 100 \
88 |     --max-sentences 64 --fast-stat-sync --max-epoch 20 --update-freq 1 \
89 |     --valid-subset test --num-workers 4 \
90 |     --warmup-updates 0 --total-num-update 50000 --lr 1.01 \
91 |     --distributed-world-size 1 --device-id 0 --save-dir mnist_single_node
92 | 
93 | 
94 | 2. Multiple GPUs on a single node; this example uses all four GPUs on one node.
95 | 
96 | .. code:: console
97 | 
98 |     $ DIST=~/hetseq
99 |     $ python3 ${DIST}/hetseq/train.py \
100 |     --task mnist --optimizer adadelta --lr-scheduler PolynomialDecayScheduler \
101 |     --data ${DIST} --clip-norm 100 \
102 |     --max-sentences 64 --fast-stat-sync --max-epoch 20 --update-freq 1 \
103 |     --valid-subset test --num-workers 4 \
104 |     --warmup-updates 0 --total-num-update 50000 --lr 1.01 \
105 |     --save-dir mnist_node1gpu4
106 | 
107 | 3. Multiple GPUs on multiple nodes; this example uses two nodes with four GPUs each.
108 | 
109 | 
110 | * on the main node
111 | 
112 | .. code:: console
113 | 
114 |     $ DIST=~/hetseq
115 |     $ python3 ${DIST}/hetseq/train.py \
116 |     --task mnist --optimizer adadelta --lr-scheduler PolynomialDecayScheduler \
117 |     --data ${DIST} --clip-norm 100 \
118 |     --max-sentences 64 --fast-stat-sync --max-epoch 20 --update-freq 1 \
119 |     --valid-subset test --num-workers 4 \
120 |     --warmup-updates 0 --total-num-update 50000 --lr 1.01 \
121 |     --save-dir mnist_node2gpu4 \
122 |     --distributed-init-method tcp://10.00.123.456:11111 \
123 |     --distributed-world-size 8 --distributed-gpus 4 --distributed-rank 0
124 | 
125 | * on the second node
126 | 
code:: console 128 | 129 | $ DIST=~/hetseq 130 | $ python3 ${DIST}/hetseq/train.py \ 131 | --task mnist --optimizer adadelta --lr-scheduler PolynomialDecayScheduler \ 132 | --data ${DIST} --clip-norm 100 \ 133 | --max-sentences 64 --fast-stat-sync --max-epoch 20 --update-freq 1 \ 134 | --valid-subset test --num-workers 4 \ 135 | --warmup-updates 0 --total-num-update 50000 --lr 1.01 \ 136 | --distributed-init-method tcp://10.00.123.456:11111 \ 137 | --distributed-world-size 8 --distributed-gpus 4 --distributed-rank 4 138 | 139 | 140 | Evaluate MNIST Task 141 | ------------------- 142 | 143 | .. code:: console 144 | 145 | $ DIST=~/hetseq 146 | $ python3 ${DIST}/hetseq/eval_mnist.py --model_ckpt /path/to/check/point --mnist_dir ${DIST} 147 | -------------------------------------------------------------------------------- /docs/source/extending.rst: -------------------------------------------------------------------------------- 1 | ******** 2 | Overview 3 | ******** 4 | 5 | To extend HetSeq to another model, one needs to define a new :doc:`Task ` with corresponding :doc:`Model `, :doc:`Dataset `, :doc:`Optimizer ` and :doc:`Learning Rate Scheduler `. A ``MNIST`` example is given with all the extended classes. Pre-defined ``optimizers``, ``Learning Rate Scheduler``, ``datasets`` and ``models`` can be reused in other applications. 6 | 7 | Task 8 | ---- 9 | For each individual application, task is the basic unit. Defined by ``class Task`` in ``Task.py``. ``datasets`` is stored and load in a dictionary manner. Define a child class of ``Task`` to define a new task, necessary function is to define ``Model`` (in ``def build_model``), ``Dataset`` (in ``def load_dataset``). 10 | 11 | .. code:: python 12 | 13 | class Task(object): 14 | def __init__(self, args): 15 | self.args = args 16 | self.datasets = {} 17 | self.dataset_to_epoch_iter = {} 18 | 19 | def build_model(self, args): 20 | raise NotImplementedError 21 | 22 | def load_dataset(self, split, **kwargs): 23 | """Load a given dataset split. 24 | Args: 25 | split (str): name of the split (e.g., train, valid, test) 26 | """ 27 | raise NotImplementedError 28 | 29 | Dataset 30 | ------- 31 | 32 | Dataset should be defined as a child class of ``torch.utils.data.Dataset`` to be compatible with ``torch.utils.data.dataloader``. Need to define 33 | * ``__getitem__`` (get item), 34 | * ``__len__`` (total length of the datset), 35 | * ``ordered_indices`` (index used to split and assignment to different GPUs, 36 | * ``np.arange(len(self))``), 37 | * ``num_tokens`` (total tokens in a instance, ``1`` for image model), 38 | * ``collater`` (collater function to combined the output of ``__getitem__``, typically use ``torch.utils.data.dataloader.default_collate``) 39 | * ``set_epoch`` (``pass``) function in the class. 40 | 41 | See following ``MNIST example`` for more information. 42 | 43 | 44 | .. note:: 45 | 46 | In our implementation, each process/GPU has its own dataset and dataloader. When dataset is small (like MNIST example), the dataset can be put into ``__init__`` function. However, if the dataset is large (like BERT example or ImageNet), the dataset can not be loaded into memory at once, then the loading process should be defined inside ``__getitem__`` function. 47 | 48 | 49 | Model 50 | ----- 51 | 52 | Model should be defined as a child class of ``torch.nn.Module``. By default, the model should output a loss function. This is compatible with the ``def train_step(self, sample, model, optimizer, ignore_grad=False)`` function inside ``class Task``. 
48 | 
49 | Model
50 | -----
51 | 
52 | The model should be defined as a child class of ``torch.nn.Module``. By default, the model should output a loss; this is what the ``def train_step(self, sample, model, optimizer, ignore_grad=False)`` function inside ``class Task`` expects. One can change this logic, but the result must still fit ``train_step``.
53 | 
54 | 
55 | Optimizer
56 | ---------
57 | 
58 | An optimizer in distributed data parallel (DDP) training has to manipulate both gradients and learning rates. In our implementation, the optimizer (``class _Optimizer(object)``) is defined as a higher-level class than ``torch.optim.Optimizer`` so that extra parameters can be recorded. For example, the ``Adam`` optimizer provided in HetSeq has an initial learning rate ``lr``, beta1 and beta2 ``betas``, an epsilon ``eps`` to avoid division by zero, and a weight decay ``weight_decay``.
59 | 
60 | .. code:: python
61 | 
62 |     class _Adam(_Optimizer):
63 |         def __init__(self, args, params):
64 |             super().__init__(args)
65 | 
66 |             self._optimizer = Adam(params, **self.optimizer_config)
67 | 
68 |         @property
69 |         def optimizer_config(self):
70 |             """
71 |             Return a kwarg dictionary that will be used to override optimizer
72 |             args stored in checkpoints. This allows us to load a checkpoint and
73 |             resume training using a different set of optimizer args, e.g., with a
74 |             different learning rate.
75 |             """
76 |             return {
77 |                 'lr': self.args.lr[0],
78 |                 'betas': eval(self.args.adam_betas),
79 |                 'eps': self.args.adam_eps,
80 |                 'weight_decay': self.args.weight_decay,
81 |             }
82 | 
83 | 
84 | 
85 | 
86 | Learning Rate Scheduler
87 | -----------------------
88 | 
89 | In HetSeq, the common ``PolynomialDecayScheduler`` is provided and is compatible with both the ``BERT`` and ``MNIST`` models.
90 | Other learning rate schedulers can be added easily by providing ``step_update`` and ``step`` functions.
91 | 
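For instance, a bare-bones constant-with-warmup scheduler could look like the sketch below. This is an illustration only: the two hook names come from the sentence above, while the ``set_lr``/``get_lr`` methods on the wrapped optimizer are assumptions here.

.. code:: python

    class ConstantWithWarmupScheduler:
        """Sketch: linear warmup to the peak learning rate, then constant."""

        def __init__(self, args, optimizer):
            self.optimizer = optimizer
            self.peak_lr = args.lr[0]
            self.warmup_updates = args.warmup_updates

        def step_update(self, num_updates):
            # called after every parameter update
            if 0 <= num_updates < self.warmup_updates:
                lr = self.peak_lr * (num_updates + 1) / self.warmup_updates
            else:
                lr = self.peak_lr
            self.optimizer.set_lr(lr)
            return lr

        def step(self, epoch, val_loss=None):
            # called at the end of every epoch
            return self.optimizer.get_lr()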
92 | 
93 | *************
94 | MNIST example
95 | *************
96 | 
97 | The MNIST example is adapted from the `PyTorch mnist example `__. It is a convolutional neural network model for image classification. We adapt the original dataset, model and data loader to be compatible with HetSeq.
98 | 
99 | 
100 | Task
101 | ----
102 | 
103 | .. code:: python
104 | 
105 |     class MNISTTask(Task):
106 |         def __init__(self, args):
107 |             super(MNISTTask, self).__init__(args)
108 | 
109 |         @classmethod
110 |         def setup_task(cls, args, **kwargs):
111 |             """Setup the task (e.g., load dictionaries).
112 |             Args:
113 |                 args (argparse.Namespace): parsed command-line arguments
114 |             """
115 |             return cls(args)
116 | 
117 |         def build_model(self, args):
118 |             model = MNISTNet()
119 |             return model
120 | 
121 |         def load_dataset(self, split, **kwargs):
122 |             """Load a given dataset split.
123 |             Args:
124 |                 split (str): name of the split (e.g., train, valid, test)
125 |             """
126 |             path = self.args.data
127 | 
128 |             if not os.path.exists(path):
129 |                 raise FileNotFoundError(
130 |                     "Dataset not found: ({})".format(path)
131 |                 )
132 | 
133 |             if os.path.isdir(path):
134 |                 if os.path.exists(os.path.join(path, 'MNIST/processed/')):
135 |                     path = os.path.join(path, 'MNIST/processed/')
136 |                 elif os.path.basename(os.path.normpath(path)) != 'processed':
137 |                     datasets.MNIST(path, train=True, download=True)
138 |                     path = os.path.join(path, 'MNIST/processed/')
139 | 
140 |             files = [os.path.join(path, f) for f in os.listdir(path)] if os.path.isdir(path) else [path]
141 |             files = sorted([f for f in files if split in f])
142 | 
143 |             assert len(files) == 1, "no suitable file in split ***{}***".format(split)
144 | 
145 |             dataset = MNISTDataset(files[0])
146 | 
147 |             print('| loaded {} sentences from: {}'.format(len(dataset), path), flush=True)
148 | 
149 |             self.datasets[split] = dataset
150 |             print('| loading finished')
151 | 
152 | 
153 | Dataset
154 | -------
155 | 
156 | .. code:: python
157 | 
158 |     class MNISTDataset(torch.utils.data.Dataset):
159 |         def __init__(self, path):
160 |             self.data = None
161 |             self.path = path
162 |             self.read_data(self.path)
163 |             self.transform = transforms.Compose([
164 |                 transforms.ToTensor(),
165 |                 transforms.Normalize((0.1307,), (0.3081,))
166 |             ])
167 | 
168 |         def read_data(self, path):
169 |             self.data = torch.load(path)
170 |             self._len = len(self.data[0])
171 |             self.image = self.data[0]
172 |             self.label = self.data[1]
173 | 
174 |         def __getitem__(self, index):
175 |             img, target = self.image[index], int(self.label[index])
176 |             img = Image.fromarray(img.numpy(), mode='L')
177 |             img = self.transform(img)
178 |             return img, target
179 | 
180 |         def __len__(self):
181 |             return self._len
182 | 
183 |         def ordered_indices(self):
184 |             """Return an ordered list of indices. Batches will be constructed based
185 |             on this order."""
186 |             return np.arange(len(self))
187 | 
188 |         def num_tokens(self, index: int):
189 |             return 1
190 | 
191 |         def collater(self, samples):
192 |             if len(samples) == 0:
193 |                 return None
194 |             else:
195 |                 return default_collate(samples)
196 | 
197 |         def set_epoch(self, epoch):
198 |             pass
199 | 
200 | Model
201 | -----
202 | 
203 | .. code:: python
204 | 
205 |     class MNISTNet(nn.Module):
206 |         def __init__(self):
207 |             super(MNISTNet, self).__init__()
208 |             self.conv1 = nn.Conv2d(1, 32, 3, 1)
209 |             self.conv2 = nn.Conv2d(32, 64, 3, 1)
210 |             self.dropout1 = nn.Dropout2d(0.25)
211 |             self.dropout2 = nn.Dropout2d(0.5)
212 |             self.fc1 = nn.Linear(9216, 128)
213 |             self.fc2 = nn.Linear(128, 10)
214 | 
215 |         def forward(self, x, target, eval=False):
216 |             x = self.conv1(x)
217 |             x = F.relu(x)
218 |             x = self.conv2(x)
219 |             x = F.relu(x)
220 |             x = F.max_pool2d(x, 2)
221 |             x = self.dropout1(x)
222 |             x = torch.flatten(x, 1)
223 |             x = self.fc1(x)
224 |             x = F.relu(x)
225 |             x = self.dropout2(x)
226 |             x = self.fc2(x)
227 |             output = F.log_softmax(x, dim=1)
228 |             loss = F.nll_loss(output, target)
229 |             return loss
230 | 
231 | Running Script
232 | --------------
233 | 
234 | See the :doc:`running script ` for details.
235 | 
236 | 
--------------------------------------------------------------------------------
/docs/source/index.rst:
--------------------------------------------------------------------------------
1 | .. HetSeq documentation master file, created by
2 |    sphinx-quickstart on Mon Aug 17 16:33:04 2020.
3 |    You can adapt this file completely to your liking, but it should at least
4 |    contain the root `toctree` directive.
5 | 
6 | :github_url: https://github.com/yifding/hetseq
7 | 
8 | HetSeq's documentation!
9 | ==================================
10 | 
11 | HetSeq is a distributed neural network platform designed to run on heterogeneous infrastructure with a common scientific shared file system. It can be run directly from the command line with SSH or a task queue submission system, without privilege or any extra packages. It takes care of data index randomization and assignment to different GPUs in the multi-node and multi-GPU setting. Users can easily extend HetSeq to many other models with minimum effort.
12 | 
13 | .. note::
14 | 
15 |     HetSeq requires installation of `PyTorch `__ with GPU support and `NCCL `__.
16 | 
17 | 
18 | .. toctree::
19 |    :maxdepth: 3
20 |    :caption: Getting Started:
21 | 
22 |    installing
23 |    distribute
24 | 
25 | .. toctree::
26 |    :maxdepth: 3
27 |    :caption: Running HetSeq:
28 | 
29 |    parameter
30 |    examples
31 | 
32 | .. toctree::
33 |    :maxdepth: 3
34 |    :caption: Extending HetSeq:
35 | 
36 |    extending
37 | 
38 | .. toctree::
39 |    :maxdepth: 3
40 |    :caption: Reference:
41 | 
42 |    task
43 |    dataset
44 |    model
45 |    optimizer
46 |    lr_scheduler
47 |    meters
48 |    progress_bar
49 | 
50 | .. toctree::
51 |    :maxdepth: 3
52 |    :caption: Support:
53 | 
54 |    support
55 | 
56 | 
57 | 
58 | 
59 | 
60 | Indices and tables
61 | ==================
62 | 
63 | * :ref:`genindex`
64 | * :ref:`search`
--------------------------------------------------------------------------------
/docs/source/installing.rst:
--------------------------------------------------------------------------------
1 | ************
2 | Installation
3 | ************
4 | 
5 | Create Conda Virtual Environment
6 | --------------------------------
7 | We recommend creating and activating a conda virtual environment with Python 3.7.4:
8 | 
9 | .. code:: console
10 | 
11 |     $ conda create --name hetseq
12 |     $ conda activate hetseq
13 |     $ conda install python=3.7.4
14 | 
15 | Git Clone and Install Packages
16 | --------------------------------------------
17 | 
18 | .. code:: console
19 | 
20 |     $ git clone https://github.com/yifding/hetseq.git
21 |     $ cd /path/to/hetseq
22 |     $ pip install -r requirements.txt
23 |     $ pip install --editable .
24 | 
25 | Download BERT Processed File
26 | ----------------------------
27 | 
28 | (Required to run the ``bert`` model.)
29 | Download the data files, including the training corpus, model configuration, and BPE dictionary: the test corpus from `here `__, the full data from `this link `__. Download ``test_DATA.zip`` for a test run or ``DATA.zip`` for a full run, unzip it, and place the ``preprocessing/`` directory inside the package directory.
30 | 
31 | Download MNIST Dataset (not required)
32 | -------------------------------------
33 | 
34 | (For the ``mnist`` model.)
35 | Download the MNIST dataset from ``torchvision``; see the example `here `__.
36 | 
37 | .. code:: python
38 | 
39 |     from torchvision import datasets
40 |     dataset1 = datasets.MNIST('../data', train=True, download=True)
41 | 
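Because HetSeq relies on GPU-enabled PyTorch with the NCCL backend, a quick sanity check of the environment before a multi-node launch can save time (this snippet is a suggestion, not part of HetSeq):

.. code:: python

    import torch

    print(torch.__version__)
    print(torch.cuda.is_available())              # must be True
    print(torch.cuda.device_count())              # number of GPUs on this node
    print(torch.distributed.is_nccl_available())  # must be True for the c10d/NCCL backend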
--------------------------------------------------------------------------------
/docs/source/lr_scheduler.rst:
--------------------------------------------------------------------------------
1 | Learning Rate Scheduler
2 | -----------------------
3 | 
4 | .. autoclass:: hetseq.lr_scheduler.PolynomialDecayScheduler
--------------------------------------------------------------------------------
/docs/source/meters.rst:
--------------------------------------------------------------------------------
1 | Meters
2 | ------
3 | .. autoclass:: hetseq.meters.AverageMeter
4 | 
5 | .. autoclass:: hetseq.meters.TimeMeter
6 | 
7 | .. autoclass:: hetseq.meters.StopwatchMeter
8 | 
--------------------------------------------------------------------------------
/docs/source/model.rst:
--------------------------------------------------------------------------------
1 | Model
2 | -----
3 | 
4 | .. autoclass:: hetseq.tasks.MNISTNet
5 | 
6 | .. autoclass:: hetseq.bert_modeling.BertConfig
7 | 
8 | .. autoclass:: hetseq.bert_modeling.BertForPreTraining
--------------------------------------------------------------------------------
/docs/source/optimizer.rst:
--------------------------------------------------------------------------------
1 | Optimizer
2 | ---------
3 | 
4 | .. autoclass:: hetseq.optim.Adam
5 | 
6 | .. autoclass:: hetseq.optim.Adadelta
7 | 
--------------------------------------------------------------------------------
/docs/source/parameter.rst:
--------------------------------------------------------------------------------
1 | **********
2 | Parameters
3 | **********
4 | 
5 | Overview
6 | --------
7 | 
8 | To run HetSeq, almost all parameters are passed through the command line and processed by ``argparse``. Those parameters can be grouped into several clusters:
9 | 
10 | * ``Extendable Components Parameters``: including ``task``, ``optimizer``, and ``lr_scheduler``;
11 | * ``Distributed Parameters``: to set up distributed training environments;
12 | * ``Training Parameters``: other important parameters to control stop criteria, logging, checkpoints, etc.
13 | 
14 | Here we explain most parameters in detail.
15 | 
16 | Extendable Components Parameters
17 | --------------------------------
18 | 
19 | * **Task**: ``--task``: Application name; its corresponding class defines the major parts of the application. Currently supports ``bert`` and ``mnist``, and can be extended to other models.
20 | 
21 | 
22 | * ``--task bert``:
23 |   Extra parameters for the ``bert`` task:
24 |   * ``--data``: Dataset directory or file to be loaded in the corresponding task.
25 |   * ``--config_file``: Configuration file of the ``BERT`` model; an example can be found `here `__
26 |   * ``--dict``: Path of the BPE dictionary for the BERT model. Typically it has ~30,000 tokens.
27 |   * ``--max_pred_length``: max number of tokens in a sentence, ``512`` by default.
28 |   * ``--num_file``: number of input files for training, used with ``--data`` to debug. ``0`` by default to use all the data files.
29 | 
30 | 
31 | * ``--task mnist``:
32 |   Extra parameters for the ``mnist`` task:
33 |   * ``--data``: Dataset directory or file to be loaded in the corresponding task, compatible with ``torchvision.datasets.MNIST(path, train=True, download=True)``.
34 | 
35 | 
36 | * **Optimizer**: ``--optimizer``: Optimizers defined in HetSeq are based on ``torch.optim.Optimizer`` with extra gradient and learning rate manipulation functions. Currently supports ``adam`` and ``adadelta``, which can be extended to many other optimizers.
37 | 
38 | * ``--optimizer adam``:
39 |   Extra parameters for the ``adam`` optimizer: `Fixed Weight Decay Regularization in Adam `__.
40 |   * ``--adam-betas``: betas to control momentum and velocity. Default='(0.9, 0.999)'.
41 |   * ``--adam-eps``: epsilon to avoid division by zero. Default=1e-8.
42 |   * ``--weight-decay``: weight decay. ``0`` by default.
43 | 
44 | 
45 | * ``--optimizer adadelta``:
46 |   Extra parameters for the ``adadelta`` optimizer:
47 |   * ``--adadelta_rho``: Default=0.9.
48 |   * ``--adadelta_eps``: epsilon to avoid division by zero. Default=1e-6.
49 |   * ``--dadelta_weight_decay``: ``0`` by default.
50 | 
51 | 
52 | * **Lr_scheduler**: ``--lr_scheduler``: Learning rate schedulers defined in HetSeq are customized to consider the stop criteria ``end-learning-rate``, ``total-num-update`` and ``warmup-updates``. Currently supports the ``Polynomial Decay Scheduler``.
53 | 
54 | * ``--lr-scheduler PolynomialDecayScheduler``:
55 |   Extra parameters for ``PolynomialDecayScheduler``:
56 |   * ``--force-anneal``: force annealing at the specified epoch; unset by default.
57 |   * ``--power``: decay power. ``1.0`` by default.
58 |   * ``--warmup-updates``: warm up the learning rate linearly for the first N updates, ``0`` by default.
59 |   * ``--total-num-update``: total number of update steps until the learning rate decays to ``--end-learning-rate``, ``10000`` by default.
60 |   * ``--end-learning-rate``: learning rate when training stops. ``0`` by default.
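Putting the scheduler parameters together: under the usual polynomial-decay-with-warmup formulation (our reading of the defaults above, not a quotation of the HetSeq source), the learning rate at update ``t`` behaves like:

.. code:: python

    def polynomial_decay_lr(t, lr, end_lr=0.0, warmup_updates=0,
                            total_num_update=10000, power=1.0):
        """Sketch of the standard polynomial decay schedule with linear warmup."""
        if warmup_updates > 0 and t < warmup_updates:
            return lr * (t + 1) / warmup_updates          # linear warmup
        if t >= total_num_update:
            return end_lr                                 # decay finished
        remaining = 1 - (t - warmup_updates) / (total_num_update - warmup_updates)
        return (lr - end_lr) * remaining ** power + end_lr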
61 | 
62 | 
63 | Distributed Parameters
64 | ----------------------
65 | Distributed parameters play a key role in HetSeq: they set up the distributed training environment, defining the number of nodes, the number of GPUs, the communication method, etc.
66 | 
67 | * ``--fast-stat-sync``: Enable fast sync of stats between nodes; this hardcodes syncing only some default stats from logging_output.
68 | * ``--device-id``: index of the single GPU used in the training. ``0`` by default.
69 | 
70 | Parameters related to ``torch.nn.parallel.distributed.DistributedDataParallel``; see the `document `__ for more information. Our implementation keeps input and output tensors on the same device.
71 | 
72 | * ``--bucket-cap-mb``: ``25`` by default
73 | * ``--find-unused-parameters``: ``False`` by default
74 | 
75 | Parameters related to ``torch.distributed.init_process_group`` control the main environment of distributed training. See :ref:`distribute` or the `document `__ for more information.
76 | * ``--ddp-backend``: distributed data parallel backend; currently only supports ``c10d`` with ``NCCL`` to communicate between GPUs. Default: 'c10d'.
77 | * ``--distributed-init-method``: initialization method used to communicate between GPUs. Default ``None``.
78 | * ``--distributed-world-size``: total number of GPUs/processes in the distributed setting. Default: ``max(1, torch.cuda.device_count())``.
79 | * ``--distributed-rank``: rank of the current GPU, ``0`` by default.
80 | 
81 | 
82 | 
83 | Training Parameters
84 | -------------------
85 | 
86 | * ``--max-epoch``: maximum number of epochs allowed in the training. ``0`` by default.
87 | 
88 | * ``--max-update``: maximum number of updates allowed in the training. ``0`` by default.
89 | 
90 | * ``--required-batch-size-multiple``: require the batch size to be a multiple of the given number. ``1`` by default.
91 | 
92 | * ``--update-freq``: update parameters every ``N_i`` batches when in epoch ``i``. ``1`` by default.
93 | 
94 | * ``--max-tokens``: maximum number of tokens in a batch; unset by default.
95 | 
96 | * ``--max-sentences``: maximum number of sentences/images/instances in a batch (the batch size); unset by default.
97 | 
98 | .. note::
99 | 
100 |     Either ``--max-tokens`` or ``--max-sentences`` must be assigned in the parameter settings.
101 | 
102 | * ``--train-subset``: string naming the training subset, ``train`` by default.
103 | 
104 | * ``--num-workers``: number of worker processes used for data loading.
105 | 
106 | * ``--save-interval-updates``: save a checkpoint (and validate) every N updates, ``0`` by default.
107 | 
108 | * ``--seed``: the only seed in the training process; it controls all possible random steps (e.g., in ``torch``, ``numpy`` and ``random``). ``19940802`` by default.
109 | 
110 | * ``--log-interval``: log progress every N batches (when the progress bar is disabled), ``1`` by default.
111 | 
112 | * ``--log-format``: log format to use, choices=['none', 'simple'], ``simple`` by default.
113 | 
--------------------------------------------------------------------------------
/docs/source/progress_bar.rst:
--------------------------------------------------------------------------------
1 | Progress_bar
2 | ------------
3 | 
4 | .. autoclass:: hetseq.progress_bar.progress_bar
5 | 
6 | .. autoclass:: hetseq.progress_bar.noop_progress_bar
7 | 
8 | ..
autoclass:: hetseq.progress_bar.simple_progress_bar 9 | 10 | 11 | 12 | 13 | -------------------------------------------------------------------------------- /docs/source/support.rst: -------------------------------------------------------------------------------- 1 | Copyright 2 | --------- 3 | HetSeq: Distributed GPU Training on Heterogeneous Infrastructure. 4 | 5 | Yifan Ding, Nicholas Botzer, Tim Weninger (2020). 6 | 7 | `Project Web `__ and 8 | `Project GitHub `__ 9 | 10 | `@Weninger Lab `__, Department of Computer Science & Engineering, University of Notre Dame. 11 | 12 | Contact 13 | ------- 14 | Welcome to contact us if any question. 15 | 16 | `Yifan Ding `__: yding4@nd.edu 17 | 18 | Tim Weninger: tweninge@nd.edu 19 | 20 | 21 | 22 | -------------------------------------------------------------------------------- /docs/source/task.rst: -------------------------------------------------------------------------------- 1 | Task 2 | ---- 3 | 4 | .. autoclass:: hetseq.tasks.Task 5 | 6 | .. autoclass:: hetseq.tasks.MNISTNet 7 | 8 | .. autoclass:: hetseq.tasks.LanguageModelingTask -------------------------------------------------------------------------------- /hetseq/__init__.py: -------------------------------------------------------------------------------- 1 | import hetseq.data 2 | import hetseq.tasks 3 | -------------------------------------------------------------------------------- /hetseq/checkpoint_utils.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | from typing import Union 3 | import collections 4 | import logging 5 | import os 6 | import re 7 | import traceback 8 | import shutil 9 | 10 | import torch 11 | from torch.serialization import default_restore_location 12 | 13 | 14 | def save_checkpoint(args, controller, epoch_itr, val_loss): 15 | import distributed_utils, meters 16 | 17 | prev_best = getattr(save_checkpoint, 'best', val_loss) 18 | if val_loss is not None: 19 | best_function = max if args.maximize_best_checkpoint_metric else min 20 | save_checkpoint.best = best_function(val_loss, prev_best) 21 | 22 | if args.no_save or not distributed_utils.is_master(args): 23 | return 24 | 25 | def is_better(a, b): 26 | return a >= b if args.maximize_best_checkpoint_metric else a <= b 27 | 28 | write_timer = meters.StopwatchMeter() 29 | write_timer.start() 30 | 31 | epoch = epoch_itr.epoch 32 | end_of_epoch = epoch_itr.end_of_epoch() 33 | updates = controller.get_num_updates() 34 | 35 | checkpoint_conds = collections.OrderedDict() 36 | checkpoint_conds['checkpoint{}.pt'.format(epoch)] = ( 37 | end_of_epoch and not args.no_epoch_checkpoints and 38 | epoch % args.save_interval == 0 39 | ) 40 | checkpoint_conds['checkpoint_{}_{}.pt'.format(epoch, updates)] = ( 41 | not end_of_epoch and args.save_interval_updates > 0 and 42 | updates % args.save_interval_updates == 0 43 | ) 44 | checkpoint_conds['checkpoint_best.pt'] = ( 45 | val_loss is not None and 46 | (not hasattr(save_checkpoint, 'best') or is_better(val_loss, save_checkpoint.best)) 47 | ) 48 | checkpoint_conds['checkpoint_last.pt'] = not args.no_last_checkpoints 49 | 50 | extra_state = { 51 | 'train_iterator': epoch_itr.state_dict(), 52 | 'val_loss': val_loss, 53 | } 54 | if hasattr(save_checkpoint, 'best'): 55 | extra_state.update({'best': save_checkpoint.best}) 56 | 57 | checkpoints = [os.path.join(args.save_dir, fn) for fn, cond in checkpoint_conds.items() if cond] 58 | if len(checkpoints) > 0: 59 | controller.save_checkpoint(checkpoints[0], 
extra_state) 60 | for cp in checkpoints[1:]: 61 | shutil.copyfile(checkpoints[0], cp) 62 | 63 | write_timer.stop() 64 | print('| saved checkpoint {} (epoch {} @ {} updates) (writing took {} seconds)'.format( 65 | checkpoints[0], epoch, updates, write_timer.sum)) 66 | 67 | if not end_of_epoch and args.keep_interval_updates > 0: 68 | # remove old checkpoints; checkpoints are sorted in descending order 69 | checkpoints = checkpoint_paths( 70 | args.save_dir, pattern=r'checkpoint_\d+_(\d+)\.pt', 71 | ) 72 | for old_chk in checkpoints[args.keep_interval_updates:]: 73 | if os.path.lexists(old_chk): 74 | os.remove(old_chk) 75 | 76 | if args.keep_last_epochs > 0: 77 | # remove old epoch checkpoints; checkpoints are sorted in descending order 78 | checkpoints = checkpoint_paths( 79 | args.save_dir, pattern=r'checkpoint(\d+)\.pt', 80 | ) 81 | for old_chk in checkpoints[args.keep_last_epochs:]: 82 | if os.path.lexists(old_chk): 83 | os.remove(old_chk) 84 | 85 | 86 | def load_checkpoint(args, controller): 87 | """Load a checkpoint and restore the training iterator.""" 88 | # only one worker should attempt to create the required dir 89 | if args.distributed_rank == 0: 90 | os.makedirs(args.save_dir, exist_ok=True) 91 | 92 | if args.restore_file == 'checkpoint_last.pt' or args.restore_file =='checkpoint_best.pt': 93 | checkpoint_path = os.path.join(args.save_dir, args.restore_file) 94 | else: 95 | checkpoint_path = args.restore_file 96 | 97 | extra_state = controller.load_checkpoint( 98 | checkpoint_path, 99 | args.reset_optimizer, 100 | args.reset_lr_scheduler, 101 | eval(args.optimizer_overrides), 102 | reset_meters=args.reset_meters, 103 | ) 104 | 105 | 106 | if( 107 | extra_state is not None 108 | and 'best' in extra_state 109 | and not args.reset_optimizer 110 | and not args.reset_meters 111 | ): 112 | save_checkpoint.best = extra_state['best'] 113 | 114 | 115 | if extra_state is not None and not args.reset_dataloader: 116 | # restore iterator from checkpoint 117 | itr_state = extra_state['train_iterator'] 118 | epoch_itr = controller.get_train_iterator(epoch=itr_state['epoch'], load_dataset=True) 119 | epoch_itr.load_state_dict(itr_state) 120 | else: 121 | epoch_itr = controller.get_train_iterator(epoch=0, load_dataset=True) 122 | 123 | controller.lr_step(epoch_itr.epoch) 124 | 125 | return extra_state, epoch_itr 126 | 127 | 128 | def load_checkpoint_to_cpu(path, arg_overrides=None): 129 | """Loads a checkpoint to CPU (with upgrading for backward compatibility).""" 130 | state = torch.load( 131 | path, map_location=lambda s, l: default_restore_location(s, 'cpu'), 132 | ) 133 | args = state['args'] 134 | if arg_overrides is not None: 135 | for arg_name, arg_val in arg_overrides.items(): 136 | setattr(args, arg_name, arg_val) 137 | # state = _upgrade_state_dict(state) 138 | return state 139 | 140 | 141 | 142 | 143 | def checkpoint_paths(path, pattern=r'checkpoint(\d+)\.pt'): 144 | """Retrieves all checkpoints found in `path` directory. 145 | Checkpoints are identified by matching filename to the specified pattern. If 146 | the pattern contains groups, the result will be sorted by the first group in 147 | descending order. 
148 | """ 149 | pt_regexp = re.compile(pattern) 150 | files = os.listdir(path) 151 | 152 | entries = [] 153 | for i, f in enumerate(files): 154 | m = pt_regexp.fullmatch(f) 155 | if m is not None: 156 | idx = int(m.group(1)) if len(m.groups()) > 0 else i 157 | entries.append((idx, m.group(0))) 158 | return [os.path.join(path, x[1]) for x in sorted(entries, reverse=True)] 159 | 160 | 161 | def torch_persistent_save(*args, **kwargs): 162 | for i in range(3): 163 | try: 164 | return torch.save(*args, **kwargs) 165 | except Exception: 166 | if i == 2: 167 | logging.error(traceback.format_exc()) 168 | 169 | 170 | def convert_state_dict_type(state_dict, ttype=torch.FloatTensor): 171 | if isinstance(state_dict, dict): 172 | cpu_dict = OrderedDict() 173 | for k, v in state_dict.items(): 174 | cpu_dict[k] = convert_state_dict_type(v) 175 | return cpu_dict 176 | elif isinstance(state_dict, list): 177 | return [convert_state_dict_type(v) for v in state_dict] 178 | elif torch.is_tensor(state_dict): 179 | return state_dict.type(ttype) 180 | else: 181 | return state_dict 182 | 183 | 184 | def save_state( 185 | #criterion = None 186 | filename, args, model_state_dict, criterion, optimizer, lr_scheduler, 187 | num_updates, optim_history=None, extra_state=None, 188 | ): 189 | if optim_history is None: 190 | optim_history = [] 191 | if extra_state is None: 192 | extra_state = {} 193 | state_dict = { 194 | 'args': args, 195 | 'model': model_state_dict if model_state_dict else {}, 196 | 'optimizer_history': optim_history + [ 197 | { 198 | #'criterion_name': criterion.__class__.__name__, 199 | 'optimizer_name': optimizer.__class__.__name__, 200 | 'lr_scheduler_state': lr_scheduler.state_dict(), 201 | 'num_updates': num_updates, 202 | } 203 | ], 204 | 'extra_state': {}, 205 | } 206 | if not args.no_save_optimizer_state: 207 | state_dict['last_optimizer_state'] = convert_state_dict_type(optimizer.state_dict()) 208 | torch_persistent_save(state_dict, filename) 209 | 210 | 211 | def verify_checkpoint_directory(save_dir: str) -> None: 212 | if not os.path.exists(save_dir): 213 | os.makedirs(save_dir, exist_ok=True) 214 | temp_file_path = os.path.join(save_dir, 'dummy') 215 | try: 216 | with open(temp_file_path, 'w'): 217 | pass 218 | except OSError as e: 219 | print('| Unable to access checkpoint save directory: {}'.format(save_dir)) 220 | raise e 221 | else: 222 | os.remove(temp_file_path) -------------------------------------------------------------------------------- /hetseq/data/BERT_DATA.py: -------------------------------------------------------------------------------- 1 | import os 2 | import collections 3 | 4 | import h5py 5 | from tqdm import tqdm 6 | 7 | import numpy as np 8 | import torch.utils.data 9 | 10 | 11 | # # method is too large(memory 7% took 20GB) and ~25min to load data. 
12 | class CombineBertData(torch.utils.data.Dataset): 13 | def __init__(self, files, max_pred_length=512, 14 | keys=('input_ids', 'input_mask', 'segment_ids', 15 | 'masked_lm_positions', 'masked_lm_ids','next_sentence_labels')): 16 | # can potentially using multiple processing/thread to speed up reading 17 | # if input is too large, considering other strategies to load data 18 | self.max_pred_length = max_pred_length 19 | self.keys = keys 20 | self.inputs = collections.OrderedDict() 21 | 22 | for key in self.keys: 23 | self.inputs[key] = [] 24 | 25 | for input_file in tqdm(files): 26 | f = h5py.File(input_file, "r") 27 | for i, key in enumerate(keys): 28 | if i < 5: 29 | self.inputs[key].append(f[key][:]) 30 | else: 31 | self.inputs[key].append(np.asarray(f[key][:])) 32 | f.close() 33 | 34 | for key in self.inputs: 35 | self.inputs[key] = np.concatenate(self.inputs[key]) 36 | 37 | def __len__(self): 38 | return len(self.inputs[self.keys[0]]) 39 | 40 | def __getitem__(self, index): 41 | return [self.inputs[key][index] for key in self.keys] 42 | # ['input_ids', 'input_mask', 'segment_ids','masked_lm_positions', 43 | # 'masked_lm_ids','next_sentence_labels'] 44 | 45 | 46 | class BertTask(object): 47 | def __init__(self, args): 48 | self.args = self.load() 49 | self.dict = self.load_vocab(args.dict) 50 | self.seed = args.seed 51 | self.datasets = {} 52 | 53 | @staticmethod 54 | def load_vocab(vocab_file): 55 | """Loads a vocabulary file into a dictionary.""" 56 | vocab = collections.OrderedDict() 57 | index = 0 58 | with open(vocab_file, "r", encoding="utf-8") as reader: 59 | while True: 60 | token = reader.readline() 61 | if not token: 62 | break 63 | token = token.strip() 64 | vocab[token] = index 65 | index += 1 66 | return vocab 67 | 68 | def load_dataset(self, split='train'): 69 | """combine multiple files into one single dataset 70 | Args: 71 | split (str): name must included in the file(e.g., train, valid, test) 72 | 73 | """ 74 | path = self.args.data 75 | if not os.path.exists(path): 76 | raise FileNotFoundError( 77 | "Dataset not found: ({})".format(path) 78 | ) 79 | 80 | files = os.listdir(path) if os.path.isdir(path) else [path] 81 | files = [f for f in files if split in f] 82 | assert len(files) > 0 83 | 84 | self.datasets[split] = CombineBertData(files) 85 | 86 | 87 | """ 88 | dataset = data_utils.load_indexed_dataset( 89 | split_path, self.dictionary, self.args.dataset_impl, combine=combine 90 | ) 91 | if dataset is None: 92 | raise FileNotFoundError( 93 | "Dataset not found: {} ({})".format(split, split_path) 94 | ) 95 | 96 | dataset = TokenBlockDataset( 97 | dataset, 98 | dataset.sizes, 99 | self.args.tokens_per_sample, 100 | pad=self.dictionary.pad(), 101 | eos=self.dictionary.eos(), 102 | break_mode=self.args.sample_break_mode, 103 | include_targets=True, 104 | ) 105 | 106 | add_eos_for_other_targets = ( 107 | self.args.sample_break_mode is not None 108 | and self.args.sample_break_mode != "none" 109 | ) 110 | 111 | self.datasets[split] = MonolingualDataset( 112 | dataset, 113 | dataset.sizes, 114 | self.dictionary, 115 | self.output_dictionary, 116 | add_eos_for_other_targets=add_eos_for_other_targets, 117 | shuffle=True, 118 | targets=self.targets, 119 | add_bos_token=self.args.add_bos_token, 120 | ) 121 | """ -------------------------------------------------------------------------------- /hetseq/data/BERT_DATA_test.py: -------------------------------------------------------------------------------- 1 | import os 2 | from BERT_DATA import CombineBertData 3 | 4 | split = 
"train" 5 | files = "/scratch365/yding4/bert_project/bert_prep_working_dir/" \ 6 | "hdf5_lower_case_1_seq_len_128_max_pred_20_masked_lm_prob_0.15_random_seed_12345_dupe_factor_5/wikicorpus_en" 7 | 8 | 9 | 10 | path = files 11 | if not os.path.exists(path): 12 | raise FileNotFoundError( 13 | "Dataset not found: ({})".format(path) 14 | ) 15 | 16 | files = [os.path.join(path, f) for f in os.listdir(path)] if os.path.isdir(path) else [path] 17 | print(files) 18 | files = [f for f in files if split in f] 19 | print(files) 20 | assert len(files) > 0 21 | 22 | self.datasets[split] = CombineBertData(files) -------------------------------------------------------------------------------- /hetseq/data/__init__.py: -------------------------------------------------------------------------------- 1 | from .h5pyDataset import BertH5pyData, ConBertH5pyData 2 | from .mnist_dataset import MNISTDataset 3 | from .bert_ner_dataset import BertNerDataset 4 | from .bert_el_dataset import BertELDataset 5 | 6 | 7 | __all__ = [ 8 | 'BertH5pyData', 9 | 'ConBertH5pyData', 10 | 'MNISTDataset', 11 | 'BertNerDataset', 12 | ] 13 | -------------------------------------------------------------------------------- /hetseq/data/bert_el_dataset.py: -------------------------------------------------------------------------------- 1 | from functools import lru_cache 2 | 3 | import torch 4 | import numpy as np 5 | 6 | 7 | class BertELDataset(torch.utils.data.Dataset): 8 | def __init__(self, dataset, args): 9 | self.args = args 10 | self.dataset = dataset 11 | 12 | @lru_cache(maxsize=8) 13 | def __getitem__(self, index): 14 | return self.dataset.__getitem__(index) 15 | 16 | def __len__(self): 17 | return len(self.dataset) 18 | 19 | def ordered_indices(self): 20 | """Return an ordered list of indices. Batches will be constructed based 21 | on this order.""" 22 | return np.arange(len(self.dataset)) 23 | 24 | def num_tokens(self, index: int): 25 | return len(self.dataset[index]['labels']) 26 | 27 | def collater(self, samples): 28 | if len(samples) == 0: 29 | return None 30 | else: 31 | return self.args.data_collator(samples) 32 | 33 | def set_epoch(self, epoch): 34 | pass 35 | 36 | 37 | -------------------------------------------------------------------------------- /hetseq/data/bert_ner_dataset.py: -------------------------------------------------------------------------------- 1 | from functools import lru_cache 2 | 3 | import torch 4 | import numpy as np 5 | 6 | 7 | class BertNerDataset(torch.utils.data.Dataset): 8 | def __init__(self, dataset, args): 9 | self.args = args 10 | self.dataset = dataset 11 | 12 | @lru_cache(maxsize=8) 13 | def __getitem__(self, index): 14 | return self.dataset.__getitem__(index) 15 | 16 | def __len__(self): 17 | return len(self.dataset) 18 | 19 | def ordered_indices(self): 20 | """Return an ordered list of indices. 
Batches will be constructed based 21 | on this order.""" 22 | return np.arange(len(self.dataset)) 23 | 24 | def num_tokens(self, index: int): 25 | return len(self.dataset[index]['labels']) 26 | 27 | def collater(self, samples): 28 | if len(samples) == 0: 29 | return None 30 | else: 31 | return self.args.data_collator(samples) 32 | 33 | def set_epoch(self, epoch): 34 | pass 35 | 36 | 37 | -------------------------------------------------------------------------------- /hetseq/data/data_utils.py: -------------------------------------------------------------------------------- 1 | try: 2 | from collections.abc import Iterable 3 | except ImportError: 4 | from collections import Iterable 5 | import contextlib 6 | import itertools 7 | import os 8 | import sys 9 | import types 10 | 11 | import numpy as np 12 | 13 | # # get current state, apply seed, return back to the stored state 14 | @contextlib.contextmanager 15 | def numpy_seed(seed, *addl_seeds): 16 | """Context manager which seeds the NumPy PRNG with the specified seed and 17 | restores the state afterward""" 18 | if seed is None: 19 | yield 20 | return 21 | if len(addl_seeds) > 0: 22 | seed = int(hash((seed, *addl_seeds)) % 1e6) 23 | state = np.random.get_state() 24 | np.random.seed(seed) 25 | try: 26 | yield 27 | finally: 28 | np.random.set_state(state) 29 | 30 | 31 | def batch_by_size( 32 | indices, num_tokens_fn, max_tokens=None, max_sentences=None, 33 | required_batch_size_multiple=1, 34 | ): 35 | """ 36 | Yield mini-batches of indices bucketed by size. Batches may contain 37 | sequences of different lengths. 38 | Args: 39 | indices (List[int]): ordered list of dataset indices 40 | num_tokens_fn (callable): function that returns the number of tokens at 41 | a given index 42 | max_tokens (int, optional): max number of tokens in each batch 43 | (default: None). 44 | max_sentences (int, optional): max number of sentences in each 45 | batch (default: None). 46 | required_batch_size_multiple (int, optional): require batch size to 47 | be a multiple of N (default: 1). 
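    Example (illustrative; not from the HetSeq test suite)::
        >>> # batches of at most 32 sentences or 4096 tokens, sizes a multiple of 8
        >>> batches = batch_by_size(
        ...     indices, dataset.num_tokens,
        ...     max_tokens=4096, max_sentences=32,
        ...     required_batch_size_multiple=8,
        ... )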
48 | """ 49 | try: 50 | from data.data_utils_fast import batch_by_size_fast # #dependency issue 51 | except ImportError: 52 | raise ImportError( 53 | 'Please build Cython components with: `pip install --editable .` ' 54 | 'or `python setup.py build_ext --inplace`' 55 | ) 56 | 57 | max_tokens = max_tokens if max_tokens is not None else sys.maxsize 58 | max_sentences = max_sentences if max_sentences is not None else sys.maxsize 59 | bsz_mult = required_batch_size_multiple 60 | 61 | return batch_by_size_fast(indices, num_tokens_fn, max_tokens, max_sentences, bsz_mult) 62 | -------------------------------------------------------------------------------- /hetseq/data/data_utils_fast.pyx: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | cimport cython 4 | cimport numpy as np 5 | 6 | DTYPE = np.int64 7 | ctypedef np.int64_t DTYPE_t 8 | 9 | 10 | cdef _is_batch_full(list batch, long num_tokens, long max_tokens, long max_sentences): 11 | if len(batch) == 0: 12 | return 0 13 | if len(batch) == max_sentences: 14 | return 1 15 | if num_tokens > max_tokens: 16 | return 1 17 | return 0 18 | 19 | 20 | @cython.cdivision(True) 21 | cpdef list batch_by_size_fast( 22 | np.ndarray[DTYPE_t, ndim=1] indices, 23 | num_tokens_fn, 24 | long max_tokens, 25 | long max_sentences, 26 | int bsz_mult, 27 | ): 28 | cdef long sample_len = 0 29 | cdef list sample_lens = [] 30 | cdef list batch = [] 31 | cdef list batches = [] 32 | cdef long mod_len 33 | cdef long i 34 | cdef long idx 35 | cdef long num_tokens 36 | cdef DTYPE_t[:] indices_view = indices 37 | 38 | for i in range(len(indices_view)): 39 | idx = indices_view[i] 40 | num_tokens = num_tokens_fn(idx) 41 | sample_lens.append(num_tokens) 42 | sample_len = max(sample_len, num_tokens) 43 | 44 | assert sample_len <= max_tokens, ( 45 | "sentence at index {} of size {} exceeds max_tokens " 46 | "limit of {}!".format(idx, sample_len, max_tokens) 47 | ) 48 | num_tokens = (len(batch) + 1) * sample_len 49 | 50 | if _is_batch_full(batch, num_tokens, max_tokens, max_sentences): 51 | mod_len = max( 52 | bsz_mult * (len(batch) // bsz_mult), 53 | len(batch) % bsz_mult, 54 | ) 55 | batches.append(batch[:mod_len]) 56 | batch = batch[mod_len:] 57 | sample_lens = sample_lens[mod_len:] 58 | sample_len = max(sample_lens) if len(sample_lens) > 0 else 0 59 | batch.append(idx) 60 | if len(batch) > 0: 61 | batches.append(batch) 62 | return batches -------------------------------------------------------------------------------- /hetseq/data/h5pyDataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import bisect 3 | from functools import lru_cache 4 | 5 | import h5py 6 | 7 | import numpy as np 8 | import torch 9 | import torch.utils.data 10 | from torch.utils.data.dataloader import default_collate 11 | 12 | 13 | class BertH5pyData(torch.utils.data.Dataset): # # don't know whether support multiprocess loading? 
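    # Note on the multiprocessing question above: __getitem__ re-opens the HDF5
    # file with libver='latest' and swmr=True on every access and closes it again,
    # so no h5py handle is shared across dataloader workers. This pattern is
    # generally safe with num_workers > 0, at the cost of one file open per access.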
14 | def __init__(self, path, max_pred_length=512): 15 | super(BertH5pyData, self).__init__() 16 | self.keys = ('input_ids', 'input_mask', 'segment_ids', 17 | 'masked_lm_positions', 'masked_lm_ids', 'next_sentence_labels') 18 | self.max_pred_length = max_pred_length 19 | self.data_file = None 20 | self.path = path 21 | self.read_data(path) 22 | 23 | def read_data(self, path): 24 | with h5py.File(path, "r", libver='latest', swmr=True) as data_file: 25 | self._len = len(data_file[self.keys[0]]) 26 | 27 | def check_index(self, i): 28 | if i < 0 or i >= self._len: 29 | raise IndexError('index out of range') 30 | 31 | @lru_cache(maxsize=8) 32 | def __getitem__(self, index): 33 | with h5py.File(self.path, "r", libver='latest', swmr=True) as data_file: 34 | self.check_index(index) 35 | 36 | inputs = [data_file[key][index] for key in self.keys] 37 | 38 | [input_ids, input_mask, segment_ids, masked_lm_positions, masked_lm_ids, next_sentence_labels] = [ 39 | torch.from_numpy(input.astype(np.int64)) if indice < 5 else torch.from_numpy( 40 | np.asarray(input.astype(np.int64))) for indice, input in enumerate(inputs)] 41 | 42 | masked_lm_labels = torch.ones(input_ids.shape, dtype=torch.long) * -1 43 | index = self.max_pred_length 44 | # store number of masked tokens in index 45 | padded_mask_indices = (masked_lm_positions == 0).nonzero() 46 | if len(padded_mask_indices) != 0: 47 | index = padded_mask_indices[0].item() 48 | masked_lm_labels[masked_lm_positions[:index]] = masked_lm_ids[:index] 49 | 50 | return [input_ids, segment_ids, input_mask, 51 | masked_lm_labels, next_sentence_labels] 52 | 53 | def __del__(self): 54 | if self.data_file: 55 | self.data_file.flush() 56 | self.data_file.close() #encounter bug, don't know how to fix it 57 | 58 | def __len__(self): 59 | return self._len 60 | # debug 61 | # return 11 62 | 63 | def size(self, idx: int): 64 | """ 65 | Return an example's size as a float or tuple. 
66 | """ 67 | return self.max_pred_length # in our BERT preparation, the length is always 512 68 | 69 | def set_epoch(self, epoch): 70 | pass 71 | 72 | class ConBertH5pyData(torch.utils.data.Dataset): 73 | @staticmethod 74 | def cumsum(sequence, sample_ratios): 75 | r, s = [], 0 76 | for e, ratio in zip(sequence, sample_ratios): 77 | curr_len = int(ratio * len(e)) 78 | r.append(curr_len + s) 79 | s += curr_len 80 | return r 81 | 82 | def __init__(self, datasets, sample_ratios=1): 83 | super(ConBertH5pyData, self).__init__() 84 | assert len(datasets) > 0, "datasets should not be an empty iterable" 85 | self.datasets = list(datasets) 86 | if isinstance(sample_ratios, int): 87 | sample_ratios = [sample_ratios] * len(self.datasets) 88 | self.sample_ratios = sample_ratios 89 | self.cumulative_sizes = self.cumsum(self.datasets, sample_ratios) 90 | self.real_sizes = [len(d) for d in self.datasets] 91 | 92 | def __len__(self): 93 | return self.cumulative_sizes[-1] 94 | 95 | def __getitem__(self, idx): 96 | dataset_idx, sample_idx = self._get_dataset_and_sample_index(idx) 97 | return self.datasets[dataset_idx][sample_idx] 98 | 99 | def _get_dataset_and_sample_index(self, idx: int): 100 | dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx) 101 | if dataset_idx == 0: 102 | sample_idx = idx 103 | else: 104 | sample_idx = idx - self.cumulative_sizes[dataset_idx - 1] 105 | sample_idx = sample_idx % self.real_sizes[dataset_idx] 106 | return dataset_idx, sample_idx 107 | 108 | def collater(self, samples): 109 | # For now only supports datasets with same underlying collater implementations 110 | # print("samples", type(samples)) 111 | if len(samples) == 0: 112 | return None 113 | if hasattr(self.datasets[0], 'collater'): 114 | return self.datasets[0].collater(samples) 115 | else: 116 | return default_collate(samples) 117 | 118 | def ordered_indices(self): 119 | """Return an ordered list of indices. Batches will be constructed based 120 | on this order.""" 121 | return np.arange(len(self)) 122 | 123 | def num_tokens(self, index: int): 124 | return np.max(self.size(index)) 125 | 126 | def size(self, idx: int): 127 | """ 128 | Return an example's size as a float or tuple. 129 | """ 130 | dataset_idx, sample_idx = self._get_dataset_and_sample_index(idx) 131 | return self.datasets[dataset_idx].size(sample_idx) 132 | 133 | def set_epoch(self, epoch): 134 | pass 135 | 136 | 137 | -------------------------------------------------------------------------------- /hetseq/data/iterators.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | import math 3 | import os 4 | 5 | import numpy as np 6 | import torch 7 | 8 | from hetseq.data import data_utils 9 | 10 | class CountingIterator(object): 11 | """Wrapper around an iterable that maintains the iteration count. 
12 | Args: 13 | iterable (iterable): iterable to wrap 14 | Attributes: 15 | count (int): number of elements consumed from this iterator 16 | """ 17 | 18 | def __init__(self, iterable, start=0): 19 | self.iterable = iterable 20 | self.count = start 21 | self.itr = iter(self) 22 | self.len = start + len(iterable) 23 | 24 | def __len__(self): 25 | return self.len 26 | 27 | def __iter__(self): 28 | for x in self.iterable: 29 | self.count += 1 30 | yield x 31 | 32 | def __next__(self): 33 | return next(self.itr) 34 | 35 | def has_next(self): 36 | """Whether the iterator has been exhausted.""" 37 | return self.count < len(self) 38 | 39 | def skip(self, num_to_skip): 40 | """Fast-forward the iterator by skipping *num_to_skip* elements.""" 41 | next(itertools.islice(self.itr, num_to_skip, num_to_skip), None) 42 | return self 43 | 44 | 45 | class EpochBatchIterating(object): 46 | def __len__(self) -> int: 47 | raise NotImplementedError 48 | 49 | def next_epoch_itr(self, shuffle=True, fix_batches_to_gpus=False): 50 | raise NotImplementedError 51 | 52 | def end_of_epoch(self) -> bool: 53 | """Returns whether the most recent epoch iterator has been exhausted""" 54 | raise NotImplementedError 55 | 56 | @property 57 | def iterations_in_epoch(self) -> int: 58 | raise NotImplementedError 59 | 60 | def state_dict(self): 61 | raise NotImplementedError 62 | 63 | def load_state_dict(self, state_dict): 64 | raise NotImplementedError 65 | 66 | 67 | class EpochBatchIterator(EpochBatchIterating): 68 | """A multi-epoch iterator over a :class:`torch.utils.data.Dataset`. 69 | Compared to :class:`torch.utils.data.DataLoader`, this iterator: 70 | - can be reused across multiple epochs with the :func:`next_epoch_itr` 71 | method (optionally shuffled between epochs) 72 | - can be serialized/deserialized with the :func:`state_dict` and 73 | :func:`load_state_dict` methods 74 | - supports sharding with the *num_shards* and *shard_id* arguments 75 | Args: 76 | dataset (~torch.utils.data.Dataset): dataset from which to load the data 77 | collate_fn (callable): merges a list of samples to form a mini-batch 78 | batch_sampler (~torch.utils.data.Sampler): an iterator over batches of 79 | indices 80 | seed (int, optional): seed for random number generator for 81 | reproducibility (default: 1). 82 | num_shards (int, optional): shard the data iterator into N 83 | shards (default: 1). 84 | shard_id (int, optional): which shard of the data iterator to 85 | return (default: 0). 86 | num_workers (int, optional): how many subprocesses to use for data 87 | loading. 0 means the data will be loaded in the main process 88 | (default: 0). 89 | epoch (int, optional): the epoch to start the iterator from 90 | (default: 0). 91 | """ 92 | 93 | def __init__( 94 | self, dataset, collate_fn, batch_sampler, seed=1, num_shards=1, shard_id=0, 95 | num_workers=0, epoch=0, 96 | ): 97 | assert isinstance(dataset, torch.utils.data.Dataset) 98 | self.dataset = dataset 99 | self.collate_fn = collate_fn 100 | self.frozen_batches = tuple(batch_sampler) 101 | self.seed = seed 102 | self.num_shards = num_shards 103 | self.shard_id = shard_id 104 | self.num_workers = num_workers 105 | 106 | self.epoch = epoch 107 | self._cur_epoch_itr = None 108 | self._next_epoch_itr = None 109 | self._supports_prefetch = getattr(dataset, 'supports_prefetch', False) 110 | 111 | def __len__(self): 112 | return len(self.frozen_batches) 113 | 114 | def next_epoch_itr(self, shuffle=True, fix_batches_to_gpus=False): 115 | """Return a new iterator over the dataset. 
116 | Args: 117 | shuffle (bool, optional): shuffle batches before returning the 118 | iterator (default: True). 119 | fix_batches_to_gpus: ensure that batches are always 120 | allocated to the same shards across epochs. Requires 121 | that :attr:`dataset` supports prefetching (default: False). 122 | """ 123 | if self._next_epoch_itr is not None: 124 | self._cur_epoch_itr = self._next_epoch_itr 125 | self._next_epoch_itr = None 126 | else: 127 | self.epoch += 1 128 | self._cur_epoch_itr = self._get_iterator_for_epoch( 129 | self.epoch, shuffle, fix_batches_to_gpus=fix_batches_to_gpus, 130 | ) 131 | self.dataset.set_epoch(self.epoch) 132 | return self._cur_epoch_itr 133 | 134 | def end_of_epoch(self) -> bool: 135 | """Returns whether the most recent epoch iterator has been exhausted""" 136 | return not self._cur_epoch_itr.has_next() 137 | 138 | @property 139 | def iterations_in_epoch(self): 140 | """The number of consumed batches in the current epoch.""" 141 | if self._cur_epoch_itr is not None: 142 | return self._cur_epoch_itr.count 143 | elif self._next_epoch_itr is not None: 144 | return self._next_epoch_itr.count 145 | return 0 146 | 147 | def state_dict(self): 148 | """Returns a dictionary containing a whole state of the iterator.""" 149 | return { 150 | 'epoch': self.epoch, 151 | 'iterations_in_epoch': self.iterations_in_epoch, 152 | } 153 | 154 | def load_state_dict(self, state_dict): 155 | """Copies the state of the iterator from the given *state_dict*.""" 156 | self.epoch = state_dict['epoch'] 157 | itr_pos = state_dict.get('iterations_in_epoch', 0) 158 | if itr_pos > 0: 159 | # fast-forward epoch iterator 160 | self._next_epoch_itr = self._get_iterator_for_epoch( 161 | self.epoch, 162 | shuffle=state_dict.get('shuffle', True), 163 | offset=itr_pos, 164 | ) 165 | 166 | def _get_iterator_for_epoch(self, epoch, shuffle, fix_batches_to_gpus=False, offset=0): 167 | 168 | def shuffle_batches(batches, seed): 169 | # set seed based on the seed and epoch number so that we get 170 | # reproducible results when resuming from checkpoints 171 | with data_utils.numpy_seed(seed): 172 | np.random.shuffle(batches) 173 | return batches 174 | 175 | if self._supports_prefetch: 176 | batches = self.frozen_batches 177 | 178 | if shuffle and not fix_batches_to_gpus: 179 | batches = shuffle_batches(list(batches), self.seed + epoch) 180 | 181 | batches = list(ShardedIterator( 182 | batches, self.num_shards, self.shard_id, fill_value=[] 183 | )) 184 | self.dataset.prefetch([i for s in batches for i in s]) 185 | 186 | if shuffle and fix_batches_to_gpus: 187 | batches = shuffle_batches(batches, self.seed + epoch + self.shard_id) 188 | else: 189 | if shuffle: 190 | batches = shuffle_batches(list(self.frozen_batches), self.seed + epoch) 191 | else: 192 | batches = self.frozen_batches 193 | batches = list(ShardedIterator( 194 | batches, self.num_shards, self.shard_id, fill_value=[] 195 | )) 196 | 197 | if offset > 0 and offset >= len(batches): 198 | return None 199 | 200 | if self.num_workers > 0: 201 | os.environ['PYTHONWARNINGS'] = 'ignore:semaphore_tracker:UserWarning' 202 | 203 | return CountingIterator( 204 | torch.utils.data.DataLoader( 205 | self.dataset, 206 | collate_fn=self.collate_fn, 207 | batch_sampler=batches[offset:], 208 | num_workers=self.num_workers, 209 | ), 210 | start=offset, 211 | ) 212 | 213 | 214 | class GroupedIterator(object): 215 | """Wrapper around an iterable that returns groups (chunks) of items. 
216 | Args: 217 | iterable (iterable): iterable to wrap 218 | chunk_size (int): size of each chunk 219 | """ 220 | 221 | def __init__(self, iterable, chunk_size): 222 | self._len = int(math.ceil(len(iterable) / float(chunk_size))) 223 | self.offset = int(math.ceil(getattr(iterable, 'count', 0) / float(chunk_size))) 224 | self.itr = iterable 225 | self.chunk_size = chunk_size 226 | 227 | def __len__(self): 228 | return self._len 229 | 230 | def __iter__(self): 231 | return self 232 | 233 | def __next__(self): 234 | chunk = [] 235 | try: 236 | for _ in range(self.chunk_size): 237 | chunk.append(next(self.itr)) 238 | except StopIteration as e: 239 | if len(chunk) == 0: 240 | raise e 241 | return chunk 242 | 243 | 244 | class ShardedIterator(object): 245 | """A sharded wrapper around an iterable, padded to length. 246 | Args: 247 | iterable (iterable): iterable to wrap 248 | num_shards (int): number of shards to split the iterable into 249 | shard_id (int): which shard to iterator over 250 | fill_value (Any, optional): padding value when the iterable doesn't 251 | evenly divide *num_shards* (default: None). 252 | """ 253 | 254 | def __init__(self, iterable, num_shards, shard_id, fill_value=None): 255 | if shard_id < 0 or shard_id >= num_shards: 256 | raise ValueError('shard_id must be between 0 and num_shards') 257 | 258 | self._sharded_len = len(iterable) // num_shards 259 | if len(iterable) % num_shards > 0: 260 | self._sharded_len += 1 261 | 262 | self.itr = itertools.zip_longest( 263 | range(self._sharded_len), 264 | itertools.islice(iterable, shard_id, len(iterable), num_shards), 265 | fillvalue=fill_value, 266 | ) 267 | 268 | def __len__(self): 269 | return self._sharded_len 270 | 271 | def __iter__(self): 272 | return self 273 | 274 | def __next__(self): 275 | return next(self.itr)[1] -------------------------------------------------------------------------------- /hetseq/data/mnist_dataset.py: -------------------------------------------------------------------------------- 1 | from functools import lru_cache 2 | from torchvision import transforms 3 | from PIL import Image 4 | 5 | import numpy as np 6 | import torch 7 | import torch.utils.data 8 | from torch.utils.data.dataloader import default_collate 9 | 10 | 11 | class MNISTDataset(torch.utils.data.Dataset): 12 | def __init__(self, path): 13 | self.data = None 14 | self.path = path 15 | self.read_data(self.path) 16 | self.transform = transforms.Compose([ 17 | transforms.ToTensor(), 18 | transforms.Normalize((0.1307,), (0.3081,)) 19 | ]) 20 | 21 | """ 22 | **YD** original read_data 23 | def read_data(self, path): 24 | self.data = torch.load(path) 25 | self._len = len(self.data[0]) 26 | self.image = self.data[0].unsqueeze(1).float() 27 | self.label = self.data[1].long() 28 | """ 29 | 30 | def read_data(self, path): 31 | self.data = torch.load(path) 32 | self._len = len(self.data[0]) 33 | self.image = self.data[0] 34 | self.label = self.data[1] 35 | 36 | # **YD** 37 | # print(self.data[0].shape, self.data[1].shape) 38 | # raise ValueError('debugging for data shape') 39 | 40 | """ 41 | **YD** original __getitem__ 42 | @lru_cache(maxsize=8) 43 | def __getitem__(self, index): 44 | # print(self.image.shape, self.data[1].shape) 45 | return [self.image[index, :, :, :], self.label[index]] 46 | """ 47 | @lru_cache(maxsize=8) 48 | def __getitem__(self, index): 49 | img, target = self.image[index], int(self.label[index]) 50 | img = Image.fromarray(img.numpy(), mode='L') 51 | img = self.transform(img) 52 | return img, target 53 | # return 
[self.image[index, :, :, :], self.label[index]] 54 | 55 | def __len__(self): 56 | return self._len 57 | 58 | def ordered_indices(self): 59 | """Return an ordered list of indices. Batches will be constructed based 60 | on this order.""" 61 | return np.arange(len(self)) 62 | 63 | def num_tokens(self, index: int): 64 | return 1 65 | 66 | def collater(self, samples): 67 | # For now only supports datasets with same underlying collater implementations 68 | # print("samples", type(samples)) 69 | if len(samples) == 0: 70 | return None 71 | else: 72 | return default_collate(samples) 73 | 74 | def set_epoch(self, epoch): 75 | pass 76 | 77 | 78 | if __name__ == '__main__': 79 | path = '/scratch365/yding4/mnist/MNIST/processed/training.pt' 80 | dataset = MNISTDataset(path) 81 | data = torch.load(path) 82 | print(len(dataset)) 83 | print(data[0].shape, data[1].shape) -------------------------------------------------------------------------------- /hetseq/data_collator/__init__.py: -------------------------------------------------------------------------------- 1 | from .data_collator import YD_DataCollatorForTokenClassification, YD_DataCollatorForELClassification 2 | 3 | __all__ = [ 4 | 'YD_DataCollatorForTokenClassification', 5 | 'YD_DataCollatorForELClassification', 6 | ] 7 | -------------------------------------------------------------------------------- /hetseq/distributed_utils.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import socket 3 | import warnings 4 | 5 | import torch 6 | import torch.distributed as dist 7 | 8 | from hetseq import utils 9 | 10 | 11 | def distributed_init(args): 12 | if args.distributed_world_size == 1: 13 | raise ValueError('Cannot initialize distributed with distributed_world_size=1') 14 | 15 | if torch.distributed.is_initialized(): 16 | warnings.warn('Distributed is already initialized, cannot initialize twice!') 17 | else: 18 | print('| distributed init (rank {}): {}'.format( 19 | args.distributed_rank, args.distributed_init_method), flush=True) 20 | dist.init_process_group( 21 | backend=args.distributed_backend, 22 | init_method=args.distributed_init_method, 23 | world_size=args.distributed_world_size, 24 | rank=args.distributed_rank, 25 | ) 26 | print('| initialized host {} as rank {}'.format( 27 | socket.gethostname(), args.distributed_rank), flush=True) 28 | 29 | # perform a dummy all-reduce to initialize the NCCL communicator 30 | if torch.cuda.is_available(): 31 | dist.all_reduce(torch.zeros(1).cuda()) 32 | else: 33 | dist.all_reduce(torch.zeros(1)) 34 | 35 | suppress_output(is_master(args)) 36 | 37 | print("previous rank: {}, using rank :{} for host {}".format( 38 | args.distributed_rank, torch.distributed.get_rank(), socket.gethostname())) 39 | args.distributed_rank = torch.distributed.get_rank() 40 | print('| actual rank {}'.format(args.distributed_rank)) 41 | return args.distributed_rank 42 | 43 | 44 | def is_master(args): 45 | return args.distributed_rank == 0 46 | 47 | 48 | def suppress_output(is_master): 49 | """Suppress printing on the current device. 
Force printing with `force=True`.""" 50 | import builtins as __builtin__ 51 | builtin_print = __builtin__.print 52 | 53 | def print(*args, **kwargs): 54 | force = kwargs.pop('force', False) 55 | if is_master or force: 56 | builtin_print(*args, **kwargs) 57 | 58 | __builtin__.print = print 59 | 60 | 61 | def get_rank(): 62 | return dist.get_rank() 63 | 64 | 65 | def get_world_size(): 66 | return dist.get_world_size() 67 | 68 | 69 | def get_default_group(): 70 | return dist.group.WORLD 71 | 72 | 73 | def all_reduce(tensor, group=None): 74 | if group is None: 75 | group = get_default_group() 76 | return dist.all_reduce(tensor, group=group) 77 | 78 | 79 | def all_gather_list(data, group=None, max_size=16384): 80 | """Gathers arbitrary data from all nodes into a list. 81 | Similar to :func:`~torch.distributed.all_gather` but for arbitrary Python 82 | data. Note that *data* must be picklable. 83 | Args: 84 | data (Any): data from the local worker to be gathered on other workers 85 | group (optional): group of the collective 86 | max_size (int, optional): maximum size of the data to be gathered 87 | across workers 88 | """ 89 | rank = get_rank() 90 | world_size = get_world_size() 91 | 92 | buffer_size = max_size * world_size 93 | if not hasattr(all_gather_list, '_buffer') or \ 94 | all_gather_list._buffer.numel() < buffer_size: 95 | all_gather_list._buffer = torch.cuda.ByteTensor(buffer_size) 96 | all_gather_list._cpu_buffer = torch.ByteTensor(max_size).pin_memory() 97 | buffer = all_gather_list._buffer 98 | buffer.zero_() 99 | cpu_buffer = all_gather_list._cpu_buffer 100 | 101 | enc = pickle.dumps(data) 102 | enc_size = len(enc) 103 | if enc_size + 2 > max_size: 104 | raise ValueError('encoded data exceeds max_size: {}'.format(enc_size + 2)) 105 | assert max_size < 255*256 106 | 107 | cpu_buffer[0] = enc_size // 255 # this encoding works for max_size < 65k 108 | cpu_buffer[1] = enc_size % 255 109 | cpu_buffer[2 : enc_size + 2] = torch.ByteTensor(list(enc)) 110 | start = rank * max_size 111 | size = enc_size + 2 112 | buffer[start : start + size].copy_(cpu_buffer[:size]) 113 | 114 | all_reduce(buffer, group=group) 115 | 116 | try: 117 | result = [] 118 | for i in range(world_size): 119 | out_buffer = buffer[i * max_size : (i + 1) * max_size] 120 | size = (255 * utils.item(out_buffer[0])) + utils.item(out_buffer[1]) 121 | if size > 0: 122 | result.append(pickle.loads(bytes(out_buffer[2 : size + 2].tolist()))) 123 | return result 124 | except pickle.UnpicklingError: 125 | raise Exception( 126 | 'Unable to unpickle data from other workers. all_gather_list requires all ' 127 | 'workers to enter the function together, so this error usually indicates ' 128 | 'that the workers have fallen out of sync somehow. Workers can fall out of ' 129 | 'sync if one of them runs out of memory, or if there are other conditions ' 130 | 'in your training script that can cause one worker to finish an epoch ' 131 | 'while other workers are still iterating over their portions of the data.' 132 | ) -------------------------------------------------------------------------------- /hetseq/eval_bert_fine_tuning_ner.py: -------------------------------------------------------------------------------- 1 | __DATE__ = '2020/12/26' 2 | __AUTHOR__ = 'Yifan_Ding' 3 | __E_MAIL = 'yding4@nd.edu' 4 | 5 | # 1. load dataset, either from CONLL2003 or other entity linking datasets. 6 | # 2. load model from a trained one, load model architecture and state_dict from a trained one. 7 | # 3. predict loss, generate predicted label 8 | # 4. 
compare predicted label with ground truth label to obtain evaluation results. 9 | 10 | def prepare_data_files(args): 11 | # (hypothetical helper name: the fragment below is wrapped in a function so the stub is importable) 12 | if args.test_file is None: 13 | raise ValueError('Evaluation must specify test_file!') 14 | return {'test': args.test_file} -------------------------------------------------------------------------------- /hetseq/eval_mnist.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torchvision import datasets, transforms 7 | 8 | 9 | class MNISTNet(nn.Module): 10 | def __init__(self): 11 | super(MNISTNet, self).__init__() 12 | self.conv1 = nn.Conv2d(1, 32, 3, 1) 13 | self.conv2 = nn.Conv2d(32, 64, 3, 1) 14 | self.dropout1 = nn.Dropout2d(0.25) 15 | self.dropout2 = nn.Dropout2d(0.5) 16 | self.fc1 = nn.Linear(9216, 128) 17 | self.fc2 = nn.Linear(128, 10) 18 | 19 | def forward(self, x, target, eval=False): 20 | # print('shape', x.shape, target.shape) 21 | # print(target) 22 | 23 | x = self.conv1(x) 24 | x = F.relu(x) 25 | x = self.conv2(x) 26 | x = F.relu(x) 27 | x = F.max_pool2d(x, 2) 28 | x = self.dropout1(x) 29 | x = torch.flatten(x, 1) 30 | x = self.fc1(x) 31 | x = F.relu(x) 32 | x = self.dropout2(x) 33 | x = self.fc2(x) 34 | output = F.log_softmax(x, dim=1) 35 | loss = F.nll_loss(output, target) 36 | return (output, loss) if eval else loss 37 | # return loss 38 | 39 | def main(args): 40 | transform = transforms.Compose([ 41 | transforms.ToTensor(), 42 | transforms.Normalize((0.1307,), (0.3081,)) 43 | ]) 44 | dataset2 = datasets.MNIST(args.mnist_dir, train=False, download=True, 45 | transform=transform) 46 | 47 | device = torch.device("cuda") 48 | checkpoint = torch.load(args.model_ckpt) 49 | 50 | model = MNISTNet() 51 | model.load_state_dict(checkpoint['model']) 52 | model.to(device) 53 | model.eval() 54 | 55 | kwargs = {'batch_size': 64} 56 | test_loader = torch.utils.data.DataLoader(dataset2, **kwargs) 57 | 58 | 59 | test_loss = 0 60 | correct = 0 61 | with torch.no_grad(): 62 | for data, target in test_loader: 63 | data, target = data.to(device), target.to(device) 64 | outputs = model(data, target, eval=True) 65 | output = outputs[0] 66 | loss = outputs[1] 67 | test_loss += loss.item() 68 | pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability 69 | correct += pred.eq(target.view_as(pred)).sum().item() 70 | 71 | test_loss /= len(test_loader.dataset) 72 | 73 | print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format( 74 | test_loss, correct, len(test_loader.dataset), 75 | 100. 
* correct / len(test_loader.dataset))) 76 | 77 | 78 | if __name__ == '__main__': 79 | parser = argparse.ArgumentParser( 80 | description='evaluate trained mnist model', 81 | allow_abbrev=False, 82 | ) 83 | 84 | parser.add_argument( 85 | '--model_ckpt', 86 | type=str, 87 | default='/scratch365/yding4/hetseq/CRC_RUN_FILE/new_test/node1gpu4/checkpoint_last.pt', 88 | help='Specify the trained MNIST model checkpoint', 89 | ) 90 | 91 | parser.add_argument( 92 | '--mnist_dir', 93 | type=str, 94 | default='/scratch365/yding4/mnist/MNIST/processed', 95 | help='Specify the processed MNIST data directory', 96 | ) 97 | 98 | 99 | args = parser.parse_args() 100 | main(args) 101 | -------------------------------------------------------------------------------- /hetseq/file_utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import (absolute_import, division, print_function, unicode_literals) 2 | 3 | import json 4 | import logging 5 | import os 6 | import shutil 7 | import tempfile 8 | from functools import wraps 9 | from hashlib import sha256 10 | import sys 11 | from io import open 12 | 13 | import boto3 14 | import requests 15 | from botocore.exceptions import ClientError 16 | from tqdm import tqdm 17 | 18 | try: 19 | from urllib.parse import urlparse 20 | except ImportError: 21 | from urlparse import urlparse 22 | 23 | try: 24 | from pathlib import Path 25 | PYTORCH_PRETRAINED_BERT_CACHE = Path(os.getenv('PYTORCH_PRETRAINED_BERT_CACHE', 26 | Path.home() / '.pytorch_pretrained_bert')) 27 | except AttributeError: 28 | PYTORCH_PRETRAINED_BERT_CACHE = os.getenv('PYTORCH_PRETRAINED_BERT_CACHE', 29 | os.path.join(os.path.expanduser("~"), '.pytorch_pretrained_bert')) 30 | 31 | logger = logging.getLogger(__name__) # pylint: disable=invalid-name 32 | 33 | 34 | def url_to_filename(url, etag=None): 35 | """ 36 | Convert `url` into a hashed filename in a repeatable way. 37 | If `etag` is specified, append its hash to the url's, delimited 38 | by a period. 39 | """ 40 | url_bytes = url.encode('utf-8') 41 | url_hash = sha256(url_bytes) 42 | filename = url_hash.hexdigest() 43 | 44 | if etag: 45 | etag_bytes = etag.encode('utf-8') 46 | etag_hash = sha256(etag_bytes) 47 | filename += '.' + etag_hash.hexdigest() 48 | 49 | return filename 50 | 51 | 52 | def filename_to_url(filename, cache_dir=None): 53 | """ 54 | Return the url and etag (which may be ``None``) stored for `filename`. 55 | Raise ``EnvironmentError`` if `filename` or its stored metadata do not exist. 56 | """ 57 | if cache_dir is None: 58 | cache_dir = PYTORCH_PRETRAINED_BERT_CACHE 59 | if sys.version_info[0] == 3 and isinstance(cache_dir, Path): 60 | cache_dir = str(cache_dir) 61 | 62 | cache_path = os.path.join(cache_dir, filename) 63 | if not os.path.exists(cache_path): 64 | raise EnvironmentError("file {} not found".format(cache_path)) 65 | 66 | meta_path = cache_path + '.json' 67 | if not os.path.exists(meta_path): 68 | raise EnvironmentError("file {} not found".format(meta_path)) 69 | 70 | with open(meta_path, encoding="utf-8") as meta_file: 71 | metadata = json.load(meta_file) 72 | url = metadata['url'] 73 | etag = metadata['etag'] 74 | 75 | return url, etag 76 | 77 | 78 | def cached_path(url_or_filename, cache_dir=None): 79 | """ 80 | Given something that might be a URL (or might be a local path), 81 | determine which. If it's a URL, download the file and cache it, and 82 | return the path to the cached file. If it's already a local path, 83 | make sure the file exists and then return the path. 
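    Example (hypothetical inputs): cached_path('https://example.com/bert_config.json')
    downloads and caches on first use, while cached_path('/tmp/bert_config.json')
    simply verifies that the local file exists and returns it unchanged.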
84 | """ 85 | if cache_dir is None: 86 | cache_dir = PYTORCH_PRETRAINED_BERT_CACHE 87 | if sys.version_info[0] == 3 and isinstance(url_or_filename, Path): 88 | url_or_filename = str(url_or_filename) 89 | if sys.version_info[0] == 3 and isinstance(cache_dir, Path): 90 | cache_dir = str(cache_dir) 91 | 92 | parsed = urlparse(url_or_filename) 93 | 94 | if parsed.scheme in ('http', 'https', 's3'): 95 | # URL, so get it from the cache (downloading if necessary) 96 | return get_from_cache(url_or_filename, cache_dir) 97 | elif os.path.exists(url_or_filename): 98 | # File, and it exists. 99 | return url_or_filename 100 | elif parsed.scheme == '': 101 | # File, but it doesn't exist. 102 | raise EnvironmentError("file {} not found".format(url_or_filename)) 103 | else: 104 | # Something unknown 105 | raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename)) 106 | 107 | 108 | def split_s3_path(url): 109 | """Split a full s3 path into the bucket name and path.""" 110 | parsed = urlparse(url) 111 | if not parsed.netloc or not parsed.path: 112 | raise ValueError("bad s3 path {}".format(url)) 113 | bucket_name = parsed.netloc 114 | s3_path = parsed.path 115 | # Remove '/' at beginning of path. 116 | if s3_path.startswith("/"): 117 | s3_path = s3_path[1:] 118 | return bucket_name, s3_path 119 | 120 | 121 | def s3_request(func): 122 | """ 123 | Wrapper function for s3 requests in order to create more helpful error 124 | messages. 125 | """ 126 | 127 | @wraps(func) 128 | def wrapper(url, *args, **kwargs): 129 | try: 130 | return func(url, *args, **kwargs) 131 | except ClientError as exc: 132 | if int(exc.response["Error"]["Code"]) == 404: 133 | raise EnvironmentError("file {} not found".format(url)) 134 | else: 135 | raise 136 | 137 | return wrapper 138 | 139 | 140 | @s3_request 141 | def s3_etag(url): 142 | """Check ETag on S3 object.""" 143 | s3_resource = boto3.resource("s3") 144 | bucket_name, s3_path = split_s3_path(url) 145 | s3_object = s3_resource.Object(bucket_name, s3_path) 146 | return s3_object.e_tag 147 | 148 | 149 | @s3_request 150 | def s3_get(url, temp_file): 151 | """Pull a file directly from S3.""" 152 | s3_resource = boto3.resource("s3") 153 | bucket_name, s3_path = split_s3_path(url) 154 | s3_resource.Bucket(bucket_name).download_fileobj(s3_path, temp_file) 155 | 156 | 157 | def http_get(url, temp_file): 158 | req = requests.get(url, stream=True) 159 | content_length = req.headers.get('Content-Length') 160 | total = int(content_length) if content_length is not None else None 161 | progress = tqdm(unit="B", total=total) 162 | for chunk in req.iter_content(chunk_size=1024): 163 | if chunk: # filter out keep-alive new chunks 164 | progress.update(len(chunk)) 165 | temp_file.write(chunk) 166 | progress.close() 167 | 168 | 169 | def get_from_cache(url, cache_dir=None): 170 | """ 171 | Given a URL, look for the corresponding dataset in the local cache. 172 | If it's not there, download it. Then return the path to the cached file. 173 | """ 174 | if cache_dir is None: 175 | cache_dir = PYTORCH_PRETRAINED_BERT_CACHE 176 | if sys.version_info[0] == 3 and isinstance(cache_dir, Path): 177 | cache_dir = str(cache_dir) 178 | 179 | if not os.path.exists(cache_dir): 180 | os.makedirs(cache_dir) 181 | 182 | # Get eTag to add to filename, if it exists. 
183 | if url.startswith("s3://"): 184 | etag = s3_etag(url) 185 | else: 186 | response = requests.head(url, allow_redirects=True) 187 | if response.status_code != 200: 188 | raise IOError("HEAD request failed for url {} with status code {}" 189 | .format(url, response.status_code)) 190 | etag = response.headers.get("ETag") 191 | 192 | filename = url_to_filename(url, etag) 193 | 194 | # get cache path to put the file 195 | cache_path = os.path.join(cache_dir, filename) 196 | 197 | if not os.path.exists(cache_path): 198 | # Download to temporary file, then copy to cache dir once finished. 199 | # Otherwise you get corrupt cache entries if the download gets interrupted. 200 | with tempfile.NamedTemporaryFile() as temp_file: 201 | logger.info("%s not found in cache, downloading to %s", url, temp_file.name) 202 | 203 | # GET file object 204 | if url.startswith("s3://"): 205 | s3_get(url, temp_file) 206 | else: 207 | http_get(url, temp_file) 208 | 209 | # we are copying the file before closing it, so flush to avoid truncation 210 | temp_file.flush() 211 | # shutil.copyfileobj() starts at the current position, so go to the start 212 | temp_file.seek(0) 213 | 214 | logger.info("copying %s to cache at %s", temp_file.name, cache_path) 215 | with open(cache_path, 'wb') as cache_file: 216 | shutil.copyfileobj(temp_file, cache_file) 217 | 218 | logger.info("creating metadata file for %s", cache_path) 219 | meta = {'url': url, 'etag': etag} 220 | meta_path = cache_path + '.json' 221 | with open(meta_path, 'w', encoding="utf-8") as meta_file: 222 | json.dump(meta, meta_file) 223 | 224 | logger.info("removing temp file %s", temp_file.name) 225 | 226 | return cache_path 227 | 228 | 229 | def read_set_from_file(filename): 230 | ''' 231 | Extract a de-duped collection (set) of text from a file. 232 | Expected file format is one item per line. 
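    Example: a file containing the three lines 'apple', 'banana', 'apple' yields
    the set {'apple', 'banana'}.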
233 | ''' 234 | collection = set() 235 | with open(filename, 'r', encoding='utf-8') as file_: 236 | for line in file_: 237 | collection.add(line.rstrip()) 238 | return collection 239 | 240 | 241 | def get_file_extension(path, dot=True, lower=True): 242 | ext = os.path.splitext(path)[1] 243 | ext = ext if dot else ext[1:] 244 | return ext.lower() if lower else ext 245 | -------------------------------------------------------------------------------- /hetseq/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from hetseq.optim import _Optimizer 4 | 5 | 6 | class _LRScheduler(object): 7 | def __init__(self, args, optimizer): 8 | super().__init__() 9 | # # TODO make optimizer, lr_scheduler compatible 10 | if not isinstance(optimizer, _Optimizer): 11 | raise ValueError('optimizer must be an instance of _Optimizer') 12 | self.args = args 13 | self.optimizer = optimizer 14 | self.best = None 15 | 16 | ''' 17 | @staticmethod 18 | def add_args(parser): 19 | """Add arguments to the parser for this LR scheduler.""" 20 | pass 21 | ''' 22 | 23 | def state_dict(self): 24 | """Return the LR scheduler state dict.""" 25 | return {'best': self.best} 26 | 27 | def load_state_dict(self, state_dict): 28 | """Load an LR scheduler state dict.""" 29 | self.best = state_dict['best'] 30 | 31 | def step(self, epoch, val_loss=None): 32 | """Update the learning rate at the end of the given epoch.""" 33 | if val_loss is not None: 34 | if self.best is None: 35 | self.best = val_loss 36 | else: 37 | self.best = min(self.best, val_loss) 38 | 39 | def step_update(self, num_updates): 40 | """Update the learning rate after each update.""" 41 | return self.optimizer.get_lr() 42 | 43 | 44 | class PolynomialDecayScheduler(_LRScheduler): 45 | """Decay the LR on a fixed schedule.""" 46 | 47 | def __init__(self, args, optimizer): 48 | super().__init__(args, optimizer) 49 | 50 | # set defaults 51 | args.warmup_updates = getattr(args, 'warmup_updates', 0) or 0 52 | 53 | self.lr = args.lr[0] 54 | if args.warmup_updates > 0: 55 | self.warmup_factor = 1. 
/ args.warmup_updates 56 | else: 57 | self.warmup_factor = 1 58 | self.end_learning_rate = args.end_learning_rate 59 | self.total_num_update = args.total_num_update 60 | self.power = args.power 61 | self.optimizer.set_lr(self.warmup_factor * self.lr) 62 | 63 | @staticmethod 64 | def add_args(parser): 65 | """Add arguments to the parser for this LR scheduler.""" 66 | parser.add_argument('--force-anneal', '--fa', type=int, metavar='N', 67 | help='force annealing at specified epoch') 68 | parser.add_argument('--warmup-updates', default=0, type=int, metavar='N', 69 | help='warmup the learning rate linearly for the first N updates') 70 | parser.add_argument('--end-learning-rate', default=0.0, type=float) 71 | parser.add_argument('--power', default=1.0, type=float) 72 | parser.add_argument('--total-num-update', default=1000000, type=int) 73 | 74 | def get_next_lr(self, epoch): 75 | lrs = self.args.lr 76 | if self.args.force_anneal is None or epoch < self.args.force_anneal: 77 | # use fixed LR schedule 78 | next_lr = lrs[min(epoch, len(lrs) - 1)] 79 | else: 80 | # anneal based on lr_shrink 81 | next_lr = self.optimizer.get_lr() 82 | return next_lr 83 | 84 | def step(self, epoch, val_loss=None): 85 | """Update the learning rate at the end of the given epoch.""" 86 | super().step(epoch, val_loss) 87 | self.lr = self.get_next_lr(epoch) 88 | self.optimizer.set_lr(self.warmup_factor * self.lr) 89 | return self.optimizer.get_lr() 90 | 91 | def step_update(self, num_updates): 92 | """Update the learning rate after each update.""" 93 | if self.args.warmup_updates > 0 and num_updates <= self.args.warmup_updates: 94 | self.warmup_factor = num_updates / float(self.args.warmup_updates) 95 | lr = self.warmup_factor * self.lr 96 | elif num_updates >= self.total_num_update: 97 | lr = self.end_learning_rate 98 | else: 99 | warmup = self.args.warmup_updates 100 | lr_range = self.lr - self.end_learning_rate 101 | pct_remaining = 1 - (num_updates - warmup) / (self.total_num_update - warmup) 102 | lr = lr_range * pct_remaining ** (self.power) + self.end_learning_rate 103 | self.optimizer.set_lr(lr) 104 | return self.optimizer.get_lr() 105 | 106 | -------------------------------------------------------------------------------- /hetseq/meters.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | 4 | class AverageMeter(object): 5 | """Computes and stores the average and current value""" 6 | def __init__(self): 7 | self.reset() 8 | 9 | def reset(self): 10 | self.val = 0 11 | self.avg = 0 12 | self.sum = 0 13 | self.count = 0 14 | 15 | def update(self, val, n=1): 16 | self.val = val 17 | self.sum += val * n 18 | self.count += n 19 | self.avg = self.sum / self.count 20 | 21 | 22 | class TimeMeter(object): 23 | """Computes the average occurrence of some event per second""" 24 | def __init__(self, init=0): 25 | self.reset(init) 26 | 27 | def reset(self, init=0): 28 | self.init = init 29 | self.start = time.time() 30 | self.n = 0 31 | 32 | def update(self, val=1): 33 | self.n += val 34 | 35 | @property 36 | def avg(self): 37 | return self.n / self.elapsed_time 38 | 39 | @property 40 | def elapsed_time(self): 41 | return self.init + (time.time() - self.start) 42 | 43 | 44 | class StopwatchMeter(object): 45 | """Computes the sum/avg duration of some event in seconds""" 46 | def __init__(self): 47 | self.reset() 48 | 49 | def start(self): 50 | self.start_time = time.time() 51 | 52 | def stop(self, n=1): 53 | if self.start_time is not None: 54 | delta = time.time() - 
self.start_time 55 | self.sum += delta 56 | self.n += n 57 | self.start_time = None 58 | 59 | def reset(self): 60 | self.sum = 0 61 | self.n = 0 62 | self.start_time = None 63 | 64 | @property 65 | def avg(self): 66 | return self.sum / self.n -------------------------------------------------------------------------------- /hetseq/model/__init__.py: -------------------------------------------------------------------------------- 1 | from .bert_for_EL_classification import BertForELClassification 2 | 3 | __all__ = [ 4 | 'BertForELClassification', 5 | ] 6 | -------------------------------------------------------------------------------- /hetseq/model/bert_for_EL_classification.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.nn import ( 5 | CosineEmbeddingLoss, 6 | CrossEntropyLoss, 7 | ) 8 | 9 | from hetseq.bert_modeling import ( 10 | BertPreTrainedModel, 11 | BertModel, 12 | ) 13 | 14 | _OUT_DICT_ENTITY_ID = -1 15 | _IGNORE_CLASSIFICATION_LABEL = -100 16 | NER_LABEL_DICT = {'B': 0, 'I': 1, 'O': 2} 17 | 18 | # **YD** tag each token with its correct mention indicator and link each mention with its correct entity ID. 19 | 20 | 21 | class BertForELClassification(BertPreTrainedModel): 22 | def __init__(self, config, args): 23 | super(BertForELClassification, self).__init__(config) 24 | self.config = config 25 | self.args = args 26 | self.num_labels = args.num_labels 27 | self.bert = BertModel(config) 28 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 29 | self.classifier = nn.Linear(config.hidden_size, self.num_labels) 30 | 31 | self.num_entity_labels = args.num_entity_labels 32 | self.dim_entity_emb = args.dim_entity_emb 33 | self.entity_classifier = nn.Linear(config.hidden_size, self.dim_entity_emb) 34 | 35 | self.apply(self.init_bert_weights) 36 | 37 | # **YD** TODO args.EntityEmbedding to be added. 38 | self.entity_emb = nn.Embedding.from_pretrained(args.EntityEmbedding, freeze=True) 39 | assert len(self.entity_emb.weight.shape) == 2 40 | assert self.entity_emb.weight.shape[0] == self.num_entity_labels 41 | assert self.entity_emb.weight.shape[1] == self.dim_entity_emb 42 | 43 | self.activate = torch.tanh 44 | 45 | def forward(self, input_ids, token_type_ids=None, attention_mask=None, 46 | labels=None, entity_labels=None, checkpoint_activations=False): 47 | 48 | sequence_output, _ = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False) 49 | sequence_output = self.dropout(sequence_output) 50 | logits = self.classifier(sequence_output) 51 | 52 | # **YD** entity branch forward. 
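# (Shape sketch, assuming batch size B, sequence length L, hidden size H and
# entity-embedding dimension D = self.dim_entity_emb: sequence_output is [B, L, H];
# entity_classifier projects it to [B, L, D], and the cosine-embedding loss below
# compares these projections against rows of self.entity_emb.weight, a
# [num_entity_labels, D] matrix of frozen pretrained entity embeddings.)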
53 | entity_logits = self.entity_classifier(sequence_output) 54 | # **YD** may not require activation function 55 | entity_logits = self.activate(entity_logits) 56 | 57 | # entity_logits = F.normalize(entity_logits, 2, 2) 58 | # entity_logits = torch.matmul(entity_logits, self.entity_emb.weight.T) 59 | # entity_logits = torch.log(entity_logits) 60 | 61 | if labels is not None: 62 | loss_fct = CrossEntropyLoss() 63 | entity_loss_fct = CosineEmbeddingLoss() 64 | # Only keep active parts of the loss 65 | if attention_mask is not None: 66 | ''' 67 | active_loss = attention_mask.view(-1) == 1 68 | active_logits = logits.view(-1, self.num_labels)[active_loss] 69 | active_labels = labels.view(-1)[active_loss] 70 | loss = loss_fct(active_logits, active_labels) 71 | ''' 72 | active_loss = attention_mask.view(-1) == 1 73 | active_logits = logits.view(-1, self.num_labels) 74 | active_labels = torch.where( 75 | active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels) 76 | ) 77 | ner_loss = loss_fct(active_logits, active_labels) 78 | 79 | ''' 80 | entity_labels[entity_labels == _OUT_DICT_ENTITY_ID] = _IGNORE_CLASSIFICATION_LABEL 81 | assert entity_labels.requires_grad is False 82 | entity_active_logits = entity_logits.view(-1, self.num_entity_labels) 83 | entity_active_labels = torch.where( 84 | active_loss, entity_labels.view(-1), 85 | torch.tensor(entity_loss_fct.ignore_index).type_as(entity_labels) 86 | ) 87 | entity_loss = entity_loss_fct(entity_active_logits, entity_active_labels) 88 | ''' 89 | 90 | # entity_active_loss = (labels.view(-1) == NER_LABEL_DICT['B']) | active_loss 91 | entity_active_loss = (entity_labels.view(-1) > 0) 92 | entity_active_logits = entity_logits.view(-1, self.dim_entity_emb)[entity_active_loss] 93 | entity_active_labels = entity_labels.view(-1)[entity_active_loss] 94 | 95 | entity_loss = entity_loss_fct( 96 | entity_active_logits, 97 | self.entity_emb.weight[entity_active_labels], 98 | torch.tensor(1).type_as(entity_labels) 99 | ) 100 | 101 | print('ner_loss', ner_loss, 'entity_loss', entity_loss) 102 | if torch.isnan(entity_loss): 103 | loss = ner_loss 104 | else: 105 | loss = ner_loss + entity_loss 106 | assert not torch.isnan(loss) 107 | else: 108 | # loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) 109 | raise ValueError('attention_mask must not be None') 110 | 111 | return loss 112 | else: 113 | return logits, entity_logits 114 | -------------------------------------------------------------------------------- /hetseq/progress_bar.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | import json 3 | from numbers import Number 4 | import os 5 | import sys 6 | 7 | from hetseq import distributed_utils 8 | from hetseq.meters import AverageMeter, StopwatchMeter, TimeMeter 9 | 10 | g_tbmf_wrapper = None 11 | 12 | 13 | def build_progress_bar(args, iterator, epoch=None, prefix=None, default='tqdm', no_progress_bar='none'): 14 | if args.log_format is None: 15 | args.log_format = no_progress_bar if args.no_progress_bar else default 16 | 17 | if args.log_format == 'tqdm' and not sys.stderr.isatty(): 18 | args.log_format = 'simple' 19 | 20 | if args.log_format == 'json': 21 | bar = json_progress_bar(iterator, epoch, prefix, args.log_interval) 22 | elif args.log_format == 'none': 23 | bar = noop_progress_bar(iterator, epoch, prefix) 24 | elif args.log_format == 'simple': 25 | bar = simple_progress_bar(iterator, epoch, prefix, args.log_interval) 26 | elif args.log_format == 
'tqdm': 27 | bar = tqdm_progress_bar(iterator, epoch, prefix) 28 | else: 29 | raise ValueError('Unknown log format: {}'.format(args.log_format)) 30 | 31 | return bar 32 | 33 | 34 | def format_stat(stat): 35 | if isinstance(stat, Number): 36 | stat = '{:g}'.format(stat) 37 | elif isinstance(stat, AverageMeter): 38 | stat = '{:.3f}'.format(stat.avg) 39 | elif isinstance(stat, TimeMeter): 40 | stat = '{:g}'.format(round(stat.avg)) 41 | elif isinstance(stat, StopwatchMeter): 42 | #stat = '{:g}'.format(round(stat.sum)) 43 | stat = '{:.4f}'.format(stat.sum) 44 | return stat 45 | 46 | 47 | class progress_bar(object): 48 | """Abstract class for progress bars.""" 49 | def __init__(self, iterable, epoch=None, prefix=None): 50 | self.iterable = iterable 51 | self.offset = getattr(iterable, 'offset', 0) 52 | self.epoch = epoch 53 | self.prefix = '' 54 | if epoch is not None: 55 | self.prefix += '| epoch {:03d}'.format(epoch) 56 | if prefix is not None: 57 | self.prefix += ' | {}'.format(prefix) 58 | 59 | def __len__(self): 60 | return len(self.iterable) 61 | 62 | def __enter__(self): 63 | return self 64 | 65 | def __exit__(self, *exc): 66 | return False 67 | 68 | def __iter__(self): 69 | raise NotImplementedError 70 | 71 | def log(self, stats, tag='', step=None): 72 | """Log intermediate stats according to log_interval.""" 73 | raise NotImplementedError 74 | 75 | def print(self, stats, tag='', step=None): 76 | """Print end-of-epoch stats.""" 77 | raise NotImplementedError 78 | 79 | def _str_commas(self, stats): 80 | return ', '.join(key + '=' + stats[key].strip() 81 | for key in stats.keys()) 82 | 83 | def _str_pipes(self, stats): 84 | return ' | '.join(key + ' ' + stats[key].strip() 85 | for key in stats.keys()) 86 | 87 | def _format_stats(self, stats): 88 | postfix = OrderedDict(stats) 89 | # Preprocess stats according to datatype 90 | for key in postfix.keys(): 91 | postfix[key] = str(format_stat(postfix[key])) 92 | return postfix 93 | 94 | 95 | class noop_progress_bar(progress_bar): 96 | """No logging.""" 97 | 98 | def __init__(self, iterable, epoch=None, prefix=None): 99 | super().__init__(iterable, epoch, prefix) 100 | 101 | def __iter__(self): 102 | for obj in self.iterable: 103 | yield obj 104 | 105 | def log(self, stats, tag='', step=None): 106 | """Log intermediate stats according to log_interval.""" 107 | pass 108 | 109 | def print(self, stats, tag='', step=None): 110 | """Print end-of-epoch stats.""" 111 | pass 112 | 113 | 114 | class simple_progress_bar(progress_bar): 115 | """A minimal logger for non-TTY environments.""" 116 | 117 | def __init__(self, iterable, epoch=None, prefix=None, log_interval=1000): 118 | super().__init__(iterable, epoch, prefix) 119 | self.log_interval = log_interval 120 | self.stats = None 121 | 122 | def __iter__(self): 123 | size = len(self.iterable) 124 | for i, obj in enumerate(self.iterable, start=self.offset): 125 | yield obj 126 | if self.stats is not None and i > 0 and \ 127 | self.log_interval is not None and i % self.log_interval == 0: 128 | postfix = self._str_commas(self.stats) 129 | print('{}: {:5d} / {:d} {}'.format(self.prefix, i, size, postfix), 130 | flush=True) 131 | 132 | def log(self, stats, tag='', step=None): 133 | """Log intermediate stats according to log_interval.""" 134 | self.stats = self._format_stats(stats) 135 | 136 | def print(self, stats, tag='', step=None): 137 | """Print end-of-epoch stats.""" 138 | postfix = self._str_pipes(self._format_stats(stats)) 139 | print('{} | {}'.format(self.prefix, postfix), flush=True) 
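# A minimal usage sketch (assumption: `args` only needs the three attributes that
# build_progress_bar reads, i.e. log_format, no_progress_bar and log_interval):
#
#   from argparse import Namespace
#   args = Namespace(log_format='simple', no_progress_bar=False, log_interval=2)
#   bar = build_progress_bar(args, range(10), epoch=1)
#   for i, _ in enumerate(bar):
#       bar.log({'loss': 0.5}, tag='train', step=i)
#   bar.print({'loss': 0.5}, tag='train')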
-------------------------------------------------------------------------------- /hetseq/tasks/__init__.py: -------------------------------------------------------------------------------- 1 | from .tasks import Task, LanguageModelingTask, MNISTTask 2 | # from .bert_for_el_classification_task import BertForELClassificationTask 3 | # from .bert_for_token_classification_task import BertForTokenClassificationTask 4 | 5 | __all__ = [ 6 | 'Task', 7 | 'LanguageModelingTask', 8 | #'BertForTokenClassificationTask', 9 | 'MNISTTask', 10 | #'BertForELClassificationTask', 11 | ] 12 | -------------------------------------------------------------------------------- /hetseq/train.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import math 3 | import random 4 | import argparse 5 | import collections 6 | 7 | import torch 8 | import numpy as np 9 | 10 | from tqdm import tqdm 11 | 12 | from hetseq import ( 13 | checkpoint_utils, 14 | distributed_utils, 15 | options, 16 | progress_bar, 17 | tasks, 18 | utils 19 | ) 20 | from hetseq.data import iterators 21 | from hetseq.controller import Controller 22 | from hetseq.meters import AverageMeter, StopwatchMeter 23 | 24 | 25 | def main(args, init_distributed=False): 26 | assert args.max_tokens is not None or args.max_sentences is not None, \ 27 | 'Must specify batch size either with --max-tokens or --max-sentences' 28 | 29 | if torch.cuda.is_available() and not args.cpu: 30 | torch.cuda.set_device(args.device_id) 31 | 32 | # set random seed 33 | np.random.seed(args.seed) 34 | torch.manual_seed(args.seed) 35 | 36 | if init_distributed: 37 | args.distributed_rank = distributed_utils.distributed_init(args) 38 | 39 | if distributed_utils.is_master(args): 40 | checkpoint_utils.verify_checkpoint_directory(args.save_dir) 41 | 42 | print(args, flush=True) 43 | 44 | # Setup task, e.g., translation, language modeling, etc. 45 | task = None 46 | if args.task == 'bert': 47 | task = tasks.LanguageModelingTask.setup_task(args) 48 | elif args.task == 'mnist': 49 | task = tasks.MNISTTask.setup_task(args) 50 | elif args.task == 'BertForTokenClassification': 51 | task = tasks.BertForTokenClassificationTask.setup_task(args) 52 | elif args.task == 'BertForELClassification': 53 | task = tasks.BertForELClassificationTask.setup_task(args) 54 | assert task != None 55 | 56 | # Load valid dataset (we load training data below, based on the latest checkpoint) 57 | for valid_sub_split in args.valid_subset.split(','): 58 | task.load_dataset(valid_sub_split, combine=False, epoch=0) 59 | 60 | # Build model 61 | model = task.build_model(args) 62 | 63 | print('| num. model params: {} (num. 
trained: {})'.format( 64 | sum(p.numel() for p in model.parameters()), 65 | sum(p.numel() for p in model.parameters() if p.requires_grad), 66 | )) 67 | 68 | # Build controller 69 | controller = Controller(args, task, model) 70 | print('| training on {} GPUs'.format(args.distributed_world_size)) 71 | print('| max tokens per GPU = {} and max sentences per GPU = {}'.format( 72 | args.max_tokens, 73 | args.max_sentences, 74 | )) 75 | 76 | # Load the latest checkpoint if one is available and restore the 77 | # corresponding train iterator 78 | 79 | extra_state, epoch_itr = checkpoint_utils.load_checkpoint(args, controller) 80 | 81 | # Train until the learning rate gets too small 82 | max_epoch = args.max_epoch or math.inf 83 | max_update = args.max_update or math.inf 84 | 85 | lr = controller.get_lr() 86 | train_meter = StopwatchMeter() 87 | train_meter.start() 88 | 89 | while ( 90 | lr > args.min_lr 91 | and (epoch_itr.epoch < max_epoch 92 | or (epoch_itr.epoch == max_epoch 93 | and epoch_itr._next_epoch_itr is not None)) 94 | and controller.get_num_updates() < max_update 95 | ): 96 | # train for one epoch 97 | train(args, controller, task, epoch_itr) # #revise-task 6 98 | 99 | # debug 100 | valid_losses = [None] 101 | 102 | # only use first validation loss to update the learning rate 103 | lr = controller.lr_step(epoch_itr.epoch, valid_losses[0]) 104 | 105 | # save checkpoint 106 | if epoch_itr.epoch % args.save_interval == 0: 107 | checkpoint_utils.save_checkpoint(args, controller, epoch_itr, valid_losses[0]) 108 | 109 | reload_dataset = hasattr(args, 'data') and args.data is not None and ':' in getattr(args, 'data', '') 110 | # sharded data: get train iterator for next epoch 111 | epoch_itr = controller.get_train_iterator(epoch_itr.epoch, load_dataset=reload_dataset) 112 | 113 | train_meter.stop() 114 | print('| done training in {:.1f} seconds'.format(train_meter.sum)) 115 | 116 | 117 | def train(args, controller, task, epoch_itr): # #revise-task 7 118 | """Train the model for one epoch.""" 119 | # Update parameters every N batches, CORE scaling method 120 | update_freq = args.update_freq[epoch_itr.epoch - 1] \ 121 | if epoch_itr.epoch <= len(args.update_freq) else args.update_freq[-1] 122 | 123 | # Initialize data iterator 124 | itr = epoch_itr.next_epoch_itr( 125 | fix_batches_to_gpus=args.fix_batches_to_gpus, 126 | shuffle=(epoch_itr.epoch >= args.curriculum), 127 | ) 128 | 129 | itr = iterators.GroupedIterator(itr, update_freq) 130 | 131 | progress = progress_bar.build_progress_bar( 132 | args, itr, epoch_itr.epoch, no_progress_bar='simple', 133 | ) 134 | 135 | extra_meters = collections.defaultdict(lambda: AverageMeter()) 136 | valid_subsets = args.valid_subset.split(',') 137 | max_update = args.max_update or math.inf 138 | 139 | 140 | loop = enumerate(progress, start=epoch_itr.iterations_in_epoch) 141 | 142 | for i, samples in loop: 143 | log_output = controller.train_step(samples) 144 | if log_output is None: 145 | continue 146 | 147 | # log mid-epoch stats 148 | 149 | stats = get_training_stats(controller) 150 | for k, v in log_output.items(): 151 | if k in ['loss', 'nll_loss', 'ntokens', 'nsentences', 'sample_size']: 152 | continue # these are already logged above 153 | if 'loss' in k or k == 'accuracy': 154 | extra_meters[k].update(v, log_output['sample_size']) 155 | else: 156 | extra_meters[k].update(v) 157 | stats[k] = extra_meters[k].avg 158 | progress.log(stats, tag='train', step=stats['num_updates']) 159 | 160 | # ignore the first mini-batch in words-per-second and 
updates-per-second calculation 161 | if i == 0: 162 | controller.get_meter('wps').reset() 163 | controller.get_meter('ups').reset() 164 | 165 | num_updates = controller.get_num_updates() 166 | 167 | if num_updates >= max_update: 168 | break 169 | 170 | 171 | def get_training_stats(controller): 172 | stats = collections.OrderedDict() 173 | stats['loss'] = controller.get_meter('train_loss') # #training loss 174 | if controller.get_meter('train_nll_loss').count > 0: 175 | nll_loss = controller.get_meter('train_nll_loss') 176 | stats['nll_loss'] = nll_loss 177 | else: 178 | nll_loss = controller.get_meter('train_loss') # #fall back to training loss when the nll meter is empty 179 | stats['ppl'] = utils.get_perplexity(nll_loss.avg) # #perplexity = 2**nll_loss.avg 180 | stats['wps'] = controller.get_meter('wps') # #words per second 181 | stats['ups'] = controller.get_meter('ups') # #updates per second 182 | stats['wpb'] = controller.get_meter('wpb') # #words per batch 183 | stats['bsz'] = controller.get_meter('bsz') # #batch size 184 | stats['num_updates'] = controller.get_num_updates() # #number of updates 185 | stats['lr'] = controller.get_lr() # #learning rate 186 | stats['gnorm'] = controller.get_meter('gnorm') # #gradient norm 187 | stats['clip'] = controller.get_meter('clip') # #gradient clip 188 | stats['oom'] = controller.get_meter('oom') # #out of memory 189 | if controller.get_meter('loss_scale') is not None: 190 | stats['loss_scale'] = controller.get_meter('loss_scale') 191 | stats['wall'] = round(controller.get_meter('wall').elapsed_time) # #wall-clock time 192 | stats['train_wall'] = controller.get_meter('train_wall') # #training wall-clock time 193 | return stats 194 | 195 | 196 | def distributed_main(i, args, start_rank=0): 197 | args.device_id = i 198 | if args.distributed_rank is None: # torch.multiprocessing.spawn 199 | args.distributed_rank = start_rank + i 200 | main(args, init_distributed=True) 201 | 202 | 203 | def cli_main(): 204 | task_parser = argparse.ArgumentParser(allow_abbrev=False) 205 | task_parser.add_argument('--task', type=str, 206 | default='bert', choices=['bert', 'mnist', 'BertForELClassification', 'BertForTokenClassification']) 207 | task_parser.add_argument('--optimizer', type=str, 208 | default='adam', choices=['adam', 'adadelta']) 209 | task_parser.add_argument('--lr-scheduler', type=str, 210 | default='PolynomialDecayScheduler', choices=['PolynomialDecayScheduler']) 211 | 212 | pre_args, s = task_parser.parse_known_args() 213 | 214 | parser = options.get_training_parser(task=pre_args.task, 215 | optimizer=pre_args.optimizer, 216 | lr_scheduler=pre_args.lr_scheduler, 217 | ) 218 | args = options.parse_args_and_arch(parser, s) 219 | 220 | if args.distributed_init_method is not None: 221 | assert args.distributed_gpus <= torch.cuda.device_count() 222 | 223 | if args.distributed_gpus > 1 and not args.distributed_no_spawn: # #by default run this logic 224 | start_rank = args.distributed_rank 225 | args.distributed_rank = None # assign automatically 226 | torch.multiprocessing.spawn( 227 | fn=distributed_main, 228 | args=(args, start_rank), 229 | nprocs=args.distributed_gpus, 230 | ) 231 | else: 232 | distributed_main(args.device_id, args) 233 | elif args.distributed_world_size > 1: 234 | # fallback for single node with multiple GPUs 235 | assert args.distributed_world_size <= torch.cuda.device_count() 236 | port = random.randint(10000, 20000) 237 | args.distributed_init_method = 'tcp://localhost:{port}'.format(port=port) 238 | args.distributed_rank = None # set based on device id 239 | 
torch.multiprocessing.spawn( 240 | fn=distributed_main, 241 | args=(args, ), 242 | nprocs=args.distributed_world_size, 243 | ) 244 | else: 245 | # single GPU training 246 | main(args) 247 | 248 | 249 | if __name__ == "__main__": 250 | cli_main() -------------------------------------------------------------------------------- /hetseq/utils.py: -------------------------------------------------------------------------------- 1 | import importlib.util 2 | import math 3 | import os 4 | import sys 5 | from typing import Callable, List 6 | import warnings 7 | 8 | import torch 9 | import torch.nn.functional as F 10 | 11 | 12 | def apply_to_sample(f, sample): 13 | if len(sample) == 0: 14 | return {} 15 | 16 | def _apply(x): 17 | if torch.is_tensor(x): 18 | return f(x) 19 | elif isinstance(x, dict): 20 | return { 21 | key: _apply(value) 22 | for key, value in x.items() 23 | } 24 | elif isinstance(x, list): 25 | return [_apply(x) for x in x] 26 | else: 27 | return x 28 | 29 | return _apply(sample) 30 | 31 | 32 | def move_to_cuda(sample): 33 | 34 | def _move_to_cuda(tensor): 35 | return tensor.cuda() 36 | 37 | return apply_to_sample(_move_to_cuda, sample) 38 | 39 | 40 | def get_incremental_state(module, incremental_state, key): 41 | """Helper for getting incremental state for an nn.Module.""" 42 | full_key = _get_full_incremental_state_key(module, key) 43 | if incremental_state is None or full_key not in incremental_state: 44 | return None 45 | return incremental_state[full_key] 46 | 47 | 48 | def set_incremental_state(module, incremental_state, key, value): 49 | """Helper for setting incremental state for an nn.Module.""" 50 | if incremental_state is not None: 51 | full_key = _get_full_incremental_state_key(module, key) 52 | incremental_state[full_key] = value 53 | 54 | 55 | def load_align_dict(replace_unk): 56 | if replace_unk is None: 57 | align_dict = None 58 | elif isinstance(replace_unk, str) and len(replace_unk) > 0: 59 | # Load alignment dictionary for unknown word replacement if it was passed as an argument. 60 | align_dict = {} 61 | with open(replace_unk, 'r') as f: 62 | for line in f: 63 | cols = line.split() 64 | align_dict[cols[0]] = cols[1] 65 | else: 66 | # No alignment dictionary provided but we still want to perform unknown word replacement by copying the 67 | # original source word. 68 | align_dict = {} 69 | return align_dict 70 | 71 | 72 | def make_positions(tensor, padding_idx, onnx_trace=False): 73 | """Replace non-padding symbols with their position numbers. 74 | Position numbers begin at padding_idx+1. Padding symbols are ignored. 75 | """ 76 | # The series of casts and type-conversions here are carefully 77 | # balanced to both work with ONNX export and XLA. In particular XLA 78 | # prefers ints, cumsum defaults to output longs, and ONNX doesn't know 79 | # how to handle the dtype kwarg in cumsum. 
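# (Worked example: with padding_idx=1 and tensor=[[5, 6, 1], [7, 1, 1]], mask is
# [[1, 1, 0], [1, 0, 0]], the cumulative sum is [[1, 2, 2], [1, 1, 1]], and the result
# is [[2, 3, 1], [2, 1, 1]]: real tokens count up from padding_idx + 1 while pads
# keep the value padding_idx.)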
80 | mask = tensor.ne(padding_idx).int() 81 | return ( 82 | torch.cumsum(mask, dim=1).type_as(mask) * mask 83 | ).long() + padding_idx 84 | 85 | 86 | def item(tensor): 87 | if hasattr(tensor, 'item'): 88 | return tensor.item() 89 | if hasattr(tensor, '__getitem__'): 90 | return tensor[0] 91 | return tensor 92 | 93 | 94 | def fill_with_neg_inf(t): 95 | """FP16-compatible function that fills a tensor with -inf.""" 96 | return t.float().fill_(float('-inf')).type_as(t) 97 | 98 | 99 | def resolve_max_positions(*args): 100 | """Resolve max position constraints from multiple sources.""" 101 | 102 | def map_value_update(d1, d2): 103 | updated_value = dict(d1) # a shallow copy suffices: the values are scalar position limits 104 | for key in d2: 105 | if key not in updated_value: 106 | updated_value[key] = d2[key] 107 | else: 108 | updated_value[key] = min(d1[key], d2[key]) 109 | return updated_value 110 | 111 | def nullsafe_min(l): 112 | minim = None 113 | for item in l: 114 | if minim is None: 115 | minim = item 116 | elif item is not None and item < minim: 117 | minim = item 118 | return minim 119 | 120 | max_positions = None 121 | for arg in args: 122 | if max_positions is None: 123 | max_positions = arg 124 | elif arg is not None: 125 | if isinstance(arg, float) or isinstance(arg, int): 126 | max_positions = min(max_positions, arg) 127 | elif isinstance(arg, dict): 128 | max_positions = map_value_update(max_positions, arg) 129 | else: 130 | max_positions = tuple( 131 | map(nullsafe_min, zip(max_positions, arg)) 132 | ) 133 | 134 | return max_positions 135 | 136 | 137 | def import_user_module(args): 138 | module_path = getattr(args, 'user_dir', None) 139 | if module_path is not None: 140 | module_path = os.path.abspath(args.user_dir) 141 | if not os.path.exists(module_path): 142 | fairseq_rel_path = os.path.join(os.path.dirname(__file__), '..', args.user_dir) 143 | if os.path.exists(fairseq_rel_path): 144 | module_path = fairseq_rel_path 145 | module_parent, module_name = os.path.split(module_path) 146 | 147 | if module_name not in sys.modules: 148 | sys.path.insert(0, module_parent) 149 | importlib.import_module(module_name) 150 | sys.path.pop(0) 151 | 152 | 153 | def softmax(x, dim, onnx_trace=False): 154 | if onnx_trace: 155 | return F.softmax(x.float(), dim=dim) 156 | else: 157 | return F.softmax(x, dim=dim, dtype=torch.float32) 158 | 159 | 160 | def log_softmax(x, dim, onnx_trace=False): 161 | if onnx_trace: 162 | return F.log_softmax(x.float(), dim=dim) 163 | else: 164 | return F.log_softmax(x, dim=dim, dtype=torch.float32) 165 | 166 | 167 | def get_perplexity(loss): 168 | try: 169 | return float('{:.2f}'.format(math.pow(2, loss))) 170 | except OverflowError: 171 | return float('inf') 172 | 173 | 174 | def deprecation_warning(message, stacklevel=3): 175 | # don't use DeprecationWarning, since it's ignored by default 176 | warnings.warn(message, stacklevel=stacklevel) 177 | 178 | 179 | def get_activation_fn(activation: str) -> Callable: 180 | """ Returns the activation function corresponding to `activation` """ 181 | if activation == 'relu': 182 | return F.relu 183 | elif activation == 'gelu': 184 | return F.gelu # torch's built-in GELU 185 | elif activation == 'gelu_fast': 186 | deprecation_warning('--activation-fn=gelu_fast has been renamed to gelu_accurate') 187 | return get_activation_fn('gelu_accurate') 188 | elif activation == 'gelu_accurate': 189 | return lambda x: 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * x ** 3))) # standard tanh approximation of GELU 190 | elif activation == 'tanh': 191 | return torch.tanh 192 | elif activation == 'linear': 193 | return lambda x: x 194 | else: 195 | raise RuntimeError("--activation-fn {} not supported".format(activation)) 196 | 197
| 198 | def get_available_activation_fns() -> List: 199 | return [ 200 | 'relu', 201 | 'gelu', 202 | 'gelu_fast', # deprecated 203 | 'gelu_accurate', 204 | 'tanh', 205 | 'linear', 206 | ] 207 | 208 | 209 | def has_parameters(module): 210 | try: 211 | next(module.parameters()) 212 | return True 213 | except StopIteration: 214 | return False -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | cython==0.29.13 2 | numpy==1.17.0 3 | h5py==2.10.0 4 | torch==1.6.0 5 | tqdm==4.36.1 6 | chardet==3.0.4 7 | idna==2.8 8 | python-dateutil==2.8.0 9 | sphinx-rtd-theme==0.5.0 10 | sphinx==3.2.1 11 | boto3==1.9.244 12 | torchvision==0.7.0 13 | datasets==1.1.3 14 | transformers==4.1.1 15 | seqeval==1.2.2 -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from setuptools import setup, find_packages, Extension 3 | 4 | with open("README.md", "r", encoding="utf-8") as fh: 5 | long_description = fh.read() 6 | 7 | if sys.platform == 'darwin': 8 | extra_compile_args = ['-stdlib=libc++', '-O3'] 9 | else: 10 | extra_compile_args = ['-std=c++11', '-O3'] 11 | 12 | 13 | class NumpyExtension(Extension): 14 | """Source: https://stackoverflow.com/a/54128391""" 15 | 16 | def __init__(self, *args, **kwargs): 17 | self.__include_dirs = [] 18 | super().__init__(*args, **kwargs) 19 | 20 | @property 21 | def include_dirs(self): 22 | import numpy 23 | return self.__include_dirs + [numpy.get_include()] 24 | 25 | @include_dirs.setter 26 | def include_dirs(self, dirs): 27 | self.__include_dirs = dirs 28 | 29 | 30 | extensions = [ 31 | NumpyExtension( 32 | 'hetseq.data.data_utils_fast', 33 | sources=['hetseq/data/data_utils_fast.pyx'], 34 | language='c++', 35 | extra_compile_args=extra_compile_args, 36 | ), 37 | ] 38 | 39 | 40 | setup( 41 | name='hetseq', 42 | version='1.0.1', 43 | author='Yifan Ding', 44 | author_email='dyf0125@gmail.com', 45 | url="https://github.com/yifding/hetseq", 46 | project_urls={ 47 | "Documentation": "https://hetseq.readthedocs.io", 48 | "Medium Post": "https://towardsdatascience.com/training-bert-at-a-university-eedcf940c754", 49 | }, 50 | description='Distributed GPU Training on Heterogeneous Infrastructure', 51 | long_description=long_description, 52 | long_description_content_type='text/markdown', 53 | setup_requires=[ 54 | 'cython', 55 | 'numpy', 56 | 'setuptools>=18.0', 57 | ], 58 | install_requires=[ 59 | 'cython', 60 | 'numpy', 61 | 'h5py==2.10.0', 62 | 'torch==1.6.0', 63 | 'tqdm==4.36.1', 64 | 'chardet==3.0.4', 65 | 'idna==2.8', 66 | 'python-dateutil==2.8.0', 67 | 'sphinx-rtd-theme==0.5.0', 68 | 'sphinx==3.2.1', 69 | 'boto3==1.9.244', 70 | 'torchvision==0.7.0', 71 | 'datasets==1.1.3', 72 | 'transformers==4.1.1', 73 | 'seqeval==1.2.2', 74 | ], 75 | packages=find_packages(exclude=[]), 76 | ext_modules=extensions, 77 | zip_safe=False, 78 | ) 79 | -------------------------------------------------------------------------------- /test/TRANFORMERS_test_eval_bert_fine_tuning.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | 4 | from datasets import load_dataset, ClassLabel 5 | 6 | from transformers import ( 7 | BertConfig, 8 | BertForPreTraining, 9 | BertForTokenClassification, 10 | ) 11 | 12 | from hetseq import utils 13 | from hetseq.data_collator import 
YD_DataCollatorForTokenClassification 14 | 15 | from transformers import ( 16 | BertTokenizerFast, 17 | DataCollatorForTokenClassification, 18 | ) 19 | 20 | import torch 21 | from torch.utils.data.dataloader import DataLoader 22 | 23 | from tqdm import tqdm 24 | from seqeval.metrics import accuracy_score, f1_score, precision_score, recall_score 25 | 26 | 27 | __DATE__ = '2020/1/3' 28 | __AUTHOR__ = 'Yifan_Ding' 29 | __E_MAIL = 'dyf0125@gmail.edu' 30 | 31 | _NER_COLUMNS = ['input_ids', 'labels', 'token_type_ids', 'attention_mask'] 32 | 33 | # 1. load dataset, either from CONLL2003 or other entity linking datasets. 34 | # 2. load model from a trained one, load model architecture and state_dict from a trained one. 35 | # 3. predict loss, generate predicted label 36 | # 4. compare predicted label with ground truth label to obtain evaluation results. 37 | 38 | 39 | def main(args): 40 | dataset = prepare_dataset(args) 41 | tokenizer = prepare_tokenizer(args) 42 | data_collator = YD_DataCollatorForTokenClassification(tokenizer) 43 | # data_collator = DataCollatorForTokenClassification(tokenizer) 44 | 45 | if 'test' in dataset: 46 | column_names = dataset["test"].column_names 47 | features = dataset["test"].features 48 | else: 49 | raise ValueError('Evaluation must specify test_file!') 50 | 51 | text_column_name = "tokens" if "tokens" in column_names else column_names[0] 52 | label_column_name = ('ner_tags' if 'ner_tags' in column_names else column_names[1]) 53 | 54 | if isinstance(features[label_column_name].feature, ClassLabel): 55 | label_list = features[label_column_name].feature.names 56 | # No need to convert the labels since they are already ints. 57 | label_to_id = {i: i for i in range(len(label_list))} 58 | else: 59 | if 'test' in dataset: 60 | label_list = get_label_list(dataset["test"][label_column_name]) 61 | else: 62 | raise ValueError('Evaluation must specify test_file!') 63 | 64 | label_to_id = {l: i for i, l in enumerate(label_list)} 65 | args.num_labels = len(label_list) 66 | 67 | # Tokenize all texts and align the labels with them. 68 | def tokenize_and_align_labels(examples, label_all_tokens=False): 69 | tokenized_inputs = tokenizer( 70 | examples[text_column_name], 71 | padding=False, 72 | truncation=True, 73 | # We use this argument because the texts in our dataset are lists of words (with a label for each word). 74 | is_split_into_words=True, 75 | return_offsets_mapping=True, 76 | ) 77 | #print('tokenized_inputs', tokenized_inputs) 78 | 79 | offset_mappings = tokenized_inputs.pop("offset_mapping") 80 | labels = [] 81 | for label, offset_mapping in zip(examples[label_column_name], offset_mappings): 82 | label_index = 0 83 | current_label = -100 84 | label_ids = [] 85 | for offset in offset_mapping: 86 | # We set the label for the first token of each word. Special characters will have an offset of (0, 0) 87 | # so the test ignores them. 88 | if offset[0] == 0 and offset[1] != 0: 89 | current_label = label_to_id[label[label_index]] 90 | label_index += 1 91 | label_ids.append(current_label) 92 | # For special tokens, we set the label to -100 so it's automatically ignored in the loss function. 93 | elif offset[0] == 0 and offset[1] == 0: 94 | label_ids.append(-100) 95 | # For the other tokens in a word, we set the label to either the current label or -100, depending on 96 | # the label_all_tokens flag. 
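# (Worked example, assuming the word 'Washington' with tag B-LOC is tokenized into
# 'Wash' (offset (0, 4)) and '##ington' (offset (4, 10)): 'Wash' gets label_to_id['B-LOC'],
# '##ington' gets that same id only if label_all_tokens is True and -100 otherwise,
# and special tokens such as [CLS]/[SEP] have offset (0, 0), so they always get -100.)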
97 | else: 98 | label_ids.append(current_label if label_all_tokens else -100) 99 | 100 | labels.append(label_ids) 101 | tokenized_inputs["labels"] = labels 102 | return tokenized_inputs 103 | 104 | tokenized_datasets = dataset.map( 105 | tokenize_and_align_labels, 106 | # batched=True, **YD** test non-batched results 107 | batched=True, 108 | num_proc=8, 109 | load_from_cache_file=False, 110 | ) 111 | 112 | test_dataset = tokenized_datasets['test'] 113 | # **YD** core code to keep only useful parameters for model 114 | test_dataset.set_format(type=test_dataset.format["type"], columns=_NER_COLUMNS) 115 | 116 | # **YD** dataloader 117 | data_loader = DataLoader( 118 | test_dataset, 119 | batch_size=args.batch_size, 120 | sampler=None, 121 | collate_fn=data_collator, 122 | drop_last=False, 123 | # num_workers=8, 124 | num_workers=0, 125 | ) 126 | 127 | model = prepare_model(args) 128 | model.cuda() 129 | 130 | true_predictions = [] 131 | true_labels = [] 132 | 133 | for index, input in tqdm(enumerate(data_loader)): 134 | labels, input['labels'] = input['labels'].tolist(), None 135 | 136 | # print(input.keys()) 137 | input = utils.move_to_cuda(input) 138 | predictions = model(**input) 139 | if index == 0: 140 | print('predictions', predictions) 141 | predictions = predictions['logits'] 142 | predictions = torch.argmax(predictions, axis=2).tolist() 143 | if index == 0: 144 | print('labels', labels) 145 | print('predictions', predictions) 146 | 147 | true_predictions.extend([ 148 | [label_list[p] for (p, l) in zip(prediction, label) if l != -100] 149 | for prediction, label in zip(predictions, labels) 150 | ]) 151 | 152 | true_labels.extend([ 153 | [label_list[l] for (p, l) in zip(prediction, label) if l != -100] 154 | for prediction, label in zip(predictions, labels) 155 | ]) 156 | 157 | #true_predictions = [true_prediction for true_prediction in true_predictions if true_prediction != []] 158 | #true_labels = [true_label for true_label in true_labels if true_label != []] 159 | 160 | print('true_predictions', true_predictions[0], true_predictions[-1]) 161 | print('true_labels', true_labels[0], true_labels[-1]) 162 | 163 | print( 164 | { 165 | "accuracy_score": accuracy_score(true_labels, true_predictions), 166 | "precision": precision_score(true_labels, true_predictions), 167 | "recall": recall_score(true_labels, true_predictions), 168 | "f1": f1_score(true_labels, true_predictions), 169 | } 170 | ) 171 | 172 | 173 | def prepare_dataset(args): 174 | data_files = {} 175 | if args.test_file is not None: 176 | data_files["test"] = args.test_file 177 | else: 178 | raise ValueError('Evaluation must specify test_file!') 179 | 180 | if args.train_file is not None: 181 | data_files["train"] = args.train_file 182 | if args.validation_file is not None: 183 | data_files["validation"] = args.validation_file 184 | 185 | extension = args.extension_file 186 | print('extension', extension) 187 | print('data_files', data_files) 188 | dataset = load_dataset(extension, data_files=data_files) 189 | return dataset 190 | 191 | 192 | def prepare_tokenizer(args): 193 | config = BertConfig.from_json_file(args.config_file) 194 | # tokenizer = BertTokenizerFast(args.vocab_file, model_max_length=512) 195 | # print('config', type(config), config,) 196 | tokenizer = BertTokenizerFast(args.vocab_file, model_max_length=config.max_position_embeddings) 197 | return tokenizer 198 | 199 | 200 | def prepare_model(args): 201 | config = BertConfig.from_json_file(args.config_file) 202 | config.num_labels = args.num_labels 203 | model 
= BertForTokenClassification(config) 204 | if args.hetseq_state_dict != '': 205 | # load hetseq state_dictionary 206 | model.load_state_dict(torch.load(args.hetseq_state_dict, map_location='cpu')['model'], strict=True) 207 | 208 | elif args.transformers_state_dict != '': 209 | model.load_state_dict(torch.load(args.transformers_state_dict, map_location='cpu'), strict=True) 210 | return model 211 | 212 | 213 | def get_label_list(labels): 214 | unique_labels = set() 215 | for label in labels: 216 | unique_labels = unique_labels | set(label) 217 | label_list = list(unique_labels) 218 | label_list.sort() 219 | return label_list 220 | 221 | 222 | def cli_main(): 223 | parser = argparse.ArgumentParser(allow_abbrev=False) 224 | 225 | # **YD** datasets argument 226 | parser.add_argument( 227 | '--train_file', 228 | help='local training file for BERT-ner fine-tuning', 229 | type=str, 230 | ) 231 | 232 | parser.add_argument( 233 | '--validation_file', 234 | help='local validation file for BERT-ner fine-tuning', 235 | type=str, 236 | ) 237 | 238 | parser.add_argument( 239 | '--test_file', 240 | help='local test file for BERT-ner fine-tuning', 241 | type=str, 242 | ) 243 | 244 | parser.add_argument( 245 | '--extension_file', 246 | help='local extension file for building dataset similar to conll2003 using "datasets" package', 247 | type=str, 248 | ) 249 | 250 | # **YD** BERT model related argument 251 | parser.add_argument( 252 | '--config_file', 253 | help='BERT config_file', 254 | type=str, 255 | ) 256 | 257 | parser.add_argument( 258 | '--vocab_file', 259 | help='BERT vocabulary file', 260 | type=str, 261 | ) 262 | 263 | parser.add_argument( 264 | '--hetseq_state_dict', 265 | help='hetseq pre-trained state dictionary ', 266 | type=str, 267 | default='', 268 | ) 269 | 270 | parser.add_argument( 271 | '--transformers_state_dict', 272 | help='transformers official pre-trained state dictionary ', 273 | type=str, 274 | default='', 275 | ) 276 | 277 | parser.add_argument( 278 | '--batch_size', 279 | help='batch size for data_loader', 280 | type=int, 281 | default=8, 282 | ) 283 | 284 | args = parser.parse_args() 285 | main(args) 286 | 287 | 288 | if __name__ == '__main__': 289 | cli_main() -------------------------------------------------------------------------------- /test/TRANFORMERS_test_eval_bert_fine_tuning.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #$-m abe 4 | #$-M yding4@nd.edu 5 | #$-q gpu@qa-xp-003 # specify the queue 6 | #$-l gpu_card=4 7 | #$-N eval_ner_fine_tuning 8 | 9 | export PATH=/afs/crc.nd.edu/user/y/yding4/.conda/envs/hetseq/bin:$PATH 10 | export LD_LIBRARY_PATH=/afs/crc.nd.edu/user/y/yding4/.conda/envs/hetseq/lib:$LD_LIBRARY_PATH 11 | 12 | DIST=/scratch365/yding4/hetseq 13 | CODE=/scratch365/yding4/hetseq/test/TRANFORMERS_test_eval_bert_fine_tuning.py 14 | EXTENSION_FILE=/scratch365/yding4/EL_resource/preprocess/build_datasets_huggingface/ace2004.py 15 | 16 | HETSEQ_STATE_DICT=/scratch365/yding4/hetseq/CRC_RUN_FILE/Train_bert_fine_tuning_ner/bert_from_transformers_to_aida_ner/bert_fine_tuning_ner/checkpoint14.pt 17 | TRANSFORMERS_STATE_DICT=/scratch365/yding4/EL_resource/transformer/script/distri_run_m_d_conll/pytorch_model.bin 18 | #TRANSFORMERS_STATE_DICT=/scratch365/yding4/hetseq/CRC_RUN_FILE/Train_bert_fine_tuning_ner/bert_from_transformers_to_aida_ner/pytorch_model.bin 19 | DATA_DIR=/scratch365/yding4/EL_resource/data/processed/EL_CONLL_NER_stanza/ 20 | TEST_FILE=ace2004.conll 21 | 22 | 23 | python3 ${CODE} \ 24 | 
--test_file ${DATA_DIR}${TEST_FILE} \ 25 | --extension_file ${EXTENSION_FILE} \ 26 | --config_file ${DIST}/preprocessing/uncased_L-12_H-768_A-12/bert_config.json \ 27 | --vocab_file ${DIST}/preprocessing/uncased_L-12_H-768_A-12/vocab.txt \ 28 | --hetseq_state_dict ${HETSEQ_STATE_DICT} 29 | #--transformers_state_dict ${TRANSFORMERS_STATE_DICT} 30 | -------------------------------------------------------------------------------- /test/WLGPU_test_conll_2003_dataset.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #$-m abe 4 | #$-M yding4@nd.edu 5 | #$-q gpu@qa-xp-003 # specify the queue 6 | #$-l gpu_card=4 7 | #$-N run_ner 8 | 9 | export PATH=/afs/crc.nd.edu/user/y/yding4/.conda/envs/hetseq/bin:$PATH 10 | export LD_LIBRARY_PATH=/afs/crc.nd.edu/user/y/yding4/.conda/envs/hetseq/lib:$LD_LIBRARY_PATH 11 | 12 | DIST=/home/yding4/hetseq 13 | CODE=/home/yding4/hetseq/test/test_conll_2003_dataset.py 14 | EXTENSION_FILE=/home/yding4/EL_resource/preprocess/build_datasets_huggingface/BI_local_conll2003.py 15 | 16 | TRANSFORMERS_STATE_DICT=/home/yding4/hetseq/preprocessing/transformers_ckpt/pytorch_model.bin 17 | DATA_DIR=/home/yding4/EL_resource/data/CONLL2003/ 18 | TRAIN_FILE=train.txt 19 | VALIDATION_FILE=valid.txt 20 | TEST_FILE=test.txt 21 | 22 | 23 | 24 | python3 ${CODE} \ 25 | --train_file ${DATA_DIR}${TRAIN_FILE} \ 26 | --validation_file ${DATA_DIR}${VALIDATION_FILE} \ 27 | --test_file ${DATA_DIR}${TEST_FILE} \ 28 | --extension_file ${EXTENSION_FILE} \ 29 | --config_file ${DIST}/preprocessing/uncased_L-12_H-768_A-12/bert_config.json \ 30 | --vocab_file ${DIST}/preprocessing/uncased_L-12_H-768_A-12/vocab.txt \ 31 | --transformers_state_dict ${TRANSFORMERS_STATE_DICT} -------------------------------------------------------------------------------- /test/test_conll_2003_NER/test_conll_2003_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | 4 | from datasets import load_dataset, ClassLabel 5 | 6 | from hetseq.bert_modeling import ( 7 | BertConfig, 8 | BertForPreTraining, 9 | BertForTokenClassification, 10 | ) 11 | 12 | from transformers import ( 13 | BertTokenizerFast, 14 | DataCollatorForTokenClassification, 15 | ) 16 | 17 | import torch 18 | from torch.utils.data.dataloader import DataLoader 19 | 20 | from hetseq.data_collator import YD_DataCollatorForTokenClassification 21 | 22 | _NER_COLUMNS = ['input_ids', 'labels', 'token_type_ids', 'attention_mask'] 23 | 24 | 25 | def main(args): 26 | # 1. prepare dataset from customized dataset 27 | dataset = prepare_dataset(args) 28 | 29 | # 2. 
prepare tokenizer from customized dictionary and build data_collator 30 | tokenizer = prepare_tokenizer(args) 31 | # data_collator = DataCollatorForTokenClassification(tokenizer) 32 | data_collator = YD_DataCollatorForTokenClassification(tokenizer) 33 | args.data_collator = data_collator 34 | 35 | if 'train' in dataset: 36 | column_names = dataset["train"].column_names 37 | features = dataset["train"].features 38 | elif 'validation' in dataset: 39 | column_names = dataset["validation"].column_names 40 | features = dataset["validation"].features 41 | elif 'test' in dataset: 42 | column_names = dataset["test"].column_names 43 | features = dataset["test"].features 44 | else: 45 | raise ValueError('dataset must contain "train"/"validation"/"test"') 46 | 47 | text_column_name = "tokens" if "tokens" in column_names else column_names[0] 48 | label_column_name = ('ner_tags' if 'ner_tags' in column_names else column_names[1]) 49 | 50 | if isinstance(features[label_column_name].feature, ClassLabel): 51 | label_list = features[label_column_name].feature.names 52 | # No need to convert the labels since they are already ints. 53 | label_to_id = {i: i for i in range(len(label_list))} 54 | else: 55 | if 'train' in dataset: 56 | label_list = get_label_list(dataset["train"][label_column_name]) 57 | elif 'validation' in dataset: 58 | label_list = get_label_list(dataset["validation"][label_column_name]) 59 | elif 'test' in dataset: 60 | label_list = get_label_list(dataset["test"][label_column_name]) 61 | else: 62 | raise ValueError('dataset must contain "train"/"validation"/"test"') 63 | 64 | label_to_id = {l: i for i, l in enumerate(label_list)} 65 | 66 | args.num_labels = len(label_list) 67 | 68 | # Tokenize all texts and align the labels with them. 69 | def tokenize_and_align_labels(examples, label_all_tokens=False): 70 | tokenized_inputs = tokenizer( 71 | examples[text_column_name], 72 | padding=False, 73 | truncation=True, 74 | # We use this argument because the texts in our dataset are lists of words (with a label for each word). 75 | is_split_into_words=True, 76 | return_offsets_mapping=True, 77 | ) 78 | #print('tokenized_inputs', tokenized_inputs) 79 | 80 | offset_mappings = tokenized_inputs.pop("offset_mapping") 81 | labels = [] 82 | for label, offset_mapping in zip(examples[label_column_name], offset_mappings): 83 | label_index = 0 84 | current_label = -100 85 | label_ids = [] 86 | for offset in offset_mapping: 87 | # We set the label for the first token of each word. Special characters will have an offset of (0, 0) 88 | # so the test ignores them. 89 | if offset[0] == 0 and offset[1] != 0: 90 | current_label = label_to_id[label[label_index]] 91 | label_index += 1 92 | label_ids.append(current_label) 93 | # For special tokens, we set the label to -100 so it's automatically ignored in the loss function. 94 | elif offset[0] == 0 and offset[1] == 0: 95 | label_ids.append(-100) 96 | # For the other tokens in a word, we set the label to either the current label or -100, depending on 97 | # the label_all_tokens flag.
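# Illustrative example (hypothetical values, not taken from the dataset): a word "Indonesia" labeled
# B-LOC may be split into word pieces with offsets (0, 4) and (4, 9). Only the first piece has
# offset[0] == 0 and offset[1] != 0, so it receives B-LOC; the continuation piece falls through to
# the branch below and receives the current label when label_all_tokens=True, or -100 otherwise.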
98 | else: 99 | label_ids.append(current_label if label_all_tokens else -100) 100 | 101 | labels.append(label_ids) 102 | tokenized_inputs["labels"] = labels 103 | return tokenized_inputs 104 | 105 | tokenized_datasets = dataset.map( 106 | tokenize_and_align_labels, 107 | # batched=True, **YD** test non-batched results 108 | batched=True, 109 | num_proc=8, 110 | load_from_cache_file=False, 111 | ) 112 | 113 | # print(tokenized_datasets['train']) 114 | # print(tokenized_datasets['train'][0]) 115 | train_dataset = tokenized_datasets['train'] 116 | 117 | # **YD** core code to keep only the useful columns for the model 118 | train_dataset.set_format(type=train_dataset.format["type"], columns=_NER_COLUMNS) 119 | 120 | 121 | # **YD** dataloader 122 | data_loader = DataLoader( 123 | train_dataset, 124 | batch_size=args.batch_size, 125 | #sampler=train_sampler, 126 | sampler=None, 127 | collate_fn=data_collator, 128 | drop_last=False, 129 | # num_workers=8, 130 | num_workers=0, 131 | ) 132 | 133 | # 3. prepare the BERT model and load its pre-trained checkpoint 134 | # **YD** num_labels is determined by the dataset, so define the model after the dataset; in hetseq this may require passing extra parameters 135 | model = prepare_model(args) 136 | 137 | for index, input in enumerate(data_loader): 138 | if index == 0: 139 | print('input.keys()', input.keys()) 140 | print('input.items()', input.items()) 141 | print(model(**input)) 142 | print('input_ids shape', input['input_ids'].shape, 'labels shape', input['labels'].shape) 143 | if index == 10: 144 | break 145 | """ 146 | 147 | # test customized dataset 148 | from hetseq.data import BertNerDataset 149 | cus_dataset = BertNerDataset(train_dataset, args) 150 | print('len(cus_dataset)', len(cus_dataset)) 151 | 152 | data_loader = DataLoader( 153 | cus_dataset, 154 | batch_size=args.batch_size, 155 | # sampler=train_sampler, 156 | sampler=None, 157 | collate_fn=args.data_collator, 158 | drop_last=False, 159 | # num_workers=8, 160 | num_workers=0, 161 | ) 162 | 163 | model = prepare_model(args) 164 | 165 | for index, input in enumerate(data_loader): 166 | if index == 0: 167 | print('input.keys()', input.keys(), input.items()) 168 | print(model(**input)) 169 | print('input_ids shape', input['input_ids'].shape, 'labels shape', input['labels'].shape) 170 | if index == 10: 171 | break 172 | """ 173 | 174 | 175 | def prepare_tokenizer(args): 176 | config = BertConfig.from_json_file(args.config_file) 177 | # tokenizer = BertTokenizerFast(args.vocab_file, model_max_length=512) 178 | # print('config', type(config), config,) 179 | tokenizer = BertTokenizerFast(args.vocab_file, model_max_length=config.max_position_embeddings) 180 | return tokenizer 181 | 182 | 183 | def prepare_model(args): 184 | config = BertConfig.from_json_file(args.config_file) 185 | model = BertForTokenClassification(config, args.num_labels) 186 | if args.hetseq_state_dict != '': 187 | # load hetseq state_dictionary 188 | model.load_state_dict(torch.load(args.hetseq_state_dict, map_location='cpu')['model'], strict=False) 189 | elif args.transformers_state_dict != '': 190 | model.load_state_dict(torch.load(args.transformers_state_dict, map_location='cpu'), strict=False) 191 | 192 | return model 193 | 194 | 195 | def prepare_dataset(args): 196 | data_files = {} 197 | if args.train_file is not None: 198 | data_files["train"] = args.train_file 199 | if args.validation_file is not None: 200 | data_files["validation"] = args.validation_file 201 | if args.test_file is not None: 202 | data_files["test"] = args.test_file 203 | 204 |
extension = args.extension_file 205 | print('extension', extension) 206 | print('data_files', data_files) 207 | dataset = load_dataset(extension, data_files=data_files) 208 | 209 | return dataset 210 | 211 | 212 | def cli_main(): 213 | parser = argparse.ArgumentParser(allow_abbrev=False) 214 | 215 | # **YD** datasets argument 216 | parser.add_argument( 217 | '--train_file', 218 | help='local training file for BERT-ner fine-tuning', 219 | type=str, 220 | ) 221 | 222 | parser.add_argument( 223 | '--validation_file', 224 | help='local validation file for BERT-ner fine-tuning', 225 | type=str, 226 | ) 227 | 228 | parser.add_argument( 229 | '--test_file', 230 | help='local test file for BERT-ner fine-tuning', 231 | type=str, 232 | ) 233 | 234 | parser.add_argument( 235 | '--extension_file', 236 | help='local extension file for building dataset similar to conll2003 using "datasets" package', 237 | type=str, 238 | ) 239 | 240 | # **YD** BERT model related argument 241 | parser.add_argument( 242 | '--config_file', 243 | help='BERT config_file', 244 | type=str, 245 | ) 246 | 247 | parser.add_argument( 248 | '--vocab_file', 249 | help='BERT vocabulary file', 250 | type=str, 251 | ) 252 | 253 | parser.add_argument( 254 | '--hetseq_state_dict', 255 | help='hetseq pre-trained state dictionary ', 256 | type=str, 257 | default='', 258 | ) 259 | 260 | parser.add_argument( 261 | '--transformers_state_dict', 262 | help='transformers official pre-trained state dictionary ', 263 | type=str, 264 | default='', 265 | ) 266 | 267 | parser.add_argument( 268 | '--batch_size', 269 | help='batch size for data_loader', 270 | type=int, 271 | default=8, 272 | ) 273 | 274 | args = parser.parse_args() 275 | main(args) 276 | 277 | 278 | if __name__ == '__main__': 279 | cli_main() -------------------------------------------------------------------------------- /test/test_conll_2003_NER/test_conll_2003_dataset.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #$-m abe 4 | #$-M yding4@nd.edu 5 | #$-q gpu@qa-xp-003 # specify the queue 6 | #$-l gpu_card=4 7 | #$-N run_ner 8 | 9 | export PATH=/afs/crc.nd.edu/user/y/yding4/.conda/envs/hetseq/bin:$PATH 10 | export LD_LIBRARY_PATH=/afs/crc.nd.edu/user/y/yding4/.conda/envs/hetseq/lib:$LD_LIBRARY_PATH 11 | 12 | DIST=/scratch365/yding4/hetseq 13 | CODE=/scratch365/yding4/hetseq/test/test_conll_2003_dataset.py 14 | EXTENSION_FILE=/scratch365/yding4/EL_resource/preprocess/build_datasets_huggingface/BI_local_conll2003.py 15 | 16 | HETSEQ_STATE_DICT=/scratch365/yding4/Reddit_EL/RUN_FILE/Reddit_1st_test/node2gpu4/Reddit_512-num_sentece_6_gpu_8-update_4-phase_1/checkpoint_last.pt 17 | DATA_DIR=/scratch365/yding4/EL_resource/data/CONLL2003/ 18 | TRAIN_FILE=train.txt 19 | VALIDATION_FILE=valid.txt 20 | TEST_FILE=test.txt 21 | 22 | 23 | 24 | python3 ${CODE} \ 25 | --train_file ${DATA_DIR}${TRAIN_FILE} \ 26 | --validation_file ${DATA_DIR}${VALIDATION_FILE} \ 27 | --test_file ${DATA_DIR}${TEST_FILE} \ 28 | --extension_file ${EXTENSION_FILE} \ 29 | --config_file ${DIST}/preprocessing/uncased_L-12_H-768_A-12/bert_config.json \ 30 | --vocab_file ${DIST}/preprocessing/uncased_L-12_H-768_A-12/vocab.txt \ 31 | --hetseq_state_dict ${HETSEQ_STATE_DICT} -------------------------------------------------------------------------------- /test/test_entity_vocab/test_diff_dict.py: -------------------------------------------------------------------------------- 1 | # 1. load wikipedia2vec entity vector and vocab. 2 | # 2. load luke entity dictionary. 
3 | # 3. load BIO_EL_aida dataset. 4 | # 4. for each entity in aida, check whether it appears in the top 0.5M entries of the luke dictionary 5 | 6 | # **YD** the biggest problem is that the entities in the evaluation dataset do not show up in the training dataset 7 | import argparse 8 | import collections 9 | 10 | import datasets 11 | import wikipedia2vec 12 | 13 | _NER_VOCAB = {"O": 0, "B": 1, "I": 2} 14 | _UNK_ENTITY_NAME = '{UNK}' 15 | 16 | def prepare_dataset(args): 17 | data_files = {} 18 | if args.train_file is not None: 19 | data_files["train"] = args.train_file 20 | if args.validation_file is not None: 21 | data_files["validation"] = args.validation_file 22 | if args.test_file is not None: 23 | data_files["test"] = args.test_file 24 | 25 | extension = args.extension_file 26 | print('extension', extension) 27 | print('data_files', data_files) 28 | dataset = datasets.load_dataset(extension, data_files=data_files) 29 | 30 | return dataset 31 | 32 | 33 | def load_args(): 34 | parser = argparse.ArgumentParser(allow_abbrev=False) 35 | 36 | # **YD** datasets argument 37 | 38 | parser.add_argument( 39 | '--extension_file', 40 | help='local extension file for building dataset similar to conll2003 using "datasets" package', 41 | type=str, 42 | default='/scratch365/yding4/EL_resource/preprocess/build_EL_datasets_huggingface/BIO_EL_aida.py', 43 | ) 44 | 45 | parser.add_argument( 46 | '--train_file', 47 | help='local training file for BERT-ner fine-tuning', 48 | type=str, 49 | default='/scratch365/yding4/EL_resource/data/raw/AIDA-CONLL/aida_train.txt', 50 | ) 51 | 52 | parser.add_argument( 53 | '--validation_file', 54 | help='local validation file for BERT-ner fine-tuning', 55 | type=str, 56 | default='/scratch365/yding4/EL_resource/data/raw/AIDA-CONLL/testa_testb_aggregate_original', 57 | ) 58 | 59 | parser.add_argument( 60 | '--test_file', 61 | help='local test file for BERT-ner fine-tuning', 62 | type=str, 63 | default='/scratch365/yding4/EL_resource/data/raw/AIDA-CONLL/testa_testb_aggregate_original', 64 | ) 65 | 66 | parser.add_argument( 67 | '--wikipedia2vec_file', 68 | help='wikipedia2vec_file', 69 | type=str, 70 | default='/scratch365/yding4/hetseq/preprocessing/wikipedia2vec/enwiki_20180420_100d.txt', 71 | ) 72 | 73 | parser.add_argument( 74 | '--luke_entity_dict_file', 75 | help='luke_entity_dict_file', 76 | type=str, 77 | default='/scratch365/yding4/luke_dir/luke_large_500k/entity_vocab.tsv', 78 | ) 79 | 80 | args = parser.parse_args() 81 | return args 82 | 83 | 84 | def load_luke_entity_dict(args): 85 | dic = collections.OrderedDict() 86 | with open(args.luke_entity_dict_file, 'r') as reader: 87 | for line in reader: 88 | line = line.rstrip() 89 | parts = line.split('\t') 90 | assert len(parts) == 2 91 | dic[parts[0]] = parts[1] 92 | 93 | return dic 94 | 95 | 96 | def collect_entity_freq_from_aida(dataset): 97 | ans = dict() 98 | for split in ['train', 'validation', 'test']: 99 | if split in dataset: 100 | dataset_split = dataset[split] 101 | for index in range(len(dataset_split)): 102 | instance = dataset_split[index] 103 | for entity_name, ner_tag in zip(instance['entity_names'], instance['ner_tags']): 104 | if ner_tag == _NER_VOCAB['B'] and entity_name != _UNK_ENTITY_NAME: 105 | entity_name = entity_name.replace('_', ' ') 106 | if entity_name not in ans: 107 | ans[entity_name] = 0 108 | ans[entity_name] += 1 109 | return ans 110 | 111 | 112 | if __name__ == '__main__': 113 | """ 114 | args = load_args() 115 | 116 | # 1. load wikipedia2vec entity vector and vocab.
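# (load_text parses the full plain-text vector file and can be slow; if a binary model file is
# available, wikipedia2vec.Wikipedia2Vec.load() is presumably much faster for repeated runs)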
117 | print('====> loading wikipedia2vec entity vector') 118 | wikipedia2vec_dict = wikipedia2vec.Wikipedia2Vec.load_text(args.wikipedia2vec_file) 119 | print('--------> loading wikipedia2vec entity vector finished') 120 | 121 | # 2. load luke entity dictionary. 122 | luke_entity_dict = load_luke_entity_dict(args) 123 | 124 | # 3. load BIO_EL_aida dataset. 125 | dataset = prepare_dataset(args) 126 | 127 | # 4. for each entity in aida, check whether it appears in the top 0.5M entries of the luke dictionary 128 | # 4-1. collect entity_names from BIO_EL_aida dataset, count the entity frequency in a dictionary. 129 | entity_freq_from_aida = collect_entity_freq_from_aida(dataset) 130 | """ 131 | 132 | """ 133 | # 4-1.a: replace the underscore '_' with a space 134 | # 4-2. collect entity_names from BIO_EL_aida dataset, count the entity frequency in a dictionary. 135 | find_entity_freq, not_find_entity_freq = 0, 0 136 | find_entity_set, not_find_entity_set = set(), set() 137 | for key, value in entity_freq_from_aida.items(): 138 | if wikipedia2vec_dict.get_entity(key): 139 | find_entity_set.add(key) 140 | find_entity_freq += value 141 | else: 142 | not_find_entity_set.add(key) 143 | not_find_entity_freq += value 144 | 145 | print(find_entity_freq, not_find_entity_freq) 146 | 147 | find_entity_freq, not_find_entity_freq = 0, 0 148 | find_entity_set, not_find_entity_set = set(), set() 149 | for key, value in entity_freq_from_aida.items(): 150 | if wikipedia2vec_dict.get_entity(key): 151 | if key in luke_entity_dict: 152 | find_entity_set.add(key) 153 | find_entity_freq += value 154 | else: 155 | not_find_entity_set.add(key) 156 | not_find_entity_freq += value 157 | else: 158 | not_find_entity_set.add(key) 159 | not_find_entity_freq += value 160 | 161 | print(find_entity_freq, not_find_entity_freq) 162 | """ 163 | 164 | find_entity_freq, not_find_entity_freq = 0, 0 165 | find_entity_set, not_find_entity_set = set(), set() 166 | for key in luke_entity_dict: 167 | if wikipedia2vec_dict.get_entity(key): 168 | find_entity_set.add(key) 169 | find_entity_freq += 1  # one count per luke entity; no aida frequency is available in this loop 170 | else: 171 | not_find_entity_set.add(key) 172 | not_find_entity_freq += 1 173 | 174 | print(find_entity_freq, not_find_entity_freq) -------------------------------------------------------------------------------- /test/test_eval_bert_fine_tuning.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | 4 | from datasets import load_dataset, ClassLabel 5 | 6 | from hetseq.bert_modeling import ( 7 | BertConfig, 8 | BertForPreTraining, 9 | BertForTokenClassification, 10 | ) 11 | 12 | from hetseq import utils 13 | from hetseq.data_collator import YD_DataCollatorForTokenClassification 14 | 15 | from transformers import ( 16 | BertTokenizerFast, 17 | DataCollatorForTokenClassification, 18 | ) 19 | 20 | import torch 21 | from torch.utils.data.dataloader import DataLoader 22 | 23 | from tqdm import tqdm 24 | from seqeval.metrics import accuracy_score, f1_score, precision_score, recall_score 25 | 26 | 27 | __DATE__ = '2020/1/3' 28 | __AUTHOR__ = 'Yifan_Ding' 29 | __E_MAIL = 'dyf0125@gmail.edu' 30 | 31 | _NER_COLUMNS = ['input_ids', 'labels', 'token_type_ids', 'attention_mask'] 32 | 33 | # 1. load dataset, either from CONLL2003 or other entity linking datasets. 34 | # 2. load the model: architecture plus state_dict from a trained checkpoint. 35 | # 3. run prediction to generate predicted labels. 36 | # 4. compare predicted label with ground truth label to obtain evaluation results.
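The four numbered steps above can be sanity-checked in isolation. Below is a minimal sketch of step 4 on made-up tag sequences (hypothetical inputs, not output of this script); the point is that seqeval scores complete entity spans rather than individual tokens, so a missed span lowers recall even when most tokens are predicted correctly.

```python
# Minimal sketch of step 4 with made-up tag sequences (hypothetical example).
from seqeval.metrics import accuracy_score, f1_score, precision_score, recall_score

true_labels = [['B-PER', 'I-PER', 'O'], ['B-LOC', 'O']]   # 2 gold entity spans
true_predictions = [['B-PER', 'I-PER', 'O'], ['O', 'O']]  # 1 predicted span, matching exactly

print({
    'accuracy_score': accuracy_score(true_labels, true_predictions),  # token-level: 4/5 = 0.8
    'precision': precision_score(true_labels, true_predictions),      # 1 of 1 predicted spans correct -> 1.0
    'recall': recall_score(true_labels, true_predictions),            # 1 of 2 gold spans found -> 0.5
    'f1': f1_score(true_labels, true_predictions),                    # harmonic mean -> ~0.67
})
```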
37 | 38 | 39 | def main(args): 40 | dataset = prepare_dataset(args) 41 | tokenizer = prepare_tokenizer(args) 42 | # data_collator = YD_DataCollatorForTokenClassification(tokenizer) 43 | data_collator = DataCollatorForTokenClassification(tokenizer) 44 | 45 | if 'test' in dataset: 46 | column_names = dataset["test"].column_names 47 | features = dataset["test"].features 48 | else: 49 | raise ValueError('Evaluation must specify test_file!') 50 | 51 | text_column_name = "tokens" if "tokens" in column_names else column_names[0] 52 | label_column_name = ('ner_tags' if 'ner_tags' in column_names else column_names[1]) 53 | 54 | if isinstance(features[label_column_name].feature, ClassLabel): 55 | label_list = features[label_column_name].feature.names 56 | # No need to convert the labels since they are already ints. 57 | label_to_id = {i: i for i in range(len(label_list))} 58 | else: 59 | if 'test' in dataset: 60 | label_list = get_label_list(dataset["test"][label_column_name]) 61 | else: 62 | raise ValueError('Evaluation must specify test_file!') 63 | 64 | label_to_id = {l: i for i, l in enumerate(label_list)} 65 | args.num_labels = len(label_list) 66 | 67 | # Tokenize all texts and align the labels with them. 68 | def tokenize_and_align_labels(examples, label_all_tokens=False): 69 | tokenized_inputs = tokenizer( 70 | examples[text_column_name], 71 | padding=False, 72 | truncation=True, 73 | # We use this argument because the texts in our dataset are lists of words (with a label for each word). 74 | is_split_into_words=True, 75 | return_offsets_mapping=True, 76 | ) 77 | #print('tokenized_inputs', tokenized_inputs) 78 | 79 | offset_mappings = tokenized_inputs.pop("offset_mapping") 80 | labels = [] 81 | for label, offset_mapping in zip(examples[label_column_name], offset_mappings): 82 | label_index = 0 83 | current_label = -100 84 | label_ids = [] 85 | for offset in offset_mapping: 86 | # We set the label for the first token of each word. Special characters will have an offset of (0, 0) 87 | # so the test ignores them. 88 | if offset[0] == 0 and offset[1] != 0: 89 | current_label = label_to_id[label[label_index]] 90 | label_index += 1 91 | label_ids.append(current_label) 92 | # For special tokens, we set the label to -100 so it's automatically ignored in the loss function. 93 | elif offset[0] == 0 and offset[1] == 0: 94 | label_ids.append(-100) 95 | # For the other tokens in a word, we set the label to either the current label or -100, depending on 96 | # the label_all_tokens flag.
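# Note for the evaluation loop below: positions labeled -100 (special tokens and non-first word
# pieces) are filtered out again via `if l != -100` when building true_predictions/true_labels,
# so they never reach the seqeval metrics.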
97 | else: 98 | label_ids.append(current_label if label_all_tokens else -100) 99 | 100 | labels.append(label_ids) 101 | tokenized_inputs["labels"] = labels 102 | return tokenized_inputs 103 | 104 | tokenized_datasets = dataset.map( 105 | tokenize_and_align_labels, 106 | # batched=True, **YD** test non-batched results 107 | batched=True, 108 | num_proc=8, 109 | load_from_cache_file=False, 110 | ) 111 | 112 | test_dataset = tokenized_datasets['test'] 113 | # **YD** core code to keep only the useful columns for the model 114 | test_dataset.set_format(type=test_dataset.format["type"], columns=_NER_COLUMNS) 115 | 116 | # **YD** dataloader 117 | data_loader = DataLoader( 118 | test_dataset, 119 | batch_size=args.batch_size, 120 | sampler=None, 121 | collate_fn=data_collator, 122 | drop_last=False, 123 | # num_workers=8, 124 | num_workers=0, 125 | ) 126 | 127 | model = prepare_model(args) 128 | model.cuda() 129 | 130 | true_predictions = [] 131 | true_labels = [] 132 | 133 | for index, input in tqdm(enumerate(data_loader)): 134 | labels, input['labels'] = input['labels'].tolist(), None 135 | 136 | # print(input.keys()) 137 | input = utils.move_to_cuda(input) 138 | predictions = model(**input) 139 | if index == 0: 140 | print('predictions', predictions) 141 | predictions = torch.argmax(predictions, axis=2).tolist() 142 | if index == 0: 143 | print('labels', labels) 144 | print('predictions', predictions) 145 | 146 | true_predictions.extend([ 147 | [label_list[p] for (p, l) in zip(prediction, label) if l != -100] 148 | for prediction, label in zip(predictions, labels) 149 | ]) 150 | 151 | true_labels.extend([ 152 | [label_list[l] for (p, l) in zip(prediction, label) if l != -100] 153 | for prediction, label in zip(predictions, labels) 154 | ]) 155 | 156 | true_predictions = [true_prediction for true_prediction in true_predictions if true_prediction != []] 157 | true_labels = [true_label for true_label in true_labels if true_label != []] 158 | 159 | print('true_predictions', true_predictions[0], true_predictions[-1]) 160 | print('true_labels', true_labels[0], true_labels[-1]) 161 | 162 | print( 163 | { 164 | "accuracy_score": accuracy_score(true_labels, true_predictions), 165 | "precision": precision_score(true_labels, true_predictions), 166 | "recall": recall_score(true_labels, true_predictions), 167 | "f1": f1_score(true_labels, true_predictions), 168 | } 169 | ) 170 | 171 | def prepare_dataset(args): 172 | data_files = {} 173 | if args.test_file is not None: 174 | data_files["test"] = args.test_file 175 | else: 176 | raise ValueError('Evaluation must specify test_file!') 177 | 178 | if args.train_file is not None: 179 | data_files["train"] = args.train_file 180 | if args.validation_file is not None: 181 | data_files["validation"] = args.validation_file 182 | 183 | extension = args.extension_file 184 | print('extension', extension) 185 | print('data_files', data_files) 186 | dataset = load_dataset(extension, data_files=data_files) 187 | return dataset 188 | 189 | 190 | def prepare_tokenizer(args): 191 | config = BertConfig.from_json_file(args.config_file) 192 | # tokenizer = BertTokenizerFast(args.vocab_file, model_max_length=512) 193 | # print('config', type(config), config,) 194 | tokenizer = BertTokenizerFast(args.vocab_file, model_max_length=config.max_position_embeddings) 195 | return tokenizer 196 | 197 | 198 | def prepare_model(args): 199 | config = BertConfig.from_json_file(args.config_file) 200 | model = BertForTokenClassification(config, args.num_labels) 201 | if args.hetseq_state_dict !=
'': 202 | # load hetseq state_dictionary 203 | model.load_state_dict(torch.load(args.hetseq_state_dict, map_location='cpu')['model'], strict=True) 204 | 205 | return model 206 | 207 | 208 | def get_label_list(labels): 209 | unique_labels = set() 210 | for label in labels: 211 | unique_labels = unique_labels | set(label) 212 | label_list = list(unique_labels) 213 | label_list.sort() 214 | return label_list 215 | 216 | 217 | def cli_main(): 218 | parser = argparse.ArgumentParser(allow_abbrev=False) 219 | 220 | # **YD** datasets argument 221 | parser.add_argument( 222 | '--train_file', 223 | help='local training file for BERT-ner fine-tuning', 224 | type=str, 225 | ) 226 | 227 | parser.add_argument( 228 | '--validation_file', 229 | help='local validation file for BERT-ner fine-tuning', 230 | type=str, 231 | ) 232 | 233 | parser.add_argument( 234 | '--test_file', 235 | help='local test file for BERT-ner fine-tuning', 236 | type=str, 237 | ) 238 | 239 | parser.add_argument( 240 | '--extension_file', 241 | help='local extension file for building dataset similar to conll2003 using "datasets" package', 242 | type=str, 243 | ) 244 | 245 | # **YD** BERT model related argument 246 | parser.add_argument( 247 | '--config_file', 248 | help='BERT config_file', 249 | type=str, 250 | ) 251 | 252 | parser.add_argument( 253 | '--vocab_file', 254 | help='BERT vocabulary file', 255 | type=str, 256 | ) 257 | 258 | parser.add_argument( 259 | '--hetseq_state_dict', 260 | help='hetseq pre-trained state dictionary ', 261 | type=str, 262 | default='', 263 | ) 264 | 265 | parser.add_argument( 266 | '--batch_size', 267 | help='batch size for data_loader', 268 | type=int, 269 | default=8, 270 | ) 271 | 272 | args = parser.parse_args() 273 | main(args) 274 | 275 | 276 | if __name__ == '__main__': 277 | cli_main() -------------------------------------------------------------------------------- /test/test_eval_bert_fine_tuning.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #$-m abe 4 | #$-M yding4@nd.edu 5 | #$-q gpu@qa-xp-003 # specify the queue 6 | #$-l gpu_card=4 7 | #$-N eval_ner_fine_tuning 8 | 9 | export PATH=/afs/crc.nd.edu/user/y/yding4/.conda/envs/hetseq/bin:$PATH 10 | export LD_LIBRARY_PATH=/afs/crc.nd.edu/user/y/yding4/.conda/envs/hetseq/lib:$LD_LIBRARY_PATH 11 | 12 | DIST=/scratch365/yding4/hetseq 13 | CODE=/scratch365/yding4/hetseq/test/test_eval_bert_fine_tuning.py 14 | EXTENSION_FILE=/scratch365/yding4/EL_resource/preprocess/build_datasets_huggingface/ace2004.py 15 | 16 | HETSEQ_STATE_DICT=/scratch365/yding4/hetseq/CRC_RUN_FILE/Train_bert_fine_tuning_ner/bert_from_reddit_to_aida_ner/bert_fine_tuning_ner/checkpoint1.pt 17 | DATA_DIR=/scratch365/yding4/EL_resource/data/processed/EL_CONLL_NER_stanza/ 18 | TEST_FILE=ace2004.conll 19 | 20 | 21 | python3 ${CODE} \ 22 | --test_file ${DATA_DIR}${TEST_FILE} \ 23 | --extension_file ${EXTENSION_FILE} \ 24 | --config_file ${DIST}/preprocessing/uncased_L-12_H-768_A-12/bert_config.json \ 25 | --vocab_file ${DIST}/preprocessing/uncased_L-12_H-768_A-12/vocab.txt \ 26 | --hetseq_state_dict ${HETSEQ_STATE_DICT} -------------------------------------------------------------------------------- /test/test_eval_bert_fine_tuning_aida.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #$-m abe 4 | #$-M yding4@nd.edu 5 | #$-q gpu@qa-xp-003 # specify the queue 6 | #$-l gpu_card=4 7 | #$-N eval_ner_fine_tuning 8 | 9 | export 
PATH=/afs/crc.nd.edu/user/y/yding4/.conda/envs/hetseq/bin:$PATH 10 | export LD_LIBRARY_PATH=/afs/crc.nd.edu/user/y/yding4/.conda/envs/hetseq/lib:$LD_LIBRARY_PATH 11 | 12 | DIST=/scratch365/yding4/hetseq 13 | CODE=/scratch365/yding4/hetseq/test/test_eval_bert_fine_tuning.py 14 | EXTENSION_FILE=/scratch365/yding4/EL_resource/preprocess/build_datasets_huggingface/BI_local_conll2003.py 15 | 16 | HETSEQ_STATE_DICT=/scratch365/yding4/hetseq/CRC_RUN_FILE/Train_bert_fine_tuning_ner/bert_from_reddit_to_aida_ner/bert_fine_tuning_ner/checkpoint_last.pt 17 | DATA_DIR=/scratch365/yding4/EL_resource/data/CONLL2003/ 18 | TRAIN_FILE=train.txt 19 | VALIDATION_FILE=valid.txt 20 | TEST_FILE=test.txt 21 | 22 | 23 | python3 ${CODE} \ 24 | --train_file ${DATA_DIR}${TRAIN_FILE} \ 25 | --validation_file ${DATA_DIR}${VALIDATION_FILE} \ 26 | --test_file ${DATA_DIR}${TRAIN_FILE} \ 27 | --extension_file ${EXTENSION_FILE} \ 28 | --config_file ${DIST}/preprocessing/uncased_L-12_H-768_A-12/bert_config.json \ 29 | --vocab_file ${DIST}/preprocessing/uncased_L-12_H-768_A-12/vocab.txt \ 30 | --hetseq_state_dict ${HETSEQ_STATE_DICT} 31 | 32 | #--test_file ${DATA_DIR}${TEST_FILE} \ 33 | --------------------------------------------------------------------------------
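For reference, the `prepare_model` helpers in the test scripts above all share one checkpoint-loading pattern; here it is in isolation as a minimal sketch (the paths and the label count are placeholders, not values from this repository):

```python
# Sketch of the shared prepare_model loading pattern; paths below are placeholders.
import torch

from hetseq.bert_modeling import BertConfig, BertForTokenClassification

config = BertConfig.from_json_file('bert_config.json')
model = BertForTokenClassification(config, 9)  # 9 labels, e.g. CONLL-2003 BIO tags

hetseq_state_dict = 'checkpoint_last.pt'  # a hetseq training checkpoint (placeholder)
transformers_state_dict = ''              # or a transformers pytorch_model.bin

if hetseq_state_dict != '':
    # a hetseq checkpoint stores the model weights under its 'model' key
    state_dict = torch.load(hetseq_state_dict, map_location='cpu')['model']
    model.load_state_dict(state_dict, strict=False)
elif transformers_state_dict != '':
    # a transformers pytorch_model.bin is already a bare state dict
    model.load_state_dict(torch.load(transformers_state_dict, map_location='cpu'), strict=False)

# strict=False tolerates the randomly initialized token-classification head when the
# checkpoint comes from pre-training; the eval scripts use strict=True once the head
# has been fine-tuned.
```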