├── LICENSE.txt ├── README.md ├── data ├── README.md ├── sample.l1 ├── sample.l2 └── sample.pair ├── experiments ├── run_extract.sh └── run_finetuning.sh ├── poetry.lock ├── pyproject.toml ├── requirements.txt ├── run_qa_alignment.py ├── scripts ├── SQuAD.py ├── doc_to_overlaped_squad.py ├── fileUtils.py ├── get_sent_align_for_overlap.py ├── mergeSQuADjson.py ├── pairDoc2SQuAD.py └── score.py └── utils.py /LICENSE.txt: -------------------------------------------------------------------------------- 1 | SOFTWARE LICENSE AGREEMENT FOR EVALUATION 2 | 3 | This SOFTWARE EVALUATION LICENSE AGREEMENT (this "Agreement") is a legal contract between a person who uses or otherwise accesses or installs the Software ("User(s)"), and Nippon Telegraph and Telephone corporation ("NTT"). 4 | READ THE TERMS AND CONDITIONS OF THIS AGREEMENT CAREFULLY BEFORE INSTALLING OR OTHERWISE ACCESSING OR USING NTT'S PROPRIETARY SOFTWARE ACCOMPANIED BY THIS AGREEMENT (the "SOFTWARE"). THE SOFTWARE IS COPYRIGHTED AND IT IS LICENSED TO USER UNDER THIS AGREEMENT, NOT SOLD TO USER. BY INSTALLING OR OTHERWISE ACCESSING OR USING THE SOFTWARE, USER ACKNOWLEDGES THAT USER HAS READ THIS AGREEMENT, THAT USER UNDERSTANDS IT, AND THAT USER ACCEPTS AND AGREES TO BE BOUND BY ITS TERMS. IF AT ANY TIME USER IS NOT WILLING TO BE BOUND BY THE TERMS OF THIS AGREEMENT, USER SHOULD TERMINATE THE INSTALLATION PROCESS, IMMEDIATELY CEASE AND REFRAIN FROM ACCESSING OR USING THE SOFTWARE AND DELETE ANY COPIES USER MAY HAVE. THIS AGREEMENT REPRESENTS THE ENTIRE AGREEMENT BETWEEN USER AND NTT CONCERNING THE SOFTWARE. 5 | 6 | 7 | BACKGROUND 8 | A. NTT is the owner of all rights, including all patent rights, copyrights and trade secret rights, in and to the Software and related documentation listed in Exhibit A to this Agreement. 9 | B. User wishes to obtain a royalty free license to use the Software to enable User to evaluate, and NTT wishes to grant such a license to User, pursuant and subject to the terms and conditions of this Agreement. 10 | C. As a condition to NTT's provision of the Software to User, NTT has required User to execute this Agreement. 11 | In consideration of these premises, and the mutual promises and conditions in this Agreement, the parties hereby agree as follows: 12 | 1. Grant of Evaluation License. NTT hereby grants to User, and User hereby accepts, under the terms and conditions of this Agreement, a royalty free, nontransferable and nonexclusive license to use the Software internally for the purposes of testing, analyzing, and evaluating the methods or mechanisms as shown in the research paper submitted by NTT to a certain academy. User may make a reasonable number of backup copies of the Software solely for User's internal use pursuant to the license granted in this Section 1. 13 | 2. Shipment and Installation. NTT will ship or deliver the Software by any method that NTT deems appropriate. User shall be solely responsible for proper installation of the Software. 14 | 3. Term. This Agreement is effective whichever is earlier (i) upon User's acceptance of the Agreement, or (ii) upon User's installing, accessing, and using the Software, even if User has not expressly accepted this Agreement. 
Without prejudice to any other rights, NTT may terminate this Agreement without notice to User (i) if User breaches or fails to comply with any of the limitations or other requirements described herein, and (ii) if NTT receives a notice from the academy stating that the research paper would not be published, and in any such case User agrees that NTT may, in addition to any other remedies it may have at law or in equity, remotely disable the Software. User may terminate this Agreement at any time by User's decision to terminate the Agreement to NTT and ceasing use of the Software. Upon any termination or expiration of this Agreement for any reason, User agrees to uninstall the Software and either return to NTT the Software and all copies thereof, or to destroy all such materials and provide written verification of such destruction to NTT. 15 | 4. Proprietary Rights 16 | (a) The Software is the valuable, confidential, and proprietary property of NTT, and NTT shall retain exclusive title to this property both during the term and after the termination of this Agreement. Without limitation, User acknowledges that all patent rights, copyrights and trade secret rights in the Software shall remain the exclusive property of NTT at all times. User shall use not less than reasonable care in safeguarding the confidentiality of the Software. 17 | (b) USER SHALL NOT, IN WHOLE OR IN PART, AT ANY TIME DURING THE TERM OF OR AFTER THE TERMINATION OF THIS AGREEMENT: (i) SELL, ASSIGN, LEASE, DISTRIBUTE, OR OTHERWISE TRANSFER THE SOFTWARE TO ANY THIRD PARTY; (ii) EXCEPT AS OTHERWISE PROVIDED HEREIN, COPY OR REPRODUCE THE SOFTWARE IN ANY MANNER; (iii) DISCLOSE THE SOFTWARE TO ANY THIRD PARTY, EXCEPT TO USER'S EMPLOYEES WHO REQUIRE ACCESS TO THE SOFTWARE FOR THE PURPOSES OF THIS AGREEMENT; (iv) MODIFY, DISASSEMBLE, DECOMPILE, REVERSE ENGINEER OR TRANSLATE THE SOFTWARE; OR (v) ALLOW ANY PERSON OR ENTITY TO COMMIT ANY OF THE ACTIONS DESCRIBED IN (i) THROUGH (iv) ABOVE. 18 | (c) User shall take appropriate action, by instruction, agreement, or otherwise, with respect to its employees permitted under this Agreement to have access to the Software to ensure that all of User's obligations under this Section 4 shall be satisfied. 19 | 5. Indemnity. User shall defend, indemnify and hold harmless NTT, its agents and employees, from any loss, damage, or liability arising in connection with User's improper or unauthorized use of the Software. NTT SHALL HAVE THE SOLE RIGHT TO CONDUCT OR DEFEND ANY ACTION RELATING TO THE SOFTWARE. 20 | 6. Disclaimer. THE SOFTWARE IS LICENSED TO USER "AS IS," WITHOUT ANY TRAINING, MAINTENANCE, OR SERVICE OBLIGATIONS WHATSOEVER ON THE PART OF NTT. NTT MAKES NO EXPRESS OR IMPLIED WARRANTIES OF ANY TYPE WHATSOEVER, INCLUDING WITHOUT LIMITATION THE IMPLIED WARRANTIES OF MERCHANTABILITY, OF FITNESS FOR A PARTICULAR PURPOSE AND OF NON-INFRINGEMENT ON COPYRIGHT OR ANY OTHER RIGHT OF THIRD PARTIES. USER ASSUMES ALL RISKS ASSOCIATED WITH ITS USE OF THE SOFTWARE, INCLUDING WITHOUT LIMITATION RISKS RELATING TO QUALITY, PERFORMANCE, DATA LOSS, AND UTILITY IN A PRODUCTION ENVIRONMENT. 21 | 7. Limitation of Liability.
IN NO EVENT SHALL NTT BE LIABLE TO USER OR TO ANY THIRD PARTY FOR ANY INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, INCLUDING BUT NOT LIMITED TO DAMAGES FOR PERSONAL INJURY, PROPERTY DAMAGE, LOST PROFITS, OR OTHER ECONOMIC LOSS, ARISING IN CONNECTION WITH USER'S USE OF OR INABILITY TO USE THE SOFTWARE, IN CONNECTION WITH NTT'S PROVISION OF OR FAILURE TO PROVIDE SERVICES PERTAINING TO THE SOFTWARE, OR AS A RESULT OF ANY DEFECT IN THE SOFTWARE. THIS DISCLAIMER OF LIABILITY SHALL APPLY REGARDLESS OF THE FORM OF ACTION THAT MAY BE BROUGHT AGAINST NTT, WHETHER IN CONTRACT OR TORT, INCLUDING WITHOUT LIMITATION ANY ACTION FOR NEGLIGENCE. USER'S SOLE REMEDY IN THE EVENT OF ANY BREACH OF THIS AGREEMENT BY NTT SHALL BE TERMINATION PURSUANT TO SECTION 3. 22 | 8. No Assignment or Sublicense. Neither this Agreement nor any right or license under this Agreement, nor the Software, may be sublicensed, assigned, or otherwise transferred by User without NTT's prior written consent. 23 | 9. General 24 | (a) If any provision, or part of a provision, of this Agreement is or becomes illegal, unenforceable, or invalidated, by operation of law or otherwise, that provision or part shall to that extent be deemed omitted, and the remainder of this Agreement shall remain in full force and effect. 25 | (b) This Agreement is the complete and exclusive statement of the agreement between the parties with respect to the subject matter hereof, and supersedes all written and oral contracts, proposals, and other communications between the parties relating to that subject matter. 26 | (c) Subject to Section 8, this Agreement shall be binding on, and shall inure to the benefit of, the respective successors and assigns of NTT and User. 27 | (d) If either party to this Agreement initiates a legal action or proceeding to enforce or interpret any part of this Agreement, the prevailing party in such action shall be entitled to recover, as an element of the costs of such action and not as damages, its attorneys' fees and other costs associated with such action or proceeding. 28 | (e) This Agreement shall be governed by and interpreted under the laws of Japan, without reference to conflicts of law principles. All disputes arising out of or in connection with this Agreement shall be finally settled by arbitration in Tokyo in accordance with the Commercial Arbitration Rules of the Japan Commercial Arbitration Association. The arbitration shall be conducted by three (3) arbitrators and in Japanese. The award rendered by the arbitrators shall be final and binding upon the parties. Judgment upon the award may be entered in any court having jurisdiction thereof. 29 | (f) NTT shall not be liable to the User or to any third party for any delay or failure to perform NTT's obligation set forth under this Agreement due to any cause beyond NTT's reasonable control. 30 | 31 | 32 | EXHIBIT A 33 | The software and related documentation in this repository, excluding sample data. 34 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SpanAlign: Sentence Alignment Method based on Cross-Language Span Prediction and ILP 2 | This repository includes the software described in "[SpanAlign: Sentence Alignment Method based on Cross-Language Span Prediction and ILP](https://www.aclweb.org/anthology/2020.coling-main.418/)" published at COLING'20. 3 | 4 | ## Setup 5 | This software is tested on the following.
6 | * 1 GeForce RTX 2080Ti 7 | * Python 3.8.6 8 | * CUDA 10.1 9 | * ILOG CPLEX 12.8.0.0 and 20.1 academic version 10 | * torch 1.7.1+cu101 11 | * transformers 4.1.1 12 | * nltk 13 | * tqdm 4.54.1 14 | * h5py 2.10.0 15 | * sentencepiece 0.1.94 16 | * tensorboardX 2.1 17 | * tabulate 0.8.7 18 | 19 | These Python libraries can be installed with pip as follows: 20 | ```sh 21 | pip install -r requirements.txt 22 | ``` 23 | 24 | ## How to use with sample data 25 | ### Preprocessing 26 | First, you need to convert the dataset to the JSON format of the SQuAD 2.0 QA task. (For this example, we create train/dev/test JSON files from the single sample dataset.) 27 | ```sh 28 | # Preprocessing for train/development sets 29 | python scripts/pairDoc2SQuAD.py data/sample.{pair,l1,l2} data/train.json train 30 | python scripts/pairDoc2SQuAD.py data/sample.{pair,l1,l2} data/dev.json dev 31 | # Preprocessing for a test set 32 | python scripts/doc_to_overlaped_squad.py data/sample.{l1,l2} data/test.json -t test 33 | ``` 34 | 35 | When you have many JSON files for the training/development sets, you can convert them separately and merge them into one file with `scripts/mergeSQuADjson.py`, like this: 36 | ```sh 37 | python scripts/mergeSQuADjson.py -s data/train*.json --squad_version 2.0 data/train.json 38 | ``` 39 | 40 | ### Fine-tuning the Model and Getting Alignments 41 | According to your environment, you need to replace `__SET_DIR_PATH__` with the root path of this repository in `experiments/run_{finetuning,extract}.sh` and `__SET_CPLEX_PATH__` with the executable path of CPLEX in `scripts/get_sent_align_for_overlap.py`. 42 | These commands will fine-tune XLM-RoBERTa for sentence alignment. 43 | ```sh 44 | cd experiments 45 | sh ./run_finetuning.sh 46 | ``` 47 | 48 | The following script, `run_extract.sh`, performs cross-language span prediction to extract alignment hypotheses and then optimizes these alignments with Integer Linear Programming. 49 | ```sh 50 | sh ./run_extract.sh ./finetuning ./output ../data/test.json test_sample 51 | ``` 52 | 53 | The alignment results with three symmetrization methods are at `experiments/output/test/test.{e2f,f2e,bidi}.pair`. 54 | ```shell 55 | $ cat output/test/test.bidi.pair 56 | []:[1]:1.0000 57 | [12,13]:[2,3,4,5,6,7,8,9,10,11,12,13,14]:7.0965 58 | [1]:[]:1.0000 59 | [2,3,4,5,6,7,8,9,10,11]:[15]:5.1970 60 | ``` 61 | 62 | Here, `bidi` denotes the bidirectional symmetrization described in our paper. 63 | 64 | ### Evaluate 65 | Sentence alignment accuracies can be calculated as follows. 66 | ```sh 67 | $ python ../scripts/score.py -g ../data/sample.pair -t ./output/test/test.bidi.pair 68 | --------------------------------- 69 | | | Strict | Lax | 70 | | Precision | 0.000 | 0.500 | 71 | | Recall | 0.000 | 0.250 | 72 | | F1 | 0.000 | 0.333 | 73 | --------------------------------- 74 | trg/src 0 1 75 | --------- --------------------- ---------------------- 76 | 0 0.000/0.000/0.000 (0) 0.000/0.000/0.000 (0) 77 | 1 0.000/0.000/0.000 (0) 0.000/0.000/0.000 (11) 78 | 2 0.000/0.000/0.000 (0) 0.000/0.000/0.000 (0) 79 | 3 0.000/0.000/0.000 (0) 0.000/0.000/0.000 (1) 80 | ``` 81 | We found that parts of this result can increase or decrease slightly when a different GPU architecture is used. 82 | 83 | ## License 84 | This software is released under the NTT License; see `LICENSE.txt`.
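As a supplement to the Evaluate section above: each line of a `.pair` file is a source index group and a target index group, and hypothesis files carry an extra score field. The following is only a minimal sketch of how the strict precision/recall/F1 numbers follow from that format; it is not the repository's `scripts/score.py`, and the function names and file paths are illustrative placeholders.

```python
# Minimal sketch of strict scoring over .pair files (not scripts/score.py).
# Gold lines look like "[12]:[10,11,12]"; hypothesis lines may append a
# score, e.g. "[12,13]:[2,3]:7.0965". "[]" denotes a null alignment.

def _to_ids(field):
    """'[10,11,12]' -> (10, 11, 12); '[]' -> ()."""
    return tuple(sorted(int(i) for i in field.strip("[]").split(",") if i))


def read_pairs(path):
    """Return the set of (source indices, target indices) alignment groups."""
    groups = set()
    with open(path, encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            src, tgt = line.split(":")[:2]  # the third field, if any, is the score
            groups.add((_to_ids(src), _to_ids(tgt)))
    return groups


def strict_prf(gold_path, hyp_path):
    """Strict scoring: a hypothesis group counts only if it matches a gold group exactly."""
    gold, hyp = read_pairs(gold_path), read_pairs(hyp_path)
    tp = len(gold & hyp)
    precision = tp / len(hyp) if hyp else 0.0
    recall = tp / len(gold) if gold else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return precision, recall, f1


if __name__ == "__main__":
    # Example paths; adjust to your own gold and hypothesis files.
    print(strict_prf("data/sample.pair", "experiments/output/test/test.bidi.pair"))
```

`scripts/score.py` additionally reports the lax scores and the trg/src breakdown shown in the table above, which this sketch does not reproduce.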
85 | -------------------------------------------------------------------------------- /data/README.md: -------------------------------------------------------------------------------- 1 | ### Sample data 2 | * `sample.l1`: a document in English. 3 | * `sample.l2`: a document in Japanese. 4 | * `sample.pair`: sentence alignments between `sample.l1` and `sample.l2`. These alignments are annotated by us. 5 | 6 | ### License 7 | `sample.l1` and `sample.l2` are extracted from the [tatoeba corpus](https://tatoeba.org/jpn/). 8 | Some sentences are concatenate to create an example of many-to-many sentence alignment. 9 | For the license of the Tatoeba corpus, please see [this page](https://tatoeba.org/eng/terms_of_use). 10 | -------------------------------------------------------------------------------- /data/sample.l1: -------------------------------------------------------------------------------- 1 | We can not learn Japanese without learning Kanji. 2 | I learned many things about Greek culture. 3 | It's been a while since I've been here, but nothing has changed. 4 | I understand what you're saying. 5 | In other words, he is a man of faith. 6 | The principal severely reproved the students whenever they made a mess in the hallway. 7 | In other words, it takes all sorts of people to make a world. 8 | He writes a letter. 9 | In other words, reliability is impossible unless there is a natural warmth. 10 | She lost to him in tennis. 11 | She loves him all the more because he has faults. 12 | Only two things are infinite, the universe and human stupidity, and I'm not sure about the former. 13 | What do I have to do to make you believe me? 14 | -------------------------------------------------------------------------------- /data/sample.l2: -------------------------------------------------------------------------------- 1 | 言う価値のあることがなければ、しゃべるな。 2 | 言い換えると、世の中にはいろいろな人間が必要だということだ。 3 | 久しぶりに来たけど全然変わってないな。 4 | 彼女は彼に欠点があるからかえって彼を愛している。 5 | 校長先生は、生徒が廊下を散らかしたときは、きびしく叱りました。 6 | 漢字を学ばないで日本語の勉強はできない。 7 | 言いかえると、生来の温かさがない限り信頼性は不可能だ。 8 | 彼女は彼にテニスで負けてしまった。 9 | どうしたら信じてくれる? 10 | 果てがないものは二つだけある。 11 | 宇宙と人間の愚かさだ。 12 | 前者については確かではないが。 13 | 私はギリシャ文明について多くのことを学びました。 14 | 言いたいことは分かってるよ。 15 | 彼は手紙を書く。 16 | -------------------------------------------------------------------------------- /data/sample.pair: -------------------------------------------------------------------------------- 1 | [1]:[6] 2 | [2]:[13] 3 | [3]:[3] 4 | [4]:[14] 5 | [6]:[5] 6 | [7]:[2] 7 | [8]:[15] 8 | [9]:[7] 9 | [10]:[8] 10 | [11]:[4] 11 | [12]:[10,11,12] 12 | [13]:[9] 13 | -------------------------------------------------------------------------------- /experiments/run_extract.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | # created by Katsuki Chousa 4 | # updated on Dec. 22, 2020 by Katsuki Chousa 5 | 6 | PROJECT_DIR=__SET_DIR_PATH__ 7 | EXPERIMENT_DIR=$PROJECT_DIR/experiments 8 | MODEL_TYPE=xlm-roberta 9 | DATA_DIR=$PROJECT_DIR/data 10 | 11 | if [ $# -ne 4 ]; then 12 | echo "$0 model_path output_dir TEST.json test_title" 13 | exit 1 14 | fi 15 | 16 | model_path=$1 17 | test_file=$3 18 | test_prefix=`basename $test_file .json` 19 | output_dir=$2/$test_prefix 20 | test_title=$4 21 | 22 | date 23 | hostname 24 | echo $EXPERIMENT_DIR 25 | 26 | echo "" 27 | echo "### extraction ###" 28 | mkdir -p $output_dir 29 | 30 | if [ ! 
-e $output_dir/nbest_predictions_.json ]; then 31 | python $PROJECT_DIR/run_qa_alignment.py \ 32 | --model_type $MODEL_TYPE \ 33 | --model_name_or_path $model_path \ 34 | --version_2_with_negative \ 35 | --do_eval \ 36 | --predict_file $test_file \ 37 | --max_seq_length 384 \ 38 | --max_query_length 158 \ 39 | --max_answer_length 158 \ 40 | --doc_stride 64 \ 41 | --n_best_size 10 \ 42 | --data_dir $output_dir \ 43 | --output_dir $output_dir \ 44 | --overwrite_output_dir \ 45 | --save_steps 5000 \ 46 | --per_gpu_eval_batch_size 240 \ 47 | --thread 8 2>&1 \ 48 | | tee $output_dir/span_prediction.log 49 | fi 50 | 51 | python $PROJECT_DIR/scripts/get_sent_align_for_overlap.py \ 52 | --nbest 1 \ 53 | $DATA_DIR/sample.{l1,l2} \ 54 | $output_dir/nbest_predictions_.json \ 55 | $test_title \ 56 | $output_dir/$test_prefix \ 57 | -------------------------------------------------------------------------------- /experiments/run_finetuning.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | # created by Katsuki Chousa 4 | # updated on Dec. 22, 2020 by Katsuki Chousa 5 | 6 | PROJECT_DIR=__SET_DIR_PATH__ 7 | EXPERIMENT_DIR=$PROJECT_DIR/experiments 8 | OUTPUT_DIR=$EXPERIMENT_DIR/finetuning 9 | 10 | DATA_DIR=$PROJECT_DIR/data 11 | TRAIN_FILE=$DATA_DIR/train.json 12 | DEV_FILE=$DATA_DIR/dev.json 13 | 14 | MODEL_TYPE=xlm-roberta 15 | MODEL_NAME=xlm-roberta-base 16 | 17 | date 18 | hostname 19 | echo $EXPERIMENT_DIR 20 | 21 | echo "" 22 | echo "### finetuning ###" 23 | mkdir -p $OUTPUT_DIR 24 | python $PROJECT_DIR/run_qa_alignment.py \ 25 | --model_type $MODEL_TYPE \ 26 | --model_name_or_path $MODEL_NAME \ 27 | --version_2_with_negative \ 28 | --do_train \ 29 | --do_eval \ 30 | --eval_all_checkpoints \ 31 | --train_file $TRAIN_FILE \ 32 | --predict_file $DEV_FILE \ 33 | --learning_rate 3e-5 \ 34 | --per_gpu_train_batch_size 5 \ 35 | --num_train_epochs 5 \ 36 | --max_seq_length 384 \ 37 | --max_query_length 158 \ 38 | --max_answer_length 158 \ 39 | --doc_stride 64 \ 40 | --n_best_size 10 \ 41 | --data_dir $OUTPUT_DIR \ 42 | --output_dir $OUTPUT_DIR \ 43 | --overwrite_output_dir \ 44 | --save_steps 5000 \ 45 | --thread 4 2>&1 \ 46 | | tee $EXPERIMENT_DIR/finetuning.log 47 | -------------------------------------------------------------------------------- /poetry.lock: -------------------------------------------------------------------------------- 1 | [[package]] 2 | name = "appnope" 3 | version = "0.1.2" 4 | description = "Disable App Nap on macOS >= 10.9" 5 | category = "dev" 6 | optional = false 7 | python-versions = "*" 8 | 9 | [[package]] 10 | name = "backcall" 11 | version = "0.2.0" 12 | description = "Specifications for callback functions passed in to an API" 13 | category = "dev" 14 | optional = false 15 | python-versions = "*" 16 | 17 | [[package]] 18 | name = "certifi" 19 | version = "2020.12.5" 20 | description = "Python package for providing Mozilla's CA Bundle." 
21 | category = "main" 22 | optional = false 23 | python-versions = "*" 24 | 25 | [[package]] 26 | name = "chardet" 27 | version = "4.0.0" 28 | description = "Universal encoding detector for Python 2 and 3" 29 | category = "main" 30 | optional = false 31 | python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" 32 | 33 | [[package]] 34 | name = "click" 35 | version = "7.1.2" 36 | description = "Composable command line interface toolkit" 37 | category = "main" 38 | optional = false 39 | python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" 40 | 41 | [[package]] 42 | name = "colorama" 43 | version = "0.4.4" 44 | description = "Cross-platform colored terminal text." 45 | category = "dev" 46 | optional = false 47 | python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" 48 | 49 | [[package]] 50 | name = "decorator" 51 | version = "4.4.2" 52 | description = "Decorators for Humans" 53 | category = "dev" 54 | optional = false 55 | python-versions = ">=2.6, !=3.0.*, !=3.1.*" 56 | 57 | [[package]] 58 | name = "filelock" 59 | version = "3.0.12" 60 | description = "A platform independent file lock." 61 | category = "main" 62 | optional = false 63 | python-versions = "*" 64 | 65 | [[package]] 66 | name = "h5py" 67 | version = "2.10.0" 68 | description = "Read and write HDF5 files from Python" 69 | category = "main" 70 | optional = false 71 | python-versions = "*" 72 | 73 | [package.dependencies] 74 | numpy = ">=1.7" 75 | six = "*" 76 | 77 | [[package]] 78 | name = "idna" 79 | version = "2.10" 80 | description = "Internationalized Domain Names in Applications (IDNA)" 81 | category = "main" 82 | optional = false 83 | python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" 84 | 85 | [[package]] 86 | name = "ipython" 87 | version = "7.19.0" 88 | description = "IPython: Productive Interactive Computing" 89 | category = "dev" 90 | optional = false 91 | python-versions = ">=3.7" 92 | 93 | [package.dependencies] 94 | appnope = {version = "*", markers = "sys_platform == \"darwin\""} 95 | backcall = "*" 96 | colorama = {version = "*", markers = "sys_platform == \"win32\""} 97 | decorator = "*" 98 | jedi = ">=0.10" 99 | pexpect = {version = ">4.3", markers = "sys_platform != \"win32\""} 100 | pickleshare = "*" 101 | prompt-toolkit = ">=2.0.0,<3.0.0 || >3.0.0,<3.0.1 || >3.0.1,<3.1.0" 102 | pygments = "*" 103 | traitlets = ">=4.2" 104 | 105 | [package.extras] 106 | all = ["Sphinx (>=1.3)", "ipykernel", "ipyparallel", "ipywidgets", "nbconvert", "nbformat", "nose (>=0.10.1)", "notebook", "numpy (>=1.14)", "pygments", "qtconsole", "requests", "testpath"] 107 | doc = ["Sphinx (>=1.3)"] 108 | kernel = ["ipykernel"] 109 | nbconvert = ["nbconvert"] 110 | nbformat = ["nbformat"] 111 | notebook = ["notebook", "ipywidgets"] 112 | parallel = ["ipyparallel"] 113 | qtconsole = ["qtconsole"] 114 | test = ["nose (>=0.10.1)", "requests", "testpath", "pygments", "nbformat", "ipykernel", "numpy (>=1.14)"] 115 | 116 | [[package]] 117 | name = "ipython-genutils" 118 | version = "0.2.0" 119 | description = "Vestigial utilities from IPython" 120 | category = "dev" 121 | optional = false 122 | python-versions = "*" 123 | 124 | [[package]] 125 | name = "jedi" 126 | version = "0.17.2" 127 | description = "An autocompletion tool for Python that can be used for text editors." 
128 | category = "dev" 129 | optional = false 130 | python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" 131 | 132 | [package.dependencies] 133 | parso = ">=0.7.0,<0.8.0" 134 | 135 | [package.extras] 136 | qa = ["flake8 (==3.7.9)"] 137 | testing = ["Django (<3.1)", "colorama", "docopt", "pytest (>=3.9.0,<5.0.0)"] 138 | 139 | [[package]] 140 | name = "joblib" 141 | version = "1.0.0" 142 | description = "Lightweight pipelining with Python functions" 143 | category = "main" 144 | optional = false 145 | python-versions = ">=3.6" 146 | 147 | [[package]] 148 | name = "numpy" 149 | version = "1.19.4" 150 | description = "NumPy is the fundamental package for array computing with Python." 151 | category = "main" 152 | optional = false 153 | python-versions = ">=3.6" 154 | 155 | [[package]] 156 | name = "packaging" 157 | version = "20.8" 158 | description = "Core utilities for Python packages" 159 | category = "main" 160 | optional = false 161 | python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" 162 | 163 | [package.dependencies] 164 | pyparsing = ">=2.0.2" 165 | 166 | [[package]] 167 | name = "parso" 168 | version = "0.7.1" 169 | description = "A Python Parser" 170 | category = "dev" 171 | optional = false 172 | python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" 173 | 174 | [package.extras] 175 | testing = ["docopt", "pytest (>=3.0.7)"] 176 | 177 | [[package]] 178 | name = "pexpect" 179 | version = "4.8.0" 180 | description = "Pexpect allows easy control of interactive console applications." 181 | category = "dev" 182 | optional = false 183 | python-versions = "*" 184 | 185 | [package.dependencies] 186 | ptyprocess = ">=0.5" 187 | 188 | [[package]] 189 | name = "pickleshare" 190 | version = "0.7.5" 191 | description = "Tiny 'shelve'-like database with concurrency support" 192 | category = "dev" 193 | optional = false 194 | python-versions = "*" 195 | 196 | [[package]] 197 | name = "prompt-toolkit" 198 | version = "3.0.8" 199 | description = "Library for building powerful interactive command lines in Python" 200 | category = "dev" 201 | optional = false 202 | python-versions = ">=3.6.1" 203 | 204 | [package.dependencies] 205 | wcwidth = "*" 206 | 207 | [[package]] 208 | name = "protobuf" 209 | version = "3.14.0" 210 | description = "Protocol Buffers" 211 | category = "main" 212 | optional = false 213 | python-versions = "*" 214 | 215 | [package.dependencies] 216 | six = ">=1.9" 217 | 218 | [[package]] 219 | name = "ptyprocess" 220 | version = "0.6.0" 221 | description = "Run a subprocess in a pseudo terminal" 222 | category = "dev" 223 | optional = false 224 | python-versions = "*" 225 | 226 | [[package]] 227 | name = "pudb" 228 | version = "2019.2" 229 | description = "A full-screen, console-based Python debugger" 230 | category = "dev" 231 | optional = false 232 | python-versions = "*" 233 | 234 | [package.dependencies] 235 | pygments = ">=1.0" 236 | urwid = ">=1.1.1" 237 | 238 | [[package]] 239 | name = "pygments" 240 | version = "2.7.3" 241 | description = "Pygments is a syntax highlighting package written in Python." 
242 | category = "dev" 243 | optional = false 244 | python-versions = ">=3.5" 245 | 246 | [[package]] 247 | name = "pyparsing" 248 | version = "2.4.7" 249 | description = "Python parsing module" 250 | category = "main" 251 | optional = false 252 | python-versions = ">=2.6, !=3.0.*, !=3.1.*, !=3.2.*" 253 | 254 | [[package]] 255 | name = "regex" 256 | version = "2020.11.13" 257 | description = "Alternative regular expression module, to replace re." 258 | category = "main" 259 | optional = false 260 | python-versions = "*" 261 | 262 | [[package]] 263 | name = "requests" 264 | version = "2.25.1" 265 | description = "Python HTTP for Humans." 266 | category = "main" 267 | optional = false 268 | python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" 269 | 270 | [package.dependencies] 271 | certifi = ">=2017.4.17" 272 | chardet = ">=3.0.2,<5" 273 | idna = ">=2.5,<3" 274 | urllib3 = ">=1.21.1,<1.27" 275 | 276 | [package.extras] 277 | security = ["pyOpenSSL (>=0.14)", "cryptography (>=1.3.4)"] 278 | socks = ["PySocks (>=1.5.6,!=1.5.7)", "win-inet-pton"] 279 | 280 | [[package]] 281 | name = "sacremoses" 282 | version = "0.0.43" 283 | description = "SacreMoses" 284 | category = "main" 285 | optional = false 286 | python-versions = "*" 287 | 288 | [package.dependencies] 289 | click = "*" 290 | joblib = "*" 291 | regex = "*" 292 | six = "*" 293 | tqdm = "*" 294 | 295 | [[package]] 296 | name = "sentencepiece" 297 | version = "0.1.94" 298 | description = "SentencePiece python wrapper" 299 | category = "main" 300 | optional = false 301 | python-versions = "*" 302 | 303 | [[package]] 304 | name = "six" 305 | version = "1.15.0" 306 | description = "Python 2 and 3 compatibility utilities" 307 | category = "main" 308 | optional = false 309 | python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*" 310 | 311 | [[package]] 312 | name = "tabulate" 313 | version = "0.8.7" 314 | description = "Pretty-print tabular data" 315 | category = "main" 316 | optional = false 317 | python-versions = "*" 318 | 319 | [package.extras] 320 | widechars = ["wcwidth"] 321 | 322 | [[package]] 323 | name = "tensorboardx" 324 | version = "2.1" 325 | description = "TensorBoardX lets you watch Tensors Flow without Tensorflow" 326 | category = "main" 327 | optional = false 328 | python-versions = "*" 329 | 330 | [package.dependencies] 331 | numpy = "*" 332 | protobuf = ">=3.8.0" 333 | six = "*" 334 | 335 | [[package]] 336 | name = "tokenizers" 337 | version = "0.9.4" 338 | description = "Fast and Customizable Tokenizers" 339 | category = "main" 340 | optional = false 341 | python-versions = "*" 342 | 343 | [package.extras] 344 | testing = ["pytest"] 345 | 346 | [[package]] 347 | name = "torch" 348 | version = "1.7.1+cu101" 349 | description = "Tensors and Dynamic neural networks in Python with strong GPU acceleration" 350 | category = "main" 351 | optional = false 352 | python-versions = ">=3.6.2" 353 | 354 | [package.dependencies] 355 | numpy = "*" 356 | typing-extensions = "*" 357 | 358 | [package.source] 359 | type = "url" 360 | url = "https://download.pytorch.org/whl/cu101/torch-1.7.1%2Bcu101-cp38-cp38-linux_x86_64.whl" 361 | 362 | [[package]] 363 | name = "tqdm" 364 | version = "4.54.1" 365 | description = "Fast, Extensible Progress Meter" 366 | category = "main" 367 | optional = false 368 | python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,>=2.7" 369 | 370 | [package.extras] 371 | dev = ["py-make (>=0.1.0)", "twine", "argopt", "pydoc-markdown", "wheel"] 372 | 373 | [[package]] 374 | name = "traitlets" 375 | version = 
"5.0.5" 376 | description = "Traitlets Python configuration system" 377 | category = "dev" 378 | optional = false 379 | python-versions = ">=3.7" 380 | 381 | [package.dependencies] 382 | ipython-genutils = "*" 383 | 384 | [package.extras] 385 | test = ["pytest"] 386 | 387 | [[package]] 388 | name = "transformers" 389 | version = "4.1.1" 390 | description = "State-of-the-art Natural Language Processing for TensorFlow 2.0 and PyTorch" 391 | category = "main" 392 | optional = false 393 | python-versions = ">=3.6.0" 394 | 395 | [package.dependencies] 396 | filelock = "*" 397 | numpy = "*" 398 | packaging = "*" 399 | regex = "!=2019.12.17" 400 | requests = "*" 401 | sacremoses = "*" 402 | tokenizers = "0.9.4" 403 | tqdm = ">=4.27" 404 | 405 | [package.extras] 406 | all = ["tensorflow (>=2.0)", "onnxconverter-common", "keras2onnx", "torch (>=1.0)", "jax (>=0.2.0)", "jaxlib (==0.1.55)", "flax (>=0.2.2)", "sentencepiece (==0.1.91)", "protobuf", "tokenizers (==0.9.4)"] 407 | dev = ["tensorflow (>=2.0)", "onnxconverter-common", "keras2onnx", "torch (>=1.0)", "jax (>=0.2.0)", "jaxlib (==0.1.55)", "flax (>=0.2.2)", "sentencepiece (==0.1.91)", "protobuf", "tokenizers (==0.9.4)", "pytest", "pytest-xdist", "timeout-decorator", "parameterized", "psutil", "faiss-cpu", "datasets", "cookiecutter (==1.7.2)", "black (>=20.8b1)", "isort (>=5.5.4)", "flake8 (>=3.8.3)", "fugashi (>=1.0)", "ipadic (>=1.0.0,<2.0)", "unidic-lite (>=1.0.7)", "unidic (>=1.0.2)", "recommonmark", "sphinx (==3.2.1)", "sphinx-markdown-tables", "sphinx-rtd-theme (==0.4.3)", "sphinx-copybutton", "scikit-learn"] 408 | docs = ["recommonmark", "sphinx (==3.2.1)", "sphinx-markdown-tables", "sphinx-rtd-theme (==0.4.3)", "sphinx-copybutton"] 409 | flax = ["jax (>=0.2.0)", "jaxlib (==0.1.55)", "flax (>=0.2.2)"] 410 | ja = ["fugashi (>=1.0)", "ipadic (>=1.0.0,<2.0)", "unidic-lite (>=1.0.7)", "unidic (>=1.0.2)"] 411 | modelcreation = ["cookiecutter (==1.7.2)"] 412 | onnxruntime = ["onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)"] 413 | quality = ["black (>=20.8b1)", "isort (>=5.5.4)", "flake8 (>=3.8.3)"] 414 | retrieval = ["faiss-cpu", "datasets"] 415 | sentencepiece = ["sentencepiece (==0.1.91)", "protobuf"] 416 | serving = ["pydantic", "uvicorn", "fastapi", "starlette"] 417 | sklearn = ["scikit-learn"] 418 | testing = ["pytest", "pytest-xdist", "timeout-decorator", "parameterized", "psutil", "faiss-cpu", "datasets", "cookiecutter (==1.7.2)"] 419 | tf = ["tensorflow (>=2.0)", "onnxconverter-common", "keras2onnx"] 420 | tf-cpu = ["tensorflow-cpu (>=2.0)", "onnxconverter-common", "keras2onnx"] 421 | tokenizers = ["tokenizers (==0.9.4)"] 422 | torch = ["torch (>=1.0)"] 423 | 424 | [[package]] 425 | name = "typing-extensions" 426 | version = "3.7.4.3" 427 | description = "Backported and Experimental Type Hints for Python 3.5+" 428 | category = "main" 429 | optional = false 430 | python-versions = "*" 431 | 432 | [[package]] 433 | name = "urllib3" 434 | version = "1.26.2" 435 | description = "HTTP library with thread-safe connection pooling, file post, and more." 436 | category = "main" 437 | optional = false 438 | python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, <4" 439 | 440 | [package.extras] 441 | brotli = ["brotlipy (>=0.6.0)"] 442 | secure = ["pyOpenSSL (>=0.14)", "cryptography (>=1.3.4)", "idna (>=2.0.0)", "certifi", "ipaddress"] 443 | socks = ["PySocks (>=1.5.6,!=1.5.7,<2.0)"] 444 | 445 | [[package]] 446 | name = "urwid" 447 | version = "2.1.2" 448 | description = "A full-featured console (xterm et al.) 
user interface library" 449 | category = "dev" 450 | optional = false 451 | python-versions = "*" 452 | 453 | [[package]] 454 | name = "wcwidth" 455 | version = "0.2.5" 456 | description = "Measures the displayed width of unicode strings in a terminal" 457 | category = "dev" 458 | optional = false 459 | python-versions = "*" 460 | 461 | [metadata] 462 | lock-version = "1.1" 463 | python-versions = "^3.8" 464 | content-hash = "77c25d8945fd172070c63e0e6485a785eaffe871a2f0c785e15488bbde138386" 465 | 466 | [metadata.files] 467 | appnope = [ 468 | {file = "appnope-0.1.2-py2.py3-none-any.whl", hash = "sha256:93aa393e9d6c54c5cd570ccadd8edad61ea0c4b9ea7a01409020c9aa019eb442"}, 469 | {file = "appnope-0.1.2.tar.gz", hash = "sha256:dd83cd4b5b460958838f6eb3000c660b1f9caf2a5b1de4264e941512f603258a"}, 470 | ] 471 | backcall = [ 472 | {file = "backcall-0.2.0-py2.py3-none-any.whl", hash = "sha256:fbbce6a29f263178a1f7915c1940bde0ec2b2a967566fe1c65c1dfb7422bd255"}, 473 | {file = "backcall-0.2.0.tar.gz", hash = "sha256:5cbdbf27be5e7cfadb448baf0aa95508f91f2bbc6c6437cd9cd06e2a4c215e1e"}, 474 | ] 475 | certifi = [ 476 | {file = "certifi-2020.12.5-py2.py3-none-any.whl", hash = "sha256:719a74fb9e33b9bd44cc7f3a8d94bc35e4049deebe19ba7d8e108280cfd59830"}, 477 | {file = "certifi-2020.12.5.tar.gz", hash = "sha256:1a4995114262bffbc2413b159f2a1a480c969de6e6eb13ee966d470af86af59c"}, 478 | ] 479 | chardet = [ 480 | {file = "chardet-4.0.0-py2.py3-none-any.whl", hash = "sha256:f864054d66fd9118f2e67044ac8981a54775ec5b67aed0441892edb553d21da5"}, 481 | {file = "chardet-4.0.0.tar.gz", hash = "sha256:0d6f53a15db4120f2b08c94f11e7d93d2c911ee118b6b30a04ec3ee8310179fa"}, 482 | ] 483 | click = [ 484 | {file = "click-7.1.2-py2.py3-none-any.whl", hash = "sha256:dacca89f4bfadd5de3d7489b7c8a566eee0d3676333fbb50030263894c38c0dc"}, 485 | {file = "click-7.1.2.tar.gz", hash = "sha256:d2b5255c7c6349bc1bd1e59e08cd12acbbd63ce649f2588755783aa94dfb6b1a"}, 486 | ] 487 | colorama = [ 488 | {file = "colorama-0.4.4-py2.py3-none-any.whl", hash = "sha256:9f47eda37229f68eee03b24b9748937c7dc3868f906e8ba69fbcbdd3bc5dc3e2"}, 489 | {file = "colorama-0.4.4.tar.gz", hash = "sha256:5941b2b48a20143d2267e95b1c2a7603ce057ee39fd88e7329b0c292aa16869b"}, 490 | ] 491 | decorator = [ 492 | {file = "decorator-4.4.2-py2.py3-none-any.whl", hash = "sha256:41fa54c2a0cc4ba648be4fd43cff00aedf5b9465c9bf18d64325bc225f08f760"}, 493 | {file = "decorator-4.4.2.tar.gz", hash = "sha256:e3a62f0520172440ca0dcc823749319382e377f37f140a0b99ef45fecb84bfe7"}, 494 | ] 495 | filelock = [ 496 | {file = "filelock-3.0.12-py3-none-any.whl", hash = "sha256:929b7d63ec5b7d6b71b0fa5ac14e030b3f70b75747cef1b10da9b879fef15836"}, 497 | {file = "filelock-3.0.12.tar.gz", hash = "sha256:18d82244ee114f543149c66a6e0c14e9c4f8a1044b5cdaadd0f82159d6a6ff59"}, 498 | ] 499 | h5py = [ 500 | {file = "h5py-2.10.0-cp27-cp27m-macosx_10_6_intel.whl", hash = "sha256:ecf4d0b56ee394a0984de15bceeb97cbe1fe485f1ac205121293fc44dcf3f31f"}, 501 | {file = "h5py-2.10.0-cp27-cp27m-manylinux1_i686.whl", hash = "sha256:86868dc07b9cc8cb7627372a2e6636cdc7a53b7e2854ad020c9e9d8a4d3fd0f5"}, 502 | {file = "h5py-2.10.0-cp27-cp27m-manylinux1_x86_64.whl", hash = "sha256:aac4b57097ac29089f179bbc2a6e14102dd210618e94d77ee4831c65f82f17c0"}, 503 | {file = "h5py-2.10.0-cp27-cp27m-win32.whl", hash = "sha256:7be5754a159236e95bd196419485343e2b5875e806fe68919e087b6351f40a70"}, 504 | {file = "h5py-2.10.0-cp27-cp27m-win_amd64.whl", hash = "sha256:13c87efa24768a5e24e360a40e0bc4c49bcb7ce1bb13a3a7f9902cec302ccd36"}, 505 | {file = 
"h5py-2.10.0-cp27-cp27mu-manylinux1_i686.whl", hash = "sha256:79b23f47c6524d61f899254f5cd5e486e19868f1823298bc0c29d345c2447172"}, 506 | {file = "h5py-2.10.0-cp27-cp27mu-manylinux1_x86_64.whl", hash = "sha256:cbf28ae4b5af0f05aa6e7551cee304f1d317dbed1eb7ac1d827cee2f1ef97a99"}, 507 | {file = "h5py-2.10.0-cp34-cp34m-manylinux1_i686.whl", hash = "sha256:c0d4b04bbf96c47b6d360cd06939e72def512b20a18a8547fa4af810258355d5"}, 508 | {file = "h5py-2.10.0-cp34-cp34m-manylinux1_x86_64.whl", hash = "sha256:549ad124df27c056b2e255ea1c44d30fb7a17d17676d03096ad5cd85edb32dc1"}, 509 | {file = "h5py-2.10.0-cp35-cp35m-macosx_10_6_intel.whl", hash = "sha256:a5f82cd4938ff8761d9760af3274acf55afc3c91c649c50ab18fcff5510a14a5"}, 510 | {file = "h5py-2.10.0-cp35-cp35m-manylinux1_i686.whl", hash = "sha256:3dad1730b6470fad853ef56d755d06bb916ee68a3d8272b3bab0c1ddf83bb99e"}, 511 | {file = "h5py-2.10.0-cp35-cp35m-manylinux1_x86_64.whl", hash = "sha256:063947eaed5f271679ed4ffa36bb96f57bc14f44dd4336a827d9a02702e6ce6b"}, 512 | {file = "h5py-2.10.0-cp35-cp35m-win32.whl", hash = "sha256:c54a2c0dd4957776ace7f95879d81582298c5daf89e77fb8bee7378f132951de"}, 513 | {file = "h5py-2.10.0-cp35-cp35m-win_amd64.whl", hash = "sha256:6998be619c695910cb0effe5eb15d3a511d3d1a5d217d4bd0bebad1151ec2262"}, 514 | {file = "h5py-2.10.0-cp36-cp36m-macosx_10_6_intel.whl", hash = "sha256:ff7d241f866b718e4584fa95f520cb19405220c501bd3a53ee11871ba5166ea2"}, 515 | {file = "h5py-2.10.0-cp36-cp36m-manylinux1_i686.whl", hash = "sha256:54817b696e87eb9e403e42643305f142cd8b940fe9b3b490bbf98c3b8a894cf4"}, 516 | {file = "h5py-2.10.0-cp36-cp36m-manylinux1_x86_64.whl", hash = "sha256:d3c59549f90a891691991c17f8e58c8544060fdf3ccdea267100fa5f561ff62f"}, 517 | {file = "h5py-2.10.0-cp36-cp36m-win32.whl", hash = "sha256:d7ae7a0576b06cb8e8a1c265a8bc4b73d05fdee6429bffc9a26a6eb531e79d72"}, 518 | {file = "h5py-2.10.0-cp36-cp36m-win_amd64.whl", hash = "sha256:bffbc48331b4a801d2f4b7dac8a72609f0b10e6e516e5c480a3e3241e091c878"}, 519 | {file = "h5py-2.10.0-cp37-cp37m-macosx_10_6_intel.whl", hash = "sha256:51ae56894c6c93159086ffa2c94b5b3388c0400548ab26555c143e7cfa05b8e5"}, 520 | {file = "h5py-2.10.0-cp37-cp37m-manylinux1_i686.whl", hash = "sha256:16ead3c57141101e3296ebeed79c9c143c32bdd0e82a61a2fc67e8e6d493e9d1"}, 521 | {file = "h5py-2.10.0-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:f0e25bb91e7a02efccb50aba6591d3fe2c725479e34769802fcdd4076abfa917"}, 522 | {file = "h5py-2.10.0-cp37-cp37m-win32.whl", hash = "sha256:f23951a53d18398ef1344c186fb04b26163ca6ce449ebd23404b153fd111ded9"}, 523 | {file = "h5py-2.10.0-cp37-cp37m-win_amd64.whl", hash = "sha256:8bb1d2de101f39743f91512a9750fb6c351c032e5cd3204b4487383e34da7f75"}, 524 | {file = "h5py-2.10.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:64f74da4a1dd0d2042e7d04cf8294e04ddad686f8eba9bb79e517ae582f6668d"}, 525 | {file = "h5py-2.10.0-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:d35f7a3a6cefec82bfdad2785e78359a0e6a5fbb3f605dd5623ce88082ccd681"}, 526 | {file = "h5py-2.10.0-cp38-cp38-win32.whl", hash = "sha256:6ef7ab1089e3ef53ca099038f3c0a94d03e3560e6aff0e9d6c64c55fb13fc681"}, 527 | {file = "h5py-2.10.0-cp38-cp38-win_amd64.whl", hash = "sha256:769e141512b54dee14ec76ed354fcacfc7d97fea5a7646b709f7400cf1838630"}, 528 | {file = "h5py-2.10.0.tar.gz", hash = "sha256:84412798925dc870ffd7107f045d7659e60f5d46d1c70c700375248bf6bf512d"}, 529 | ] 530 | idna = [ 531 | {file = "idna-2.10-py2.py3-none-any.whl", hash = "sha256:b97d804b1e9b523befed77c48dacec60e6dcb0b5391d57af6a65a312a90648c0"}, 532 | {file = "idna-2.10.tar.gz", hash = 
"sha256:b307872f855b18632ce0c21c5e45be78c0ea7ae4c15c828c20788b26921eb3f6"}, 533 | ] 534 | ipython = [ 535 | {file = "ipython-7.19.0-py3-none-any.whl", hash = "sha256:c987e8178ced651532b3b1ff9965925bfd445c279239697052561a9ab806d28f"}, 536 | {file = "ipython-7.19.0.tar.gz", hash = "sha256:cbb2ef3d5961d44e6a963b9817d4ea4e1fa2eb589c371a470fed14d8d40cbd6a"}, 537 | ] 538 | ipython-genutils = [ 539 | {file = "ipython_genutils-0.2.0-py2.py3-none-any.whl", hash = "sha256:72dd37233799e619666c9f639a9da83c34013a73e8bbc79a7a6348d93c61fab8"}, 540 | {file = "ipython_genutils-0.2.0.tar.gz", hash = "sha256:eb2e116e75ecef9d4d228fdc66af54269afa26ab4463042e33785b887c628ba8"}, 541 | ] 542 | jedi = [ 543 | {file = "jedi-0.17.2-py2.py3-none-any.whl", hash = "sha256:98cc583fa0f2f8304968199b01b6b4b94f469a1f4a74c1560506ca2a211378b5"}, 544 | {file = "jedi-0.17.2.tar.gz", hash = "sha256:86ed7d9b750603e4ba582ea8edc678657fb4007894a12bcf6f4bb97892f31d20"}, 545 | ] 546 | joblib = [ 547 | {file = "joblib-1.0.0-py3-none-any.whl", hash = "sha256:75ead23f13484a2a414874779d69ade40d4fa1abe62b222a23cd50d4bc822f6f"}, 548 | {file = "joblib-1.0.0.tar.gz", hash = "sha256:7ad866067ac1fdec27d51c8678ea760601b70e32ff1881d4dc8e1171f2b64b24"}, 549 | ] 550 | numpy = [ 551 | {file = "numpy-1.19.4-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:e9b30d4bd69498fc0c3fe9db5f62fffbb06b8eb9321f92cc970f2969be5e3949"}, 552 | {file = "numpy-1.19.4-cp36-cp36m-manylinux1_i686.whl", hash = "sha256:fedbd128668ead37f33917820b704784aff695e0019309ad446a6d0b065b57e4"}, 553 | {file = "numpy-1.19.4-cp36-cp36m-manylinux1_x86_64.whl", hash = "sha256:8ece138c3a16db8c1ad38f52eb32be6086cc72f403150a79336eb2045723a1ad"}, 554 | {file = "numpy-1.19.4-cp36-cp36m-manylinux2010_i686.whl", hash = "sha256:64324f64f90a9e4ef732be0928be853eee378fd6a01be21a0a8469c4f2682c83"}, 555 | {file = "numpy-1.19.4-cp36-cp36m-manylinux2010_x86_64.whl", hash = "sha256:ad6f2ff5b1989a4899bf89800a671d71b1612e5ff40866d1f4d8bcf48d4e5764"}, 556 | {file = "numpy-1.19.4-cp36-cp36m-manylinux2014_aarch64.whl", hash = "sha256:d6c7bb82883680e168b55b49c70af29b84b84abb161cbac2800e8fcb6f2109b6"}, 557 | {file = "numpy-1.19.4-cp36-cp36m-win32.whl", hash = "sha256:13d166f77d6dc02c0a73c1101dd87fdf01339febec1030bd810dcd53fff3b0f1"}, 558 | {file = "numpy-1.19.4-cp36-cp36m-win_amd64.whl", hash = "sha256:448ebb1b3bf64c0267d6b09a7cba26b5ae61b6d2dbabff7c91b660c7eccf2bdb"}, 559 | {file = "numpy-1.19.4-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:27d3f3b9e3406579a8af3a9f262f5339005dd25e0ecf3cf1559ff8a49ed5cbf2"}, 560 | {file = "numpy-1.19.4-cp37-cp37m-manylinux1_i686.whl", hash = "sha256:16c1b388cc31a9baa06d91a19366fb99ddbe1c7b205293ed072211ee5bac1ed2"}, 561 | {file = "numpy-1.19.4-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:e5b6ed0f0b42317050c88022349d994fe72bfe35f5908617512cd8c8ef9da2a9"}, 562 | {file = "numpy-1.19.4-cp37-cp37m-manylinux2010_i686.whl", hash = "sha256:18bed2bcb39e3f758296584337966e68d2d5ba6aab7e038688ad53c8f889f757"}, 563 | {file = "numpy-1.19.4-cp37-cp37m-manylinux2010_x86_64.whl", hash = "sha256:fe45becb4c2f72a0907c1d0246ea6449fe7a9e2293bb0e11c4e9a32bb0930a15"}, 564 | {file = "numpy-1.19.4-cp37-cp37m-manylinux2014_aarch64.whl", hash = "sha256:6d7593a705d662be5bfe24111af14763016765f43cb6923ed86223f965f52387"}, 565 | {file = "numpy-1.19.4-cp37-cp37m-win32.whl", hash = "sha256:6ae6c680f3ebf1cf7ad1d7748868b39d9f900836df774c453c11c5440bc15b36"}, 566 | {file = "numpy-1.19.4-cp37-cp37m-win_amd64.whl", hash = "sha256:9eeb7d1d04b117ac0d38719915ae169aa6b61fca227b0b7d198d43728f0c879c"}, 567 | 
{file = "numpy-1.19.4-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:cb1017eec5257e9ac6209ac172058c430e834d5d2bc21961dceeb79d111e5909"}, 568 | {file = "numpy-1.19.4-cp38-cp38-manylinux1_i686.whl", hash = "sha256:edb01671b3caae1ca00881686003d16c2209e07b7ef8b7639f1867852b948f7c"}, 569 | {file = "numpy-1.19.4-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:f29454410db6ef8126c83bd3c968d143304633d45dc57b51252afbd79d700893"}, 570 | {file = "numpy-1.19.4-cp38-cp38-manylinux2010_i686.whl", hash = "sha256:ec149b90019852266fec2341ce1db513b843e496d5a8e8cdb5ced1923a92faab"}, 571 | {file = "numpy-1.19.4-cp38-cp38-manylinux2010_x86_64.whl", hash = "sha256:1aeef46a13e51931c0b1cf8ae1168b4a55ecd282e6688fdb0a948cc5a1d5afb9"}, 572 | {file = "numpy-1.19.4-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:08308c38e44cc926bdfce99498b21eec1f848d24c302519e64203a8da99a97db"}, 573 | {file = "numpy-1.19.4-cp38-cp38-win32.whl", hash = "sha256:5734bdc0342aba9dfc6f04920988140fb41234db42381cf7ccba64169f9fe7ac"}, 574 | {file = "numpy-1.19.4-cp38-cp38-win_amd64.whl", hash = "sha256:09c12096d843b90eafd01ea1b3307e78ddd47a55855ad402b157b6c4862197ce"}, 575 | {file = "numpy-1.19.4-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:e452dc66e08a4ce642a961f134814258a082832c78c90351b75c41ad16f79f63"}, 576 | {file = "numpy-1.19.4-cp39-cp39-manylinux1_i686.whl", hash = "sha256:a5d897c14513590a85774180be713f692df6fa8ecf6483e561a6d47309566f37"}, 577 | {file = "numpy-1.19.4-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:a09f98011236a419ee3f49cedc9ef27d7a1651df07810ae430a6b06576e0b414"}, 578 | {file = "numpy-1.19.4-cp39-cp39-manylinux2010_i686.whl", hash = "sha256:50e86c076611212ca62e5a59f518edafe0c0730f7d9195fec718da1a5c2bb1fc"}, 579 | {file = "numpy-1.19.4-cp39-cp39-manylinux2010_x86_64.whl", hash = "sha256:f0d3929fe88ee1c155129ecd82f981b8856c5d97bcb0d5f23e9b4242e79d1de3"}, 580 | {file = "numpy-1.19.4-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:c42c4b73121caf0ed6cd795512c9c09c52a7287b04d105d112068c1736d7c753"}, 581 | {file = "numpy-1.19.4-cp39-cp39-win32.whl", hash = "sha256:8cac8790a6b1ddf88640a9267ee67b1aee7a57dfa2d2dd33999d080bc8ee3a0f"}, 582 | {file = "numpy-1.19.4-cp39-cp39-win_amd64.whl", hash = "sha256:4377e10b874e653fe96985c05feed2225c912e328c8a26541f7fc600fb9c637b"}, 583 | {file = "numpy-1.19.4-pp36-pypy36_pp73-manylinux2010_x86_64.whl", hash = "sha256:2a2740aa9733d2e5b2dfb33639d98a64c3b0f24765fed86b0fd2aec07f6a0a08"}, 584 | {file = "numpy-1.19.4.zip", hash = "sha256:141ec3a3300ab89c7f2b0775289954d193cc8edb621ea05f99db9cb181530512"}, 585 | ] 586 | packaging = [ 587 | {file = "packaging-20.8-py2.py3-none-any.whl", hash = "sha256:24e0da08660a87484d1602c30bb4902d74816b6985b93de36926f5bc95741858"}, 588 | {file = "packaging-20.8.tar.gz", hash = "sha256:78598185a7008a470d64526a8059de9aaa449238f280fc9eb6b13ba6c4109093"}, 589 | ] 590 | parso = [ 591 | {file = "parso-0.7.1-py2.py3-none-any.whl", hash = "sha256:97218d9159b2520ff45eb78028ba8b50d2bc61dcc062a9682666f2dc4bd331ea"}, 592 | {file = "parso-0.7.1.tar.gz", hash = "sha256:caba44724b994a8a5e086460bb212abc5a8bc46951bf4a9a1210745953622eb9"}, 593 | ] 594 | pexpect = [ 595 | {file = "pexpect-4.8.0-py2.py3-none-any.whl", hash = "sha256:0b48a55dcb3c05f3329815901ea4fc1537514d6ba867a152b581d69ae3710937"}, 596 | {file = "pexpect-4.8.0.tar.gz", hash = "sha256:fc65a43959d153d0114afe13997d439c22823a27cefceb5ff35c2178c6784c0c"}, 597 | ] 598 | pickleshare = [ 599 | {file = "pickleshare-0.7.5-py2.py3-none-any.whl", hash = 
"sha256:9649af414d74d4df115d5d718f82acb59c9d418196b7b4290ed47a12ce62df56"}, 600 | {file = "pickleshare-0.7.5.tar.gz", hash = "sha256:87683d47965c1da65cdacaf31c8441d12b8044cdec9aca500cd78fc2c683afca"}, 601 | ] 602 | prompt-toolkit = [ 603 | {file = "prompt_toolkit-3.0.8-py3-none-any.whl", hash = "sha256:7debb9a521e0b1ee7d2fe96ee4bd60ef03c6492784de0547337ca4433e46aa63"}, 604 | {file = "prompt_toolkit-3.0.8.tar.gz", hash = "sha256:25c95d2ac813909f813c93fde734b6e44406d1477a9faef7c915ff37d39c0a8c"}, 605 | ] 606 | protobuf = [ 607 | {file = "protobuf-3.14.0-cp27-cp27m-macosx_10_9_x86_64.whl", hash = "sha256:629b03fd3caae7f815b0c66b41273f6b1900a579e2ccb41ef4493a4f5fb84f3a"}, 608 | {file = "protobuf-3.14.0-cp27-cp27mu-manylinux1_x86_64.whl", hash = "sha256:5b7a637212cc9b2bcf85dd828b1178d19efdf74dbfe1ddf8cd1b8e01fdaaa7f5"}, 609 | {file = "protobuf-3.14.0-cp35-cp35m-macosx_10_9_intel.whl", hash = "sha256:43b554b9e73a07ba84ed6cf25db0ff88b1e06be610b37656e292e3cbb5437472"}, 610 | {file = "protobuf-3.14.0-cp35-cp35m-manylinux1_x86_64.whl", hash = "sha256:5e9806a43232a1fa0c9cf5da8dc06f6910d53e4390be1fa06f06454d888a9142"}, 611 | {file = "protobuf-3.14.0-cp35-cp35m-win32.whl", hash = "sha256:1c51fda1bbc9634246e7be6016d860be01747354ed7015ebe38acf4452f470d2"}, 612 | {file = "protobuf-3.14.0-cp35-cp35m-win_amd64.whl", hash = "sha256:4b74301b30513b1a7494d3055d95c714b560fbb630d8fb9956b6f27992c9f980"}, 613 | {file = "protobuf-3.14.0-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:86a75477addde4918e9a1904e5c6af8d7b691f2a3f65587d73b16100fbe4c3b2"}, 614 | {file = "protobuf-3.14.0-cp36-cp36m-manylinux1_x86_64.whl", hash = "sha256:ecc33531a213eee22ad60e0e2aaea6c8ba0021f0cce35dbf0ab03dee6e2a23a1"}, 615 | {file = "protobuf-3.14.0-cp36-cp36m-win32.whl", hash = "sha256:72230ed56f026dd664c21d73c5db73ebba50d924d7ba6b7c0d81a121e390406e"}, 616 | {file = "protobuf-3.14.0-cp36-cp36m-win_amd64.whl", hash = "sha256:0fc96785262042e4863b3f3b5c429d4636f10d90061e1840fce1baaf59b1a836"}, 617 | {file = "protobuf-3.14.0-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:4e75105c9dfe13719b7293f75bd53033108f4ba03d44e71db0ec2a0e8401eafd"}, 618 | {file = "protobuf-3.14.0-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:2a7e2fe101a7ace75e9327b9c946d247749e564a267b0515cf41dfe450b69bac"}, 619 | {file = "protobuf-3.14.0-cp37-cp37m-win32.whl", hash = "sha256:b0d5d35faeb07e22a1ddf8dce620860c8fe145426c02d1a0ae2688c6e8ede36d"}, 620 | {file = "protobuf-3.14.0-cp37-cp37m-win_amd64.whl", hash = "sha256:8971c421dbd7aad930c9bd2694122f332350b6ccb5202a8b7b06f3f1a5c41ed5"}, 621 | {file = "protobuf-3.14.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:9616f0b65a30851e62f1713336c931fcd32c057202b7ff2cfbfca0fc7d5e3043"}, 622 | {file = "protobuf-3.14.0-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:22bcd2e284b3b1d969c12e84dc9b9a71701ec82d8ce975fdda19712e1cfd4e00"}, 623 | {file = "protobuf-3.14.0-py2.py3-none-any.whl", hash = "sha256:0e247612fadda953047f53301a7b0407cb0c3cb4ae25a6fde661597a04039b3c"}, 624 | {file = "protobuf-3.14.0.tar.gz", hash = "sha256:1d63eb389347293d8915fb47bee0951c7b5dab522a4a60118b9a18f33e21f8ce"}, 625 | ] 626 | ptyprocess = [ 627 | {file = "ptyprocess-0.6.0-py2.py3-none-any.whl", hash = "sha256:d7cc528d76e76342423ca640335bd3633420dc1366f258cb31d05e865ef5ca1f"}, 628 | {file = "ptyprocess-0.6.0.tar.gz", hash = "sha256:923f299cc5ad920c68f2bc0bc98b75b9f838b93b599941a6b63ddbc2476394c0"}, 629 | ] 630 | pudb = [ 631 | {file = "pudb-2019.2.tar.gz", hash = "sha256:e8f0ea01b134d802872184b05bffc82af29a1eb2f9374a277434b932d68f58dc"}, 632 | ] 
633 | pygments = [ 634 | {file = "Pygments-2.7.3-py3-none-any.whl", hash = "sha256:f275b6c0909e5dafd2d6269a656aa90fa58ebf4a74f8fcf9053195d226b24a08"}, 635 | {file = "Pygments-2.7.3.tar.gz", hash = "sha256:ccf3acacf3782cbed4a989426012f1c535c9a90d3a7fc3f16d231b9372d2b716"}, 636 | ] 637 | pyparsing = [ 638 | {file = "pyparsing-2.4.7-py2.py3-none-any.whl", hash = "sha256:ef9d7589ef3c200abe66653d3f1ab1033c3c419ae9b9bdb1240a85b024efc88b"}, 639 | {file = "pyparsing-2.4.7.tar.gz", hash = "sha256:c203ec8783bf771a155b207279b9bccb8dea02d8f0c9e5f8ead507bc3246ecc1"}, 640 | ] 641 | regex = [ 642 | {file = "regex-2020.11.13-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:8b882a78c320478b12ff024e81dc7d43c1462aa4a3341c754ee65d857a521f85"}, 643 | {file = "regex-2020.11.13-cp36-cp36m-manylinux1_i686.whl", hash = "sha256:a63f1a07932c9686d2d416fb295ec2c01ab246e89b4d58e5fa468089cab44b70"}, 644 | {file = "regex-2020.11.13-cp36-cp36m-manylinux1_x86_64.whl", hash = "sha256:6e4b08c6f8daca7d8f07c8d24e4331ae7953333dbd09c648ed6ebd24db5a10ee"}, 645 | {file = "regex-2020.11.13-cp36-cp36m-manylinux2010_i686.whl", hash = "sha256:bba349276b126947b014e50ab3316c027cac1495992f10e5682dc677b3dfa0c5"}, 646 | {file = "regex-2020.11.13-cp36-cp36m-manylinux2010_x86_64.whl", hash = "sha256:56e01daca75eae420bce184edd8bb341c8eebb19dd3bce7266332258f9fb9dd7"}, 647 | {file = "regex-2020.11.13-cp36-cp36m-manylinux2014_aarch64.whl", hash = "sha256:6a8ce43923c518c24a2579fda49f093f1397dad5d18346211e46f134fc624e31"}, 648 | {file = "regex-2020.11.13-cp36-cp36m-manylinux2014_i686.whl", hash = "sha256:1ab79fcb02b930de09c76d024d279686ec5d532eb814fd0ed1e0051eb8bd2daa"}, 649 | {file = "regex-2020.11.13-cp36-cp36m-manylinux2014_x86_64.whl", hash = "sha256:9801c4c1d9ae6a70aeb2128e5b4b68c45d4f0af0d1535500884d644fa9b768c6"}, 650 | {file = "regex-2020.11.13-cp36-cp36m-win32.whl", hash = "sha256:49cae022fa13f09be91b2c880e58e14b6da5d10639ed45ca69b85faf039f7a4e"}, 651 | {file = "regex-2020.11.13-cp36-cp36m-win_amd64.whl", hash = "sha256:749078d1eb89484db5f34b4012092ad14b327944ee7f1c4f74d6279a6e4d1884"}, 652 | {file = "regex-2020.11.13-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:b2f4007bff007c96a173e24dcda236e5e83bde4358a557f9ccf5e014439eae4b"}, 653 | {file = "regex-2020.11.13-cp37-cp37m-manylinux1_i686.whl", hash = "sha256:38c8fd190db64f513fe4e1baa59fed086ae71fa45083b6936b52d34df8f86a88"}, 654 | {file = "regex-2020.11.13-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:5862975b45d451b6db51c2e654990c1820523a5b07100fc6903e9c86575202a0"}, 655 | {file = "regex-2020.11.13-cp37-cp37m-manylinux2010_i686.whl", hash = "sha256:262c6825b309e6485ec2493ffc7e62a13cf13fb2a8b6d212f72bd53ad34118f1"}, 656 | {file = "regex-2020.11.13-cp37-cp37m-manylinux2010_x86_64.whl", hash = "sha256:bafb01b4688833e099d79e7efd23f99172f501a15c44f21ea2118681473fdba0"}, 657 | {file = "regex-2020.11.13-cp37-cp37m-manylinux2014_aarch64.whl", hash = "sha256:e32f5f3d1b1c663af7f9c4c1e72e6ffe9a78c03a31e149259f531e0fed826512"}, 658 | {file = "regex-2020.11.13-cp37-cp37m-manylinux2014_i686.whl", hash = "sha256:3bddc701bdd1efa0d5264d2649588cbfda549b2899dc8d50417e47a82e1387ba"}, 659 | {file = "regex-2020.11.13-cp37-cp37m-manylinux2014_x86_64.whl", hash = "sha256:02951b7dacb123d8ea6da44fe45ddd084aa6777d4b2454fa0da61d569c6fa538"}, 660 | {file = "regex-2020.11.13-cp37-cp37m-win32.whl", hash = "sha256:0d08e71e70c0237883d0bef12cad5145b84c3705e9c6a588b2a9c7080e5af2a4"}, 661 | {file = "regex-2020.11.13-cp37-cp37m-win_amd64.whl", hash = 
"sha256:1fa7ee9c2a0e30405e21031d07d7ba8617bc590d391adfc2b7f1e8b99f46f444"}, 662 | {file = "regex-2020.11.13-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:baf378ba6151f6e272824b86a774326f692bc2ef4cc5ce8d5bc76e38c813a55f"}, 663 | {file = "regex-2020.11.13-cp38-cp38-manylinux1_i686.whl", hash = "sha256:e3faaf10a0d1e8e23a9b51d1900b72e1635c2d5b0e1bea1c18022486a8e2e52d"}, 664 | {file = "regex-2020.11.13-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:2a11a3e90bd9901d70a5b31d7dd85114755a581a5da3fc996abfefa48aee78af"}, 665 | {file = "regex-2020.11.13-cp38-cp38-manylinux2010_i686.whl", hash = "sha256:d1ebb090a426db66dd80df8ca85adc4abfcbad8a7c2e9a5ec7513ede522e0a8f"}, 666 | {file = "regex-2020.11.13-cp38-cp38-manylinux2010_x86_64.whl", hash = "sha256:b2b1a5ddae3677d89b686e5c625fc5547c6e492bd755b520de5332773a8af06b"}, 667 | {file = "regex-2020.11.13-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:2c99e97d388cd0a8d30f7c514d67887d8021541b875baf09791a3baad48bb4f8"}, 668 | {file = "regex-2020.11.13-cp38-cp38-manylinux2014_i686.whl", hash = "sha256:c084582d4215593f2f1d28b65d2a2f3aceff8342aa85afd7be23a9cad74a0de5"}, 669 | {file = "regex-2020.11.13-cp38-cp38-manylinux2014_x86_64.whl", hash = "sha256:a3d748383762e56337c39ab35c6ed4deb88df5326f97a38946ddd19028ecce6b"}, 670 | {file = "regex-2020.11.13-cp38-cp38-win32.whl", hash = "sha256:7913bd25f4ab274ba37bc97ad0e21c31004224ccb02765ad984eef43e04acc6c"}, 671 | {file = "regex-2020.11.13-cp38-cp38-win_amd64.whl", hash = "sha256:6c54ce4b5d61a7129bad5c5dc279e222afd00e721bf92f9ef09e4fae28755683"}, 672 | {file = "regex-2020.11.13-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:1862a9d9194fae76a7aaf0150d5f2a8ec1da89e8b55890b1786b8f88a0f619dc"}, 673 | {file = "regex-2020.11.13-cp39-cp39-manylinux1_i686.whl", hash = "sha256:4902e6aa086cbb224241adbc2f06235927d5cdacffb2425c73e6570e8d862364"}, 674 | {file = "regex-2020.11.13-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:7a25fcbeae08f96a754b45bdc050e1fb94b95cab046bf56b016c25e9ab127b3e"}, 675 | {file = "regex-2020.11.13-cp39-cp39-manylinux2010_i686.whl", hash = "sha256:d2d8ce12b7c12c87e41123997ebaf1a5767a5be3ec545f64675388970f415e2e"}, 676 | {file = "regex-2020.11.13-cp39-cp39-manylinux2010_x86_64.whl", hash = "sha256:f7d29a6fc4760300f86ae329e3b6ca28ea9c20823df123a2ea8693e967b29917"}, 677 | {file = "regex-2020.11.13-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:717881211f46de3ab130b58ec0908267961fadc06e44f974466d1887f865bd5b"}, 678 | {file = "regex-2020.11.13-cp39-cp39-manylinux2014_i686.whl", hash = "sha256:3128e30d83f2e70b0bed9b2a34e92707d0877e460b402faca908c6667092ada9"}, 679 | {file = "regex-2020.11.13-cp39-cp39-manylinux2014_x86_64.whl", hash = "sha256:8f6a2229e8ad946e36815f2a03386bb8353d4bde368fdf8ca5f0cb97264d3b5c"}, 680 | {file = "regex-2020.11.13-cp39-cp39-win32.whl", hash = "sha256:f8f295db00ef5f8bae530fc39af0b40486ca6068733fb860b42115052206466f"}, 681 | {file = "regex-2020.11.13-cp39-cp39-win_amd64.whl", hash = "sha256:a15f64ae3a027b64496a71ab1f722355e570c3fac5ba2801cafce846bf5af01d"}, 682 | {file = "regex-2020.11.13.tar.gz", hash = "sha256:83d6b356e116ca119db8e7c6fc2983289d87b27b3fac238cfe5dca529d884562"}, 683 | ] 684 | requests = [ 685 | {file = "requests-2.25.1-py2.py3-none-any.whl", hash = "sha256:c210084e36a42ae6b9219e00e48287def368a26d03a048ddad7bfee44f75871e"}, 686 | {file = "requests-2.25.1.tar.gz", hash = "sha256:27973dd4a904a4f13b263a19c866c13b92a39ed1c964655f025f3f8d3d75b804"}, 687 | ] 688 | sacremoses = [ 689 | {file = "sacremoses-0.0.43.tar.gz", hash = 
"sha256:123c1bf2664351fb05e16f87d3786dbe44a050cfd7b85161c09ad9a63a8e2948"}, 690 | ] 691 | sentencepiece = [ 692 | {file = "sentencepiece-0.1.94-cp35-cp35m-macosx_10_6_x86_64.whl", hash = "sha256:7b6c794d30272a5e635e958fdb4976dd991bf35eed90441104a042b2e51723c7"}, 693 | {file = "sentencepiece-0.1.94-cp35-cp35m-manylinux2014_aarch64.whl", hash = "sha256:b5e3eedad0ef5b3a4ae1d201fc0edc7f4b4d567c016913d4b996ebf0ab66748b"}, 694 | {file = "sentencepiece-0.1.94-cp35-cp35m-manylinux2014_i686.whl", hash = "sha256:58db565195ee31efbaca9d00937f9f73aa131cc820c2ad46a39ac62f8671866f"}, 695 | {file = "sentencepiece-0.1.94-cp35-cp35m-manylinux2014_ppc64le.whl", hash = "sha256:cbde526df19d6bcfa2b8503b2a4bf6996dd3172f631fd2b7efd7b6435d96407c"}, 696 | {file = "sentencepiece-0.1.94-cp35-cp35m-manylinux2014_s390x.whl", hash = "sha256:b01057743c2488c8d6e7b45b0732ee23976ac3d58d11cd90390cbc3221c07402"}, 697 | {file = "sentencepiece-0.1.94-cp35-cp35m-manylinux2014_x86_64.whl", hash = "sha256:cd6434909e1c8494b3254bf3150420e45489214d9bc7ab6ad4d1804d75d6d58f"}, 698 | {file = "sentencepiece-0.1.94-cp36-cp36m-macosx_10_6_x86_64.whl", hash = "sha256:7b4867845e6935c43e37042a451d2ce84d9d97365300151a8c1c1cc724acad32"}, 699 | {file = "sentencepiece-0.1.94-cp36-cp36m-manylinux2014_aarch64.whl", hash = "sha256:4d7d0844a57156b630fb98e21203c2755b342824b8c5a445e4ac78612c291218"}, 700 | {file = "sentencepiece-0.1.94-cp36-cp36m-manylinux2014_i686.whl", hash = "sha256:a75f418bd92c6c92e2ee0c95e89b45b76bc54e45ed7cf2b3b74d313b263d1baa"}, 701 | {file = "sentencepiece-0.1.94-cp36-cp36m-manylinux2014_ppc64le.whl", hash = "sha256:995e645a94107e46317987d348216a0fb1ae3a8befec9c99cc506b8994aa133d"}, 702 | {file = "sentencepiece-0.1.94-cp36-cp36m-manylinux2014_s390x.whl", hash = "sha256:232a882ebf074966e24943119ab83554642bd339bd5d6bd2641092133983bc6a"}, 703 | {file = "sentencepiece-0.1.94-cp36-cp36m-manylinux2014_x86_64.whl", hash = "sha256:db744b73b5a5fd7adfa5cfc4eb4b7d0f408c2059783fd52c934b49743a0d2326"}, 704 | {file = "sentencepiece-0.1.94-cp36-cp36m-win32.whl", hash = "sha256:1d7c9f52a2e32a7a2eb9685ddf74a86b5df94fcaccf37be661ac9bb5c9db4893"}, 705 | {file = "sentencepiece-0.1.94-cp36-cp36m-win_amd64.whl", hash = "sha256:11bd70be4baf4e67b1714e43bcd1e7fed0ce04616a20388367299846fdaf712d"}, 706 | {file = "sentencepiece-0.1.94-cp37-cp37m-macosx_10_6_x86_64.whl", hash = "sha256:9c8476febe8eb0a165cf04192ebd2b15124d83cfc44269e10d2a83ace677f109"}, 707 | {file = "sentencepiece-0.1.94-cp37-cp37m-manylinux2014_aarch64.whl", hash = "sha256:9d2245d400424ab261e3253308001606668126a08efdc19ee2c348b0e228e1e1"}, 708 | {file = "sentencepiece-0.1.94-cp37-cp37m-manylinux2014_i686.whl", hash = "sha256:e4aef0be184f3c5b72a1c3f7e01fbf245eb3b3c70365f823e24542008afe387f"}, 709 | {file = "sentencepiece-0.1.94-cp37-cp37m-manylinux2014_ppc64le.whl", hash = "sha256:5c2969c4f62039d82f761c9548011bf39673a1eb8dc8f747943b88851523c943"}, 710 | {file = "sentencepiece-0.1.94-cp37-cp37m-manylinux2014_s390x.whl", hash = "sha256:c9d440d9ecf8c8787b89bc8596f7a47c548a9968f802d654faaf5652598ffbb0"}, 711 | {file = "sentencepiece-0.1.94-cp37-cp37m-manylinux2014_x86_64.whl", hash = "sha256:295ef1ccf570c33728040a461cf837611495e8a5bd954012a5784fb3529ff460"}, 712 | {file = "sentencepiece-0.1.94-cp37-cp37m-win32.whl", hash = "sha256:9d446ad41744a898f34800ee492553b4a24255a0f922cb32fe33a3c0a865d153"}, 713 | {file = "sentencepiece-0.1.94-cp37-cp37m-win_amd64.whl", hash = "sha256:fd12969cf8420870bee743398e2e60f722d1ffdf9d201dc1d6b09096c971bfd9"}, 714 | {file = 
"sentencepiece-0.1.94-cp38-cp38-macosx_10_6_x86_64.whl", hash = "sha256:3f6c0b5c501053a2f9d99daccbf187f367ded5ae35e9e031feae56188b352433"}, 715 | {file = "sentencepiece-0.1.94-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:05e6ef6a669d2e2d3232d95acfb2a9d255272484b898ea0650659d95448bf93f"}, 716 | {file = "sentencepiece-0.1.94-cp38-cp38-manylinux2014_i686.whl", hash = "sha256:1e6b711563163fc8cf2c873d08b4495244859e3f6d6c18859b524395d8550482"}, 717 | {file = "sentencepiece-0.1.94-cp38-cp38-manylinux2014_ppc64le.whl", hash = "sha256:bf524fa6243cfd05a04f65a6b17516ddd58438adf3c35df02ca3ebb832270a47"}, 718 | {file = "sentencepiece-0.1.94-cp38-cp38-manylinux2014_s390x.whl", hash = "sha256:ed49f1187a25db531e2ad95718a5640a3f7e0467bc82e4267cc6f7b6caa3054a"}, 719 | {file = "sentencepiece-0.1.94-cp38-cp38-manylinux2014_x86_64.whl", hash = "sha256:fb31a1827da0de50dc8ca33d4e657121594092c7231a4fb2d6a86149dfd98bc5"}, 720 | {file = "sentencepiece-0.1.94-cp38-cp38-win32.whl", hash = "sha256:9c87f759dddefff52c12d4a3500a00faf22ea476a004c33c78794699069d8fc9"}, 721 | {file = "sentencepiece-0.1.94-cp38-cp38-win_amd64.whl", hash = "sha256:4c11b2fc89c71510a900e2dbd4d93fb18a867ce7160f298bb6bb8a581d646d63"}, 722 | {file = "sentencepiece-0.1.94-cp39-cp39-macosx_10_6_x86_64.whl", hash = "sha256:88ef71e36b09ddd53498064efaec5470a09698df2427362cc4e86198d88aa01e"}, 723 | {file = "sentencepiece-0.1.94-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:a89d90b45ba5025fcd19cad685c7572624a036d883091af967a75f3793c2aee4"}, 724 | {file = "sentencepiece-0.1.94-cp39-cp39-manylinux2014_i686.whl", hash = "sha256:c571b26017d8dd1c47dc2eeae09caa15cfe3d2f31fb01f004d463403a1f1349b"}, 725 | {file = "sentencepiece-0.1.94-cp39-cp39-manylinux2014_ppc64le.whl", hash = "sha256:42d35adb51eb530d57c56c2cd445dbf9bd9db36bf82741aa5b42216f7f34c12d"}, 726 | {file = "sentencepiece-0.1.94-cp39-cp39-manylinux2014_s390x.whl", hash = "sha256:fee3e6849b9e0cef774fb003ba2950b282b1910cdd761794bbf8dc0aa9d5f7d3"}, 727 | {file = "sentencepiece-0.1.94-cp39-cp39-manylinux2014_x86_64.whl", hash = "sha256:e5074d8239dcc6130dce8ffd734ab797f86679fc75a4a1d96adc243293178c05"}, 728 | {file = "sentencepiece-0.1.94.tar.gz", hash = "sha256:849d74885f6f7af03a5d354b919bf23c757f94257d7a068bc464efd70d651e3a"}, 729 | ] 730 | six = [ 731 | {file = "six-1.15.0-py2.py3-none-any.whl", hash = "sha256:8b74bedcbbbaca38ff6d7491d76f2b06b3592611af620f8426e82dddb04a5ced"}, 732 | {file = "six-1.15.0.tar.gz", hash = "sha256:30639c035cdb23534cd4aa2dd52c3bf48f06e5f4a941509c8bafd8ce11080259"}, 733 | ] 734 | tabulate = [ 735 | {file = "tabulate-0.8.7-py3-none-any.whl", hash = "sha256:ac64cb76d53b1231d364babcd72abbb16855adac7de6665122f97b593f1eb2ba"}, 736 | {file = "tabulate-0.8.7.tar.gz", hash = "sha256:db2723a20d04bcda8522165c73eea7c300eda74e0ce852d9022e0159d7895007"}, 737 | ] 738 | tensorboardx = [ 739 | {file = "tensorboardX-2.1-py2.py3-none-any.whl", hash = "sha256:2d81c10d9e3225dcd9bb5fb277588610bdf45317603e7682f6953d83b5b38f6a"}, 740 | {file = "tensorboardX-2.1.tar.gz", hash = "sha256:9e8907cf2ab900542d6cb72bf91aa87b43005a7f0aa43126268697e3727872f9"}, 741 | ] 742 | tokenizers = [ 743 | {file = "tokenizers-0.9.4-cp35-cp35m-macosx_10_11_x86_64.whl", hash = "sha256:082de5272363aee13f36641065a3dd2d78f5b51486e3ab7d6d34138905a46303"}, 744 | {file = "tokenizers-0.9.4-cp35-cp35m-manylinux1_x86_64.whl", hash = "sha256:543dcb31b8534cf3ad66817f925f50f4ccd182ed1433fcd07adaed5d389f682b"}, 745 | {file = "tokenizers-0.9.4-cp35-cp35m-manylinux2010_x86_64.whl", hash = 
"sha256:89f816e5aa61c464e9d82025f2c4f1f66cd92f648ab9194a154ba2b0e180dc70"}, 746 | {file = "tokenizers-0.9.4-cp35-cp35m-manylinux2014_aarch64.whl", hash = "sha256:768f36e743604f567f4e4817a76738ed1bcdaecfef5ae8c74bdf2277a7a1902d"}, 747 | {file = "tokenizers-0.9.4-cp35-cp35m-manylinux2014_ppc64le.whl", hash = "sha256:800917d7085245db0b55f88b2a12bd0ba4eb5966e8b88bd9f21aa46aadfa8204"}, 748 | {file = "tokenizers-0.9.4-cp35-cp35m-manylinux2014_s390x.whl", hash = "sha256:bce664d24c744387760beab14cc7bd4e405bbef93c333ba3ca4a93347949c3ba"}, 749 | {file = "tokenizers-0.9.4-cp35-cp35m-win32.whl", hash = "sha256:b57fc7f2003f1f7b873dcffd5d0ee7c71f01709c54c36f4d191e4a7911d49565"}, 750 | {file = "tokenizers-0.9.4-cp35-cp35m-win_amd64.whl", hash = "sha256:1313d63ce286c6c9812a51ea39ae84cf1b8f2887c8ce8cc813459fdfbf526c9b"}, 751 | {file = "tokenizers-0.9.4-cp36-cp36m-macosx_10_11_x86_64.whl", hash = "sha256:2dd1156815cf2ca2a0942c8efc72e0725b6cd4640a61e026c72bf5a330f4383a"}, 752 | {file = "tokenizers-0.9.4-cp36-cp36m-manylinux1_x86_64.whl", hash = "sha256:58e1904c3e75e37be379ee4b29b21b05189d54bfab0260b334cff6e5a44a4f45"}, 753 | {file = "tokenizers-0.9.4-cp36-cp36m-manylinux2010_x86_64.whl", hash = "sha256:4fd1a765af0a7aff7dab58d7fcd63a2e4a860e829b931bdfd59e2c56ba1769b9"}, 754 | {file = "tokenizers-0.9.4-cp36-cp36m-manylinux2014_aarch64.whl", hash = "sha256:3cf5b470b2e06aadee22771740d87a706216385f881308c70cb317476ec40904"}, 755 | {file = "tokenizers-0.9.4-cp36-cp36m-manylinux2014_ppc64le.whl", hash = "sha256:c83f7a26d6f0c765906440c7f2b726cbd18e5c7a63e0364095600c91e2905cc4"}, 756 | {file = "tokenizers-0.9.4-cp36-cp36m-manylinux2014_s390x.whl", hash = "sha256:427257e78b71e9310d0c035df9b054525d1da91cc46efbae95fee2d523b88eb9"}, 757 | {file = "tokenizers-0.9.4-cp36-cp36m-win32.whl", hash = "sha256:4a5ddd6689e18b6c5398b97134e79e948e1bbe7664f6962aa63f50fb05cae091"}, 758 | {file = "tokenizers-0.9.4-cp36-cp36m-win_amd64.whl", hash = "sha256:53395c4423e8309b208f1e973337c08a3cb68af5eb9dee8d8618428fd4579803"}, 759 | {file = "tokenizers-0.9.4-cp37-cp37m-macosx_10_11_x86_64.whl", hash = "sha256:d2824dedd9f26e3757159d99c743b287ebf78775ccf4a36a3e0ec7058ee66303"}, 760 | {file = "tokenizers-0.9.4-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:b49f17c2ac2bf88875a74d63e8070fd5a69e8c3b2874dee47649826b603a3af1"}, 761 | {file = "tokenizers-0.9.4-cp37-cp37m-manylinux2010_x86_64.whl", hash = "sha256:da361a88b21cd141441fb139d1ee05c815103d49d10b49bfb4218a240d0d5a84"}, 762 | {file = "tokenizers-0.9.4-cp37-cp37m-manylinux2014_aarch64.whl", hash = "sha256:a03c101d8058c851a7647cc74c68d4db511d7a3db8a73f7ec715e4fe14281ed7"}, 763 | {file = "tokenizers-0.9.4-cp37-cp37m-manylinux2014_ppc64le.whl", hash = "sha256:8d8ca7daa2f2274ec9327961ac828c20fcadd76e88d07f611742f240a6c73abe"}, 764 | {file = "tokenizers-0.9.4-cp37-cp37m-manylinux2014_s390x.whl", hash = "sha256:9de00f951fa8c1cf5c54a5a813447c9bf810759822de6ba6cfa42d7f503ff799"}, 765 | {file = "tokenizers-0.9.4-cp37-cp37m-win32.whl", hash = "sha256:535cf3edfd0df2c1887ea388691dd8f614331f47b41cb40c0901a2ce070ff7e0"}, 766 | {file = "tokenizers-0.9.4-cp37-cp37m-win_amd64.whl", hash = "sha256:f3351eef9187ba7b9ceb04ff74fcda535f26c4146fe40155c6ed6087302944fd"}, 767 | {file = "tokenizers-0.9.4-cp38-cp38-macosx_10_11_x86_64.whl", hash = "sha256:06e1a1c50c7600d8162d8f0eeed460ad9e9234ffee7d5c7bcd1308024d781647"}, 768 | {file = "tokenizers-0.9.4-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:c60b8ba2d8a948bb40c39223a4b2553c7c1df9f732b0077722b91df5d63c5e37"}, 769 | {file = 
"tokenizers-0.9.4-cp38-cp38-manylinux2010_x86_64.whl", hash = "sha256:31184c4691aed1e84088d7a18c1000bbc59f7bedeec95774ec4027129ea16272"}, 770 | {file = "tokenizers-0.9.4-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:abdbd169738c33e2e643e7701230f43c2f4e6e03d49283d4250f19159f6a6c71"}, 771 | {file = "tokenizers-0.9.4-cp38-cp38-manylinux2014_ppc64le.whl", hash = "sha256:ac4c0a2f052a83146c6475dc22f9eb740d352b29779ac6036459f00d897025b8"}, 772 | {file = "tokenizers-0.9.4-cp38-cp38-manylinux2014_s390x.whl", hash = "sha256:96879e21be25b63fb99fa7d65b50b05c2a0333f104ca003917df7433d6eb073e"}, 773 | {file = "tokenizers-0.9.4-cp38-cp38-win32.whl", hash = "sha256:1764a705be63fb61abcaa96637399f124528f9a01925c88efb438aefe315b61b"}, 774 | {file = "tokenizers-0.9.4-cp38-cp38-win_amd64.whl", hash = "sha256:a3180c8a1cb77eca8fe9c291e0f197aee202c93ffdea4f96d06ca154f319980c"}, 775 | {file = "tokenizers-0.9.4-cp39-cp39-macosx_10_11_x86_64.whl", hash = "sha256:d518ef8323690cd4d51979ff2f44edbac5862db8c8af125e815e41cf4517c638"}, 776 | {file = "tokenizers-0.9.4-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:807f321731a3466b9e0230cbc8e6d9c5581d5ac6536d96360b5fe1ec457d837f"}, 777 | {file = "tokenizers-0.9.4-cp39-cp39-manylinux2010_x86_64.whl", hash = "sha256:3ea6d65a32c8b3236553e489573f42855af484d24bf96ab32a5d6d1a2c4b0ed0"}, 778 | {file = "tokenizers-0.9.4-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:15440ba1db7c7b3eb7b5881b276555e25420ce14639926585837b7b60ddb55a8"}, 779 | {file = "tokenizers-0.9.4-cp39-cp39-manylinux2014_ppc64le.whl", hash = "sha256:bd46747f5c7d6e1721234d5ec1c0038bcfe0050c147c92171c3ef5b36d6fb2a9"}, 780 | {file = "tokenizers-0.9.4-cp39-cp39-manylinux2014_s390x.whl", hash = "sha256:9f79b57a4d6a1aa8379a931e8ee54cb155cc3f5f1ba5172bcdea504dbd4cb746"}, 781 | {file = "tokenizers-0.9.4-cp39-cp39-win32.whl", hash = "sha256:c496748853c0300b8b7be916e130f0de8224575ee72e8889405477f120bfe575"}, 782 | {file = "tokenizers-0.9.4-cp39-cp39-win_amd64.whl", hash = "sha256:2479ef9a30fe8a961cb49c8bf6a5c5e2ce8e1b87849374c9756f41cf06189bdf"}, 783 | {file = "tokenizers-0.9.4.tar.gz", hash = "sha256:3ea3038008f1f74c8a1e1e2e73728690eed2d7fa4db0a51bcea391e644672426"}, 784 | ] 785 | torch = [] 786 | tqdm = [ 787 | {file = "tqdm-4.54.1-py2.py3-none-any.whl", hash = "sha256:d4f413aecb61c9779888c64ddf0c62910ad56dcbe857d8922bb505d4dbff0df1"}, 788 | {file = "tqdm-4.54.1.tar.gz", hash = "sha256:38b658a3e4ecf9b4f6f8ff75ca16221ae3378b2e175d846b6b33ea3a20852cf5"}, 789 | ] 790 | traitlets = [ 791 | {file = "traitlets-5.0.5-py3-none-any.whl", hash = "sha256:69ff3f9d5351f31a7ad80443c2674b7099df13cc41fc5fa6e2f6d3b0330b0426"}, 792 | {file = "traitlets-5.0.5.tar.gz", hash = "sha256:178f4ce988f69189f7e523337a3e11d91c786ded9360174a3d9ca83e79bc5396"}, 793 | ] 794 | transformers = [ 795 | {file = "transformers-4.1.1-py3-none-any.whl", hash = "sha256:3a525c33b544eccc0afbf03f7636db92e6d56d728dc6a4cdae7862af379c2193"}, 796 | {file = "transformers-4.1.1.tar.gz", hash = "sha256:f2cf80855edfb47d87894a4462dde0d9973b99ee1b78ef4cfb16191b92b79858"}, 797 | ] 798 | typing-extensions = [ 799 | {file = "typing_extensions-3.7.4.3-py2-none-any.whl", hash = "sha256:dafc7639cde7f1b6e1acc0f457842a83e722ccca8eef5270af2d74792619a89f"}, 800 | {file = "typing_extensions-3.7.4.3-py3-none-any.whl", hash = "sha256:7cb407020f00f7bfc3cb3e7881628838e69d8f3fcab2f64742a5e76b2f841918"}, 801 | {file = "typing_extensions-3.7.4.3.tar.gz", hash = "sha256:99d4073b617d30288f569d3f13d2bd7548c3a7e4c8de87db09a9d29bb3a4a60c"}, 802 | ] 803 | urllib3 = [ 804 | {file = 
"urllib3-1.26.2-py2.py3-none-any.whl", hash = "sha256:d8ff90d979214d7b4f8ce956e80f4028fc6860e4431f731ea4a8c08f23f99473"}, 805 | {file = "urllib3-1.26.2.tar.gz", hash = "sha256:19188f96923873c92ccb987120ec4acaa12f0461fa9ce5d3d0772bc965a39e08"}, 806 | ] 807 | urwid = [ 808 | {file = "urwid-2.1.2.tar.gz", hash = "sha256:588bee9c1cb208d0906a9f73c613d2bd32c3ed3702012f51efe318a3f2127eae"}, 809 | ] 810 | wcwidth = [ 811 | {file = "wcwidth-0.2.5-py2.py3-none-any.whl", hash = "sha256:beb4802a9cebb9144e99086eff703a642a13d6a0052920003a230f3294bbe784"}, 812 | {file = "wcwidth-0.2.5.tar.gz", hash = "sha256:c4d647b99872929fdb7bdcaa4fbe7f01413ed3d98077df798530e5b04f116c83"}, 813 | ] 814 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "spanalign" 3 | version = "0.1.0" 4 | description = "The code described in \"SpanAlign: Sentence Alignment Method based on Cross-Language Span Prediction and ILP\" published at COLING'20." 5 | authors = ["Katsuki Chousa "] 6 | license = "NTT License" 7 | 8 | [tool.poetry.dependencies] 9 | python = "^3.8" 10 | sacremoses = "^0.0.43" 11 | transformers = "^4.1.1" 12 | tqdm = "^4.54.1" 13 | tensorboardX = "^2.1" 14 | sentencepiece = "^0.1.94" 15 | h5py = "2.10.0" 16 | torch = {url = "https://download.pytorch.org/whl/cu101/torch-1.7.1%2Bcu101-cp38-cp38-linux_x86_64.whl"} 17 | tabulate = "^0.8.7" 18 | 19 | [tool.poetry.dev-dependencies] 20 | pudb = "^2019.2" 21 | ipython = "^7.19.0" 22 | 23 | [build-system] 24 | requires = ["poetry-core>=1.0.0"] 25 | build-backend = "poetry.core.masonry.api" 26 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | certifi==2020.12.5; python_full_version >= "3.6.0" 2 | chardet==4.0.0; python_full_version >= "3.6.0" 3 | click==7.1.2; python_full_version >= "3.6.0" 4 | filelock==3.0.12; python_full_version >= "3.6.0" 5 | h5py==2.10.0 6 | idna==2.10; python_full_version >= "3.6.0" 7 | joblib==1.0.0; python_version >= "3.6" and python_full_version >= "3.6.0" 8 | numpy==1.19.4; python_version >= "3.6" and python_full_version >= "3.6.2" 9 | packaging==20.8; python_full_version >= "3.6.0" 10 | protobuf==3.14.0 11 | pyparsing==2.4.7; python_full_version >= "3.6.0" 12 | regex==2020.11.13; python_full_version >= "3.6.0" 13 | requests==2.25.1; python_full_version >= "3.6.0" 14 | sacremoses==0.0.43 15 | sentencepiece==0.1.94 16 | six==1.15.0; python_full_version >= "3.6.0" 17 | tabulate==0.8.7 18 | tensorboardx==2.1 19 | tokenizers==0.9.4; python_full_version >= "3.6.0" 20 | torch @ https://download.pytorch.org/whl/cu101/torch-1.7.1%2Bcu101-cp38-cp38-linux_x86_64.whl ; python_full_version >= "3.6.2" 21 | tqdm==4.54.1; (python_version >= "2.7" and python_full_version < "3.0.0") or (python_full_version >= "3.4.0") 22 | transformers==4.1.1; python_full_version >= "3.6.0" 23 | typing-extensions==3.7.4.3; python_full_version >= "3.6.2" 24 | urllib3==1.26.2; python_full_version >= "3.6.0" and python_version < "4" 25 | -------------------------------------------------------------------------------- /run_qa_alignment.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | 3 | """ 4 | modified on Dec. 
22, 2020 by Katsuki Chousa 5 | original: huggingface/transformers v4.0.1 examples/question-answering/run_squad.py 6 | """ 7 | 8 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 9 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 10 | # 11 | # Licensed under the Apache License, Version 2.0 (the "License"); 12 | # you may not use this file except in compliance with the License. 13 | # You may obtain a copy of the License at 14 | # 15 | # http://www.apache.org/licenses/LICENSE-2.0 16 | # 17 | # Unless required by applicable law or agreed to in writing, software 18 | # distributed under the License is distributed on an "AS IS" BASIS, 19 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 20 | # See the License for the specific language governing permissions and 21 | # limitations under the License. 22 | """ Finetuning the library models for question-answering on SQuAD (DistilBERT, Bert, XLM, XLNet).""" 23 | 24 | 25 | import argparse 26 | import glob 27 | import logging 28 | import os 29 | import random 30 | import timeit 31 | 32 | import numpy as np 33 | import torch 34 | from torch.utils.data import DataLoader, RandomSampler, SequentialSampler 35 | from torch.utils.data.distributed import DistributedSampler 36 | from tqdm import tqdm, trange 37 | 38 | import transformers 39 | from transformers import ( 40 | MODEL_FOR_QUESTION_ANSWERING_MAPPING, 41 | WEIGHTS_NAME, 42 | AdamW, 43 | AutoConfig, 44 | AutoModelForQuestionAnswering, 45 | AutoTokenizer, 46 | get_linear_schedule_with_warmup, 47 | ) 48 | from transformers.data.metrics.squad_metrics import ( 49 | compute_predictions_log_probs, 50 | compute_predictions_logits, 51 | squad_evaluate, 52 | ) 53 | from transformers.trainer_utils import is_main_process 54 | 55 | 56 | try: 57 | from torch.utils.tensorboard import SummaryWriter 58 | except ImportError: 59 | from tensorboardX import SummaryWriter 60 | 61 | # added 62 | from transformers.data.processors.squad import ( 63 | SquadResult, 64 | squad_convert_examples_to_features, 65 | ) 66 | from utils import ( 67 | load_and_cache_examples, 68 | MySquadProcessor, 69 | HDF5Dataset 70 | ) 71 | 72 | 73 | logger = logging.getLogger(__name__) 74 | 75 | MODEL_CONFIG_CLASSES = list(MODEL_FOR_QUESTION_ANSWERING_MAPPING.keys()) 76 | MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES) 77 | 78 | 79 | def set_seed(args): 80 | random.seed(args.seed) 81 | np.random.seed(args.seed) 82 | torch.manual_seed(args.seed) 83 | if args.n_gpu > 0: 84 | torch.cuda.manual_seed_all(args.seed) 85 | 86 | 87 | def to_list(tensor): 88 | return tensor.detach().cpu().tolist() 89 | 90 | 91 | def train(args, train_dataset, model, tokenizer): 92 | """ Train the model """ 93 | if args.local_rank in [-1, 0]: 94 | tb_writer = SummaryWriter() 95 | 96 | args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu) 97 | train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset) 98 | train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size) 99 | 100 | if args.max_steps > 0: 101 | t_total = args.max_steps 102 | args.num_train_epochs = args.max_steps // (len(train_dataloader) // args.gradient_accumulation_steps) + 1 103 | else: 104 | t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs 105 | 106 | # Prepare optimizer and schedule (linear warmup and decay) 107 | no_decay = ["bias", "LayerNorm.weight"] 108 | 
optimizer_grouped_parameters = [ 109 | { 110 | "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 111 | "weight_decay": args.weight_decay, 112 | }, 113 | {"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0}, 114 | ] 115 | optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon) 116 | scheduler = get_linear_schedule_with_warmup( 117 | optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total 118 | ) 119 | 120 | # Check if saved optimizer or scheduler states exist 121 | if os.path.isfile(os.path.join(args.model_name_or_path, "optimizer.pt")) and os.path.isfile( 122 | os.path.join(args.model_name_or_path, "scheduler.pt") 123 | ): 124 | # Load in optimizer and scheduler states 125 | optimizer.load_state_dict(torch.load(os.path.join(args.model_name_or_path, "optimizer.pt"))) 126 | scheduler.load_state_dict(torch.load(os.path.join(args.model_name_or_path, "scheduler.pt"))) 127 | 128 | if args.fp16: 129 | try: 130 | from apex import amp 131 | except ImportError: 132 | raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.") 133 | 134 | model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level) 135 | 136 | # multi-gpu training (should be after apex fp16 initialization) 137 | if args.n_gpu > 1: 138 | model = torch.nn.DataParallel(model) 139 | 140 | # Distributed training (should be after apex fp16 initialization) 141 | if args.local_rank != -1: 142 | model = torch.nn.parallel.DistributedDataParallel( 143 | model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True 144 | ) 145 | 146 | # Train! 147 | logger.info("***** Running training *****") 148 | logger.info(" Num examples = %d", len(train_dataset)) 149 | logger.info(" Num Epochs = %d", args.num_train_epochs) 150 | logger.info(" Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size) 151 | logger.info( 152 | " Total train batch size (w. 
parallel, distributed & accumulation) = %d", 153 | args.train_batch_size 154 | * args.gradient_accumulation_steps 155 | * (torch.distributed.get_world_size() if args.local_rank != -1 else 1), 156 | ) 157 | logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps) 158 | logger.info(" Total optimization steps = %d", t_total) 159 | 160 | global_step = 1 161 | epochs_trained = 0 162 | steps_trained_in_current_epoch = 0 163 | # Check if continuing training from a checkpoint 164 | if os.path.exists(args.model_name_or_path): 165 | try: 166 | # set global_step to gobal_step of last saved checkpoint from model path 167 | checkpoint_suffix = args.model_name_or_path.split("-")[-1].split("/")[0] 168 | global_step = int(checkpoint_suffix) 169 | epochs_trained = global_step // (len(train_dataloader) // args.gradient_accumulation_steps) 170 | steps_trained_in_current_epoch = global_step % (len(train_dataloader) // args.gradient_accumulation_steps) 171 | 172 | logger.info(" Continuing training from checkpoint, will skip to saved global_step") 173 | logger.info(" Continuing training from epoch %d", epochs_trained) 174 | logger.info(" Continuing training from global step %d", global_step) 175 | logger.info(" Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch) 176 | except ValueError: 177 | logger.info(" Starting fine-tuning.") 178 | 179 | tr_loss, logging_loss = 0.0, 0.0 180 | model.zero_grad() 181 | train_iterator = trange( 182 | epochs_trained, int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0] 183 | ) 184 | # Added here for reproductibility 185 | set_seed(args) 186 | 187 | for _ in train_iterator: 188 | epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0]) 189 | for step, batch in enumerate(epoch_iterator): 190 | 191 | # Skip past any already trained steps if resuming training 192 | if steps_trained_in_current_epoch > 0: 193 | steps_trained_in_current_epoch -= 1 194 | continue 195 | 196 | model.train() 197 | batch = tuple(t.to(args.device) for t in batch) 198 | 199 | inputs = { 200 | "input_ids": batch[0], 201 | "attention_mask": batch[1], 202 | "token_type_ids": batch[2], 203 | "start_positions": batch[3], 204 | "end_positions": batch[4], 205 | } 206 | 207 | if args.model_type in ["xlm", "roberta", "distilbert", "camembert", "bart", "longformer"]: 208 | del inputs["token_type_ids"] 209 | 210 | if args.model_type in ["xlnet", "xlm"]: 211 | inputs.update({"cls_index": batch[5], "p_mask": batch[6]}) 212 | if args.version_2_with_negative: 213 | inputs.update({"is_impossible": batch[7]}) 214 | if hasattr(model, "config") and hasattr(model.config, "lang2id"): 215 | inputs.update( 216 | {"langs": (torch.ones(batch[0].shape, dtype=torch.int64) * args.lang_id).to(args.device)} 217 | ) 218 | 219 | outputs = model(**inputs) 220 | # model outputs are always tuple in transformers (see doc) 221 | loss = outputs[0] 222 | 223 | if args.n_gpu > 1: 224 | loss = loss.mean() # mean() to average on multi-gpu parallel (not distributed) training 225 | if args.gradient_accumulation_steps > 1: 226 | loss = loss / args.gradient_accumulation_steps 227 | 228 | if args.fp16: 229 | with amp.scale_loss(loss, optimizer) as scaled_loss: 230 | scaled_loss.backward() 231 | else: 232 | loss.backward() 233 | 234 | tr_loss += loss.item() 235 | if (step + 1) % args.gradient_accumulation_steps == 0: 236 | if args.fp16: 237 | torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm) 238 | else: 
239 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm) 240 | 241 | optimizer.step() 242 | scheduler.step() # Update learning rate schedule 243 | model.zero_grad() 244 | global_step += 1 245 | 246 | # Log metrics 247 | if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0: 248 | # Only evaluate when single GPU otherwise metrics may not average well 249 | if args.local_rank == -1 and args.evaluate_during_training: 250 | results = evaluate(args, model, tokenizer) 251 | for key, value in results.items(): 252 | tb_writer.add_scalar("eval_{}".format(key), value, global_step) 253 | tb_writer.add_scalar("lr", scheduler.get_lr()[0], global_step) 254 | tb_writer.add_scalar("loss", (tr_loss - logging_loss) / args.logging_steps, global_step) 255 | logging_loss = tr_loss 256 | 257 | # Save model checkpoint 258 | if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0: 259 | output_dir = os.path.join(args.output_dir, "checkpoint-{}".format(global_step)) 260 | # Take care of distributed/parallel training 261 | model_to_save = model.module if hasattr(model, "module") else model 262 | model_to_save.save_pretrained(output_dir) 263 | tokenizer.save_pretrained(output_dir) 264 | 265 | torch.save(args, os.path.join(output_dir, "training_args.bin")) 266 | logger.info("Saving model checkpoint to %s", output_dir) 267 | 268 | torch.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt")) 269 | torch.save(scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt")) 270 | logger.info("Saving optimizer and scheduler states to %s", output_dir) 271 | 272 | if args.max_steps > 0 and global_step > args.max_steps: 273 | epoch_iterator.close() 274 | break 275 | if args.max_steps > 0 and global_step > args.max_steps: 276 | train_iterator.close() 277 | break 278 | 279 | if args.local_rank in [-1, 0]: 280 | tb_writer.close() 281 | 282 | return global_step, tr_loss / global_step 283 | 284 | 285 | def evaluate(args, model, tokenizer, prefix=""): 286 | # dataset, examples, features = load_and_cache_examples(args, tokenizer, evaluate=True, output_examples=True) 287 | feature_reader = load_and_cache_examples(args, tokenizer, evaluate=True) 288 | dataset = HDF5Dataset(feature_reader) 289 | examples = feature_reader.load_examples() 290 | features = feature_reader.get_features() 291 | 292 | if not os.path.exists(args.output_dir) and args.local_rank in [-1, 0]: 293 | os.makedirs(args.output_dir) 294 | 295 | args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu) 296 | 297 | # Note that DistributedSampler samples randomly 298 | eval_sampler = SequentialSampler(dataset) 299 | eval_dataloader = DataLoader(dataset, sampler=eval_sampler, batch_size=args.eval_batch_size) 300 | 301 | # multi-gpu evaluate 302 | if args.n_gpu > 1 and not isinstance(model, torch.nn.DataParallel): 303 | model = torch.nn.DataParallel(model) 304 | 305 | # Eval! 
306 | logger.info("***** Running evaluation {} *****".format(prefix)) 307 | logger.info(" Num examples = %d", len(dataset)) 308 | logger.info(" Batch size = %d", args.eval_batch_size) 309 | 310 | all_results = [] 311 | start_time = timeit.default_timer() 312 | 313 | for batch in tqdm(eval_dataloader, desc="Evaluating"): 314 | model.eval() 315 | batch = tuple(t.to(args.device) for t in batch) 316 | 317 | with torch.no_grad(): 318 | inputs = { 319 | "input_ids": batch[0], 320 | "attention_mask": batch[1], 321 | "token_type_ids": batch[2], 322 | } 323 | 324 | if args.model_type in ["xlm", "roberta", "distilbert", "camembert", "bart", "longformer"]: 325 | del inputs["token_type_ids"] 326 | 327 | feature_indices = batch[3] 328 | 329 | # XLNet and XLM use more arguments for their predictions 330 | if args.model_type in ["xlnet", "xlm"]: 331 | inputs.update({"cls_index": batch[4], "p_mask": batch[5]}) 332 | # for lang_id-sensitive xlm models 333 | if hasattr(model, "config") and hasattr(model.config, "lang2id"): 334 | inputs.update( 335 | {"langs": (torch.ones(batch[0].shape, dtype=torch.int64) * args.lang_id).to(args.device)} 336 | ) 337 | outputs = model(**inputs) 338 | 339 | for i, feature_index in enumerate(feature_indices): 340 | eval_feature = features[feature_index.item()] 341 | unique_id = int(eval_feature.unique_id) 342 | 343 | output = [to_list(output[i]) for output in outputs.to_tuple()] 344 | 345 | # Some models (XLNet, XLM) use 5 arguments for their predictions, while the other "simpler" 346 | # models only use two. 347 | if len(output) >= 5: 348 | start_logits = output[0] 349 | start_top_index = output[1] 350 | end_logits = output[2] 351 | end_top_index = output[3] 352 | cls_logits = output[4] 353 | 354 | result = SquadResult( 355 | unique_id, 356 | start_logits, 357 | end_logits, 358 | start_top_index=start_top_index, 359 | end_top_index=end_top_index, 360 | cls_logits=cls_logits, 361 | ) 362 | 363 | else: 364 | start_logits, end_logits = output 365 | result = SquadResult(unique_id, start_logits, end_logits) 366 | 367 | all_results.append(result) 368 | 369 | evalTime = timeit.default_timer() - start_time 370 | logger.info(" Evaluation done in total %f secs (%f sec per example)", evalTime, evalTime / len(dataset)) 371 | 372 | # Compute predictions 373 | output_prediction_file = os.path.join(args.output_dir, "predictions_{}.json".format(prefix)) 374 | output_nbest_file = os.path.join(args.output_dir, "nbest_predictions_{}.json".format(prefix)) 375 | 376 | if args.version_2_with_negative: 377 | output_null_log_odds_file = os.path.join(args.output_dir, "null_odds_{}.json".format(prefix)) 378 | else: 379 | output_null_log_odds_file = None 380 | 381 | # XLNet and XLM use a more complex post-processing procedure 382 | if args.model_type in ["xlnet", "xlm"]: 383 | start_n_top = model.config.start_n_top if hasattr(model, "config") else model.module.config.start_n_top 384 | end_n_top = model.config.end_n_top if hasattr(model, "config") else model.module.config.end_n_top 385 | 386 | predictions = compute_predictions_log_probs( 387 | examples, 388 | features, 389 | all_results, 390 | args.n_best_size, 391 | args.max_answer_length, 392 | output_prediction_file, 393 | output_nbest_file, 394 | output_null_log_odds_file, 395 | start_n_top, 396 | end_n_top, 397 | args.version_2_with_negative, 398 | tokenizer, 399 | args.verbose_logging, 400 | ) 401 | else: 402 | predictions = compute_predictions_logits( 403 | examples, 404 | features, 405 | all_results, 406 | args.n_best_size, 407 | 
args.max_answer_length, 408 | args.do_lower_case, 409 | output_prediction_file, 410 | output_nbest_file, 411 | output_null_log_odds_file, 412 | args.verbose_logging, 413 | args.version_2_with_negative, 414 | args.null_score_diff_threshold, 415 | tokenizer, 416 | ) 417 | 418 | # Compute the F1 and exact scores. 419 | results = squad_evaluate(examples, predictions) 420 | return results 421 | 422 | 423 | def main(): 424 | parser = argparse.ArgumentParser() 425 | 426 | # Required parameters 427 | parser.add_argument( 428 | "--model_type", 429 | default=None, 430 | type=str, 431 | required=True, 432 | help="Model type selected in the list: " + ", ".join(MODEL_TYPES), 433 | ) 434 | parser.add_argument( 435 | "--model_name_or_path", 436 | default=None, 437 | type=str, 438 | required=True, 439 | help="Path to pretrained model or model identifier from huggingface.co/models", 440 | ) 441 | parser.add_argument( 442 | "--output_dir", 443 | default=None, 444 | type=str, 445 | required=True, 446 | help="The output directory where the model checkpoints and predictions will be written.", 447 | ) 448 | 449 | # Other parameters 450 | parser.add_argument( 451 | "--data_dir", 452 | default=None, 453 | type=str, 454 | help="The input data dir. Should contain the .json files for the task." 455 | + "If no data dir or train/predict files are specified, will run with tensorflow_datasets.", 456 | ) 457 | parser.add_argument( 458 | "--train_file", 459 | default=None, 460 | type=str, 461 | help="The input training file. If a data dir is specified, will look for the file there" 462 | + "If no data dir or train/predict files are specified, will run with tensorflow_datasets.", 463 | ) 464 | parser.add_argument( 465 | "--predict_file", 466 | default=None, 467 | type=str, 468 | help="The input evaluation file. If a data dir is specified, will look for the file there" 469 | + "If no data dir or train/predict files are specified, will run with tensorflow_datasets.", 470 | ) 471 | parser.add_argument( 472 | "--config_name", default="", type=str, help="Pretrained config name or path if not the same as model_name" 473 | ) 474 | parser.add_argument( 475 | "--tokenizer_name", 476 | default="", 477 | type=str, 478 | help="Pretrained tokenizer name or path if not the same as model_name", 479 | ) 480 | parser.add_argument( 481 | "--cache_dir", 482 | default="", 483 | type=str, 484 | help="Where do you want to store the pre-trained models downloaded from huggingface.co", 485 | ) 486 | 487 | parser.add_argument( 488 | "--version_2_with_negative", 489 | action="store_true", 490 | help="If true, the SQuAD examples contain some that do not have an answer.", 491 | ) 492 | parser.add_argument( 493 | "--null_score_diff_threshold", 494 | type=float, 495 | default=0.0, 496 | help="If null_score - best_non_null is greater than the threshold predict null.", 497 | ) 498 | 499 | parser.add_argument( 500 | "--max_seq_length", 501 | default=384, 502 | type=int, 503 | help="The maximum total input sequence length after WordPiece tokenization. Sequences " 504 | "longer than this will be truncated, and sequences shorter than this will be padded.", 505 | ) 506 | parser.add_argument( 507 | "--doc_stride", 508 | default=128, 509 | type=int, 510 | help="When splitting up a long document into chunks, how much stride to take between chunks.", 511 | ) 512 | parser.add_argument( 513 | "--max_query_length", 514 | default=64, 515 | type=int, 516 | help="The maximum number of tokens for the question. 
Questions longer than this will " 517 | "be truncated to this length.", 518 | ) 519 | parser.add_argument("--do_train", action="store_true", help="Whether to run training.") 520 | parser.add_argument("--do_eval", action="store_true", help="Whether to run eval on the dev set.") 521 | parser.add_argument( 522 | "--evaluate_during_training", action="store_true", help="Run evaluation during training at each logging step." 523 | ) 524 | parser.add_argument( 525 | "--do_lower_case", action="store_true", help="Set this flag if you are using an uncased model." 526 | ) 527 | 528 | parser.add_argument("--per_gpu_train_batch_size", default=8, type=int, help="Batch size per GPU/CPU for training.") 529 | parser.add_argument( 530 | "--per_gpu_eval_batch_size", default=8, type=int, help="Batch size per GPU/CPU for evaluation." 531 | ) 532 | parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.") 533 | parser.add_argument( 534 | "--gradient_accumulation_steps", 535 | type=int, 536 | default=1, 537 | help="Number of updates steps to accumulate before performing a backward/update pass.", 538 | ) 539 | parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.") 540 | parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.") 541 | parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") 542 | parser.add_argument( 543 | "--num_train_epochs", default=3.0, type=float, help="Total number of training epochs to perform." 544 | ) 545 | parser.add_argument( 546 | "--max_steps", 547 | default=-1, 548 | type=int, 549 | help="If > 0: set total number of training steps to perform. Override num_train_epochs.", 550 | ) 551 | parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.") 552 | parser.add_argument( 553 | "--n_best_size", 554 | default=20, 555 | type=int, 556 | help="The total number of n-best predictions to generate in the nbest_predictions.json output file.", 557 | ) 558 | parser.add_argument( 559 | "--max_answer_length", 560 | default=30, 561 | type=int, 562 | help="The maximum length of an answer that can be generated. This is needed because the start " 563 | "and end predictions are not conditioned on one another.", 564 | ) 565 | parser.add_argument( 566 | "--verbose_logging", 567 | action="store_true", 568 | help="If true, all of the warnings related to data processing will be printed. 
" 569 | "A number of warnings are expected for a normal SQuAD evaluation.", 570 | ) 571 | parser.add_argument( 572 | "--lang_id", 573 | default=0, 574 | type=int, 575 | help="language id of input for language-specific xlm models (see tokenization_xlm.PRETRAINED_INIT_CONFIGURATION)", 576 | ) 577 | 578 | parser.add_argument("--logging_steps", type=int, default=500, help="Log every X updates steps.") 579 | parser.add_argument("--save_steps", type=int, default=500, help="Save checkpoint every X updates steps.") 580 | parser.add_argument( 581 | "--eval_all_checkpoints", 582 | action="store_true", 583 | help="Evaluate all checkpoints starting with the same prefix as model_name ending and ending with step number", 584 | ) 585 | parser.add_argument("--no_cuda", action="store_true", help="Whether not to use CUDA when available") 586 | parser.add_argument( 587 | "--overwrite_output_dir", action="store_true", help="Overwrite the content of the output directory" 588 | ) 589 | parser.add_argument( 590 | "--overwrite_cache", action="store_true", help="Overwrite the cached training and evaluation sets" 591 | ) 592 | parser.add_argument("--seed", type=int, default=42, help="random seed for initialization") 593 | 594 | parser.add_argument("--local_rank", type=int, default=-1, help="local_rank for distributed training on gpus") 595 | parser.add_argument( 596 | "--fp16", 597 | action="store_true", 598 | help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit", 599 | ) 600 | parser.add_argument( 601 | "--fp16_opt_level", 602 | type=str, 603 | default="O1", 604 | help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']." 605 | "See details at https://nvidia.github.io/apex/amp.html", 606 | ) 607 | parser.add_argument("--server_ip", type=str, default="", help="Can be used for distant debugging.") 608 | parser.add_argument("--server_port", type=str, default="", help="Can be used for distant debugging.") 609 | 610 | parser.add_argument("--threads", type=int, default=1, help="multiple threads for converting example to features") 611 | args = parser.parse_args() 612 | 613 | if args.doc_stride >= args.max_seq_length - args.max_query_length: 614 | logger.warning( 615 | "WARNING - You've set a doc stride which may be superior to the document length in some " 616 | "examples. This could result in errors when building features from the examples. Please reduce the doc " 617 | "stride or increase the maximum length to ensure the features are correctly built." 618 | ) 619 | 620 | if ( 621 | os.path.exists(args.output_dir) 622 | and os.listdir(args.output_dir) 623 | and args.do_train 624 | and not args.overwrite_output_dir 625 | ): 626 | raise ValueError( 627 | "Output directory ({}) already exists and is not empty. 
Use --overwrite_output_dir to overcome.".format( 628 | args.output_dir 629 | ) 630 | ) 631 | 632 | # Setup distant debugging if needed 633 | if args.server_ip and args.server_port: 634 | # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script 635 | import ptvsd 636 | 637 | print("Waiting for debugger attach") 638 | ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True) 639 | ptvsd.wait_for_attach() 640 | 641 | # Setup CUDA, GPU & distributed training 642 | if args.local_rank == -1 or args.no_cuda: 643 | device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu") 644 | args.n_gpu = 0 if args.no_cuda else torch.cuda.device_count() 645 | else: # Initializes the distributed backend which will take care of sychronizing nodes/GPUs 646 | torch.cuda.set_device(args.local_rank) 647 | device = torch.device("cuda", args.local_rank) 648 | torch.distributed.init_process_group(backend="nccl") 649 | args.n_gpu = 1 650 | args.device = device 651 | 652 | # Setup logging 653 | logging.basicConfig( 654 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 655 | datefmt="%m/%d/%Y %H:%M:%S", 656 | level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN, 657 | ) 658 | logger.warning( 659 | "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s", 660 | args.local_rank, 661 | device, 662 | args.n_gpu, 663 | bool(args.local_rank != -1), 664 | args.fp16, 665 | ) 666 | # Set the verbosity to info of the Transformers logger (on main process only): 667 | if is_main_process(args.local_rank): 668 | transformers.utils.logging.set_verbosity_info() 669 | transformers.utils.logging.enable_default_handler() 670 | transformers.utils.logging.enable_explicit_format() 671 | # Set seed 672 | set_seed(args) 673 | 674 | # Load pretrained model and tokenizer 675 | if args.local_rank not in [-1, 0]: 676 | # Make sure only the first process in distributed training will download model & vocab 677 | torch.distributed.barrier() 678 | 679 | args.model_type = args.model_type.lower() 680 | config = AutoConfig.from_pretrained( 681 | args.config_name if args.config_name else args.model_name_or_path, 682 | cache_dir=args.cache_dir if args.cache_dir else None, 683 | ) 684 | tokenizer = AutoTokenizer.from_pretrained( 685 | args.tokenizer_name if args.tokenizer_name else args.model_name_or_path, 686 | do_lower_case=args.do_lower_case, 687 | cache_dir=args.cache_dir if args.cache_dir else None, 688 | use_fast=False, # SquadDataset is not compatible with Fast tokenizers which have a smarter overflow handeling 689 | ) 690 | model = AutoModelForQuestionAnswering.from_pretrained( 691 | args.model_name_or_path, 692 | from_tf=bool(".ckpt" in args.model_name_or_path), 693 | config=config, 694 | cache_dir=args.cache_dir if args.cache_dir else None, 695 | ) 696 | 697 | if args.local_rank == 0: 698 | # Make sure only the first process in distributed training will download model & vocab 699 | torch.distributed.barrier() 700 | 701 | model.to(args.device) 702 | 703 | logger.info("Training/evaluation parameters %s", args) 704 | 705 | # Before we do anything with models, we want to ensure that we get fp16 execution of torch.einsum if args.fp16 is set. 706 | # Otherwise it'll default to "promote" mode, and we'll get fp32 operations. Note that running `--fp16_opt_level="O2"` will 707 | # remove the need for this code, but it is still valid. 
708 | if args.fp16: 709 | try: 710 | import apex 711 | 712 | apex.amp.register_half_function(torch, "einsum") 713 | except ImportError: 714 | raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.") 715 | 716 | # Training 717 | if args.do_train: 718 | # train_dataset = load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=False) 719 | feature_reader = load_and_cache_examples(args, tokenizer, evaluate=False) 720 | train_dataset = HDF5Dataset(feature_reader) 721 | global_step, tr_loss = train(args, train_dataset, model, tokenizer) 722 | logger.info(" global_step = %s, average loss = %s", global_step, tr_loss) 723 | 724 | # Save the trained model and the tokenizer 725 | if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0): 726 | logger.info("Saving model checkpoint to %s", args.output_dir) 727 | # Save a trained model, configuration and tokenizer using `save_pretrained()`. 728 | # They can then be reloaded using `from_pretrained()` 729 | # Take care of distributed/parallel training 730 | model_to_save = model.module if hasattr(model, "module") else model 731 | model_to_save.save_pretrained(args.output_dir) 732 | tokenizer.save_pretrained(args.output_dir) 733 | 734 | # Good practice: save your training arguments together with the trained model 735 | torch.save(args, os.path.join(args.output_dir, "training_args.bin")) 736 | 737 | # Load a trained model and vocabulary that you have fine-tuned 738 | model = AutoModelForQuestionAnswering.from_pretrained(args.output_dir) # , force_download=True) 739 | 740 | # SquadDataset is not compatible with Fast tokenizers which have a smarter overflow handeling 741 | # So we use use_fast=False here for now until Fast-tokenizer-compatible-examples are out 742 | tokenizer = AutoTokenizer.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case, use_fast=False) 743 | model.to(args.device) 744 | 745 | # Evaluation - we can ask to evaluate all the checkpoints (sub-directories) in a directory 746 | results = {} 747 | if args.do_eval and args.local_rank in [-1, 0]: 748 | if args.do_train: 749 | logger.info("Loading checkpoints saved during training for evaluation") 750 | checkpoints = [args.output_dir] 751 | if args.eval_all_checkpoints: 752 | checkpoints = list( 753 | os.path.dirname(c) 754 | for c in sorted(glob.glob(args.output_dir + "/**/" + WEIGHTS_NAME, recursive=True)) 755 | ) 756 | 757 | else: 758 | logger.info("Loading checkpoint %s for evaluation", args.model_name_or_path) 759 | checkpoints = [args.model_name_or_path] 760 | 761 | logger.info("Evaluate the following checkpoints: %s", checkpoints) 762 | 763 | for checkpoint in checkpoints: 764 | # Reload the model 765 | global_step = checkpoint.split("-")[-1] if len(checkpoints) > 1 else "" 766 | model = AutoModelForQuestionAnswering.from_pretrained(checkpoint) # , force_download=True) 767 | model.to(args.device) 768 | 769 | # Evaluate 770 | result = evaluate(args, model, tokenizer, prefix=global_step) 771 | 772 | result = dict((k + ("_{}".format(global_step) if global_step else ""), v) for k, v in result.items()) 773 | results.update(result) 774 | 775 | logger.info("Results: {}".format(results)) 776 | 777 | return results 778 | 779 | 780 | if __name__ == "__main__": 781 | main() 782 | -------------------------------------------------------------------------------- /scripts/SQuAD.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | ''' 3 | 
Created on 2019/10/31 4 | 5 | ''' 6 | 7 | class SQuAD(): 8 | def __init__(self, version2=False): 9 | self.version = "v2.0" if version2 else "1.1" 10 | self.data = [] 11 | # end 12 | 13 | def _dataToJson(self): 14 | return [d.toJson() for d in self.data] 15 | # end 16 | 17 | def appendData(self, dat): 18 | self.data.append(dat) 19 | # end 20 | 21 | def toJson(self): 22 | jo = {"version": self.version, "data": self._dataToJson()} 23 | 24 | return jo 25 | # end 26 | # end 27 | 28 | class SQuADDataItem(): 29 | def __init__(self, title): 30 | self.title = title 31 | self.paragraphs = [] 32 | # end 33 | 34 | def _paragraphsToJson(self): 35 | return [p.toJson() for p in self.paragraphs] 36 | # end 37 | 38 | def appendParagraph(self, p): 39 | self.paragraphs.append(p) 40 | # end 41 | 42 | def toJson(self): 43 | jo = {"title": self.title, "paragraphs": self._paragraphsToJson()} 44 | 45 | return jo 46 | # end 47 | # end 48 | 49 | class SQuADParagrap(): 50 | def __init__(self, context): 51 | self.context = context 52 | self.qas = [] 53 | # end 54 | 55 | def _qasToJson(self): 56 | return [qa.toJson() for qa in self.qas] 57 | # end 58 | 59 | def appendQA(self, qa): 60 | self.qas.append(qa) 61 | # end 62 | 63 | def extendQA(self, qa_list): 64 | self.qas.extend(qa_list) 65 | # end 66 | 67 | def toJson(self): 68 | jo = {"context": self.context, "qas": self._qasToJson()} 69 | 70 | return jo 71 | # end 72 | # end 73 | 74 | class SQuAD_QA(): 75 | def __init__(self, qa_id, question, is_impossible=False): 76 | self.id = qa_id 77 | self.question = question 78 | self.answers = [] 79 | self.is_impossible = is_impossible 80 | # end 81 | 82 | def _answersToJson(self): 83 | return [ans.toJson() for ans in self.answers] 84 | # end 85 | 86 | def appendAnswer(self, answer): 87 | self.answers.append(answer) 88 | # end 89 | 90 | def toJson(self): 91 | jo = {"id": self.id, 92 | "question": self.question, 93 | "answers": self._answersToJson(), 94 | "is_impossible": self.is_impossible} 95 | 96 | return jo 97 | # end 98 | # end 99 | 100 | class SQuADAnswer(): 101 | def __init__(self, text, start, end=None): 102 | self.text = text 103 | self.answer_start = start 104 | if end is None: 105 | self.answer_end = None 106 | else: 107 | self.answer_end = end 108 | # end 109 | # end 110 | 111 | def toJson(self): 112 | if self.answer_end is None: 113 | jo = {"text": self.text, "answer_start": self.answer_start} 114 | else: 115 | jo = {"text": self.text, "answer_start": self.answer_start, "answer_end": self.answer_end} 116 | # end 117 | 118 | return jo 119 | # end 120 | # end 121 | 122 | class Span(): 123 | def __init__(self, start, end, text): 124 | self.start = start 125 | self.end = end 126 | self.text = text 127 | # end 128 | # end 129 | -------------------------------------------------------------------------------- /scripts/doc_to_overlaped_squad.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """ 4 | created by Katsuki Chousa 5 | updated on Dec. 
22, 2020 by Katsuki Chousa 6 | """ 7 | 8 | import sys 9 | import json 10 | from pathlib import Path 11 | from argparse import ArgumentParser 12 | 13 | import SQuAD 14 | 15 | def makeQA(raw_lines, qa_id_prefix, sent_ngram): 16 | qa_list = [] 17 | 18 | for num_sents in range(1, sent_ngram+1): 19 | for begin in range(len(raw_lines) - num_sents + 1): 20 | qa_id = "{}_{}-{}".format(qa_id_prefix, begin, begin+num_sents-1) 21 | q_str = "\n".join(raw_lines[begin : begin+num_sents]) 22 | qa = SQuAD.SQuAD_QA(qa_id, q_str) 23 | 24 | qa_list.append(qa) 25 | 26 | return qa_list 27 | 28 | def main(config): 29 | title = "{}_{}".format(config.title, config.raw_file_l1.stem) 30 | 31 | with config.raw_file_l1.open() as ifs: 32 | raw_lines_l1 = [x.strip().replace(" ", " ") for x in ifs.readlines()] 33 | with config.raw_file_l2.open() as ifs: 34 | raw_lines_l2 = [x.strip().replace(" ", " ") for x in ifs.readlines()] 35 | 36 | context_l1 = "\n".join(raw_lines_l1) 37 | context_l2 = "\n".join(raw_lines_l2) 38 | 39 | paragraph_l1_to_l2 = SQuAD.SQuADParagrap(context_l2) 40 | paragraph_l2_to_l1 = SQuAD.SQuADParagrap(context_l1) 41 | 42 | qa_id_prefix_l1_to_l2 = title + "_1_2" 43 | qa_list_l1_to_l2 = makeQA(raw_lines_l1, qa_id_prefix_l1_to_l2, 44 | config.ngram) 45 | qa_id_prefix_l2_to_l1 = title + "_2_1" 46 | qa_list_l2_to_l1 = makeQA(raw_lines_l2, qa_id_prefix_l2_to_l1, 47 | config.ngram) 48 | 49 | paragraph_l1_to_l2.extendQA(qa_list_l1_to_l2) 50 | paragraph_l2_to_l1.extendQA(qa_list_l2_to_l1) 51 | 52 | squad_data = SQuAD.SQuADDataItem(title) 53 | squad_data.appendParagraph(paragraph_l1_to_l2) 54 | squad_data.appendParagraph(paragraph_l2_to_l1) 55 | 56 | squad = SQuAD.SQuAD(version2=True) 57 | squad.appendData(squad_data) 58 | 59 | with config.output.open("w") as ofs: 60 | json.dump(squad.toJson(), ofs, ensure_ascii=False, indent=2) 61 | 62 | def parse_args(): 63 | parser = ArgumentParser() 64 | parser.add_argument('raw_file_l1', type=Path) 65 | parser.add_argument('raw_file_l2', type=Path) 66 | parser.add_argument('output', type=Path) 67 | parser.add_argument('-t', '--title', type=str, default='yomiuri') 68 | parser.add_argument('-n', '--ngram', type=int, default=4) 69 | 70 | return parser.parse_args() 71 | 72 | if __name__ == '__main__': 73 | config = parse_args() 74 | main(config) 75 | -------------------------------------------------------------------------------- /scripts/fileUtils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | ''' 3 | Created on 2012/09/26 4 | update: 2019/01/24 5 | for python 3.6 6 | ''' 7 | 8 | import os.path 9 | import glob 10 | import codecs 11 | import shutil 12 | import re 13 | 14 | CTRL_CODE_PAT = re.compile(r"[\x00-\x08\x0B\x0C\x0E-\x1F\x7F]") 15 | 16 | def getAbsolutePath(path_str): 17 | return os.path.abspath(os.path.expanduser(path_str)) 18 | # end 19 | 20 | def getBaseName(path_str): 21 | return os.path.basename(getAbsolutePath(path_str)) 22 | # end 23 | 24 | def getFileNameRoot(path_str, with_ext=False): 25 | (root, ext) = os.path.splitext(getBaseName(path_str)) 26 | 27 | if (with_ext): 28 | ext = ext[1:] 29 | return (root, ext) 30 | else: 31 | return root 32 | # end 33 | # end 34 | 35 | def listUpFiles(dir_name, pattern, sort=True): 36 | path = os.path.join(dir_name, pattern) 37 | fullpath = os.path.abspath(os.path.expanduser(path)) 38 | 39 | filelist = glob.glob(fullpath) 40 | 41 | if (sort): 42 | filelist.sort() 43 | # end 44 | 45 | return filelist 46 | # end 47 | 48 | def saveStringToFile(string, filename, 
enc="utf-8", mode="ow", add_terminator=False): 49 | abs_filename = getAbsolutePath(filename) 50 | 51 | if (os.path.exists(abs_filename)): 52 | if (mode == "bk"): 53 | (path, fname) = os.path.split(abs_filename) 54 | 55 | backups = listUpFiles(path, "{0}.old_*".format(fname)) 56 | 57 | if (len(backups) == 0): 58 | next_num = 0 59 | else: 60 | try: 61 | current_num = int(backups[-1].replace("{0}.old_".format(abs_filename), "")) 62 | next_num = current_num + 1 63 | except ValueError: 64 | next_num = 0 65 | # end 66 | # end 67 | 68 | if (next_num > 99): 69 | next_num = 0 70 | # end 71 | 72 | backup_name = "{0}.old_{1:02d}".format(abs_filename, next_num) 73 | shutil.copy2(abs_filename, backup_name) 74 | elif (mode == "st"): 75 | print("{0} is already exist.".format(abs_filename)) 76 | return 77 | # end 78 | # end 79 | 80 | if (add_terminator and string[-1] != "\n"): 81 | string = "{}\n".format(string) 82 | # end 83 | 84 | with codecs.open(abs_filename, mode="w", encoding=enc) as fileObj: 85 | fileObj.write(string) 86 | fileObj.flush() 87 | # end 88 | # end 89 | 90 | def saveListToFile(stringList, filename, enc="utf-8", mode="ow"): 91 | if (len(stringList) == 0 or stringList[-1] != ""): 92 | stringList.append("") 93 | # end 94 | txt = "\n".join(stringList) 95 | 96 | saveStringToFile(txt, filename, enc, mode) 97 | # end 98 | 99 | 100 | def loadFileToList(filename, enc="utf-8-sig", commentMark="#", removeEmptyLine=True, strip="lr", line_num=0, removeCtrlCode=False): 101 | line_counter = 0 102 | results = [] 103 | fullpath = getAbsolutePath(filename) 104 | with codecs.open(fullpath, "r", enc, "replace") as fileObj: 105 | for line in fileObj: 106 | line = line.replace("\0", "") 107 | 108 | if (strip == "lr"): 109 | line = line.strip() 110 | elif (strip == "l"): 111 | line = line.lstrip() 112 | elif (strip == "r"): 113 | line = line.rstrip() 114 | # end 115 | 116 | if (removeCtrlCode): 117 | line = CTRL_CODE_PAT.sub("", line) 118 | # end 119 | 120 | if (removeEmptyLine and not line): 121 | continue 122 | # end 123 | 124 | if (commentMark and line.startswith(commentMark)): 125 | continue 126 | else: 127 | results.append(line) 128 | # end 129 | 130 | # 取得行数指定(1以上)があった場合、その行数だけ取得したら終了 131 | line_counter += 1 132 | if (line_num > 0 and line_counter > line_num): 133 | break 134 | # end 135 | # end 136 | # end 137 | 138 | return results 139 | # end 140 | 141 | def loadFileLineIter(filename, enc="utf-8-sig", commentMark="#", removeEmptyLine=True, strip="lr", line_num=0, removeCtrlCode=False): 142 | fullpath = getAbsolutePath(filename) 143 | with codecs.open(fullpath, "r", enc, "replace") as fileObj: 144 | for line in fileObj: 145 | line = line.replace(u"\0", u"") 146 | 147 | if (strip == "lr"): 148 | line = line.strip() 149 | elif (strip == "l"): 150 | line = line.lstrip() 151 | elif (strip == "r"): 152 | line = line.rstrip() 153 | # end 154 | 155 | if (removeCtrlCode): 156 | line = CTRL_CODE_PAT.sub(u"", line) 157 | # end 158 | 159 | if (removeEmptyLine and not line): 160 | continue 161 | # end 162 | 163 | if (commentMark and line.startswith(commentMark)): 164 | continue 165 | else: 166 | yield line 167 | # end 168 | # end 169 | # end 170 | # end 171 | 172 | def loadFileString(filename, enc="utf-8-sig", commentMark=None, removeEmptyLine=False, strip="r", line_num=0, removeCtrlCode=False): 173 | lines = loadFileToList(filename, enc, commentMark, removeEmptyLine, strip, line_num, removeCtrlCode) 174 | lines.append("") 175 | result_str = "\n".join(lines) 176 | 177 | return result_str 178 | # end 179 | 180 | 
def collectiveReplace(files, old, new, backup=False): 181 | for fileName in files: 182 | lines = loadFileToList(fileName, commentMark=None, removeEmptyLine=False, strip="r") 183 | lines_str = "\n".join(lines) 184 | 185 | if (backup): 186 | bkfile = "{}.bak".format(fileName) 187 | saveStringToFile(lines_str, bkfile) 188 | # end 189 | 190 | lines_str = lines_str.replace(old, new) 191 | saveStringToFile(lines_str, fileName) 192 | # end 193 | # end 194 | 195 | def countFileLines(filename, enc="utf-8-sig"): 196 | fullpath = getAbsolutePath(filename) 197 | #count = 0 198 | 199 | with codecs.open(fullpath, "rb", enc, "replace") as fileObj: 200 | count = sum(1 for _line in fileObj) 201 | # end 202 | 203 | return count 204 | # end 205 | -------------------------------------------------------------------------------- /scripts/get_sent_align_for_overlap.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """ 4 | created by Katsuki Chousa 5 | updated on Dec. 22, 2020 by Katsuki Chousa 6 | """ 7 | 8 | import sys 9 | import os 10 | import re 11 | import json 12 | import subprocess 13 | import tempfile 14 | 15 | from pathlib import Path 16 | from argparse import ArgumentParser 17 | from collections import namedtuple, defaultdict 18 | from statistics import mean 19 | from typing import List 20 | 21 | CPLEX_PATH = "__SET_CPLEX_PATH__" 22 | 23 | 24 | def edit_distance(s1, s2): 25 | len1 = len(s1) 26 | len2 = len(s2) 27 | 28 | lev = [[0] * (len2 + 1) for _ in range(len1 + 1)] 29 | for i in range(len1 + 1): 30 | lev[i][0] = i 31 | for j in range(len2 + 1): 32 | lev[0][j] = j 33 | 34 | for i in range(1, len1 + 1): 35 | for j in range(1, len2 + 1): 36 | lev[i][j] = min( 37 | lev[i - 1][j] + 1, 38 | lev[i][j - 1] + 1, 39 | lev[i - 1][j - 1] + (s1[i - 1] != s2[j - 1]) 40 | ) 41 | return lev[len1][len2] 42 | 43 | 44 | def load_raw_data(filepath: Path) -> List[str]: 45 | sents = [] 46 | with filepath.open() as ifs: 47 | for line in ifs: 48 | sents.append(line.strip()) 49 | 50 | return sents 51 | 52 | class Span(object): 53 | def __init__(self, start, end): 54 | self.start = start 55 | self.end = end 56 | 57 | def __len__(self): 58 | if self.start < 0: 59 | return 0 60 | return self.end - self.start + 1 61 | 62 | 63 | Alignment = namedtuple("Alignment", ['l1', 'l2', 'prob']) 64 | def get_alignment_candidate(trg_sents: List[str], 65 | nbest_predictions: dict, 66 | num_src_sents: int, 67 | qa_id_prefix: str, 68 | nbest: int=1, 69 | max_sents: int=-1, 70 | reverse: bool=False) -> List[Alignment]: 71 | noans_flag = [False] * num_src_sents 72 | candidates = [] 73 | for pos in range(len(trg_sents)): 74 | trg_span = Span(pos, pos) 75 | src_span = Span(-1, -1) 76 | if reverse: 77 | candidates.append(Alignment(trg_span, src_span, 0)) 78 | else: 79 | candidates.append(Alignment(src_span, trg_span, 0)) 80 | 81 | for k, v in nbest_predictions.items(): 82 | if not k.startswith(qa_id_prefix): 83 | continue 84 | 85 | start, end = map(int, k.split("_")[-1].split("-")) 86 | src_span = Span(start, end) 87 | 88 | for i in range(min(len(v), nbest)): 89 | pred = v[i] 90 | pred_text = pred['text'] 91 | pred_prob = pred['probability'] 92 | 93 | def span_dist(span): 94 | s2 = "\n".join(trg_sents[span.start : span.end + 1]) 95 | return edit_distance(pred_text, s2) 96 | 97 | trg_span = Span(-1, -1) 98 | if len(pred_text) != 0: 99 | start = 0 100 | trg_span = Span(-1, -1) 101 | while start < len(trg_sents): 102 | if pred_text.find(trg_sents[start]) < 0: 103 | start += 1 104 | continue 105 |
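# trg_sents[start] occurs in the predicted answer text, so greedily extend the run of
# consecutive target sentences that the prediction also contains; among competing runs,
# span_dist() keeps the one closest to the prediction by edit distance.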
106 | end = start + 1 107 | while (end < len(trg_sents) and 108 | pred_text.find(trg_sents[end]) >= 0): 109 | end += 1 110 | 111 | cand = Span(start, end - 1) 112 | if trg_span.start < 0 or span_dist(trg_span) > span_dist(cand): 113 | trg_span = cand 114 | start = end + 1 115 | 116 | if len(src_span) > 1 and trg_span.start == -1: 117 | continue 118 | if max_sents > 0 and len(trg_span) > max_sents: 119 | continue 120 | if len(src_span) == 1 and trg_span.start == -1: 121 | noans_flag[src_span.start] = True 122 | 123 | if reverse: 124 | candidates.append(Alignment(trg_span, src_span, pred_prob)) 125 | else: 126 | candidates.append(Alignment(src_span, trg_span, pred_prob)) 127 | 128 | for pos, flag in enumerate(noans_flag): 129 | if flag: 130 | continue 131 | 132 | src_span = Span(pos, pos) 133 | trg_span = Span(-1, -1) 134 | if reverse: 135 | candidates.append(Alignment(trg_span, src_span, 0)) 136 | else: 137 | candidates.append(Alignment(src_span, trg_span, 0)) 138 | 139 | return candidates 140 | 141 | def fix_variable_name(src, trg): 142 | name = "x%d_%d_%d_%d" % (src.start, src.end, trg.start, trg.end) 143 | name = name.replace("-1", "X") 144 | 145 | return name 146 | 147 | def create_lp(alignment_candidates, output_file, sentence_penalty=0.): 148 | ofs = output_file.open('w') 149 | src_pos = set() 150 | trg_pos = set() 151 | 152 | print("Minimize", file=ofs) 153 | print('obj:', file=ofs) 154 | for alignment in alignment_candidates: 155 | src = alignment.l1 156 | trg = alignment.l2 157 | prob = alignment.prob 158 | 159 | score = ((1 - prob) + sentence_penalty) * ((max(1, len(src)) + max(1, len(trg))) / 2) 160 | if score >= 0: 161 | print("+", file=ofs, end='') 162 | print(score, fix_variable_name(src, trg), file=ofs) 163 | src_pos.add(src.start) 164 | src_pos.add(src.end) 165 | trg_pos.add(trg.start) 166 | trg_pos.add(trg.end) 167 | print('', file=ofs) 168 | 169 | print('Subject to', file=ofs) 170 | print('', file=ofs) 171 | for pos in src_pos: 172 | if pos < 0: 173 | continue 174 | 175 | print('src_%d:' % (pos), file=ofs) 176 | for alignment in alignment_candidates: 177 | src = alignment.l1 178 | trg = alignment.l2 179 | if src.start <= pos <= src.end: 180 | print("+ 1", fix_variable_name(src, trg), file=ofs) 181 | print("= 1", file=ofs) 182 | print("", file=ofs) 183 | 184 | for pos in trg_pos: 185 | if pos < 0: 186 | continue 187 | 188 | print("trg_%d:" % (pos), file=ofs) 189 | for alignment in alignment_candidates: 190 | src = alignment.l1 191 | trg = alignment.l2 192 | if trg.start <= pos <= trg.end: 193 | print("+ 1", fix_variable_name(src, trg), file=ofs) 194 | print('= 1', file=ofs) 195 | print("", file=ofs) 196 | 197 | print("Binary", file=ofs) 198 | for alignment in alignment_candidates: 199 | src = alignment.l1 200 | trg = alignment.l2 201 | print(fix_variable_name(src, trg), file=ofs) 202 | 203 | print("End", file=ofs) 204 | ofs.close() 205 | 206 | def solve_alignment(alignment_candidates: List[Alignment], title: str, 207 | sentence_penalty: float=0, use_doc_score: bool=False): 208 | lp_file = Path('/tmp/solve_{}.lp'.format(title)) 209 | create_lp(alignment_candidates, lp_file, sentence_penalty) 210 | 211 | batch_file = Path('/tmp/solve_{}.batch'.format(title)) 212 | with batch_file.open('w') as ofs: 213 | print("read", str(lp_file), file=ofs) 214 | print("optimize", file=ofs) 215 | print("display solution variables x*", file=ofs) 216 | 217 | command = [CPLEX_PATH, '-f', batch_file] 218 | result = subprocess.run(command, stdout=subprocess.PIPE, encoding='utf-8').stdout 219 | 220 
| solution_span = [] 221 | doc_score = 0 222 | for line in result.split('\n'): 223 | if len(line) == 0 or line[0] != 'x': 224 | continue 225 | 226 | a, b, c, d, _ = re.split('[ _]+', line[1:]) 227 | a, b, c, d = [int(x) if x != "X" else -1 for x in [a, b, c, d]] 228 | 229 | sent_prob = 0 230 | for cand in alignment_candidates: 231 | if cand.l1.start != a or cand.l1.end != b or cand.l2.start != c or cand.l2.end != d: 232 | continue 233 | sent_prob = max(sent_prob, cand.prob) 234 | score = (1 - sent_prob + sentence_penalty) * ((max(1, len(Span(a, b))) + max(1, len(Span(c, d)))) / 2) 235 | 236 | doc_score += score 237 | solution_span.append((Span(a, b), Span(c, d), score)) 238 | 239 | doc_score /= len(solution_span) 240 | for idx in range(len(solution_span)): 241 | l1, l2, score = solution_span[idx] 242 | if use_doc_score: 243 | score *= doc_score 244 | solution_span[idx] = (l1, l2, score) 245 | 246 | assert len(solution_span) > 0, 'solution span not found.' 247 | lp_file.unlink() 248 | batch_file.unlink() 249 | 250 | return solution_span 251 | 252 | def output_alignments(alignments, e_lines, f_lines, index_offset=1, 253 | sentence_penalty=0, ofs=sys.stdout): 254 | for src, trg, score in alignments: 255 | src_sid = list(map(str, range(src.start + index_offset, src.end + 1 + index_offset))) 256 | trg_sid = list(map(str, range(trg.start + index_offset, trg.end + 1 + index_offset))) 257 | 258 | if src.start > -1 and trg.start > -1: 259 | print("[{}]:[{}]:{:.04f}".format(",".join(src_sid), ",".join(trg_sid), score), 260 | file=ofs) 261 | elif src.start > -1: 262 | for sid in src_sid: 263 | print("[{}]:[]:{:.04f}".format(sid, score), 264 | file=ofs) 265 | else: 266 | for sid in trg_sid: 267 | print("[]:[{}]:{:.04f}".format(sid, score), 268 | file=ofs) 269 | 270 | def merge_candidate(alignments_1, alignments_2): 271 | align_set = defaultdict(lambda: [0., 0.]) 272 | 273 | for align in alignments_1: 274 | align_set[(align.l1.start, align.l1.end, align.l2.start, align.l2.end)][0] = align.prob 275 | 276 | for align in alignments_2: 277 | align_set[(align.l1.start, align.l1.end, align.l2.start, align.l2.end)][1] = align.prob 278 | 279 | alignments = [] 280 | for k, v in align_set.items(): 281 | src = Span(k[0], k[1]) 282 | trg = Span(k[2], k[3]) 283 | v = mean(v) 284 | alignments.append(Alignment(src, trg, v)) 285 | 286 | return alignments 287 | 288 | def main(config): 289 | offset = 0 if config.zero_index else 1 290 | 291 | with config.nbest_predictions.open() as ifs: 292 | nbest_predictions = json.load(ifs) 293 | 294 | sents_l1 = load_raw_data(config.lang1) 295 | sents_l2 = load_raw_data(config.lang2) 296 | 297 | l1_to_l2_prefix = config.title + "_1_2" 298 | l1_to_l2_candidate = get_alignment_candidate(sents_l2, nbest_predictions, 299 | len(sents_l1), l1_to_l2_prefix, 300 | nbest=config.nbest, 301 | max_sents=config.max_sents) 302 | l1_to_l2_alignments = solve_alignment(l1_to_l2_candidate, config.title, 303 | config.sentence_penalty, config.use_doc_score) 304 | with config.output.with_suffix('.e2f.pair').open('w') as ofs: 305 | output_alignments(l1_to_l2_alignments, len(sents_l1), len(sents_l2), offset, 306 | ofs=ofs) 307 | 308 | l2_to_l1_prefix = config.title + "_2_1" 309 | l2_to_l1_candidate = get_alignment_candidate(sents_l1, nbest_predictions, 310 | len(sents_l2), l2_to_l1_prefix, 311 | nbest=config.nbest, 312 | max_sents=config.max_sents, 313 | reverse=True) 314 | l2_to_l1_alignments = solve_alignment(l2_to_l1_candidate, config.title, 315 | config.sentence_penalty, config.use_doc_score) 316 | 
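# output_alignments() writes the *.pair format, one alignment per line, e.g.
# "[3,4]:[5]:0.0123"; unaligned sentences appear as "[3]:[]:..." or "[]:[5]:...", and
# sentence numbers are 1-based unless --zero_index is given.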
with config.output.with_suffix('.f2e.pair').open('w') as ofs: 317 | output_alignments(l2_to_l1_alignments, len(sents_l1), len(sents_l2), offset, 318 | ofs=ofs) 319 | 320 | bidi_candidate = merge_candidate(l1_to_l2_candidate, l2_to_l1_candidate) 321 | bidi_alignments = solve_alignment(bidi_candidate, config.title, 322 | config.sentence_penalty, config.use_doc_score) 323 | with config.output.with_suffix('.bidi.pair').open('w') as ofs: 324 | output_alignments(bidi_alignments, len(sents_l1), len(sents_l2), offset, 325 | ofs=ofs) 326 | 327 | def parse_args(): 328 | parser = ArgumentParser() 329 | parser.add_argument('lang1', type=Path) 330 | parser.add_argument('lang2', type=Path) 331 | parser.add_argument('nbest_predictions', type=Path) 332 | parser.add_argument('title', type=str) 333 | parser.add_argument('output', type=Path) 334 | parser.add_argument('-z', '--zero_index', action='store_true', 335 | help='output alignment based on 0-index') 336 | parser.add_argument('-n', '--nbest', type=int, default=1) 337 | parser.add_argument('-m', '--max_sents', type=int, default=-1) 338 | parser.add_argument('-p', '--sentence_penalty', type=float, default=0.) 339 | parser.add_argument('-d', '--use_doc_score', action='store_true') 340 | 341 | return parser.parse_args() 342 | 343 | if __name__ == '__main__': 344 | config = parse_args() 345 | 346 | print(config.title) 347 | main(config) 348 | -------------------------------------------------------------------------------- /scripts/mergeSQuADjson.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | ''' 3 | Created on 2019/01/25 4 | 5 | 複数の SQuDA 形式データを一つにまとめる 6 | 各 SQuAD の JSON から data 要素のリストの中身を取り出し、 7 | 出力用 JSON の data リストへ追加 8 | 9 | ''' 10 | 11 | import argparse 12 | import json 13 | from pathlib import Path 14 | 15 | import fileUtils 16 | 17 | def makeQARefMap(squad_obj): 18 | qa_map = dict() 19 | 20 | for data in squad_obj["data"]: 21 | for p in data["paragraphs"]: 22 | for qa in p["qas"]: 23 | qa_id = qa["id"] 24 | 25 | qa_map[qa_id] = qa 26 | # end 27 | # end 28 | # end 29 | 30 | return qa_map 31 | # end 32 | 33 | 34 | 35 | def merge(jsonFileList, version="1.1"): 36 | merged = {"version": version, "data": []} 37 | for json_file in jsonFileList: 38 | with open(json_file) as f: 39 | jo = json.load(f) 40 | 41 | merged["data"].extend(jo["data"]) 42 | # end 43 | # end 44 | 45 | return merged 46 | # end 47 | 48 | def main(args): 49 | if args.src_list: 50 | src_list = fileUtils.loadFileToList(args.src_list) 51 | elif args.src_files: 52 | src_list = args.src_files 53 | elif args.src_dir: 54 | if (args.mode): 55 | file_pat = "{}.{}.json".format(args.pat, args.mode) 56 | else: 57 | file_pat = "{}.json".format(args.pat) 58 | # end 59 | src_list = fileUtils.listUpFiles(args.src_dir, file_pat) 60 | else: 61 | print("No SRC") 62 | return 63 | # end 64 | 65 | print("merge files: {}".format(src_list)) 66 | merged = merge(src_list, args.version) 67 | 68 | dst_file = args.dst_file if args.dst_file else args.dst_file_op 69 | if dst_file: 70 | if isinstance(dst_file, list): 71 | dst_file = dst_file[0] 72 | # end 73 | dst_file = Path(dst_file) 74 | with open(dst_file, "w", encoding="utf-8") as f: 75 | json.dump(merged, f, ensure_ascii=False) 76 | # end 77 | else: 78 | print(json.dumps(merged)) 79 | # end 80 | # end 81 | 82 | if __name__ == '__main__': 83 | parser = argparse.ArgumentParser() 84 | parser.add_argument("dst_file", nargs="?", default=None) 85 | parser.add_argument("-c", "--src_dir", default=None) 
86 | parser.add_argument("-s", "--src_files", dest="src_files", nargs="+") 87 | parser.add_argument("-d", "--dest_file", dest="dst_file_op", nargs=1) 88 | parser.add_argument("-l", dest="src_list", default=None) 89 | parser.add_argument("-r", dest="impossible_item_rate", default=None) 90 | parser.add_argument("--mode", dest="mode", default=None) 91 | parser.add_argument("--pat", dest="pat", default="*") 92 | parser.add_argument("--squad_version", dest="version", default="1.1") 93 | args = parser.parse_args() 94 | 95 | main(args) 96 | # end 97 | -------------------------------------------------------------------------------- /scripts/pairDoc2SQuAD.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | ''' 3 | Created on 2019/12/11 4 | 5 | 人手で作成した *.pair ファイルと元テキスト(1行1文)から文対応実験用 SQuAD データを作成する。 6 | 7 | pair_id: 対訳データのファイル名に付けられた 6 桁の番号 8 | qa_id: pair_id + 文対応番号(対訳データ毎に 000000 からの連番) 9 | context: 元テキストを一纏めにしたもの。改行はどうする? 10 | -> 無視する(改行なしで繋げる) 11 | -> 特別なトークンを入れる 12 | -> \n のまま1文字として処理する(ただしJSON上ではエスケープされて \\n になる) 13 | 14 | SQuAD は 2.0 に対応。対応無しの文もデータに入れる。(is_impossible = True) 15 | 16 | 17 | ''' 18 | 19 | import argparse 20 | import json 21 | import sys 22 | from pathlib import Path 23 | from ast import literal_eval 24 | 25 | import fileUtils 26 | import SQuAD 27 | 28 | class SpanInfo(): 29 | def __init__(self, head=None, span=None, pair_head=None): 30 | self.modify(head, span, pair_head) 31 | # end 32 | 33 | def modify(self, head, span, pair_head=None): 34 | self.head = head 35 | self.span = span 36 | self.pair_head = pair_head 37 | # end 38 | # end 39 | 40 | def lineNumCsvToLineIdxs(lines_csv, zero_based_index=False): 41 | ''' 42 | 文番号の CSV 文字列を文インデックスのリストにする 43 | ただし文番号が連続していない場合は空のリストを返す 44 | ''' 45 | offset = 0 if zero_based_index else 1 46 | 47 | if len(lines_csv) == '[]': 48 | return [] 49 | 50 | line_idx_list = literal_eval(lines_csv) 51 | s_l_idx_list = [int(x)-offset for x in line_idx_list] # offset 分ずらして 0 始まりにする 52 | s_l_idx_list.sort() 53 | 54 | # 連続しているなら最大インデックスは先頭インデックス+リストの長さ−1になるはず 55 | idx_max = s_l_idx_list[-1] 56 | if idx_max != s_l_idx_list[0] + len(s_l_idx_list) - 1: 57 | return [] 58 | else: 59 | return s_l_idx_list 60 | # end 61 | # end 62 | 63 | def makeRawLineToCharSpanMap(raw_lines, line_sep=""): 64 | ''' 65 | 元テキストの各文(インデックス)を文字 Span に対応づける 66 | 文区切り記号については line_sep で指定できるようにし、その分も考慮して Span の開始位置を算出 67 | 元テキストの先頭を 0 文字目とする 68 | ''' 69 | raw_line_to_span = [] 70 | sep_len = len(line_sep) 71 | 72 | offset = 0 73 | for line in raw_lines: 74 | l_len = len(line) 75 | span = SQuAD.Span(offset, offset+l_len-1, line) 76 | raw_line_to_span.append(span) 77 | 78 | offset = offset + l_len + sep_len 79 | # end 80 | 81 | return raw_line_to_span 82 | # end 83 | 84 | def lineIndexesToSpan(sent_line_indexes, raw_line_idx_to_span_map, line_sep=""): 85 | ''' 86 | sent_line_indexes: 1対応を構成する片側の文インデックスのリスト 87 | raw_line_idx_to_span_map: 文インデックスと Span の対応マップ 88 | line_sep: 文区切り記号 89 | ''' 90 | start = raw_line_idx_to_span_map[sent_line_indexes[0]].start 91 | end = raw_line_idx_to_span_map[sent_line_indexes[-1]].end 92 | 93 | span_texts = [] 94 | for line_idx in sent_line_indexes: 95 | span_texts.append(raw_line_idx_to_span_map[line_idx].text) 96 | # end 97 | text = line_sep.join(span_texts) 98 | 99 | return SQuAD.Span(start, end, text) 100 | # end 101 | 102 | def setLineToSpan(line_to_sent_span, line_idxs, raw_line_to_span, line_sep): 103 | ''' 104 | 行:Span 対応リストの span 開始行に対応する要素の SpanInfo オブジェクトを変更し、span 開始行のインデックスを返す 105 
| ''' 106 | head_idx = line_idxs[0] 107 | char_span = lineIndexesToSpan(line_idxs, raw_line_to_span, line_sep) 108 | line_to_sent_span[head_idx].modify(head=head_idx, span=char_span) 109 | 110 | for idx in line_idxs[1:]: 111 | line_to_sent_span[idx].modify(head=-1, span=None, pair_head=-1) 112 | # end 113 | 114 | return head_idx 115 | # end 116 | 117 | def makeSentSpanList(pair_line, raw_lines_l1, raw_lines_l2, line_sep="", zero_based_index=False): 118 | ''' 119 | 文対応情報と元テキストから文字単位 Span と対応相手の情報を持った辞書データを作成し、 120 | 元テキストの文インデックスと対応付けて格納したリストを返す 121 | 1-n の対応の時、n が連続した文でないなら、その対応は無視する 122 | ''' 123 | line_to_sent_span_l1 = [SpanInfo(head=None, span=SQuAD.Span(-1, -1, line)) for line in raw_lines_l1] 124 | line_to_sent_span_l2 = [SpanInfo(head=None, span=SQuAD.Span(-1, -1, line)) for line in raw_lines_l2] 125 | 126 | raw_line_to_span_l1 = makeRawLineToCharSpanMap(raw_lines_l1, line_sep) 127 | raw_line_to_span_l2 = makeRawLineToCharSpanMap(raw_lines_l2, line_sep) 128 | 129 | for line in pair_line: 130 | lines_csv_l1, lines_csv_l2 = line.split(":") 131 | 132 | # 文対 CSV を文インデックスのリストにする 133 | line_idxs_l1 = lineNumCsvToLineIdxs(lines_csv_l1, zero_based_index) 134 | line_idxs_l2 = lineNumCsvToLineIdxs(lines_csv_l2, zero_based_index) 135 | 136 | # line_idxs が両方空リストでないなら 137 | if line_idxs_l1 and line_idxs_l2: 138 | span_head_idx_l1 = setLineToSpan(line_to_sent_span_l1, line_idxs_l1, raw_line_to_span_l1, line_sep) 139 | span_head_idx_l2 = setLineToSpan(line_to_sent_span_l2, line_idxs_l2, raw_line_to_span_l2, line_sep) 140 | 141 | # 対応する相手側の開始行インデックスを保持 142 | line_to_sent_span_l1[span_head_idx_l1].pair_head = span_head_idx_l2 143 | line_to_sent_span_l2[span_head_idx_l2].pair_head = span_head_idx_l1 144 | # end 145 | # end 146 | 147 | return line_to_sent_span_l1, line_to_sent_span_l2 148 | # end 149 | 150 | def makeContextText(raw_lines, line_sep=""): 151 | ''' 152 | 元データを文区切り記号で繋げて context 用テキストを作る 153 | ''' 154 | return line_sep.join(raw_lines) 155 | # end 156 | 157 | def loadSpansAndContext(pairs_file, raw_file_l1, raw_file_l2, line_sep="", zero_based_index=False): 158 | ''' 159 | データをファイルから読み出し、元テキスト行:span の対応リストと context 用テキストを返す 160 | ''' 161 | pair_lines = fileUtils.loadFileToList(pairs_file) 162 | raw_lines_l1 = fileUtils.loadFileToList(raw_file_l1, commentMark=None) 163 | raw_lines_l2 = fileUtils.loadFileToList(raw_file_l2, commentMark=None) 164 | 165 | # 全角スペースを半角に 166 | raw_lines_l1 = [x.replace(" ", " ") for x in raw_lines_l1] 167 | raw_lines_l2 = [x.replace(" ", " ") for x in raw_lines_l2] 168 | 169 | line_to_sent_span_l1, line_to_sent_span_l2 = makeSentSpanList(pair_lines, raw_lines_l1, raw_lines_l2, line_sep, zero_based_index) 170 | 171 | context_l1 = makeContextText(raw_lines_l1, line_sep) 172 | context_l2 = makeContextText(raw_lines_l2, line_sep) 173 | 174 | return line_to_sent_span_l1, line_to_sent_span_l2, context_l1, context_l2 175 | # end 176 | 177 | def makeQA(line_to_sent_span_Q, line_to_sent_span_A, qa_id_prefix): 178 | qa_list = [] 179 | 180 | for line_idx, q_span_info in enumerate(line_to_sent_span_Q): 181 | if q_span_info.head == -1: 182 | continue 183 | elif q_span_info.head is None: 184 | qa_id = "{}_{}_X".format(qa_id_prefix, line_idx) 185 | q_str = q_span_info.span.text 186 | qa = SQuAD.SQuAD_QA(qa_id, q_str, is_impossible=True) 187 | else: 188 | qa_id = "{}_{}".format(qa_id_prefix, line_idx) 189 | q_str = q_span_info.span.text 190 | qa = SQuAD.SQuAD_QA(qa_id, q_str, is_impossible=False) 191 | 192 | ans_head = q_span_info.pair_head 193 | ans_span = 
line_to_sent_span_A[ans_head].span 194 | answer = SQuAD.SQuADAnswer(ans_span.text, ans_span.start, ans_span.end) 195 | 196 | qa.appendAnswer(answer) 197 | # end 198 | 199 | qa_list.append(qa) 200 | # end 201 | 202 | return qa_list 203 | # end 204 | 205 | def makeParagraph(line_to_sent_span_l1, line_to_sent_span_l2, context_l1, context_l2, pid_prefix): 206 | paragraph_l1_to_l2 = SQuAD.SQuADParagrap(context_l2) 207 | paragraph_l2_to_l1 = SQuAD.SQuADParagrap(context_l1) 208 | 209 | qa_id_prefix_l1tol2 = "{}_1_2".format(pid_prefix) 210 | qa_list_l1tol2 = makeQA(line_to_sent_span_l1, line_to_sent_span_l2, qa_id_prefix_l1tol2) 211 | qa_id_prefix_l2tol1 = "{}_2_1".format(pid_prefix) 212 | qa_list_l2tol1 = makeQA(line_to_sent_span_l2, line_to_sent_span_l1, qa_id_prefix_l2tol1) 213 | 214 | paragraph_l1_to_l2.extendQA(qa_list_l1tol2) 215 | paragraph_l2_to_l1.extendQA(qa_list_l2tol1) 216 | 217 | return paragraph_l1_to_l2, paragraph_l2_to_l1 218 | # end 219 | 220 | 221 | def main(args): 222 | pairs_file = Path(args.pairs_file) 223 | raw_file_l1 = args.raw_file_l1 224 | raw_file_l2 = args.raw_file_l2 225 | dest = args.dest 226 | title_prefix = args.title_prefix 227 | line_sep = args.line_sep 228 | indent = args.indent 229 | zero_based_index = args.zero_based_index 230 | 231 | title = "{}_{}".format(title_prefix, pairs_file.stem) 232 | 233 | line_to_sent_span_l1, line_to_sent_span_l2, context_l1, context_l2 = loadSpansAndContext(pairs_file, raw_file_l1, raw_file_l2, line_sep, zero_based_index) 234 | paragraph_l1_to_l2, paragraph_l2_to_l1 = makeParagraph(line_to_sent_span_l1, line_to_sent_span_l2, context_l1, context_l2, title) 235 | 236 | squad_data = SQuAD.SQuADDataItem(title) 237 | squad_data.appendParagraph(paragraph_l1_to_l2) 238 | squad_data.appendParagraph(paragraph_l2_to_l1) 239 | 240 | squad = SQuAD.SQuAD(version2=True) 241 | squad.appendData(squad_data) 242 | 243 | with open(dest, "w", encoding='utf-8') as fp: 244 | json.dump(squad.toJson(), fp, ensure_ascii=False, indent=indent) 245 | # end 246 | # end 247 | 248 | if __name__ == '__main__': 249 | parser = argparse.ArgumentParser() 250 | parser.add_argument("pairs_file", help="input *.pair file") 251 | parser.add_argument("raw_file_l1", help="source raw text file") 252 | parser.add_argument("raw_file_l2", help="target raw text file") 253 | parser.add_argument("dest", help="output file") 254 | parser.add_argument("title_prefix", help="prefix for SQuAD title") 255 | parser.add_argument("-s", dest="line_sep", default="\n", help="line separator") 256 | parser.add_argument("-i", dest="indent", type=int, default=2, help="indent level for JSON dump") 257 | parser.add_argument("-z", dest="zero_based_index", action="store_true", help="whether sentence numbers in *.pair file are zero-based") 258 | 259 | args = parser.parse_args() 260 | main(args) 261 | # end 262 | -------------------------------------------------------------------------------- /scripts/score.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | """ 4 | modified on Dec. 22, 2020 by Katsuki Chousa 5 | original: thompsonb/vecalign score.py and dp_utils.py 6 | """ 7 | 8 | """ 9 | Copyright 2019 Brian Thompson 10 | 11 | Licensed under the Apache License, Version 2.0 (the "License"); 12 | you may not use this file except in compliance with the License. 
13 | You may obtain a copy of the License at 14 | 15 | https://www.apache.org/licenses/LICENSE-2.0 16 | 17 | Unless required by applicable law or agreed to in writing, software 18 | distributed under the License is distributed on an "AS IS" BASIS, 19 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 20 | See the License for the specific language governing permissions and 21 | limitations under the License. 22 | 23 | """ 24 | 25 | import argparse 26 | import sys 27 | from collections import defaultdict 28 | from ast import literal_eval 29 | 30 | import numpy as np 31 | 32 | """ 33 | Faster implementation of lax and strict precision and recall, based on 34 | https://www.aclweb.org/anthology/W11-4624/. 35 | 36 | """ 37 | 38 | 39 | def read_alignments(fin): 40 | alignments = [] 41 | with open(fin, 'rt', encoding="utf-8") as infile: 42 | for line in infile: 43 | fields = [x.strip() for x in line.split(':') if len(x.strip())] 44 | if len(fields) < 2: 45 | raise Exception('Got line "%s", which does not have at least two ":" separated fields' % line.strip()) 46 | try: 47 | src = literal_eval(fields[0]) 48 | tgt = literal_eval(fields[1]) 49 | except: 50 | raise Exception('Failed to parse line "%s"' % line.strip()) 51 | alignments.append((src, tgt)) 52 | 53 | # I know bluealign files have a few entries entries missing, 54 | # but I don't fix them in order to be consistent previous reported scores 55 | return alignments 56 | 57 | def _precision(goldalign, testalign): 58 | """ 59 | Computes tpstrict, fpstrict, tplax, fplax for gold/test alignments 60 | """ 61 | tpstrict = 0 # true positive strict counter 62 | tplax = 0 # true positive lax counter 63 | fpstrict = 0 # false positive strict counter 64 | fplax = 0 # false positive lax counter 65 | 66 | # convert to sets, remove alignments empty on both sides 67 | testalign = set([(tuple(x), tuple(y)) for x, y in testalign if len(x) or len(y)]) 68 | goldalign = set([(tuple(x), tuple(y)) for x, y in goldalign if len(x) or len(y)]) 69 | 70 | # mappings from source test sentence idxs to 71 | # target gold sentence idxs for which the source test sentence 72 | # was found in corresponding source gold alignment 73 | src_id_to_gold_tgt_ids = defaultdict(set) 74 | for gold_src, gold_tgt in goldalign: 75 | for gold_src_id in gold_src: 76 | for gold_tgt_id in gold_tgt: 77 | src_id_to_gold_tgt_ids[gold_src_id].add(gold_tgt_id) 78 | 79 | for (test_src, test_target) in testalign: 80 | if (test_src, test_target) == ((), ()): 81 | continue 82 | if (test_src, test_target) in goldalign: 83 | # strict match 84 | tpstrict += 1 85 | tplax += 1 86 | else: 87 | # For anything with partial gold/test overlap on the source, 88 | # see if there is also partial overlap on the gold/test target 89 | # If so, its a lax match 90 | target_ids = set() 91 | for src_test_id in test_src: 92 | for tgt_id in src_id_to_gold_tgt_ids[src_test_id]: 93 | target_ids.add(tgt_id) 94 | if set(test_target).intersection(target_ids): 95 | tplax += 1 96 | else: 97 | fplax += 1 98 | fpstrict += 1 99 | 100 | return np.array([tpstrict, fpstrict, tplax, fplax], dtype=np.int32) 101 | 102 | 103 | def score_multiple(gold_list, test_list, value_for_div_by_0=0.0, delete=True): 104 | # accumulate counts for all gold/test files 105 | pcounts = np.array([0, 0, 0, 0], dtype=np.int32) 106 | rcounts = np.array([0, 0, 0, 0], dtype=np.int32) 107 | for goldalign, testalign in zip(gold_list, test_list): 108 | pcounts += _precision(goldalign=goldalign, testalign=testalign) 109 | # recall is precision 
with no insertion/deletion and swap args 110 | if delete: 111 | test_no_del = [(x, y) for x, y in testalign if len(x) and len(y)] 112 | gold_no_del = [(x, y) for x, y in goldalign if len(x) and len(y)] 113 | rcounts += _precision(goldalign=test_no_del, testalign=gold_no_del) 114 | else: 115 | rcounts += _precision(goldalign=testalign, testalign=goldalign) 116 | 117 | # assert pcounts[0] == rcounts[0], "TP for precision and recall are mismatched!" 118 | 119 | # Compute results 120 | # pcounts: tpstrict,fnstrict,tplax,fnlax 121 | # rcounts: tpstrict,fpstrict,tplax,fplax 122 | 123 | if pcounts[0] + pcounts[1] == 0: 124 | pstrict = value_for_div_by_0 125 | else: 126 | pstrict = pcounts[0] / float(pcounts[0] + pcounts[1]) 127 | 128 | if pcounts[2] + pcounts[3] == 0: 129 | plax = value_for_div_by_0 130 | else: 131 | plax = pcounts[2] / float(pcounts[2] + pcounts[3]) 132 | 133 | if rcounts[0] + rcounts[1] == 0: 134 | rstrict = value_for_div_by_0 135 | else: 136 | rstrict = rcounts[0] / float(rcounts[0] + rcounts[1]) 137 | 138 | if rcounts[2] + rcounts[3] == 0: 139 | rlax = value_for_div_by_0 140 | else: 141 | rlax = rcounts[2] / float(rcounts[2] + rcounts[3]) 142 | 143 | if (pstrict + rstrict) == 0: 144 | fstrict = value_for_div_by_0 145 | else: 146 | fstrict = 2 * (pstrict * rstrict) / (pstrict + rstrict) 147 | 148 | if (plax + rlax) == 0: 149 | flax = value_for_div_by_0 150 | else: 151 | flax = 2 * (plax * rlax) / (plax + rlax) 152 | 153 | result = dict(recall_strict=rstrict, 154 | recall_lax=rlax, 155 | precision_strict=pstrict, 156 | precision_lax=plax, 157 | f1_strict=fstrict, 158 | f1_lax=flax) 159 | 160 | return result 161 | 162 | def score_separate(gold_alignment_list, test_alignment_list, value_for_div_by_0=0.0): 163 | max_src_align = 0 164 | max_trg_align = 0 165 | gold_list = defaultdict(lambda: []) 166 | for gold_alignments in gold_alignment_list: 167 | for alignment in gold_alignments: 168 | gold_list[(len(alignment[0]), len(alignment[1]))].append(alignment) 169 | max_src_align = max([max_src_align, len(alignment[0])]) 170 | max_trg_align = max([max_trg_align, len(alignment[1])]) 171 | 172 | test_list = defaultdict(lambda: []) 173 | for test_alignments in test_alignment_list: 174 | for alignment in test_alignments: 175 | test_list[(len(alignment[0]), len(alignment[1]))].append(alignment) 176 | 177 | res = np.full((max_src_align + 1, max_trg_align + 1), None) 178 | for num_src in range(max_src_align + 1): 179 | for num_trg in range(max_trg_align + 1): 180 | ret = score_multiple(gold_list=[gold_list[(num_src, num_trg)]], 181 | test_list=[test_list[(num_src, num_trg)]], delete=False) 182 | res[num_src][num_trg] = "{precision_strict:.3f}/{recall_strict:.3f}/{f1_strict:.3f} ".format(**ret) + "({})".format(len(gold_list[(num_src, num_trg)])) 183 | 184 | return res 185 | 186 | 187 | def log_final_scores(res): 188 | print(' ---------------------------------', file=sys.stderr) 189 | print('| | Strict | Lax |', file=sys.stderr) 190 | print('| Precision | {precision_strict:.3f} | {precision_lax:.3f} |'.format(**res), file=sys.stderr) 191 | print('| Recall | {recall_strict:.3f} | {recall_lax:.3f} |'.format(**res), file=sys.stderr) 192 | print('| F1 | {f1_strict:.3f} | {f1_lax:.3f} |'.format(**res), file=sys.stderr) 193 | print(' ---------------------------------', file=sys.stderr) 194 | 195 | def log_matrix(res): 196 | from tabulate import tabulate 197 | 198 | header = ["trg/src"] + list(range(len(res) + 1)) 199 | table = [] 200 | for num_trg in range(len(res[0])): 201 | row_result = 
[str(num_trg)] 202 | for num_src in range(len(res)): 203 | row_result.append(res[(num_src, num_trg)]) 204 | table.append(row_result) 205 | 206 | print(tabulate(table, header), file=sys.stderr) 207 | 208 | 209 | def main(): 210 | parser = argparse.ArgumentParser( 211 | 'Compute strict/lax precision and recall for one or more pairs of gold/test alignments', 212 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 213 | 214 | parser.add_argument('-t', '--test', type=str, nargs='+', required=True, 215 | help='one or more test alignment files') 216 | 217 | parser.add_argument('-g', '--gold', type=str, nargs='+', required=True, 218 | help='one or more gold alignment files') 219 | 220 | args = parser.parse_args() 221 | 222 | if len(args.test) != len(args.gold): 223 | raise Exception('number of gold/test files must be the same') 224 | 225 | gold_list = [read_alignments(x) for x in args.gold] 226 | test_list = [read_alignments(x) for x in args.test] 227 | 228 | res = score_multiple(gold_list=gold_list, test_list=test_list) 229 | log_final_scores(res) 230 | res = score_separate(gold_list, test_list) 231 | log_matrix(res) 232 | 233 | 234 | if __name__ == '__main__': 235 | main() 236 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | from functools import partial 4 | from multiprocessing import Pool, cpu_count 5 | 6 | import os 7 | import logging 8 | import json 9 | import h5py 10 | import numpy as np 11 | from tqdm import tqdm 12 | 13 | import transformers 14 | from transformers.data.processors.squad import ( 15 | squad_convert_example_to_features_init, 16 | squad_convert_example_to_features, 17 | SquadV2Processor, 18 | SquadFeatures, 19 | _is_whitespace, 20 | ) 21 | 22 | import torch 23 | from torch.utils.data import Dataset 24 | 25 | 26 | 27 | logger = logging.getLogger(__name__) 28 | 29 | def load_and_cache_examples(args, tokenizer, evaluate=False): 30 | if args.local_rank not in [-1, 0] and not evaluate: 31 | # Make sure only the first process in distributed training process the dataset, and the others will use the cache 32 | torch.distributed.barrier() 33 | 34 | # Load data features from cache or dataset file 35 | input_dir = args.data_dir if args.data_dir else "." 36 | cached_features_file = os.path.join( 37 | input_dir, 38 | "cached_{}_{}_{}".format( 39 | "dev" if evaluate else "train", 40 | list(filter(None, args.model_name_or_path.split("/"))).pop(), 41 | str(args.max_seq_length), 42 | ), 43 | ) 44 | 45 | # Init features and dataset from cache if it exists 46 | if os.path.exists(cached_features_file) and not args.overwrite_cache: 47 | logger.info("Loading features from cached file %s", cached_features_file) 48 | else: 49 | logger.info("Creating features from dataset file at %s", input_dir) 50 | 51 | assert args.data_dir, "data_dir must be set." 52 | assert not evaluate or args.predict_file, 'When evalute == true, predict_file must be specified.' 53 | assert evaluate or args.train_file, "at least one of evaluate and train_file must be specified." 
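# `args` is expected to provide at least: local_rank, data_dir, model_name_or_path,
# max_seq_length, doc_stride, max_query_length, overwrite_cache, threads and
# train_file / predict_file -- presumably the argparse namespace built by the driver script.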
54 | 55 | processor = MySquadProcessor() 56 | if evaluate: 57 | examples = processor.get_dev_examples("", filename=args.predict_file) 58 | else: 59 | examples = processor.get_train_examples("", filename=args.train_file) 60 | 61 | feature_writer = CacheDataWriter(cached_features_file) 62 | squad_convert_examples_to_features( 63 | examples=examples, 64 | tokenizer=tokenizer, 65 | max_seq_length=args.max_seq_length, 66 | doc_stride=args.doc_stride, 67 | max_query_length=args.max_query_length, 68 | is_training=not evaluate, 69 | feature_writer=feature_writer, 70 | threads=args.threads, 71 | ) 72 | 73 | if args.local_rank == 0 and not evaluate: 74 | # Make sure only the first process in distributed training process the dataset, and the others will use the cache 75 | torch.distributed.barrier() 76 | 77 | reader = CacheDataReader(cached_features_file, is_training=not evaluate) 78 | return reader 79 | 80 | def squad_convert_example_to_features_try(example, max_seq_length, doc_stride, max_query_length, is_training): 81 | try: 82 | ret = squad_convert_example_to_features(example, max_seq_length, doc_stride, max_query_length, "max_length", is_training) 83 | except: 84 | logger.warning('error on {}'.format(example.qas_id)) 85 | ret = [] 86 | 87 | return ret 88 | 89 | def squad_convert_examples_to_features(examples, tokenizer, max_seq_length, doc_stride, max_query_length, is_training, feature_writer, threads=1): 90 | 91 | 92 | # Defining helper methods 93 | threads = min(threads, cpu_count()) 94 | with Pool(threads, initializer=squad_convert_example_to_features_init, initargs=(tokenizer,)) as p: 95 | annotate_ = partial( 96 | squad_convert_example_to_features_try, 97 | max_seq_length=max_seq_length, 98 | doc_stride=doc_stride, 99 | max_query_length=max_query_length, 100 | is_training=is_training, 101 | ) 102 | 103 | unique_id = 1000000000 104 | example_index = 0 105 | tmp_features = [] 106 | for example_features in tqdm(p.imap(annotate_, examples, chunksize=32), total=len(examples), desc="convert squad examples to features"): 107 | if not example_features: 108 | continue 109 | 110 | for example_feature in example_features: 111 | example_feature.example_index = example_index 112 | example_feature.unique_id = unique_id 113 | unique_id += 1 114 | tmp_features.append(example_feature) 115 | 116 | example_index += 1 117 | if len(tmp_features) > 10000: 118 | feature_writer.write_features(tmp_features) 119 | tmp_features.clear() 120 | if tmp_features: 121 | feature_writer.write_features(tmp_features) 122 | tmp_features.clear() 123 | 124 | feature_writer.write_examples(examples) 125 | 126 | 127 | class CacheDataWriter(): 128 | def __init__(self, cache_file): 129 | self.examples_file = cache_file + ".examples" 130 | self.features_file = cache_file 131 | 132 | self.vl_int_dt = h5py.vlen_dtype("i8") 133 | self.str_dt = h5py.string_dtype(encoding="utf-8") 134 | self.tok_is_max_cntxt_dt = np.dtype([("tok_id", "i"), ("flag", "?")]) 135 | self.vl_timc_dt = h5py.vlen_dtype(self.tok_is_max_cntxt_dt) 136 | self.tok_to_orig_map_dt = np.dtype([("tok_id", "i"), ("orig_txt_id", "i")]) 137 | self.vl_ttom_dt = h5py.vlen_dtype(self.tok_to_orig_map_dt) 138 | 139 | self.feature_vals = {"input_ids": [], "attention_mask": [], "token_type_ids": [], 140 | "cls_index": [], "p_mask": [], "example_index": [], "unique_id": [], 141 | "paragraph_len": [], "token_is_max_context": [], "tokens": [], 142 | "token_to_orig_map": [], "start_position": [], "end_position": [], "is_impossible": []} 143 | 144 | with h5py.File(self.features_file, "w") 
as hdf: 145 | g_features = hdf.create_group("features") 146 | g_features.create_dataset("input_ids", shape=(1,), dtype=self.vl_int_dt, maxshape=(None,)) 147 | g_features.create_dataset("attention_mask", shape=(1,), dtype=self.vl_int_dt, maxshape=(None,)) 148 | g_features.create_dataset("token_type_ids", shape=(1,), dtype=self.vl_int_dt, maxshape=(None,)) 149 | g_features.create_dataset("cls_index", shape=(1,), dtype="i8", maxshape=(None,)) 150 | g_features.create_dataset("p_mask", shape=(1,), dtype=self.vl_int_dt, maxshape=(None,)) 151 | g_features.create_dataset("example_index", shape=(1,), dtype="i8", maxshape=(None,)) 152 | g_features.create_dataset("unique_id", shape=(1,), dtype="i8", maxshape=(None,)) 153 | g_features.create_dataset("paragraph_len", shape=(1,), dtype="i8", maxshape=(None,)) 154 | g_features.create_dataset("token_is_max_context", shape=(1,), dtype=self.vl_timc_dt, maxshape=(None,)) 155 | g_features.create_dataset("tokens", shape=(1,), dtype=self.str_dt, maxshape=(None,)) 156 | g_features.create_dataset("token_to_orig_map", shape=(1,), dtype=self.vl_ttom_dt, maxshape=(None,)) 157 | g_features.create_dataset("start_position", shape=(1,), dtype="i8", maxshape=(None,)) 158 | g_features.create_dataset("end_position", shape=(1,), dtype="i8", maxshape=(None,)) 159 | g_features.create_dataset("is_impossible", shape=(1,), dtype="?", maxshape=(None,)) 160 | g_features.create_dataset("feature_index", shape=(1,), dtype="i8", maxshape=(None,)) 161 | g_features.attrs["size"] = 0 162 | g_features.attrs["offset"] = 0 163 | 164 | # features が要素毎にリスト化して持っているので、それが dataset になる 165 | g_dataset = hdf.create_group("dataset") 166 | g_train_dataset = g_dataset.create_group("train") 167 | g_eval_dataset = g_dataset.create_group("eval") 168 | 169 | g_dataset.attrs["size"] = g_features.attrs["size"] 170 | g_train_dataset.attrs["size"] = g_dataset.attrs["size"] 171 | g_eval_dataset.attrs["size"] = g_dataset.attrs["size"] 172 | 173 | g_train_dataset["all_input_ids"] = g_features["input_ids"] 174 | g_train_dataset["all_attention_masks"] = g_features["attention_mask"] 175 | g_train_dataset["all_token_type_ids"] = g_features["token_type_ids"] 176 | g_train_dataset["all_start_positions"] = g_features["start_position"] 177 | g_train_dataset["all_end_positions"] = g_features["end_position"] 178 | g_train_dataset["all_cls_index"] = g_features["cls_index"] 179 | g_train_dataset["all_p_mask"] = g_features["p_mask"] 180 | g_train_dataset["all_is_impossible"] = g_features["is_impossible"] 181 | 182 | g_eval_dataset["all_input_ids"] = g_features["input_ids"] 183 | g_eval_dataset["all_attention_masks"] = g_features["attention_mask"] 184 | g_eval_dataset["all_token_type_ids"] = g_features["token_type_ids"] 185 | g_eval_dataset["all_cls_index"] = g_features["cls_index"] 186 | g_eval_dataset["all_p_mask"] = g_features["p_mask"] 187 | # all_example_index は 個々の feature の example_index の集約ではなく 188 | # feature の index = 何番目の feature かを表す 189 | g_eval_dataset["all_example_index"] = g_features["feature_index"] 190 | 191 | def write_examples(self, examples): 192 | torch.save(examples, self.examples_file) 193 | 194 | def _set_dataset_size(self, hdf, size): 195 | gp_dataset = hdf["/dataset"] 196 | gp_eval_dataset = hdf["/dataset/eval"] 197 | gp_train_dataset = hdf["/dataset/train"] 198 | 199 | gp_dataset.attrs["size"] = size 200 | gp_eval_dataset.attrs["size"] = size 201 | gp_train_dataset.attrs["size"] = size 202 | 203 | def write_features(self, features): 204 | with h5py.File(self.features_file, "a") as hdf: 205 | 
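# Every dataset under /features was created with maxshape=(None,), so each call appends
# a chunk: the datasets are resized to the new total length, the buffered per-field lists
# are written into the [offset:limit] slice, and the size/offset attributes are advanced.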
gp_features = hdf["/features"] 206 | offset = gp_features.attrs["offset"] 207 | limit = offset + len(features) 208 | new_size = gp_features.attrs["size"] + len(features) 209 | 210 | for _k, ds in gp_features.items(): 211 | ds.resize(size=(new_size, )) 212 | 213 | for f in features: 214 | self.feature_vals["input_ids"].append(f.input_ids) 215 | self.feature_vals["attention_mask"].append(f.attention_mask) 216 | self.feature_vals["token_type_ids"].append(f.token_type_ids) 217 | self.feature_vals["cls_index"].append(f.cls_index) 218 | self.feature_vals["p_mask"].append(f.p_mask) 219 | self.feature_vals["example_index"].append(f.example_index) 220 | self.feature_vals["unique_id"].append(f.unique_id) 221 | self.feature_vals["paragraph_len"].append(f.paragraph_len) 222 | self.feature_vals["token_is_max_context"].append(np.array([(k, v) for k,v in f.token_is_max_context.items()], 223 | dtype=self.tok_is_max_cntxt_dt)) 224 | self.feature_vals["tokens"].append(json.dumps(f.tokens, ensure_ascii=False)) 225 | self.feature_vals["token_to_orig_map"].append(np.array([(k, v) for k,v in f.token_to_orig_map.items()], 226 | dtype=self.tok_to_orig_map_dt)) 227 | self.feature_vals["start_position"].append(f.start_position) 228 | self.feature_vals["end_position"].append(f.end_position) 229 | self.feature_vals["is_impossible"].append(f.is_impossible) 230 | 231 | for k, v in self.feature_vals.items(): 232 | # limit = offset + len(v) 233 | gp_features[k][offset:limit] = v 234 | v.clear() 235 | 236 | f_idxs = [x for x in range(offset, limit)] 237 | gp_features["feature_index"][offset:limit] = f_idxs 238 | 239 | gp_features.attrs["size"] = new_size 240 | gp_features.attrs["offset"] = new_size 241 | self._set_dataset_size(hdf, new_size) 242 | 243 | 244 | class CacheDataReader(): 245 | def __init__(self, cache_file, is_training=False): 246 | self.examples_file = cache_file + ".examples" 247 | self.cache_data = h5py.File(cache_file, "r", swmr=True) 248 | if is_training: 249 | self.dataset_group = self.cache_data["/dataset/train"] 250 | self.data_keys = ["all_input_ids", 251 | "all_attention_masks", 252 | "all_token_type_ids", 253 | "all_start_positions", 254 | "all_end_positions", 255 | "all_cls_index", 256 | "all_p_mask", 257 | "all_is_impossible" 258 | ] 259 | else: 260 | self.dataset_group = self.cache_data["/dataset/eval"] 261 | self.data_keys = ["all_input_ids", 262 | "all_attention_masks", 263 | "all_token_type_ids", 264 | "all_example_index", 265 | "all_cls_index", 266 | "all_p_mask", 267 | ] 268 | 269 | def __del__(self): 270 | self.cache_data.close() 271 | 272 | s = super() 273 | if hasattr(s, "__del__"): 274 | s.__del__(self) 275 | 276 | def get_item(self, index): 277 | return tuple(self.dataset_group[key][index] for key in self.data_keys) 278 | 279 | def get_size(self): 280 | return self.dataset_group.attrs["size"] 281 | 282 | def load_examples(self): 283 | examples = torch.load(self.examples_file) 284 | return examples 285 | 286 | def get_features(self): 287 | features = [self.get_feature(index) for index in range(self.get_size())] 288 | return features 289 | 290 | def get_feature(self, index): 291 | token_is_max_context = dict(self.cache_data["/features/token_is_max_context"][index]) 292 | tokens = json.loads(self.cache_data["/features/tokens"][index]) 293 | token_to_orig_map = dict(self.cache_data["/features/token_to_orig_map"][index]) 294 | 295 | feature = SquadFeatures( 296 | self.cache_data["/features/input_ids"][index], 297 | self.cache_data["/features/attention_mask"][index], 298 | 
self.cache_data["/features/token_type_ids"][index], 299 | self.cache_data["/features/cls_index"][index], 300 | self.cache_data["/features/p_mask"][index], 301 | example_index=self.cache_data["/features/example_index"][index], 302 | unique_id=self.cache_data["/features/unique_id"][index], 303 | paragraph_len=self.cache_data["/features/paragraph_len"][index], 304 | token_is_max_context=token_is_max_context, 305 | tokens=tokens, 306 | token_to_orig_map=token_to_orig_map, 307 | start_position=self.cache_data["/features/start_position"][index], 308 | end_position=self.cache_data["/features/end_position"][index], 309 | is_impossible=self.cache_data["/features/is_impossible"][index], 310 | ) 311 | 312 | return feature 313 | 314 | def get_feature_value(self, key, index): 315 | path = "/features/{}".format(key) 316 | value = self.cache_data[path][index] 317 | return value 318 | 319 | 320 | class HDF5Dataset(Dataset): 321 | def __init__(self, cache_reader): 322 | self.cache_reader = cache_reader 323 | 324 | def __getitem__(self, index): 325 | return self.cache_reader.get_item(index) 326 | 327 | def __len__(self): 328 | return self.cache_reader.get_size() 329 | 330 | 331 | class MySquadProcessor(SquadV2Processor): 332 | def _create_examples(self, input_data, set_type): 333 | is_training = set_type == "train" 334 | examples = [] 335 | for entry in tqdm(input_data): 336 | title = entry["title"] 337 | for paragraph in entry["paragraphs"]: 338 | context_text = paragraph["context"] 339 | doc_tokens, char_to_word_offset = MySquadExample.divide_context_text(context_text) 340 | 341 | for qa in paragraph["qas"]: 342 | qas_id = qa["id"] 343 | question_text = qa["question"] 344 | start_position_character = None 345 | answer_text = None 346 | answers = [] 347 | 348 | if "is_impossible" in qa: 349 | is_impossible = qa["is_impossible"] 350 | else: 351 | is_impossible = False 352 | 353 | if not is_impossible: 354 | if is_training: 355 | answer = qa["answers"][0] 356 | answer_text = answer["text"] 357 | start_position_character = answer["answer_start"] 358 | else: 359 | answers = qa["answers"] 360 | 361 | if start_position_character is not None and not is_impossible: 362 | start_position = char_to_word_offset[start_position_character] 363 | end_position = char_to_word_offset[ 364 | min(start_position_character + len(answer_text) - 1, len(char_to_word_offset) - 1) 365 | ] 366 | else: 367 | start_position = -1 368 | end_position = -1 369 | 370 | example = MySquadExample( 371 | qas_id=qas_id, 372 | question_text=question_text, 373 | doc_tokens=doc_tokens, 374 | answer_text=answer_text, 375 | start_position=start_position, 376 | end_position=end_position, 377 | title=title, 378 | answers=answers, 379 | is_impossible=is_impossible, 380 | ) 381 | 382 | examples.append(example) 383 | return examples 384 | 385 | 386 | class MySquadExample(object): 387 | def __init__(self, 388 | qas_id, 389 | question_text, 390 | doc_tokens, 391 | answer_text, 392 | start_position, 393 | end_position, 394 | title, 395 | answers=[], 396 | is_impossible=False): 397 | 398 | self.qas_id = qas_id 399 | self.question_text = question_text 400 | self.doc_tokens = doc_tokens 401 | self.answer_text = answer_text 402 | self.title = title 403 | self.start_position = start_position 404 | self.end_position = end_position 405 | self.is_impossible = is_impossible 406 | self.answers = answers 407 | 408 | @staticmethod 409 | def divide_context_text(context_text): 410 | doc_tokens = [] 411 | char_to_word_offset = [] 412 | prev_is_whitespace = True 413 | 414 | 
for c in context_text: 415 | if _is_whitespace(c): 416 | prev_is_whitespace = True 417 | else: 418 | if prev_is_whitespace: 419 | doc_tokens.append(c) 420 | else: 421 | doc_tokens[-1] += c 422 | prev_is_whitespace = False 423 | char_to_word_offset.append(len(doc_tokens) - 1) 424 | 425 | return doc_tokens, char_to_word_offset 426 | --------------------------------------------------------------------------------
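A minimal scoring sketch (illustrative only; the file names are hypothetical and the import assumes the snippet is run from the scripts/ directory): it evaluates a system *.pair file, such as the *.bidi.pair output of get_sent_align_for_overlap.py, against a gold *.pair file with the functions from scripts/score.py.

# evaluate_pair.py -- illustrative sketch, not part of the repository
from score import read_alignments, score_multiple, log_final_scores

gold = read_alignments("gold.pair")          # reference sentence alignments
test = read_alignments("system.bidi.pair")   # system output in the same *.pair format

result = score_multiple(gold_list=[gold], test_list=[test])
log_final_scores(result)                     # prints strict/lax precision, recall and F1 to stderr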