├── .gitignore ├── LICENSE ├── README.md ├── requirements.txt ├── script ├── interactive.sh ├── preprocess.sh └── train.sh └── src ├── convert_m2_to_parallel.py └── remove.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Masahiro Kaneko 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Fairseqで文法誤り訂正モデルを学習,推論と評価する 2 | 3 | 文法誤り訂正 (Grammatical Error Correction; GEC) に入門する[ブログ記事]()のコードである. 4 | 5 | ## セットアップ 6 | 7 | 実行環境はpython3.9であり,`pip install -r requirements.txt`により必要なライブラリをインストールする. 8 | 9 | ### Fairseの準備 10 | ```shell 11 | git clone https://github.com/pytorch/fairseq.git 12 | cd fairseq 13 | pip install --editable ./ 14 | cd ../ 15 | ``` 16 | 17 | ## データセットの入手と前処理 18 | 19 | W&I+LOCNESSとFCEはwgetによりダウンロードすることが可能である.Lang-8とNUCLEはリクエストが必要であるため,各自申請して`data/m2`に配置する.これらのデータはM2形式で配布されており,Fairseqで取り扱えるようにパラレル形式(ソースデータとターゲットデータ)に変換する必要がある. 20 | ```shell 21 | # W&I+LOCNESSとFCEデータをダウンロードし`data/m2`ディレクトリに配置する.Lang-8とNUCLEは適宜リクエストして配置する. 22 | M2_DIR=data/m2 23 | PARA_DIR=data/parallel 24 | mkdir -p $M2_DIR 25 | mkdir -p $PARA_DIR 26 | wget https://www.cl.cam.ac.uk/research/nl/bea2019st/data/fce_v2.1.bea19.tar.gz -O - | tar xvf - -C $M2_DIR 27 | wget https://www.cl.cam.ac.uk/research/nl/bea2019st/data/wi+locness_v2.1.bea19.tar.gz -O - | tar xvf - -C $M2_DIR 28 | # データに対して前処理(1:M2形式からパラレルデータ形式に変換する,2:データを結合する,3:訂正されていない文対を除外する)を行う. 29 | ./script/preprocess.sh $M2_DIR $PARA_DIR 30 | ``` 31 | 32 | 上記のコマンドによりW&I+LOCNESSの評価データはダウンロードされているが,CoNLL-2014とJFLEGは以下のコマンドで`data`にダウンロードする必要がある. 
33 | 34 | ```shell 35 | # CoNLL-2014のダウンロードとパラレルデータ形式に変換 36 | wget https://www.comp.nus.edu.sg/~nlp/conll14st/conll14st-test-data.tar.gz -O - | tar xvf - -C data 37 | python src/convert_m2_to_parallel.py data/conll14st-test-data/noalt/official-2014.combined.m2 data/conll14st-test-data/noalt/conll2014.src data/conll14st-test-data/noalt/conll2014.trg 38 | # JFLEGのダウンロード 39 | git clone https://github.com/keisks/jfleg.git data/jfleg 40 | ``` 41 | 42 | ## GECモデルの学習 43 | 44 | `train.sh`を使い作成したデータをバイナリーデータにしGECモデルを学習する.ここではGECモデルとしてTransformer-bigを使用する. 45 | ```shell 46 | ./script/train.sh 47 | ``` 48 | 49 | 50 | ## GECモデルの推論と評価 51 | 52 | 推論結果を評価するために評価指標を3つ`eval`ディレクトリに配置する.W&I+LOCNESSはCodaLabで評価するためERRANTは直接使わない. 53 | ```shell 54 | mkdir eval 55 | # M2のダウンロード 56 | git clone https://github.com/kanekomasahiro/m2_python3.git eval/m2_python3 57 | # GLEUのダウンロード 58 | git clone https://github.com/kanekomasahiro/gec-ranking_python3.git eval/gec-ranking_python3 59 | # 使わないが一応ERRANTのダウンロード 60 | git clone https://github.com/chrisjbryant/errant.git eval/errant 61 | ``` 62 | 63 | `interactive.sh`を使い学習したモデルで評価データに対して推論を行う.wi,conllまたはjflegのどれを推論するか引数で指定する.そして,CoNLL-2014(評価に時間がかかることがある)とJFLEGに対しては推論結果の評価も行われる.評価結果や出力結果は`output/$seed`に保存される. 64 | ```shell 65 | ./script/interactive.sh [wi/conll/jfleg] 66 | ``` 67 | W&I+LOCNESSは評価データのターゲット側が公開されていないため,[CodaLab](https://competitions.codalab.org/competitions/20228)にGECモデルの推論結果を投稿する必要がある.アカウントを作成し,`zip`コマンドにより推論結果を圧縮しParticipateのSubmitを押してアップロードすることでスコアを取得できる. 68 | seedによって1.5ぐらい前後するがスコアとしてはW&I+LOCNESS: 50, CoNLL-2014: 49, JFLEG: 53がでる. 

--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
mock==4.0.3
subword-nmt==0.3.8
tqdm==4.62.3
scipy==1.7.3
--------------------------------------------------------------------------------
/script/interactive.sh:
--------------------------------------------------------------------------------
#!/bin/bash
# Decode one evaluation set (wi / conll / jfleg) with the trained GEC model
# and, for CoNLL-2014 and JFLEG, score the output automatically.
#
# Usage: ./script/interactive.sh [wi/conll/jfleg]

seed=1111
num_operations=8000
beam=5
test_data=$1 # wi conll jfleg

# Fail early on a missing or unknown evaluation-set name (the original
# script silently produced empty/broken paths in that case).
case "$test_data" in
    wi|conll|jfleg) ;;
    *) echo "Usage: $0 [wi/conll/jfleg]" >&2; exit 1 ;;
esac

FAIRSEQ_DIR=fairseq/fairseq_cli
DATA_DIR=data
PROCESSED_DIR=process
MODEL_DIR=model/$seed
OUTPUT_DIR=output
EVAL_DIR=eval

export PYTHONPATH=$FAIRSEQ_DIR

mkdir -p $OUTPUT_DIR/$seed

if [ -e $PROCESSED_DIR/$seed/${test_data}_bin ]; then
    echo "$test_data のバイナリーデータは既に存在する."
else
    echo "$test_data のバイナリーデータを作成する."
    cpu_num=`grep -c ^processor /proc/cpuinfo`
    # Pick the raw source file for the requested evaluation set.
    if [ "$test_data" = 'wi' ]; then
        src_file=$DATA_DIR/wi.test.src
    elif [ "$test_data" = 'conll' ]; then
        src_file=$DATA_DIR/conll14st-test-data/noalt/conll2014.src
    else
        src_file=$DATA_DIR/jfleg/test/test.src
    fi
    # Apply the BPE code learned at training time.
    subword-nmt apply-bpe -c $PROCESSED_DIR/$seed/trg_$num_operations.bpe \
        < "$src_file" \
        > $PROCESSED_DIR/$seed/$test_data.src

    # preprocess.py requires a target side; reuse the source as a dummy target.
    cp $PROCESSED_DIR/$seed/$test_data.src $PROCESSED_DIR/$seed/$test_data.trg
    python -u $FAIRSEQ_DIR/preprocess.py \
        --source-lang src \
        --target-lang trg \
        --trainpref $PROCESSED_DIR/$seed/train \
        --validpref $PROCESSED_DIR/$seed/$test_data \
        --testpref $PROCESSED_DIR/$seed/$test_data \
        --destdir $PROCESSED_DIR/$seed/${test_data}_bin \
        --srcdict $PROCESSED_DIR/$seed/bin/dict.src.txt \
        --tgtdict $PROCESSED_DIR/$seed/bin/dict.trg.txt \
        --workers $cpu_num \
        --tokenizer space
fi

# Run inference on the evaluation data with the trained GEC model.
# NOTE(review): decoding reads the dictionaries from $PROCESSED_DIR/$seed/bin;
# the ${test_data}_bin directory built above is not referenced here — confirm
# it is actually needed.
python -u $FAIRSEQ_DIR/interactive.py $PROCESSED_DIR/$seed/bin \
    --source-lang src \
    --target-lang trg \
    --path $MODEL_DIR/checkpoint_best.pt \
    --beam $beam \
    --nbest $beam \
    --no-progress-bar \
    --buffer-size 1024 \
    --batch-size 32 \
    --log-format simple \
    --remove-bpe \
    < $PROCESSED_DIR/$seed/$test_data.src > $OUTPUT_DIR/$seed/$test_data.nbest.tok

# Extract the 1-best hypothesis from each beam group of the n-best list.
cat $OUTPUT_DIR/$seed/$test_data.nbest.tok | grep "^H" | python -c "import sys; x = sys.stdin.readlines(); x = ' '.join([ x[i] for i in range(len(x)) if (i % ${beam} == 0) ]); print(x)" | cut -f3 > $OUTPUT_DIR/$seed/$test_data.best.tok
# The join+print above leaves one extra trailing line; drop it.
sed -i '$d' $OUTPUT_DIR/$seed/$test_data.best.tok

# Score the predictions (M2 scorer for CoNLL-2014, GLEU for JFLEG; W&I is
# scored on CodaLab, so nothing is done for it here).
if [ "$test_data" = 'conll' ]; then
    CONLL_DIR=data/conll14st-test-data/noalt/
    $EVAL_DIR/m2_python3/m2scorer $OUTPUT_DIR/$seed/$test_data.best.tok $CONLL_DIR/official-2014.combined.m2 > $OUTPUT_DIR/$seed/$test_data.eval
elif [ "$test_data" = 'jfleg' ]; then
    JFLEG_DIR=data/jfleg/test
    $EVAL_DIR/gec-ranking_python3/scripts/compute_gleu -s $JFLEG_DIR/test.src -r $JFLEG_DIR/test.ref0 $JFLEG_DIR/test.ref1 $JFLEG_DIR/test.ref2 $JFLEG_DIR/test.ref3 -o $OUTPUT_DIR/$seed/$test_data.best.tok -n 4 > $OUTPUT_DIR/$seed/$test_data.eval
fi
--------------------------------------------------------------------------------
/script/preprocess.sh:
--------------------------------------------------------------------------------
#!/bin/bash
# Convert every M2-format corpus to parallel (src/trg) files, build the
# training and dev splits, and drop uncorrected training pairs.
#
# Usage: ./script/preprocess.sh <m2_dir> <parallel_dir>

M2_DIR=$1
PARA_DIR=$2

# FCE: convert all three official splits (all are used for training later).
# (Fixed: the dev path previously contained a stray double slash, fce//m2.)
python src/convert_m2_to_parallel.py $M2_DIR/fce/m2/fce.train.gold.bea19.m2 \
    $PARA_DIR/fce.train.src \
    $PARA_DIR/fce.train.trg
python src/convert_m2_to_parallel.py $M2_DIR/fce/m2/fce.dev.gold.bea19.m2 \
    $PARA_DIR/fce.dev.src \
    $PARA_DIR/fce.dev.trg
python src/convert_m2_to_parallel.py $M2_DIR/fce/m2/fce.test.gold.bea19.m2 \
    $PARA_DIR/fce.test.src \
    $PARA_DIR/fce.test.trg

# NUCLE: convert the M2 training file to parallel form.
python src/convert_m2_to_parallel.py $M2_DIR/release3.3/bea2019/nucle.train.gold.bea19.m2 \
    $PARA_DIR/nucle.train.src \
    $PARA_DIR/nucle.train.trg

# Lang-8: convert the M2 training file to parallel form.
python src/convert_m2_to_parallel.py $M2_DIR/lang8/lang8.train.auto.bea19.m2 \
    $PARA_DIR/lang8.train.src \
    $PARA_DIR/lang8.train.trg

# W&I+LOCNESS: one M2 file per level (A/B/C for train; A/B/C/N for dev).
for level in A B C; do
    python src/convert_m2_to_parallel.py $M2_DIR/wi+locness/m2/$level.train.gold.bea19.m2 \
        $PARA_DIR/wi.train$level.src \
        $PARA_DIR/wi.train$level.trg
done
for level in A B C N; do
    python src/convert_m2_to_parallel.py $M2_DIR/wi+locness/m2/$level.dev.gold.bea19.m2 \
        $PARA_DIR/wi.dev$level.src \
        $PARA_DIR/wi.dev$level.trg
done

# Concatenate the training data (source and target sides).
for ext in src trg; do
    cat $PARA_DIR/fce.train.$ext \
        $PARA_DIR/fce.dev.$ext \
        $PARA_DIR/fce.test.$ext \
        $PARA_DIR/nucle.train.$ext \
        $PARA_DIR/lang8.train.$ext \
        $PARA_DIR/wi.trainA.$ext \
        $PARA_DIR/wi.trainB.$ext \
        $PARA_DIR/wi.trainC.$ext \
        > $PARA_DIR/train.$ext
done

# Remove sentence pairs without any correction from the training data
# (writes $PARA_DIR/train_corrected.{src,trg}).
python src/remove.py --source-lang src --target-lang trg --trainpref $PARA_DIR/train

# Concatenate the dev data (source and target sides).
for ext in src trg; do
    cat $PARA_DIR/wi.devA.$ext \
        $PARA_DIR/wi.devB.$ext \
        $PARA_DIR/wi.devC.$ext \
        $PARA_DIR/wi.devN.$ext \
        > $PARA_DIR/dev.$ext
done

# The W&I evaluation data has no released target side, so copy the source only.
cp $M2_DIR/wi+locness/test/ABCN.test.bea19.orig $PARA_DIR/wi.test.src
--------------------------------------------------------------------------------
/script/train.sh:
--------------------------------------------------------------------------------
#!/bin/bash
# Binarize the preprocessed parallel data with fairseq and train a
# Transformer-big GEC model.

seed=1111
num_operations=8000
cpu_num=`grep -c ^processor /proc/cpuinfo`

FAIRSEQ_DIR=fairseq/fairseq_cli
DATA_DIR=data/parallel
PROCESSED_DIR=process/$seed
MODEL_DIR=model/$seed

mkdir -p $PROCESSED_DIR

# Build the binary data that fairseq reads.
if [ -e $PROCESSED_DIR/bin ]; then
    echo 既にバイナリーデータは存在している.
else
    echo バイナリデータを作成する.

    mkdir -p $PROCESSED_DIR/bin
    # Learn BPE on the (corrected) target side only and apply the same code
    # to both sides of every split so that a joined dictionary makes sense.
    subword-nmt learn-bpe -s $num_operations < $DATA_DIR/train_corrected.trg \
        > $PROCESSED_DIR/trg_$num_operations.bpe
    subword-nmt apply-bpe -c $PROCESSED_DIR/trg_$num_operations.bpe \
        < $DATA_DIR/train_corrected.src \
        > $PROCESSED_DIR/train.src
    subword-nmt apply-bpe -c $PROCESSED_DIR/trg_$num_operations.bpe \
        < $DATA_DIR/train_corrected.trg \
        > $PROCESSED_DIR/train.trg
    subword-nmt apply-bpe -c $PROCESSED_DIR/trg_$num_operations.bpe \
        < $DATA_DIR/dev.src \
        > $PROCESSED_DIR/dev.src
    subword-nmt apply-bpe -c $PROCESSED_DIR/trg_$num_operations.bpe \
        < $DATA_DIR/dev.trg \
        > $PROCESSED_DIR/dev.trg

    python -u $FAIRSEQ_DIR/preprocess.py \
        --source-lang src \
        --target-lang trg \
        --trainpref $PROCESSED_DIR/train \
        --validpref $PROCESSED_DIR/dev \
        --testpref $PROCESSED_DIR/dev \
        --destdir $PROCESSED_DIR/bin \
        --workers $cpu_num \
        --joined-dictionary \
        --tokenizer space
fi

# Train the GEC model (Transformer-big, label-smoothed cross entropy).
mkdir -p $MODEL_DIR

python -u $FAIRSEQ_DIR/train.py $PROCESSED_DIR/bin \
    --save-dir $MODEL_DIR \
    --source-lang src \
    --target-lang trg \
    --log-format simple \
    --fp16 \
    --max-epoch 30 \
    --arch transformer_vaswani_wmt_en_de_big \
    --max-tokens 4096 \
    --optimizer adam \
    --adam-betas '(0.9, 0.98)' \
    --lr 0.0005 \
    --lr-scheduler inverse_sqrt \
    --warmup-updates 4000 \
    --warmup-init-lr 1e-07 \
    --stop-min-lr 1e-09 \
    --dropout 0.3 \
    --clip-norm 1.0 \
    --weight-decay 0.0 \
    --criterion label_smoothed_cross_entropy \
    --label-smoothing 0.1 \
    --num-workers $cpu_num \
    --no-epoch-checkpoints \
    --share-all-embeddings \
    --seed $seed
--------------------------------------------------------------------------------
/src/convert_m2_to_parallel.py:
#!/usr/bin/env python
"""Convert an M2-format GEC annotation file into parallel src/trg text files.

Each M2 block is an ``S`` (source sentence) line, zero or more ``A``
(annotation/edit) lines, and a terminating blank line.  The raw source
sentence is written to *output_src*; the sentence with all gold edits
applied is written to *output_tgt*, one line per sentence in both files.

Fixes over the previous revision:
- The ``<S>`` sentinel token had been lost (presumably stripped as markup),
  which made ``startswith('')`` vacuous and caused ``[4:]`` to chop the
  first four characters of every target sentence.  The sentinel is restored.
- The final sentence is now flushed even when the file lacks a trailing
  blank line, and consecutive blank lines no longer emit duplicate targets.
"""

import sys

# Dummy token prepended to the token list so that edits addressed to
# position 0 (insertions at the very front) have a slot to attach to.
# It is stripped (token + one space = 4 characters) before writing.
SENTINEL = '<S>'


def convert(input_path, output_src_path, output_tgt_path):
    """Read the M2 file at *input_path* and write the parallel files.

    Raises AssertionError if an applied-edits sentence loses its sentinel
    (an internal-invariant failure, not expected on well-formed M2 input).
    """
    corrected = []            # current token list being edited; [] = none pending
    prev_sid = prev_eid = -1  # span of the previous edit, to group insertions
    pos = 0                   # insertion offset within the current slot

    def flush(out_file):
        # Join the surviving tokens, verify the sentinel, strip it, write.
        target = ' '.join(tok for tok in corrected if tok != '')
        assert target.startswith(SENTINEL), '(' + target + ')'
        out_file.write(target[len(SENTINEL) + 1:] + '\n')

    with open(input_path) as input_file, \
            open(output_src_path, 'w') as output_src_file, \
            open(output_tgt_path, 'w') as output_tgt_file:
        for line in input_file:
            line = line.strip()
            if line.startswith('S'):
                line = line[2:]
                corrected = [SENTINEL] + line.split()
                output_src_file.write(line + '\n')
            elif line.startswith('A'):
                info = line[2:].split('|||')
                sid, eid = info[0].split()
                # M2 indices are 0-based over the words; +1 accounts for the
                # sentinel at index 0 of `corrected`.
                sid = int(sid) + 1
                eid = int(eid) + 1
                if info[1] == 'Um':
                    # NUCLE "unclear meaning" edits carry no usable correction.
                    continue
                # Blank out the edited span; empty slots are dropped at join time.
                for idx in range(sid, eid):
                    corrected[idx] = ''
                if sid == eid:
                    if sid == 0:
                        continue  # index was -1 in the file: a noop edit
                    # Insertion: append after the previous token.  `pos` keeps
                    # consecutive insertions at the same point in file order.
                    if sid != prev_sid or eid != prev_eid:
                        pos = len(corrected[sid - 1].split())
                    cur_words = corrected[sid - 1].split()
                    cur_words.insert(pos, info[2])
                    pos += len(info[2].split())
                    corrected[sid - 1] = ' '.join(cur_words)
                else:
                    # Replacement (or deletion when info[2] is empty).
                    corrected[sid] = info[2]
                    pos = 0
                prev_sid, prev_eid = sid, eid
            else:
                # Blank line terminates a sentence block.
                if corrected:
                    flush(output_tgt_file)
                    corrected = []
                prev_sid = prev_eid = -1
                pos = 0
        # Flush the last sentence if the file has no trailing blank line.
        if corrected:
            flush(output_tgt_file)


def main():
    """CLI entry point: convert(argv[1], argv[2], argv[3])."""
    if len(sys.argv) != 4:
        print("[USAGE] %s nucle_m2_file output_src output_tgt" % sys.argv[0])
        sys.exit(1)  # nonzero: usage errors should not look like success
    convert(sys.argv[1], sys.argv[2], sys.argv[3])


if __name__ == '__main__':
    main()
# --------------------------------------------------------------------------------
# /src/remove.py:
# --------------------------------------------------------------------------------
"""Drop uncorrected sentence pairs from a parallel training corpus.

Reads ``<trainpref>.<source-lang>`` and ``<trainpref>.<target-lang>`` in
lockstep and writes only the line pairs whose source and target differ to
``<trainpref>_corrected.<source-lang>`` / ``<trainpref>_corrected.<target-lang>``.
"""
import argparse


def parse_args():
    """Parse the fairseq-style command-line options."""
    parser = argparse.ArgumentParser()
    for flag in ('--source-lang', '--target-lang', '--trainpref'):
        parser.add_argument(flag, type=str, required=True)
    return parser.parse_args()


def main(args):
    """Copy every differing (source, target) line pair to the *_corrected files."""
    src_in = f'{args.trainpref}.{args.source_lang}'
    trg_in = f'{args.trainpref}.{args.target_lang}'
    src_out = f'{args.trainpref}_corrected.{args.source_lang}'
    trg_out = f'{args.trainpref}_corrected.{args.target_lang}'
    with open(src_in) as fin_src, open(trg_in) as fin_trg, \
            open(src_out, 'w') as fout_src, open(trg_out, 'w') as fout_trg:
        for src_line, trg_line in zip(fin_src, fin_trg):
            # An identical pair means the sentence needed no correction.
            if src_line != trg_line:
                fout_src.write(src_line)
                fout_trg.write(trg_line)


if __name__ == "__main__":
    main(parse_args())
# --------------------------------------------------------------------------------