├── tapes ├── packages.tape ├── env_local.tape ├── ratio.config ├── full.config ├── balanced.config ├── env_coe.tape ├── params.tape ├── versioners.tape ├── train.tape ├── submitters.tape ├── prepro.tape ├── main.tape └── test.tape ├── .gitignore ├── environment.yml ├── scripts ├── data_to_char.py └── spm_vocab_export.py ├── ersatz ├── __init__.py ├── candidates.py ├── subword.py ├── model.py ├── score.py ├── dataset.py ├── split.py ├── utils.py └── trainer.py ├── setup.py ├── README.md └── LICENSE /tapes/packages.tape: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .history_coe.local.ET 2 | .data/* 3 | .ipynb_checkpoints 4 | __pycache__ 5 | -------------------------------------------------------------------------------- /tapes/env_local.tape: -------------------------------------------------------------------------------- 1 | package ersatz :: .versioner=disk .path="/home/rewicks/Sources/ersatz" {} 2 | 3 | global { 4 | } 5 | -------------------------------------------------------------------------------- /tapes/ratio.config: -------------------------------------------------------------------------------- 1 | global { 2 | ducttape_output="/exp/rwicks/ersatz/exp/ratio" 3 | train_dir="/exp/rwicks/ersatz/data/ratio/train/" 4 | } 5 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: ersatz 2 | channels: 3 | - pytorch 4 | dependencies: 5 | - python=3.7 6 | - pytorch=1.7.1 7 | - pip: 8 | - sentencepiece 9 | -------------------------------------------------------------------------------- /tapes/full.config: -------------------------------------------------------------------------------- 1 | global { 2 | 
# turns input bpe file into char bpe file
# prints to std out

import sys


def to_chars(line):
    """Collapse BPE/subword tokens back to raw text, then space-separate
    every character.

    e.g. 'he llo' -> 'h e l l o'
    """
    return ' '.join(line.strip().replace(' ', ''))


def main():
    """Read the BPE file named by argv[1]; print its char-level version to stdout."""
    # context manager ensures the input file is closed even on error
    with open(sys.argv[1]) as infile:
        for line in infile:
            print(to_chars(line))


if __name__ == '__main__':
    main()
class DummyArgs():
    """Bare attribute container standing in for an argparse.Namespace."""

    def __init__(self):
        pass


def split(model="default-multilingual",
          text=None,
          input=None,
          output=None,
          batch_size=16,
          candidates="multilingual",
          cpu=False,
          columns=None,
          delimiter='\t'):
    """Programmatic entry point for sentence splitting.

    Packs the keyword arguments into an args object (plus ``list=True`` so
    results come back as a Python list) and delegates to
    ``ersatz.split.split``.
    """
    from .split import split as ersatz_split

    settings = {
        'model': model,
        'text': text,
        'input': input,
        'output': output,
        'batch_size': batch_size,
        'candidates': candidates,
        'cpu': cpu,
        'columns': columns,
        'delimiter': delimiter,
        'list': True,
    }
    options = DummyArgs()
    for name, value in settings.items():
        setattr(options, name, value)
    return ersatz_split(options)
task train : ersatz
    < train_path=$out_shuffle@make_train_data
    < valid_path=$out_shuffle@make_valid_data
    < vocab_path=$model@train_vocab
    > out
    :: left_size=@
    :: right_size=@
    :: languages=@
    :: batch_size=@
    :: min_epochs=@
    :: max_epochs=@
    :: transformer_nlayers=@
    :: linear_nlayers=@
    :: lr=@
    :: dropout=@
    :: embed_size=@
    :: factor_embed_size=@
    :: activation_type=@
    :: nhead=@
    :: log_interval=@
    :: validation_interval=@
    :: log_dir=@
    :: early_stopping=@
    :: eos_weight=@
    :: pyenv=@ :: .submitter=$grid :: devices=@
    :: devices_per_task=1
    :: .resource_flags=$gpuResourceFlags :: .action_flags=$gpuActionFlags
{
    mkdir -p $out
    # tensorboard log dir is named after the realization directory
    LOGEXT=$(echo $out | rev | cut -d'/' -f2 | rev)
    LOGDIR=$log_dir"/"$LOGEXT
    rm -rf $LOGDIR

    # FIX: the original `if [$factor_embed_size = "0"]; then` is a bash syntax
    # error ("[0" is not a command), so the task always aborted here under the
    # submitter's `set -euo pipefail`.
    # NOTE(review): passing --source_factors when factor_embed_size is 0 looks
    # inverted (it would enable factors with a zero-width embedding); the
    # condition is kept as written — confirm intent before flipping it.
    if [ "$factor_embed_size" = "0" ]; then
        FACT_VALUE="--source_factors"
        echo "$FACT_VALUE"
    else
        FACT_VALUE=""
    fi

    #PYTHONPATH=$ersatz python $ersatz/trainer.py \
    ersatz_train \
        --sentencepiece_path=$vocab_path \
        --left_size=$left_size \
        --right_size=$right_size \
        --output_path=$out \
        --transformer_nlayers=$transformer_nlayers \
        --activation_type=$activation_type \
        --linear_nlayers=$linear_nlayers \
        --min-epochs=$min_epochs \
        --max-epochs=$max_epochs \
        --lr=$lr \
        --batch_size=$batch_size \
        --dropout=$dropout \
        --embed_size=$embed_size \
        --factor_embed_size=$factor_embed_size $FACT_VALUE \
        --nhead=$nhead \
        --log_interval=$log_interval \
        --validation_interval=$validation_interval \
        --eos_weight=$eos_weight \
        --early_stopping=$early_stopping \
        --tb_dir=$LOGDIR \
        --train_path=$train_path \
        --valid_path=$valid_path
}
class Split():
    """Base candidate-splitting policy: treat every candidate as a split point."""

    def __call__(self, left_context, right_context):
        return True


class PunctuationSpace(Split):
    """English-style policy: candidate is a split point when the right context
    begins with a space and the left context ends in ?/!/. optionally followed
    by closing quotes/brackets/extra terminal punctuation.
    """

    # FIX: compiled once at class level instead of rebuilding (and recompiling)
    # the pattern on every call.
    _BOUNDARY = re.compile(r'.*[?!.][.?!")\']*')

    def __call__(self, left_context, right_context):
        # FIX: guard the empty right context (previously raised IndexError)
        if right_context and right_context[0] == ' ':
            if self._BOUNDARY.fullmatch(left_context) is not None:
                return True
        return False
class Lists(Split):
    """Splits before list-item markers (*, -, ~) at the start of the right context."""

    def __call__(self, left_context, right_context):
        stripped = right_context.strip()
        # FIX: guard whitespace-only right context (previously IndexError) and
        # return an explicit False instead of falling off the end (None).
        if stripped and stripped[0] in ('*', '-', '~'):
            return True
        return False


class MultilingualPunctuation(Split):
    """Splits when the last whitespace token of the left context ends in
    sentence-final punctuation (optionally followed by more terminal or
    closing punctuation) and the right context does not start with an
    ASCII digit (avoids splitting inside numbers like "3. 5").
    """

    def __call__(self, left_context, right_context):
        try:
            left_context = left_context.split(' ')[-1]
            if right_context[0] not in "0123456789":
                for i, ch in enumerate(left_context):
                    if ch in ending_punc:
                        # everything after the first terminal char must also be
                        # terminal or closing punctuation; j is flagged -1 otherwise
                        for j, next_ch in enumerate(left_context[i:], i):
                            if next_ch not in ending_punc and next_ch not in closing_punc:
                                j = -1
                                break
                        if j != -1:
                            return True
        # FIX: narrowed from a bare `except:` (which also swallowed
        # KeyboardInterrupt/SystemExit); only right_context[0] on an empty
        # string can raise here.
        except IndexError:
            return False
        return False


class IndividualPunctuation(Split):
    """Splits on one specific sentence-ending character followed by a space."""

    def __init__(self, unicode_char):
        # the single terminal character this policy recognizes
        self.punc = unicode_char

    def __call__(self, left_context, right_context):
        # FIX: guard the empty right context (previously raised IndexError,
        # inconsistent with MultilingualPunctuation's handling)
        if right_context and right_context[0] == ' ':
            for i, ch in enumerate(left_context):
                if ch == self.punc:
                    # chars after the first hit must be the same punctuation
                    # or closing punctuation for this to be a sentence end
                    for j, next_ch in enumerate(left_context[i:], i):
                        if next_ch != self.punc and next_ch not in closing_punc:
                            j = -1
                            break
                    if j != -1:
                        return True
        return False
-z ${pyenv:-} ]]; then 7 | virtualenv=$pyenv 8 | 9 | # Load the environment 10 | if [[ $virtualenv == conda:* ]]; then 11 | target=$(echo $virtualenv | cut -d: -f2-) 12 | source deactivate 13 | source activate $target 14 | else 15 | source $virtualenv 16 | fi 17 | fi 18 | set -u 19 | 20 | STARTED=$(date +%s) 21 | time eval "$COMMANDS" 22 | STOPPED=$(date +%s) 23 | TIME=$(($STOPPED - $STARTED)) 24 | echo $TIME > ducttape_time.txt 25 | set -u 26 | } 27 | } 28 | 29 | # COMMANDS: the bash commands from some task 30 | # TASK, REALIZATION, CONFIGURATION: variables passed by ducttape 31 | submitter sge :: action_flags 32 | :: COMMANDS 33 | :: TASK REALIZATION TASK_VARIABLES CONFIGURATION { 34 | action run { 35 | wrapper="ducttape_sge_job.sh" 36 | echo "#!/usr/bin/env bash" >> $wrapper 37 | echo "" >> $wrapper 38 | echo "#$ $resource_flags" >> $wrapper 39 | echo "#$ $action_flags" >> $wrapper 40 | echo "#$ -j y" >> $wrapper 41 | echo "#$ -o $PWD/job.out" >> $wrapper 42 | echo "#$ -N $CONFIGURATION-$TASK-$REALIZATION" >> $wrapper 43 | echo "" >> $wrapper 44 | 45 | # Bash flags aren't necessarily passed into the scheduler 46 | # so we must re-initialize them 47 | 48 | echo "set -euo pipefail" >> $wrapper 49 | echo "" >> $wrapper 50 | echo "$TASK_VARIABLES" | perl -pe 's/=/="/; s/$/"/' >> $wrapper 51 | 52 | # Setup the virtual environment 53 | cat >> $wrapper <> $wrapper 78 | 79 | # The current working directory will also be changed by most schedulers 80 | echo "cd $PWD" >> $wrapper 81 | 82 | echo >> $wrapper 83 | echo "echo \"HOSTNAME: \$(hostname)\"" >> $wrapper 84 | echo "echo" >> $wrapper 85 | echo "echo CUDA in ENV:" >> $wrapper 86 | echo "env | grep CUDA" >> $wrapper 87 | echo "echo" >> $wrapper 88 | echo "echo SGE in ENV:" >> $wrapper 89 | echo "env | grep SGE" >> $wrapper 90 | #echo "nvidia-smi" >> $wrapper 91 | echo >> $wrapper 92 | 93 | echo "$COMMANDS" >> $wrapper 94 | echo "echo \$? 
> $PWD/exitcode" >> $wrapper # saves the exit code of the inner process 95 | 96 | # Use SGE's -sync option to prevent qsub from immediately returning 97 | qsub -V -S /bin/bash $wrapper | grep -Eo "Your job [0-9]+" | grep -Eo "[0-9]+" > $PWD/job_id 98 | job_id=`cat $PWD/job_id` 99 | 100 | # async job killer 101 | exitfn () { 102 | trap SIGINT 103 | echo "wait until I kill the job $job_id" 104 | qdel $job_id 105 | exit 106 | } 107 | 108 | trap "exitfn" INT 109 | 110 | # don't use -sync y, instead, wait on exitcode 111 | while [ ! -z "`qstat -u $USER | grep $job_id`" ] 112 | do 113 | sleep 15 114 | done 115 | 116 | trap SIGINT 117 | 118 | # restore the exit code saved from the inner process 119 | EXITCODE=$(cat $PWD/exitcode) 120 | [ $EXITCODE = "0" ] 121 | } 122 | } 123 | -------------------------------------------------------------------------------- /tapes/prepro.tape: -------------------------------------------------------------------------------- 1 | global { 2 | languages=(Scenario: en="en" 3 | ar="ar" 4 | cs="cs" 5 | de="de" 6 | es="es" 7 | et="et" 8 | fi="fi" 9 | fr="fr" 10 | gu="gu" 11 | hi="hi" 12 | iu="iu" 13 | ja="ja" 14 | kk="kk" 15 | km="km" 16 | lt="lt" 17 | lv="lv" 18 | pl="pl" 19 | ps="ps" 20 | ro="ro" 21 | ru="ru" 22 | ta="ta" 23 | tr="tr" 24 | zh="zh" 25 | multilingual="ar cs de en es et fi fr gu hi iu ja kk km lt lv pl ps ro ru ta tr zh") 26 | } 27 | 28 | task train_vocab : ersatz 29 | :: vocab_size=@ 30 | :: languages=@ 31 | :: train_dir=@ 32 | :: pyenv=@ :: .submitter=$grid :: devices=@ 33 | :: devices_per_task=0 34 | :: .resource_flags=$cpuResourceFlags :: .action_flags=$cpuActionFlags 35 | > model vocab 36 | { 37 | mkdir -p out 38 | TRAIN_DATA_PATH="" 39 | for lang in $languages; do 40 | TRAIN_DATA_PATH=$train_dir$lang","$TRAIN_DATA_PATH 41 | done; 42 | TRAIN_DATA_PATH=${TRAIN_DATA_PATH%?} 43 | spm_train_py --input $TRAIN_DATA_PATH \ 44 | --input_sentence_size 10000000 \ 45 | --model_prefix out/ersatz \ 46 | --vocab_size $vocab_size \ 47 | 
--bos_piece "" \ 48 | --eos_piece "" 49 | ln -s `realpath out/ersatz.model` $model 50 | ln -s `realpath out/ersatz.vocab` $vocab 51 | } 52 | 53 | task make_train_data : ersatz 54 | < vocab_path=$model@train_vocab 55 | :: left_size=@ right_size=@ languages=@ 56 | :: train_dir=@ 57 | :: pyenv=@ :: .submitter=$grid :: devices=@ 58 | :: devices_per_task=0 59 | :: .resource_flags=$cpuResourceFlags :: .action_flags=$cpuActionFlags 60 | > out_og out_shuffle 61 | { 62 | if [[ $languages =~ \ |\' ]]; then 63 | TRAIN_DATA_PATH="" 64 | for lang in $languages; do 65 | TRAIN_DATA_PATH="${TRAIN_DATA_PATH} $train_dir$lang" 66 | done 67 | else 68 | TRAIN_DATA_PATH=$train_dir$languages 69 | fi 70 | #PYTHONPATH=. python $ersatz/dataset.py \ 71 | ersatz_preprocess \ 72 | --sentencepiece_path $vocab_path \ 73 | --left-size $left_size \ 74 | --right-size $right_size \ 75 | --output_path out_og \ 76 | --input_paths $TRAIN_DATA_PATH 77 | 78 | shuf --random-source=<(get_seeded_random 14) out_og > out_shuffle 79 | } 80 | 81 | task make_valid_data : ersatz 82 | < vocab_path=$model@train_vocab 83 | :: left_size=@ right_size=@ languages=@ 84 | :: pyenv=@ :: .submitter=$grid :: devices=@ 85 | :: devices_per_task=0 86 | :: .resource_flags=$cpuResourceFlags :: .action_flags=$cpuActionFlags 87 | > out_og out_shuffle 88 | { 89 | if [[ $languages =~ \ |\' ]]; then 90 | VALID_DATA_PATH="" 91 | for lang in $languages; do 92 | VALID_DATA_PATH="${VALID_DATA_PATH} /exp/rwicks/ersatz/data/balanced/dev/$lang" 93 | done 94 | else 95 | VALID_DATA_PATH="/exp/rwicks/ersatz/data/balanced/dev/$languages" 96 | fi 97 | #PYTHONPATH=. 
# *-* coding: utf-8 *-*
import torch

class Vocabulary():
    """Word-level vocabulary with int<->string lookups.

    NOTE(review): the special-token string literals below appear as empty
    strings (duplicate '' keys collapse the seed dict to one entry); they look
    like angle-bracket markers (e.g. '<unk>'-style) stripped somewhere
    upstream — verify against the original source before relying on them.
    """

    def __init__(self):
        # index -> token and token -> index tables seeded with special tokens
        self.itos = ['', '', '', '']
        self.stoi = {'': 0, '': 1, '': 2, '': 3}

    def __len__(self):
        return len(self.itos)

    def build_vocab(self, file_path):
        """Populate the vocabulary from a file whose first column is the token."""
        with open(file_path) as i:
            for line in i:
                word = line.split()[0].strip()
                self.add_word(word)

    def add_word(self, word):
        """Register *word* with the next free index (no-op if already present)."""
        if word not in self.stoi:
            self.stoi[word] = len(self.itos)
            self.itos.append(word)

    def embed_word(self, word):
        # 2 is the unknown-token index — TODO confirm against the original
        # (pre-stripped) special-token list
        return self.stoi.get(word, 2)

    def get_word(self, embedding):
        return self.itos[embedding]

    def detokenize(self, input_string):
        """Undo subword segmentation: drop token spaces, restore U+2581 as space."""
        input_string = input_string.replace(' ', '')
        input_string = input_string.replace('\u2581', ' ')
        return input_string

    def encode(self, input_string, out_type=int):
        """Tokenize on whitespace; return ids when out_type is int, else tokens."""
        tokens = input_string.split()
        if out_type is int:
            return [self.embed_word(token) for token in tokens]
        return tokens

    def decode(self, input_array):
        """Map a sequence of ids back to a space-joined token string."""
        return ' '.join(self.get_word(i) for i in input_array)

    def tensor_to_string(self, tensors):
        """Decode a 1-D tensor to a string, or a 2-D batch to a list of strings."""
        if len(tensors.shape) > 1:
            output = []
            for tens in tensors:
                output.append(self.decode(tens.tolist()))
            return output
        else:
            return self.decode(tensors.tolist())

    def context_to_tensor(self, contexts):
        """Convert (left, left_stream, right, right_stream, label) tuples into
        (token ids, factor ids, labels) tensors.

        NOTE(review): the label comparison string below is also a stripped
        literal ('' was presumably an end-of-sentence marker) — labels are 0
        when the label equals '' and 1 otherwise.
        """
        con_arr = []
        fact_arr = []
        lab_arr = []
        for left, left_stream, right, right_stream, label in contexts:
            tens = []
            for l in left.split():
                tens.append(self.embed_word(l))
            for r in right.split():
                tens.append(self.embed_word(r))
            con_arr.append(tens)

            fact_arr.append(left_stream + right_stream)

            if label == "":
                lab_arr.append(0)
            else:
                lab_arr.append(1)
        return torch.tensor(con_arr), torch.tensor(fact_arr), torch.tensor(lab_arr)

class SentencePiece(Vocabulary):
    """
    Implements SentencePiece.
    https://github.com/google/sentencepiece/blob/master/python/README.md
    """
    def __init__(self, serialization=None, model_path=None, vocab_path=None, sample: bool = True, alpha: float = 0.5):
        # imported lazily so the class is importable without sentencepiece
        import sentencepiece as spm
        if serialization is None:
            self.model = spm.SentencePieceProcessor()
            self.model_path = model_path
            if model_path is not None:
                self.model.Load(model_path)
            if vocab_path:
                self.vocab_path = vocab_path
                self.model.LoadVocabulary(vocab_path)
        else:
            self.model = spm.SentencePieceProcessor(model_proto=serialization)
        # sampling settings retained for interface compatibility; encode()
        # currently delegates straight to the model (see FIX below)
        self.alpha = alpha
        self.sample = sample

    def __len__(self):
        return self.model.get_piece_size()

    def embed_word(self, word):
        return self.model[word]

    def encode(self, sentence, out_type=int):
        # FIX: removed the unreachable sampling branch that followed this
        # unconditional return, and the incorrect `-> str` annotation (the
        # result type follows out_type).
        return self.model.encode(sentence, out_type=out_type)

    def decode(self, ids):
        return self.model.decode(ids)

    def merge(self, sentence: str, technique='replace') -> str:
        """Undo segmentation: drop piece spaces and restore U+2581 as space."""
        if technique == 'replace':
            return sentence.replace(' ', '').replace('▁', ' ')
        else:
            return self.model.decode(sentence)

def get_tokenizer(model_path, sample=False):
    """Convenience constructor for a SentencePiece tokenizer."""
    return SentencePiece(model_path=model_path, sample=sample)
#!/usr/bin/env python

"""
A setuptools based setup module.
See:
- https://packaging.python.org/en/latest/distributing.html
- https://github.com/pypa/sampleproject
To install:
1. Setup pypi by creating ~/.pypirc
        [distutils]
        index-servers =
            pypi
            pypitest
        [pypi]
        username=
        password=
        [pypitest]
        username=
        password=
2. Create the dist
   python3 setup.py sdist bdist_wheel
3. Push
   twine upload dist/*
"""

import os
import re

# Always prefer setuptools over distutils
from setuptools import setup, find_packages


ROOT = os.path.dirname(__file__)


def get_version():
    """
    Reads the version from ersatz's __init__.py file.
    We can't import the module because required modules may not
    yet be installed.
    """
    VERSION_RE = re.compile(r'''__version__ = ['"]([0-9.]+)['"]''')
    # FIX: use a context manager so the file handle is closed promptly
    # (the original `open(...).read()` leaked the handle)
    with open(os.path.join(ROOT, 'ersatz', '__init__.py')) as init_file:
        init = init_file.read()
    return VERSION_RE.search(init).group(1)


def get_description():
    """Reads the one-line description from ersatz's __init__.py file."""
    DESCRIPTION_RE = re.compile(r'''__description__ = ['"](.*)['"]''')
    with open(os.path.join(ROOT, 'ersatz', '__init__.py')) as init_file:
        init = init_file.read()
    return DESCRIPTION_RE.search(init).group(1)


setup(
    name = 'ersatz',

    # Versions should comply with PEP440. For a discussion on single-sourcing
    # the version across setup.py and the project code, see
    # https://packaging.python.org/en/latest/single_source_version.html
    version = get_version(),

    description = get_description(),

    long_description = "Ersatz is a simple, language-agnostic toolkit for both training sentence segmentation models as well as providing pretrained, "
                       "high-performing models for sentence segmentation in a multilingual setting.",

    # The project's main homepage.
    url = 'https://github.com/rewicks/ersatz',

    author = 'Rachel Wicks',
    author_email='rewicks@jhu.edu',
    maintainer_email='rewicks@jhu.edu',

    license = 'Apache License 2.0',

    python_requires = '>=3',

    # See https://pypi.python.org/pypi?%3Aaction=list_classifiers
    classifiers = [
        # How mature is this project? Common values are
        #   3 - Alpha
        #   4 - Beta
        #   5 - Production/Stable
        'Development Status :: 5 - Production/Stable',

        # Indicate who your project is intended for
        'Intended Audience :: Developers',
        'Intended Audience :: Science/Research',
        'Topic :: Scientific/Engineering',
        'Topic :: Scientific/Engineering :: Artificial Intelligence',
        'Topic :: Text Processing',

        # Pick your license as you wish (should match "license" above)
        'License :: OSI Approved :: Apache Software License',

        # Specify the Python versions you support here. In particular, ensure
        # that you indicate whether you support Python 2, Python 3 or both.
        'Programming Language :: Python :: 3 :: Only',
    ],

    # What does your project relate to?
    # FIX: corrected the "segmenation" typo in the package metadata
    keywords = ['sentence segmentation, data processing, preprocessing, evaluation, NLP, natural language processing, computational linguistics'],

    # Which packages to deploy?
    packages = find_packages(),

    # Mark ersatz (and recursively all its sub-packages) as supporting mypy type hints (see PEP 561).
    package_data={"ersatz": ["py.typed"]},

    # List run-time dependencies here. These will be installed by pip when
    # your project is installed. For an analysis of "install_requires" vs pip's
    # requirements files see:
    # https://packaging.python.org/en/latest/requirements.html
    install_requires = [
        'typing;python_version<"3.5"',
        'torch==1.7.1',
        'sentencepiece==0.1.95',
        'tensorboard==2.4.1',
        'progressbar2'
    ],

    # List additional groups of dependencies here (e.g. development
    # dependencies). You can install these using the following syntax,
    # for example:
    # $ pip install -e .[dev,test]
    extras_require = {},

    # To provide executable scripts, use entry points in preference to the
    # "scripts" keyword. Entry points provide cross-platform support and allow
    # pip to create the appropriate form of executable for the target platform.
    entry_points={
        'console_scripts': [
            'ersatz = ersatz.split:main',
            'ersatz_train = ersatz.trainer:main',
            'ersatz_score = ersatz.score:main',
            'ersatz_preprocess = ersatz.dataset:main'
        ],
    },
)
class Generator(nn.Module):
    """Classification head mapping a (batch, max_size, embed_size) encoding
    to log-probabilities over the two classes (no-split / split).

    The context encodings are flattened and passed through ``nlayers`` hidden
    linear layers (non-linearity after each) before the final 2-way projection.
    """

    # could change this to mean-pool or max pool
    def __init__(self, embed_size, max_size, nlayers=0, activation_type="tanh"):
        super(Generator, self).__init__()
        hidden = max_size * embed_size

        if nlayers > 0:
            # FIX: the original bound `activation` only for "tanh", so any
            # other activation_type raised NameError below; fail with a clear
            # message instead. (Only raised when the activation is needed,
            # preserving the nlayers == 0 behavior for any activation_type.)
            if activation_type == 'tanh':
                activation = nn.Tanh()
            else:
                raise ValueError('unsupported activation_type: %r' % (activation_type,))
            # the same (stateless) activation module instance is reused between
            # layers, matching the original construction
            hidden_layers = [
                nn.Linear(hidden, embed_size),
                activation
            ]
            for _ in range(1, nlayers):
                hidden_layers.append(nn.Linear(embed_size, embed_size))
                hidden_layers.append(activation)
            self.hidden_layers = nn.ModuleList(hidden_layers)
            self.proj = nn.Linear(embed_size, 2)
        else:
            self.hidden_layers = None
            self.proj = nn.Linear(hidden, 2)

    def forward(self, x):
        # flatten context positions: (batch, max_size, embed) -> (batch, hidden)
        x = x.reshape(x.size()[0], -1)
        if self.hidden_layers is not None:
            for layer in self.hidden_layers:
                x = layer(x)
        x = self.proj(x)
        return F.log_softmax(x, dim=-1)
class PositionalEncoding(nn.Module):
    """Adds fixed sinusoidal position information to a (seq, batch, embed)
    tensor, as in Vaswani et al. (2017), followed by dropout.

    Even embedding dimensions carry sin terms, odd dimensions cos terms, with
    geometrically decreasing frequencies across dimension pairs.
    """

    def __init__(self, embed_size=512, dropout=0.1, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        # positions 0..max_len-1 as a (max_len, 1) column
        positions = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        # per-dimension-pair angular frequencies: exp(-2i * ln(10000) / d)
        frequencies = torch.exp(
            torch.arange(0, embed_size, 2).float() * (-math.log(10000.0) / embed_size)
        )
        table = torch.zeros(max_len, embed_size)
        table[:, 0::2] = torch.sin(positions * frequencies)
        table[:, 1::2] = torch.cos(positions * frequencies)
        # buffer (not a parameter): moves with .to(device), saved in state_dict
        # under the original name 'pe' for checkpoint compatibility
        self.register_buffer('pe', table.unsqueeze(0).transpose(0, 1))

    def forward(self, x):
        # add the encoding for the first seq_len positions, then apply dropout
        return self.dropout(x + self.pe[:x.size(0), :])
Similarly, it produces output in the same manner: 25 | 26 | ```angular2html 27 | cat raw.txt | ersatz > output.txt 28 | ``` 29 | ```angular2html 30 | ersatz --input raw.txt --output output.txt 31 | ``` 32 | 33 | To use a specific model (rather than the default), you can pass a name via `--model_name`, or a path via `--model_path` 34 | 35 | #### Scoring 36 | _Ersatz_ also provides a simple scoring script which computes F1 from a given segmented file. 37 | ```angular2html 38 | ersatz_score GOLD_STANDARD_FILE FILE_TO_SCORE 39 | ``` 40 | The above will print all errors as well as additional metrics at the bottom. 41 | The accompanying test suite can be found [here](https://github.com/rewicks/ersatz-test-suite). 42 | # Training a Model 43 | 44 | ## Data Preprocessing 45 | 46 | ### Vocabulary 47 | Requires a pretrained [`sentencepiece`](https://github.com/google/sentencepiece) model that has had `--eos_piece` replaced with `` and `--bos_piece` replaced with ``. 48 | 49 | ```angular2html 50 | spm_train --input $TRAIN_DATA_PATH \ 51 | --model_prefix ersatz \ 52 | --bos_piece "" \ 53 | --eos_piece "" 54 | ``` 55 | 56 | ### Create training data 57 | 58 | This pipeline takes a raw text file with one sentence per line (to use as labels) and creates a new raw text file 59 | with the appropriate left/right context and labels. One line is one training example. User is expected to shuffle this 60 | file manually (i.e. via `shuf`) after creation. 61 | 62 | 1. To create: 63 | ```angular2html 64 | python dataset.py \ 65 | --sentencepiece_path $SPM_PATH \ 66 | --left-size $LEFT_SIZE \ 67 | --right-size $RIGHT_SIZE \ 68 | --output_path $OUTPUT_PATH \ 69 | $INPUT_TRAIN_FILE_PATHS 70 | 71 | 72 | shuf $OUTPUT_PATH > $SHUFFLED_TRAIN_OUTPUT_PATH 73 | ``` 74 | 2.
Repeat for validation data 75 | ```angular2html 76 | python dataset.py \ 77 | --sentencepiece_path $SPM_PATH \ 78 | --left-size $LEFT_SIZE \ 79 | --right-size $RIGHT_SIZE \ 80 | --output_path $VALIDATION_OUTPUT_PATH \ 81 | $INPUT_DEV_FILE_PATHS 82 | ``` 83 | 84 | ## Training 85 | Something like: 86 | 87 | ```angular2html 88 | python trainer.py \ 89 | --sentencepiece_path=$vocab_path \ 90 | --left_size=$left_size \ 91 | --right_size=$right_size \ 92 | --output_path=$out \ 93 | --transformer_nlayers=$transformer_nlayers \ 94 | --activation_type=$activation_type \ 95 | --linear_nlayers=$linear_nlayers \ 96 | --min-epochs=$min_epochs \ 97 | --max-epochs=$max_epochs \ 98 | --lr=$lr \ 99 | --dropout=$dropout \ 100 | --embed_size=$embed_size \ 101 | --factor_embed_size=$factor_embed_size \ 102 | --source_factors \ 103 | --nhead=$nhead \ 104 | --log_interval=$log_interval \ 105 | --validation_interval=$validation_interval \ 106 | --eos_weight=$eos_weight \ 107 | --early_stopping=$early_stopping \ 108 | --tb_dir=$LOGDIR \ 109 | $train_path \ 110 | $valid_path 111 | ``` 112 | 113 | # Splitting with a Pre-Trained Model 114 | 115 | 1. Expects a `model_path` (should probably change to a default in expected folder location...) 116 | 2. `ersatz` reads from either stdin or a file path (via `--input`). 117 | 3. `ersatz` writes to either stdout or a file path (via `--output`). 118 | 4. An alternate candidate set for splitting may be given using `--determiner_type` 119 | * `multilingual` (default) is as described in paper 120 | * `en` requires a space following punctuation 121 | * `all` a space between any two characters 122 | * Custom can be written that uses the `determiner.Split()` base class 123 | 5. By default, expects raw sentences. Splitting a `.tsv` is also a supported behavior. 124 | 1. `--text_ids` expects a comma separated list of column indices to split 125 | 2. `--delim` changes the delimiter character (default is `\t`) 126 | 6. 
Uses gpu if available, to force cpu, use `--cpu` 127 | 128 | ### Example usage 129 | Typical python usage: 130 | ```angular2html 131 | python split.py --input unsegmented.txt --output sentences.txt ersatz.model 132 | ``` 133 | 134 | std[in,out] usage: 135 | ```angular2html 136 | cat unsegmented.txt | split.py ersatz.model > sentences.txt 137 | ``` 138 | To split `.tsv` file: 139 | ```angular2html 140 | cat unsegmented.tsv | split.py ersatz.model --text_ids 1 > sentences.txt 141 | ``` 142 | 143 | # Scoring a Model's Output 144 | 145 | ```angular2html 146 | python score.py [gold_standard_file_path] [file_to_score] 147 | ``` 148 | 149 | (There are legacy arguments, but they're not used) 150 | 151 | # Changelog 152 | 153 | 1.0.0 original release 154 | -------------------------------------------------------------------------------- /tapes/main.tape: -------------------------------------------------------------------------------- 1 | import "submitters.tape" 2 | import "versioners.tape" 3 | 4 | global { 5 | pyenv="conda:ersatz" 6 | ducttape_experimental_packages=true 7 | ducttape_experimental_submitters=true 8 | ducttape_experimental_imports=true 9 | ducttape_experimental_multiproc=true 10 | devices="0,1,2,3,4,5,6,7" 11 | } 12 | 13 | import "env_coe.tape" 14 | import "params.tape" 15 | import "prepro.tape" 16 | import "train.tape" 17 | import "test.tape" 18 | 19 | plan do_english { 20 | reach score via (Scenario: en) * 21 | (Vocabulary: 125) * 22 | (LeftSize: 6) * 23 | (RightSize: 4 ) * 24 | (EmbedSize: 128) * 25 | (TransformerLayers: 2) * 26 | (LinearLayers: 1) * 27 | (ActivationType: tanh) * 28 | (EOSWeight: 1.0) 29 | } 30 | 31 | plan do_english_wikipedia { 32 | reach score via (Scenario: en) * 33 | (Vocabulary: 500) * 34 | (LeftSize: 6) * 35 | (RightSize: 4) * 36 | (EmbedSize: 64) * 37 | (TransformerLayers: 2) * 38 | (LinearLayers: 1) * 39 | (ActivationType: tanh) * 40 | (EOSWeight: 1.0) 41 | } 42 | 43 | plan do_monolinguals { 44 | reach score_dev via (Scenario: cs de 
en es et fi fr kk lt lv pl ro ru tr) * 45 | (LearningRate: 0.0001) * 46 | (Vocabulary: 500 1000) * 47 | (LeftSize: 6) * 48 | (RightSize: 4) * 49 | (EmbedSize: 128) * 50 | (FactSize: 0) * 51 | (TransformerLayers: 2) * 52 | (LinearLayers: 1) * 53 | (ActivationType: tanh) * 54 | (EOSWeight: 1.0 20.0) 55 | reach score_dev via (Scenario: ar gu hi iu km ps ta) * 56 | (LearningRate: 0.0001) * 57 | (Vocabulary: 500 1000) * 58 | (LeftSize: 6) * 59 | (RightSize: 4) * 60 | (EmbedSize: 128) * 61 | (FactSize: 0) * 62 | (TransformerLayers: 2) * 63 | (LinearLayers: 1) * 64 | (ActivationType: tanh) * 65 | (EOSWeight: 1.0 20.0) 66 | reach score_dev via (Scenario: ja) * 67 | (LearningRate: 0.0001) * 68 | (Vocabulary: 4000 8000) * 69 | (LeftSize: 6) * 70 | (RightSize: 4) * 71 | (EmbedSize: 128) * 72 | (FactSize: 0) * 73 | (TransformerLayers: 2) * 74 | (LinearLayers: 1) * 75 | (ActivationType: tanh) * 76 | (EOSWeight: 1.0 20.0) 77 | reach score_dev via (Scenario: zh) * 78 | (LearningRate: 0.0001) * 79 | (Vocabulary: 8000) * 80 | (LeftSize: 6 8) * 81 | (RightSize: 4 6) * 82 | (EmbedSize: 128) * 83 | (FactSize: 0) * 84 | (TransformerLayers: 2) * 85 | (LinearLayers: 1) * 86 | (ActivationType: tanh) * 87 | (EOSWeight: 1.0 20.0) 88 | } 89 | 90 | plan do_ml { 91 | reach train via (Scenario: multilingual) * 92 | (Vocabulary: 12000) * 93 | (LeftSize: 6) * 94 | (RightSize: 4) * 95 | (EmbedSize: 128 256) * 96 | (FactSize: 0) * 97 | (TransformerLayers: 1 2) * 98 | (LinearLayers: 1) * 99 | (ActivationType: tanh) * 100 | (EOSWeight: 10.0 20.0) * 101 | (LearningRate: 0.0001) 102 | reach train via (Scenario: multilingual) * 103 | (Vocabulary: 12000) * 104 | (LeftSize: 6) * 105 | (RightSize: 4) * 106 | (EmbedSize: 128 256) * 107 | (FactSize: 0) * 108 | (TransformerLayers: 0) * 109 | (LinearLayers: 2 3) * 110 | (ActivationType: tanh) * 111 | (EOSWeight: 10.0 20.0) * 112 | (LearningRate: 0.0001) 113 | } 114 | 115 | plan do_ml_context_grid { 116 | reach score_dev via (Scenario: multilingual) * 117 | 
(Vocabulary: 12000) * 118 | (LeftSize: 1 2 3 4 5 6) * 119 | (RightSize: 1 2 3 5 6) * 120 | (EmbedSize: 128) * 121 | (FactSize: 0) * 122 | (TransformerLayers: 2) * 123 | (LinearLayers: 1) * 124 | (ActivationType: tanh) * 125 | (EOSWeight: 20.0) * 126 | (LearningRate: 0.0001) 127 | reach score_dev via (Scenario: multilingual) * 128 | (Vocabulary: 12000) * 129 | (LeftSize: 1 2 3 4 5) * 130 | (RightSize: 4) * 131 | (EmbedSize: 128) * 132 | (FactSize: 0) * 133 | (TransformerLayers: 2) * 134 | (LinearLayers: 1) * 135 | (ActivationType: tanh) * 136 | (EOSWeight: 20.0) * 137 | (LearningRate: 0.0001) 138 | } 139 | 140 | 141 | plan do_time { 142 | reach speed_test_cpu via (Scenario: multilingual) * 143 | (Vocabulary: 12000) * 144 | (LeftSize: 6) * 145 | (RightSize: 4) * 146 | (EmbedSize: 128) * 147 | (FactSize: 0) * 148 | (TransformerLayers: 1 2) * 149 | (LinearLayers: 1) * 150 | (ActivationType: tanh) * 151 | (EOSWeight: 20.0) * 152 | (LearningRate: 0.0001) 153 | reach speed_test_cpu via (Scenario: multilingual) * 154 | (Vocabulary: 12000) * 155 | (LeftSize: 6) * 156 | (RightSize: 4) * 157 | (EmbedSize: 128) * 158 | (FactSize: 0) * 159 | (TransformerLayers: 0) * 160 | (LinearLayers: 2 3) * 161 | (ActivationType: tanh) * 162 | (EOSWeight: 20.0) * 163 | (LearningRate: 0.0001) 164 | } 165 | 166 | plan do_score { 167 | reach score_dev via (Scenario: multilingual) * 168 | (Vocabulary: 12000) * 169 | (LeftSize: 6) * 170 | (RightSize: 4) * 171 | (EmbedSize: 128) * 172 | (FactSize: 0) * 173 | (TransformerLayers: 2) * 174 | (LinearLayers: 1) * 175 | (ActivationType: tanh) * 176 | (EOSWeight: 10.0) * 177 | (LearningRate: 0.0001) 178 | } 179 | 180 | plan do_baselines { 181 | reach score_baseline via (Scenario: multilingual) * 182 | (Base: *) 183 | } 184 | -------------------------------------------------------------------------------- /tapes/test.tape: -------------------------------------------------------------------------------- 1 | global { 2 | in_languages=(Scenario: en="en" 3 | 
ar="ar" 4 | cs="cs" 5 | de="de" 6 | es="es" 7 | et="et" 8 | fi="fi" 9 | fr="fr" 10 | gu="gu" 11 | hi="hi" 12 | iu="iu" 13 | ja="ja" 14 | kk="kk" 15 | km="km" 16 | lt="lt" 17 | lv="lv" 18 | pl="pl" 19 | ps="ps" 20 | ro="ro" 21 | ru="ru" 22 | ta="ta" 23 | tr="tr" 24 | zh="zh" 25 | multilingual="ar cs de en es et fi fr gu hi iu ja kk km lt lv pl ps ro ru ta tr zh") 26 | baseline_type=(Base: punkt="punkt" 27 | moses="moses" 28 | ml_punkt="ml-punkt" always="always-split" never="never-split" spacy="spacy-split") 29 | } 30 | 31 | task split : ersatz 32 | < model_path=$out@train 33 | > out 34 | :: in_languages=@ 35 | :: pyenv=@ :: .submitter=$grid :: devices=@ 36 | :: devices_per_task=1 37 | :: .resource_flags=$gpuResourceFlags :: .action_flags=$gpuActionFlags 38 | { 39 | mkdir -p $out 40 | for lang in ${in_languages[@]}; 41 | do 42 | FILE_PATH="/exp/rwicks/ersatz/data/balanced/test/$lang" 43 | OUTPATH=$(echo $FILE_PATH | rev | cut -d'/' -f1 | rev) 44 | cat $FILE_PATH | ~mpost/bin/strip_punc_rachel.py | tr '\n' ' ' | ersatz --model $model_path/checkpoint.best --candidates=multilingual > out/$OUTPATH 45 | done; 46 | } 47 | 48 | task split_dev : ersatz 49 | < model_path=$out@train 50 | > out 51 | :: in_languages=@ 52 | :: pyenv=@ :: .submitter=$grid :: devices=@ 53 | :: devices_per_task=1 54 | :: .resource_flags=$gpuResourceFlags :: .action_flags=$gpuActionFlags 55 | { 56 | mkdir -p $out 57 | for lang in ${in_languages[@]}; 58 | do 59 | FILE_PATH="/exp/rwicks/ersatz/data/balanced/dev/$lang" 60 | OUTPATH=$(echo $FILE_PATH | rev | cut -d'/' -f1 | rev) 61 | # cat $FILE_PATH | ~mpost/bin/strip_punc_rachel.py | tr '\n' ' ' | PYTHONPATH=$ersatz python $ersatz/split.py $model_path/checkpoint.best --determiner_type=multilingual > out/$OUTPATH 62 | cat $FILE_PATH | ~mpost/bin/strip_punc_rachel.py | tr '\n' ' ' | ersatz --model $model_path/checkpoint.best --candidates=multilingual > out/$OUTPATH 63 | done; 64 | } 65 | 66 | task split_baseline : ersatz 67 | > out 68 | :: in_languages=@ 
69 | :: baseline_type=@ 70 | :: pyenv=@ :: .submitter=$grid :: devices=@ 71 | :: devices_per_task=0 72 | :: .resource_flags=$cpuResourceFlags :: .action_flags=$cpuActionFlags 73 | { 74 | mkdir -p $out 75 | echo $baseline_type 76 | for lang in ${in_languages[@]}; 77 | do 78 | FILE_PATH="/exp/rwicks/ersatz/data/balanced/test/$lang" 79 | cat $FILE_PATH | ~mpost/bin/strip_punc_rachel.py | tr '\n' ' ' | $baseline_type --lang $lang > out/$lang 80 | done; 81 | } 82 | 83 | 84 | task score : ersatz 85 | < in_dir=$out@split 86 | < log=$out@train 87 | > out 88 | :: in_languages=@ 89 | :: embed_size=@ 90 | :: left_size=@ 91 | :: right_size=@ 92 | :: transformer_nlayers=@ 93 | :: vocab_size=@ 94 | :: pyenv=@ :: .submitter=$grid :: devices=@ 95 | :: devices_per_task=0 96 | :: .resource_flags=$cpuResourceFlags :: .action_flags=$cpuActionFlags 97 | { 98 | mkdir -p $out 99 | for lang in ${in_languages[@]}; 100 | do 101 | FILE_PATH="/exp/rwicks/ersatz/data/balanced/test/$lang" 102 | if [[ $transformer_nlayers != 0 ]]; 103 | then 104 | ARCH="transformer" 105 | else 106 | ARCH="linear" 107 | fi 108 | LOG_PATH=$(echo $log | rev | cut -d'/' -f2- | rev)/job.out 109 | PARAMS=$(grep "Training with" $LOG_PATH | rev | cut -d' ' -f1 | rev) 110 | cat $FILE_PATH | ~mpost/bin/strip_punc_rachel.py > new_gold.txt 111 | #PYTHONPATH=$ersatz python $ersatz/score.py new_gold.txt $in_dir/$lang --determiner_type=multilingual > $out/$lang 112 | ersatz_score new_gold.txt $in_dir/$lang > $out/$lang 113 | FSCORE=$(grep "F1" $out/$lang | tail -1 | cut -d' ' -f2-) 114 | echo -e $lang"\t"$PARAMS"\t"$embed_size"\t"$vocab_size"\t"$ARCH"\t"$left_size"\t"$right_size"\t"$FSCORE 115 | done; 116 | 117 | } 118 | 119 | task score_dev : ersatz 120 | < in_dir=$out@split_dev 121 | < log=$out@train 122 | > out 123 | :: in_languages=@ 124 | :: embed_size=@ 125 | :: left_size=@ 126 | :: right_size=@ 127 | :: transformer_nlayers=@ 128 | :: vocab_size=@ 129 | :: pyenv=@ :: .submitter=$grid :: devices=@ 130 | :: 
devices_per_task=0 131 | :: .resource_flags=$cpuResourceFlags :: .action_flags=$cpuActionFlags 132 | { 133 | mkdir -p $out 134 | for lang in ${in_languages[@]}; 135 | do 136 | FILE_PATH="/exp/rwicks/ersatz/data/balanced/dev/$lang" 137 | if [[ $transformer_nlayers != 0 ]]; 138 | then 139 | ARCH="transformer" 140 | else 141 | ARCH="linear" 142 | fi 143 | LOG_PATH=$(echo $log | rev | cut -d'/' -f2- | rev)/job.out 144 | PARAMS=$(grep "Training with" $LOG_PATH | rev | cut -d' ' -f1 | rev) 145 | cat $FILE_PATH | ~mpost/bin/strip_punc_rachel.py > new_gold.txt 146 | #PYTHONPATH=$ersatz python $ersatz/score.py new_gold.txt $in_dir/$lang --determiner_type=multilingual > $out/$lang 147 | ersatz_score new_gold.txt $in_dir/$lang > $out/$lang 148 | FSCORE=$(grep "F1" $out/$lang | tail -1 | cut -d' ' -f2-) 149 | echo -e $lang"\t"$PARAMS"\t"$embed_size"\t"$vocab_size"\t"$ARCH"\t"$left_size"\t"$right_size"\t"$FSCORE 150 | done; 151 | 152 | } 153 | 154 | task score_baseline : ersatz 155 | < in_dir=$out@split_baseline 156 | > out 157 | :: in_languages=@ 158 | :: pyenv=@ :: .submitter=$grid :: devices=@ 159 | :: devices_per_task=0 160 | :: .resource_flags=$cpuResourceFlags :: .action_flags=$cpuActionFlags 161 | { 162 | mkdir -p $out 163 | for lang in ${in_languages[@]}; 164 | do 165 | FILE_PATH="/exp/rwicks/ersatz/data/balanced/test/$lang" 166 | cat $FILE_PATH | ~mpost/bin/strip_punc_rachel.py > new_gold.txt 167 | #PYTHONPATH=$ersatz python $ersatz/score.py new_gold.txt $in_dir/$lang --determiner_type=multilingual > $out/$lang 168 | ersatz_score new_gold.txt $in_dir/$lang > $out/$lang 169 | FSCORE=$(grep "F1" $out/$lang | tail -1 |cut -d' ' -f2-) 170 | RECALL=$(grep "Recall" $out/$lang | tail -1 | cut -d' ' -f2-) 171 | PRECISION=$(grep "Prec" $out/$lang | tail -1 | cut -d' ' -f2-) 172 | ACCURACY=$(grep "Acc" $out/$lang | tail -1 | cut -d' ' -f2-) 173 | echo -e $lang"\t"$FSCORE"\t"$RECALL"\t"$PRECISION"\t"$ACCURACY 174 | done; 175 | } 176 | 177 | task speed_test_cpu : ersatz 178 | < 
def levenshtein(gold_sequence, pred_sequence):
    """Edit distance between two token sequences (Wagner-Fischer DP).

    A zero-cost diagonal move requires token equality AND that the gold token
    is not in the exclusion list below; excluded tokens always incur an edit
    cost even when both sides agree. Returns the distance as a numpy scalar.
    """
    n_gold = len(gold_sequence) + 1
    n_pred = len(pred_sequence) + 1
    dist = np.zeros((n_gold, n_pred))
    # Base cases: transforming to/from the empty prefix costs its length.
    dist[:, 0] = np.arange(n_gold)
    dist[0, :] = np.arange(n_pred)
    for x in range(1, n_gold):
        gold_tok = gold_sequence[x - 1]
        # NOTE(review): this exclusion list appears in the original as two
        # identical entries; presumably boundary-marker tokens - confirm.
        excluded = gold_tok in ['', '']
        for y in range(1, n_pred):
            matched = (not excluded) and gold_tok == pred_sequence[y - 1]
            dist[x][y] = min(
                dist[x - 1][y] + 1,                          # deletion
                dist[x - 1][y - 1] + (0 if matched else 1),  # match / substitution
                dist[x][y - 1] + 1,                          # insertion
            )
    return dist[n_gold - 1][n_pred - 1]
| incorrect_mos = 0 102 | index = 0 103 | running_index = 0 104 | total_edits = 0 105 | errors = [] 106 | 107 | type_one = [] 108 | type_two = [] 109 | reached = False 110 | while (index < len(pred_gen)): 111 | pred = pred_gen[index] 112 | target = target_gen[index] 113 | if (pred != target): 114 | if pred in ['', ''] and target in ['', '']: 115 | mapped_index_x, mapped_index_y = context_lookup[running_index] 116 | left_context = target_content[mapped_index_x-5:mapped_index_x-1] 117 | right_context = target_content[mapped_index_x:mapped_index_x+5] 118 | 119 | 120 | if mapped_index_x < len(target_content) and mapped_index_y == len(target_content[mapped_index_x])-1: 121 | left_context += [target_content[mapped_index_x-1]] 122 | else: 123 | left_context += [target_content[mapped_index_x-1][:mapped_index_y+1]] 124 | right_context = [target_content[mapped_index_x-1][mapped_index_y+1:]] + right_context 125 | 126 | left_context = ' '.join(left_context) 127 | right_context = ' '.join(right_context).replace('\u2581', ' ') 128 | print(f'{left_context} {pred} {right_context}') 129 | if target == '': 130 | incorrect_eos += 1 131 | else: 132 | incorrect_mos += 1 133 | 134 | elif pred in ['', ''] or target in ['', '']: 135 | exit(-1) 136 | else: 137 | SUCCESS = False 138 | range = 4 139 | while not SUCCESS: 140 | try: 141 | rem_gold, rem_pred, edit_dist = align(target_gen[index:index+range], pred_gen[index:index+range]) 142 | running_index += range-len(rem_gold) 143 | pred_gen = rem_pred + pred_gen[index+range:] 144 | target_gen = rem_gold + target_gen[index+range:] 145 | SUCCESS = True 146 | except: 147 | # expands window to align in until match is found 148 | range += 2 149 | index = 0 150 | total_edits += edit_dist 151 | else: 152 | if pred in ['', '']: 153 | if pred == '': 154 | correct_eos += 1 155 | else: 156 | correct_mos += 1 157 | index += 1 158 | running_index += 1 159 | 160 | total = correct_eos + incorrect_eos + correct_mos + incorrect_mos 161 | try: 162 | accuracy = 
def main():
    """CLI entry point: score a segmented file against a gold-standard file.

    Positional arguments are the gold-standard ("rubric") file path and the
    predicted/segmented file path; all scoring is delegated to ``score``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('rubric_file_path', type=str)
    parser.add_argument('pred_file_path', type=str)
    # Legacy option: kept so existing command lines keep parsing, but scoring
    # no longer consults a determiner (README: "There are legacy arguments,
    # but they're not used"). The old code built a determiner object here and
    # then discarded it without ever passing it to score().
    parser.add_argument('--determiner_type', default='multilingual', choices=['en', 'multilingual', 'all'])

    args = parser.parse_args()

    score(args.rubric_file_path, args.pred_file_path)

if __name__ == '__main__':
    main()
8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. 
For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /ersatz/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import string 4 | import pathlib 5 | import sys 6 | import logging 7 | import argparse 8 | 9 | if __package__ is None and __name__ == '__main__': 10 | parent = pathlib.Path(__file__).absolute().parents[1] 11 | sys.path.insert(0, str(parent)) 12 | __package__ = 'ersatz' 13 | 14 | from .subword import Vocabulary, SentencePiece 15 | from .candidates import MultilingualPunctuation, PunctuationSpace, Split 16 | 17 | logger = logging.getLogger('ersatz') 18 | 19 | ####################################################################################################### 20 | 21 | # iterates over a file and yields one doc at a time with the appropriately 22 | # labelled splits; documents are separated by empty lines 23 | def document_generator(file_path, tokenizer=None): 24 | document = [] 25 | with open(file_path) as input_file: 26 | for line in input_file: 27 | if len(tokenizer.encode(line, out_type=str)) > 0: 28 | line 
= line.strip()
            line = tokenizer.encode(line, out_type=str)
            new_line = []
            for l in line:
                new_line.append(l)
                # NOTE(review): the two '' literals below look like markup stripped
                # from this copy of the file (presumably '<mos>'-style mid-sentence
                # and '<eos>' end-of-sentence tags) -- confirm against the original
                # repository before relying on them.
                new_line.append('')
            new_line[-1] = ''
            document += new_line
        else:
            # a line that tokenizes to nothing marks a document boundary
            yield document
            document = []
    # flush the final document (input need not end with a blank line)
    yield document

# this builds all the training data from a plain text file
# writes it out to a new file
def split_train_file(file_paths,
                     tokenizer,
                     output_path=None,
                     left_context_size=5,
                     right_context_size=5,
                     determiner=None):
    """Build the ersatz training file from raw text inputs.

    For every boundary tag produced by document_generator whose surrounding
    context `determiner` accepts, writes one line of the form
    ``left ||| right ||| label`` to `output_path`.

    :param file_paths: iterable of raw-text input file paths
    :param tokenizer: subword tokenizer exposing ``encode(line, out_type=str)``
    :param output_path: path the training lines are written to
    :param left_context_size: tokens of left context per example
    :param right_context_size: tokens of right context per example
    :param determiner: callable(left_str, right_str) -> bool selecting sites
    """
    # fixed seed so repeated preprocessing runs are reproducible
    random.seed(14)

    with open(output_path, 'w') as f:
        for file_path in file_paths:
            for doc in document_generator(file_path, tokenizer=tokenizer):
                if len(doc) > 0:
                    # left window starts padded; right window excludes boundary tags
                    # NOTE(review): the empty-string literals here also appear to be
                    # stripped '<...>' tags (pad / mos / eos) -- confirm upstream.
                    left_temp = ["" for x in range(left_context_size-1)] + [doc[0]]
                    right_temp = [x for x in doc[1:(2*right_context_size)+1] if x not in ["", ""]]
                    # doc alternates token/tag, so the right edge advances by 2
                    temp_index = 2*right_context_size+2
                    for index, word in enumerate(doc):
                        if word in ['', '']:

                            # the tag itself is the training label
                            label = word

                            # only emit contexts the determiner accepts as candidates
                            if determiner(''.join(left_temp).replace('\u2581', ' ').replace('', ''),
                                          ''.join(right_temp).replace('\u2581', ' ').replace('', '')):
                                f.write(' '.join(left_temp) + ' ||| ' + ' '.join(right_temp) + ' ||| ' + label + '\n')

                            # slide both windows one token to the right
                            left_temp.pop(0)
                            left_temp.append(right_temp.pop(0))
                            if temp_index < len(doc):
                                right_temp.append(doc[temp_index])
                                temp_index += 2
                            else:
                                right_temp.append("")

# split test files
# the difference between this and the previous is there are no labels in data
def split_test_file(document, tokenizer, left_context_size, right_context_size):
    """Produce (left, right) context windows for every token position.

    Unlike split_train_file, every position is a candidate and no labels
    exist; returns two parallel lists of space-joined context strings.
    """
    document = tokenizer.encode(document, out_type=str)
    left_contexts = []
    right_contexts = []
    if len(document) > 0:
        left_temp = ["" for x in range(left_context_size - 1)] + [document[0]]
        right_temp = [x for x in document[1:(right_context_size) + 1]]
        # pad the right window out to the requested size
        while (len(right_temp) < right_context_size):
            right_temp.append("")
        temp_index = right_context_size + 1
        for index, word in enumerate(document, 0):
            left_contexts.append(' '.join(left_temp))

            right_contexts.append(' '.join(right_temp))

            # slide both windows one token to the right
            left_temp.pop(0)
            left_temp.append(right_temp.pop(0))
            if temp_index < len(document):
                right_temp.append(document[temp_index])
                temp_index += 1
            else:
                right_temp.append("")
    return left_contexts, right_contexts



def write_training_files(file_path, left_contexts, right_contexts, labels, left_context_size=5, right_context_size=5):
    """Write ``left ||| right ||| label`` lines next to `file_path`.

    The output name embeds the context sizes, e.g. ``name.5-5-context.ext``
    for an input named ``name.ext``.
    """
    output_path = file_path.split('/')[-1].split('.')
    output_path = output_path[:-1] + [f'{left_context_size}-{right_context_size}-context'] + output_path[-1:]
    output_path = os.path.join('/'.join(file_path.split('/')[:-1]),'.'.join(output_path))

    with open(output_path, 'w') as f:
        for left, right, label in zip(left_contexts, right_contexts, labels):
            f.write(' '.join(left) + ' ||| ' + ' '.join(right) + ' ||| ' + label + '\n')

class SourceFactors():
    """Maps each subword token to a coarse orthographic class ("source factor")."""

    def __init__(self):
        # factor ids; UNMARK is the catch-all class
        self.codes = {
            'UNMARK': 0,
            'CAP': 1,
            'LOWER': 2,
            'PUNC': 3,
            'TITLE': 4,
            'NUMBER': 5
        }
        pass

    # specific to sentencepiece
    def compute(self, token_stream):
        """Return one factor id per input token.

        Tokens are grouped into words at each '\\u2581' (the SentencePiece
        word-boundary marker); every token of a word gets the word's class.
        """
        word = []
        output_stream = []
        token_stream = token_stream.split()
        # trailing '\u2581' sentinel flushes the final word
        for t in token_stream + ['\u2581']:
            if '\u2581' in t:
                out = None
                # potentially add a marker for truncated words in left context
                if len(word) > 0:
                    untok = ''.join(word).replace('\u2581', '')
                    if untok.istitle():
                        out = [self.codes['TITLE'] for w in word]
                    elif untok.isupper():
                        out = [self.codes['CAP'] for w in word]
                    elif untok.islower():
                        out = [self.codes['LOWER'] for w in word]
                    elif untok in string.punctuation:
                        out = [self.codes['PUNC'] for w in word]
                    else:
for w in untok:
    # any digit anywhere in the word marks the whole word NUMBER
    if w in string.digits:
        out = [self.codes['NUMBER'] for w in word]
        break
# (one indent level up: still inside `if len(word) > 0:`)
if not out:
    out = [self.codes['UNMARK'] for w in word]
output_stream += out
word = []
# (runs for every token t, inside the for-t loop)
word.append(t)
# (after the loop: exactly one factor per token)
assert(len(output_stream)==len(token_stream))
return output_stream


class ErsatzDataset():
    """Streams preprocessed ``left ||| right ||| label`` training data as tensors."""

    def __init__(self, data_path, device,
                 left_context_size=15,
                 right_context_size=5,
                 sentencepiece_path=None,
                 tokenizer=None):
        # tokenizer precedence: explicit tokenizer > sentencepiece model >
        # vocabulary built from the training data itself
        if tokenizer is None:
            if sentencepiece_path is not None:
                self.tokenizer = SentencePiece(model_path=sentencepiece_path)
            else:
                self.tokenizer = Vocabulary()
                self.tokenizer.build_vocab(data_path)
        else:
            self.tokenizer = tokenizer
        self.device = device
        # line count; only populated as batchify() iterates the file
        self.size = 0

        if not os.path.exists(data_path):
            raise Exception("path does not exist")

        self.left_context_size = left_context_size
        self.right_context_size = right_context_size

        self.data_path = data_path
        self.source_factors = SourceFactors()

    def __len__(self):
        # NOTE(review): returns 0 until batchify() has consumed the file
        return self.size

    def batchify(self, batch_size):
        """Yield (context, factors, label, context_strings) batches from data_path.

        Lines are ``left ||| right ||| label``; malformed lines and lines whose
        context sizes do not match the configured sizes are skipped.
        """
        data = []
        context_strings = []
        batch_idx = 0
        with open(self.data_path) as f:
            for line in f:
                self.size += 1
                if len(line.strip().split('|||')) == 3:
                    left, right, label = line.strip().split('|||')
                    # little check because some datasets have '|||' ... maybe change eventually to special character code ?
                    if (len(left.split()) == self.left_context_size) and (len(right.split()) == self.right_context_size):
                        data.append((left.strip(), self.source_factors.compute(left.strip()),
                                     right.strip(), self.source_factors.compute(right.strip()),
                                     label.strip()))
                        context_strings.append((left.strip(), right.strip()))
                        if len(data) >= batch_size:
                            context, factors, label = self.tokenizer.context_to_tensor(data)
                            context = context.view(len(data), -1)
                            factors = factors.view(len(data), -1)
                            label = label.view(len(data))
                            yield context, factors, label, context_strings
                            batch_idx += 1
                            data = []
                            context_strings = []
        # final partial batch
        # NOTE(review): context_to_tensor is called even when `data` is empty;
        # the len(data) > 0 guard only protects the yield -- verify the empty
        # case is safe in the tokenizer's context_to_tensor.
        context, factors, label = self.tokenizer.context_to_tensor(data)
        context = context.view(len(data), -1)
        factors = factors.view(len(data), -1)
        label = label.view(len(data))
        if len(data) > 0:
            yield context, factors, label, context_strings
            batch_idx += 1

#######################################################################################################

def parse_args():
    """Command-line arguments for the ersatz preprocessor."""
    parser = argparse.ArgumentParser(
        description="ERSATZ PREPROCESSOR: converts raw text (~one sentence per line) to expected input for ersatz training.\n"
                    "    Example: ersatz_preprocess --sp en.8000.model --output_path en.train file1.txt file2.txt file3.txt",
        formatter_class=argparse.RawTextHelpFormatter
    )

    parser.add_argument('--sentencepiece_path', '--sp', type=str, default=None,
                        help="Path to sentencepiece .model file to be used as tokenizer")
    parser.add_argument('--output_path', type=str, default="train.data",
                        help="File path where output will be written")
    parser.add_argument('--left-size', type=int, default=5,
                        help="Number of tokens of left context to use for predictions")
    parser.add_argument('--right-size', type=int, default=5,
                        help="Number of tokens of right context to use for predictions")
    parser.add_argument('--determiner_type', default='multilingual', choices=["en", "multilingual", "all"],
                        help="Type of contexts to include. Defaults to 'multilingual'\n"
                             "    * en: [EOS punctuation][any_punctuation]*[space]\n"
                             "    * multilingual: [EOS punctuation][!number]\n"
                             "    * all: all possible contexts")
    parser.add_argument('--input_paths', nargs='*', default=None,
                        help="Paths to raw text input files")
    args = parser.parse_args()
    return args

def main():
    args = parse_args()

    if args.sentencepiece_path is not None:
        tokenizer = SentencePiece(model_path=args.sentencepiece_path)
    else:
        logger.error("ERROR: No --sentencepiece_path was given. Training one as part of preprocessing is not currently supported.")
        sys.exit(-1)

    # choose the candidate-site policy (mirrors split.py's --candidates)
    if args.determiner_type == "en":
        determiner = PunctuationSpace()
    elif args.determiner_type == "multilingual":
        determiner = MultilingualPunctuation()
    else:
        determiner = Split()

    split_train_file(args.input_paths, tokenizer,
                     output_path=args.output_path,
                     left_context_size=args.left_size,
                     right_context_size=args.right_size,
                     determiner=determiner)


if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------
/ersatz/split.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python

import pathlib
import os
import torch
import argparse
import sys
import csv

if __package__ is None and __name__ == '__main__':
    # allow running as a plain script from the repository root
    parent = pathlib.Path(__file__).absolute().parents[1]
    sys.path.insert(0, str(parent))
    __package__ = 'ersatz'

from .
import __version__
from .utils import get_model_path, list_models, MODELS
from .model import ErsatzTransformer
from .dataset import SourceFactors, split_test_file
from .candidates import PunctuationSpace, MultilingualPunctuation, Split
from .subword import SentencePiece

import logging


logging.basicConfig(
    format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=os.environ.get("LOGLEVEL", "INFO").upper(),
    stream=sys.stderr,
)
logger = logging.getLogger("ersatz")

# default args for loading models
# should write something to merge default args with loaded args (overwrite when applicable)
class DefaultArgs():
    """Fallback hyperparameters for checkpoints saved without an args object."""

    def __init__(self):
        self.left_context_size=4
        self.right_context_size=6
        self.embed_size=256
        self.nhead=8
        self.dropout=0.1
        self.transformer_nlayers=2
        self.linear_nlayers=0
        self.activation_type='tanh'


def load_model(checkpoint_path):
    """Load an ErsatzTransformer checkpoint (always onto CPU) in eval mode."""
    model_dict = torch.load(checkpoint_path, map_location=torch.device('cpu'))
    tokenizer = SentencePiece(serialization=model_dict['tokenizer'])
    model = ErsatzTransformer(tokenizer, model_dict['args'])
    model.load_state_dict(model_dict['weights'])
    model.eval()
    return model

def detokenize(input_string):
    """Undo SentencePiece tokenization: drop spaces, turn '\\u2581' back into spaces."""
    input_string = input_string.replace(' ', '')
    input_string = input_string.replace('\u2581', ' ')
    return input_string

class EvalModel():
    """Wraps a trained model for inference-time sentence splitting."""

    def __init__(self, model_path):
        self.model = load_model(model_path)
        # unwrap DataParallel checkpoints so attribute access below works
        if type(self.model) is torch.nn.DataParallel:
            self.model = self.model.module

        self.tokenizer = self.model.tokenizer
        self.left_context_size = self.model.left_context_size
        self.right_context_size = self.model.right_context_size
        self.context_size = self.right_context_size + self.left_context_size

    def batchify(self, content, batch_size, candidates):
        """Build model-ready batches for every candidate split site in `content`.

        Returns a list of (contexts, factors, token_indices) batches; the last
        batch may be smaller than `batch_size` (the "remainder" path below).
        """
        source_factors = SourceFactors()
        left_contexts, right_contexts = split_test_file(content, self.tokenizer, self.left_context_size, self.right_context_size)
        if len(left_contexts) > 0:
            lines = []
            indices = []
            index = 1
            for left, right in zip(left_contexts, right_contexts):
                # keep only positions the candidate policy accepts
                if candidates(detokenize(' '.join(left)), detokenize(' '.join(right))):
                    lines.append((left, source_factors.compute(left),
                                  right, source_factors.compute(right),
                                  ''))
                    indices.append(index)
                index += 1
            indices = torch.tensor(indices)
            data, factors, _ = self.tokenizer.context_to_tensor(lines)

            nbatch = data.size(0) // batch_size
            remainder = data.size(0) % batch_size

            # carve off the partial final batch before reshaping to full batches
            if remainder > 0:
                remaining_data = data.narrow(0, nbatch*batch_size, remainder)
                remaining_factors = factors.narrow(0, nbatch*batch_size, remainder)
                remaining_indices = indices.narrow(0, nbatch*batch_size, remainder)

                data = data.narrow(0, 0, nbatch*batch_size)
                factors = factors.narrow(0, 0, nbatch*batch_size)
                indices = indices.narrow(0, 0, nbatch * batch_size)

            data = data.view(batch_size, -1).t().contiguous()
            factors = factors.view(batch_size, -1).t().contiguous()
            indices = indices.view(batch_size, -1).t().contiguous()

            if remainder > 0:
                remaining_data = remaining_data.view(remainder, -1).t().contiguous()
                remaining_factors = remaining_factors.view(remainder, -1).t().contiguous()
                remaining_indices = remaining_indices.view(remainder, -1).t().contiguous()


            batches = []
            # one slice per batch: (context_size, batch_size)
            data = data.view(-1, self.context_size, batch_size)
            factors = factors.view(-1, self.context_size, batch_size)
            indices = indices.view(-1, 1, batch_size)

            if remainder > 0:
                remaining_data = remaining_data.view(-1, self.context_size, remainder)
                remaining_factors = remaining_factors.view(-1, self.context_size, remainder)
remaining_indices = remaining_indices.view(-1, 1, remainder) 119 | for context_batch, factors_batch, index_batch in zip(data, factors, indices): 120 | batches.append((context_batch.t(), factors_batch.t(), index_batch[0])) 121 | if remainder > 0: 122 | batches.append((remaining_data[0].t(), remaining_factors[0].t(), remaining_indices[0][0])) 123 | return batches 124 | else: 125 | return [] 126 | 127 | def parallel_evaluation(self, content, batch_size, candidates=None, min_sent_length=3): 128 | batches = self.batchify(content, batch_size, candidates) 129 | eos = [] 130 | for contexts, factors, indices, in batches: 131 | data = contexts.to(self.device) 132 | if not self.model.source_factors: 133 | factors = None 134 | else: 135 | factors = factors.to(self.device) 136 | 137 | output = self.model.forward(data, factors=factors) 138 | 139 | pred = output.argmax(1) 140 | pred_ind = torch.where(pred == 0)[0] 141 | pred_ind = [indices[i].item() for i in pred_ind] 142 | eos.extend(pred_ind) 143 | if len(eos) == 0: 144 | yield content.strip() 145 | else: 146 | eos = sorted(eos) 147 | next_index = int(eos.pop(0)) 148 | this_content = self.tokenizer.encode(content, out_type=str) 149 | output = [] 150 | counter = 0 151 | for index, word in enumerate(this_content): 152 | if counter == next_index: 153 | try: 154 | next_index = int(eos.pop(0)) 155 | except: 156 | next_index = len(content)-1 157 | if (next_index - counter >= 5): 158 | output.append('\n') 159 | output.append(word) 160 | 161 | else: 162 | output.append(word) 163 | counter += 1 164 | output = self.tokenizer.merge(output, technique='utility').strip().split('\n') 165 | yield '\n'.join([o.strip() for o in output]) 166 | yield None 167 | 168 | def split(self, input_file, output_file, batch_size, candidates=None): 169 | for line in input_file: 170 | for batch_output in self.parallel_evaluation(line, batch_size, candidates=candidates): 171 | if batch_output is not None: 172 | print(batch_output.strip(), file=output_file) 173 
return output_file
        # ^ continuation of split(): hand back the (possibly StringIO) output handle


    def split_delimiter(self, input_file, output_file, batch_size, delimiter, columns, candidates=None):
        """Split selected columns of a delimited (csv/tsv) file.

        Rows whose split columns produce multiple sentences are expanded into
        multiple output rows; non-split columns repeat their last value (or
        are left blank) to keep rows aligned.
        """
        input_file = csv.reader(input_file, delimiter=delimiter)
        for line in input_file:
            new_lines = []
            max_len = 1
            for i, l in enumerate(line):
                if i in columns:
                    for batch_output in self.parallel_evaluation(l, batch_size, candidates=candidates):
                        if batch_output is not None:
                            batch_output = batch_output.split('\n')
                            new_lines.append(batch_output)
                            if len(batch_output) > max_len:
                                max_len = len(batch_output)
                else:
                    new_lines.append([line[i]])
            # emit one output row per produced sentence, padding short columns
            for x in range(max_len):
                out_line = []
                for i, col in enumerate(new_lines):
                    if x >= len(col):
                        if i not in columns:
                            out_line.append(col[-1])
                        else:
                            out_line.append('')
                    else:
                        out_line.append(col[x])
                print(delimiter.join(out_line).strip(), file=output_file)

def parse_args():
    """Command-line arguments for the ersatz segmenter."""
    parser = argparse.ArgumentParser(
        description="ERSATZ SEGMENTER: Segments input text into sentences.\n"
                    "    Example: ersatz --model fr --input wikipedia.fr --output output.fr",
        usage='%(prog)s [-h] [--model MODEL] [--input INPUT] [--output OUTPUT] [OPTIONS]',
        formatter_class=argparse.RawTextHelpFormatter
    )

    main_group = parser.add_argument_group('arguments')
    main_group.add_argument('--model', '-m', default='default-multilingual',
                            help="Either name of or path to a pre-trained ersatz model")
    main_group.add_argument('--input', '-i', default=None,
                            help="Input file. None means stdin")
    main_group.add_argument('--output', '-o', default=None,
                            help="Output file. None means stdout")
    main_group.add_argument('--batch-size', '-b', type=int, default=16,
                            help="Batch size--predictions to make at once")
    main_group.add_argument('--candidates', '-c', default='multilingual', choices=['multilingual', 'en', 'all'],
                            help = "Criteria for selecting candidate sites. Defaults to 'multilingual'\n"
                                   "    * multilingual: [EOS punctuation][!number] (sentence-ending punctuation followed by a non-digit)\n"
                                   "    * en: [EOS punctuation][any_punctuation]*[space] (sentence-ending punctuation followed by a space)\n"
                                   "    * all: all possible contexts")
    main_group.add_argument('--cpu', action='store_true', help="Uses CPU (GPU is default if available)")

    tsv_group = parser.add_argument_group('tsv options', description="Used for splitting .csv/.tsv/etc files. This mode triggered by '--columns'")
    tsv_group.add_argument('--delimiter', '-d', type=str, default='\t',
                           help="Delimiter character (default is \\t)\n"
                                "    * '--columns' must be set"
                           )
    tsv_group.add_argument('--columns', '-C', type=int, default=None, nargs="*",
                           help="Columns to split (0-indexed). If empty, plain-text\n"
                           )

    options = parser.add_argument_group('additional options')
    options.add_argument('--version', '-V', action='store_true', help="Prints ersatz version")
    options.add_argument('--download', '-D', action='store_true',
                         help="Downloads model selected via '--model'")
    options.add_argument('--list', '-l', action='store_true',
                         help="Lists available models.")
    options.add_argument('--quiet', '-q', action='store_true',
                         help="Disables logging.")

    args = parser.parse_args()

    # args.text is set programmatically (split a string instead of a file);
    # it is always None when invoked from the CLI
    args.text = None

    return args

def split(args):
    """Resolve model/IO/device from `args` and run the segmenter."""
    if args.candidates == "en":
        candidates = PunctuationSpace()
    elif args.candidates == 'multilingual':
        candidates = MultilingualPunctuation()
    else:
        candidates = Split()

    if args.input is not None:
        input_file = open(args.input, 'r')
    elif args.text is not None:
        input_file = args.text.split('\n')
    else:
        input_file = sys.stdin

    if args.output is not None:
        output_file = open(args.output, 'w')
    elif args.text is not None:
        # string-input mode: capture output in memory
        from io import StringIO
        output_file = StringIO()
    else:
        output_file = sys.stdout

    if torch.cuda.is_available() and not args.cpu:
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')

    # unknown names are treated as filesystem paths to a checkpoint
    if args.model not in MODELS:
        model = EvalModel(args.model)
    else:
        model_path = get_model_path(args.model)
        model = EvalModel(model_path)

    model.model = model.model.to(device)
    model.device = device

    with torch.no_grad():
        if args.columns is None:
            output_file = model.split(input_file, output_file, args.batch_size, candidates=candidates)
        else:
            output_file = model.split_delimiter(input_file, output_file, args.batch_size, args.delimiter, args.columns, candidates=candidates)

    if args.text:
        return
output_file.getvalue().strip().split('\n')
# ^ continuation: in string-input mode, return the split sentences as a list

def main():
    """CLI entry point: handle info/download flags, then run the splitter."""
    args = parse_args()

    if args.version:
        from . import __version__
        print("ersatz", __version__)
        sys.exit(0)

    if args.download:
        get_model_path(args.model)
        sys.exit(0)

    if args.list:
        list_models()
        sys.exit(0)

    if args.quiet:
        logger.setLevel(logging.ERROR)

    split(args)

if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------
/ersatz/utils.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-

import gzip
import os
import ssl
import sys
import urllib.request
import hashlib
import shutil
import logging
import progressbar

# TODO: change the loglevel here if -q is passed
logging.basicConfig(
    format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=os.environ.get("LOGLEVEL", "INFO").upper(),
    stream=sys.stderr,
)
logger = logging.getLogger("ersatz")

USERHOME = os.path.expanduser("~")
ERSATZ_DIR = os.environ.get("ERSATZ", os.path.join(USERHOME, ".ersatz"))

# registry of downloadable pre-trained models, keyed by CLI model name
MODELS = {
    "en" : {
        "source" : "https://github.com/rewicks/ersatz-models/raw/main/monolingual/en/01.Jun.21.en.gz",
        "info" : "An English monolingual model trained on English News Commentary",
        "description" : "monolingual/en",
        "destination": "monolingual/en/01.Jun.21.en",
        "date": "01 June 2021",
        "md5" : "75e8700396a21b1fe7e08f91b1971978"
    },
    "ar" : {
        "source" : "https://github.com/rewicks/ersatz-models/raw/main/monolingual/ar/01.Jun.21.ar.gz",
        "info" : "An Arabic monolingual model trained on Arabic News Commentary and Wikipedia data",
        "description" : "monolingual-ar",
        "destination" :
"monolingual/ar/01.Jun.21.ar", 39 | "date" : "01 June 2021", 40 | "md5" : "deb6c246bd8d48478f7872668737b3e1" 41 | }, 42 | "cs" : { 43 | "source" : "https://github.com/rewicks/ersatz-models/raw/main/monolingual/cs/01.Jun.21.cs.gz", 44 | "info" : "A Czech monolingual model trained on Czech News Commentary and Wikipedia data", 45 | "description" : "monolingual-cs", 46 | "destination" : "monolingual/cs/01.Jun.21.cs", 47 | "date" : "01 June 2021", 48 | "md5" : "71fca6f2ab670843a2b698ab118b9fce" 49 | }, 50 | "de" : { 51 | "source" : "https://github.com/rewicks/ersatz-models/raw/main/monolingual/de/01.Jun.21.de.gz", 52 | "info" : "A German monolingual model trained on German News Commentary and Wikipedia data", 53 | "description" : "monolingual-de", 54 | "destination" : "monolingual/de/01.Jun.21.de", 55 | "date" : "01 June 2021", 56 | "md5" : "3dcb90a96e5a1c4e151a4bdac24db459" 57 | }, 58 | 59 | 60 | "es" : { 61 | "source" : "https://github.com/rewicks/ersatz-models/raw/main/monolingual/es/01.Jun.21.es.gz", 62 | "info" : "A Spanish monolingual model trained on Spanish News Commentary data", 63 | "description" : "monolingual-es", 64 | "destination" : "monolingual/es/01.Jun.21.es", 65 | "date" : "01 June 2021", 66 | "md5" : "6bf02a677365ead6db9efdfc7fce5586" 67 | }, 68 | "et" : { 69 | "source" : "https://github.com/rewicks/ersatz-models/raw/main/monolingual/et/01.Jun.21.et.gz", 70 | "info" : "A Estonian monolingual model trained on Estonian News Crawl data", 71 | "description" : "monolingual-et", 72 | "destination" : "monolingual/et/01.Jun.21.et", 73 | "date" : "01 June 2021", 74 | "md5" : "a8e0b2f93e4300c41097e83e34cf793f" 75 | }, 76 | "fi" : { 77 | "source" : "https://github.com/rewicks/ersatz-models/raw/main/monolingual/fi/01.Jun.21.fi.gz", 78 | "info" : "A Finnish monolingual model trained on Finnish News Crawl and Wikipedia data", 79 | "description" : "monolingual-fi", 80 | "destination" : "monolingual/fi/01.Jun.21.fi", 81 | "date" : "01 June 2021", 82 | "md5" : 
"02382a6b06c4e0475f4940a60fd1506c" 83 | }, 84 | "fr" : { 85 | "source" : "https://github.com/rewicks/ersatz-models/raw/main/monolingual/fr/01.Jun.21.fr.gz", 86 | "info" : "A French monolingual model trained on French News Commentary and Wikipedia data", 87 | "description" : "monolingual-fr", 88 | "destination" : "monolingual/fr/01.Jun.21.fr", 89 | "date" : "01 June 2021", 90 | "md5" : "db67663cd9bd33b4e5aad919e5c1fefa" 91 | }, 92 | "gu" : { 93 | "source" : "https://github.com/rewicks/ersatz-models/raw/main/monolingual/gu/01.Jun.21.gu.gz", 94 | "info" : "A Gujarti monolingual model trained on Gujarti News Crawl and Common Crawl data", 95 | "description" : "monolingual-gu", 96 | "destination" : "monolingual/gu/01.Jun.21.gu", 97 | "date" : "01 June 2021", 98 | "md5" : "939929a3859daafa64c8b0360bced552" 99 | }, 100 | "hi" : { 101 | "source" : "https://github.com/rewicks/ersatz-models/raw/main/monolingual/hi/01.Jun.21.hi.gz", 102 | "info" : "A Hindi monolingual model trained on Hindi News Commentary, News Crawl, and Wikipeda data", 103 | "description" : "monolingual-hi", 104 | "destination" : "monolingual/hi/01.Jun.21.hi", 105 | "date" : "01 June 2021", 106 | "md5" : "2df0baa5ce535ef2f8c221f4899ed38d" 107 | }, 108 | "iu" : { 109 | "source" : "https://github.com/rewicks/ersatz-models/raw/main/monolingual/iu/01.Jun.21.iu.gz", 110 | "info" : "A Inuktitut monolingual model trained on Inuktitut data from the Nunavut-Hansard-Inuktitut-English Parallel Corpus 3.0", 111 | "description" : "monolingual-iu", 112 | "destination" : "monolingual/iu/01.Jun.21.iu", 113 | "date" : "01 June 2021", 114 | "md5" : "c5ca9cefc5633528d039b89924030dcb" 115 | }, 116 | "ja" : { 117 | "source" : "https://github.com/rewicks/ersatz-models/raw/main/monolingual/ja/01.Jun.21.ja.gz", 118 | "info" : "A Japanese monolingual model trained on Japanese News Commentary and News Crawl data", 119 | "description" : "monolingual-ja", 120 | "destination" : "monolingual/ja/01.Jun.21.ja", 121 | "date" : "01 June 
2021", 122 | "md5" : "4b6f85485757d8d5df16212193b7b2c8" 123 | }, 124 | "kk" : { 125 | "source" : "https://github.com/rewicks/ersatz-models/raw/main/monolingual/kk/01.Jun.21.kk.gz", 126 | "info" : "A Kazakh monolingual model trained on Kazakh News Commentary and News Crawl data", 127 | "description" : "monolingual-kk", 128 | "destination" : "monolingual/kk/01.Jun.21.kk", 129 | "date" : "01 June 2021", 130 | "md5" : "25464c09d8621ed3c4c0cdd594f771bb" 131 | }, 132 | "km" : { 133 | "source" : "https://github.com/rewicks/ersatz-models/raw/main/monolingual/km/01.Jun.21.km.gz", 134 | "info" : "A Khmer monolingual model trained on Khmer JW300 Corpus and Common Crawl data", 135 | "description" : "monolingual-km", 136 | "destination" : "monolingual/km/01.Jun.21.km", 137 | "date" : "01 June 2021", 138 | "md5" : "9c163b927d2641c205dcadeda8aefba4" 139 | }, 140 | "lt" : { 141 | "source" : "https://github.com/rewicks/ersatz-models/raw/main/monolingual/lt/01.Jun.21.lt.gz", 142 | "info" : "A Lithuanian monolingual model trained on Lithuanian News Crawl and Wikipedia data", 143 | "description" : "monolingual-lt", 144 | "destination" : "monolingual/lt/01.Jun.21.lt", 145 | "date" : "01 June 2021", 146 | "md5" : "b1f3ee5f2a745adf15cedcb6cb7ed2be" 147 | }, 148 | "lv" : { 149 | "source" : "https://github.com/rewicks/ersatz-models/raw/main/monolingual/lv/01.Jun.21.lv.gz", 150 | "info" : "A Latvian monolingual model trained on Latvian News Crawl data", 151 | "description" : "monolingual-lv", 152 | "destination" : "monolingual/lv/01.Jun.21.lv", 153 | "date" : "01 June 2021", 154 | "md5" : "be0c7a7d8e7f9d8933d1c4f5e1b52b0b" 155 | }, 156 | "pl" : { 157 | "source" : "https://github.com/rewicks/ersatz-models/raw/main/monolingual/pl/01.Jun.21.pl.gz", 158 | "info" : "A Polish monolingual model trained on Polish News Crawl, Global Voices, and Wikipedia data", 159 | "description" : "monolingual-pl", 160 | "destination" : "monolingual/pl/01.Jun.21.pl", 161 | "date" : "01 June 2021", 162 | "md5" : 
"6b874717e93147dd8f55fad2c4e7a1a3" 163 | }, 164 | "ps" : { 165 | "source" : "https://github.com/rewicks/ersatz-models/raw/main/monolingual/ps/01.Jun.21.ps.gz", 166 | "info" : "A Pashto monolingual model trained on Pashto News Crawl, SADA, SYSTRAN, and TRANSTAC data", 167 | "description" : "monolingual-ps", 168 | "destination" : "monolingual/ps/01.Jun.21.ps", 169 | "date" : "01 June 2021", 170 | "md5" : "13b2863e0d907606625743d0f091c294" 171 | }, 172 | "ro" : { 173 | "source" : "https://github.com/rewicks/ersatz-models/raw/main/monolingual/ro/01.Jun.21.ro.gz", 174 | "info" : "A Romanian monolingual model trained on Romanian News Crawl, Global Voices, and Wikipedia data", 175 | "description" : "monolingual-ro", 176 | "destination" : "monolingual/ro/01.Jun.21.ro", 177 | "date" : "01 June 2021", 178 | "md5" : "574a396c19330003b12dc91bf1e77ef5" 179 | }, 180 | "ru" : { 181 | "source" : "https://github.com/rewicks/ersatz-models/raw/main/monolingual/ru/01.Jun.21.ru.gz", 182 | "info" : "A Russian monolingual model trained on Russian News Commentary and Wikipedia data", 183 | "description" : "monolingual-ru", 184 | "destination" : "monolingual/ru/01.Jun.21.ru", 185 | "date" : "01 June 2021", 186 | "md5" : "50136fed7ad3330eb8c59c3c79864179" 187 | }, 188 | "ta" : { 189 | "source" : "https://github.com/rewicks/ersatz-models/raw/main/monolingual/ta/01.Jun.21.ta.gz", 190 | "info" : "A Tamil monolingual model trained on Tamil Wikipedia and News Crawl data", 191 | "description" : "monolingual-ta", 192 | "destination" : "monolingual/ta/01.Jun.21.ta", 193 | "date" : "01 June 2021", 194 | "md5" : "00785d88c84ee656343b0d12082e1c3a" 195 | }, 196 | "tr" : { 197 | "source" : "https://github.com/rewicks/ersatz-models/raw/main/monolingual/tr/01.Jun.21.tr.gz", 198 | "info" : "A Turkish monolingual model trained on Turkish Global Voices, Wikipedia, and News Crawl data", 199 | "description" : "monolingual-tr", 200 | "destination" : "monolingual/tr/01.Jun.21.tr", 201 | "date" : "01 June 2021", 
def list_models():
    """Print the name, description, and info line for every known model."""
    for model_name, model in MODELS.items():
        print(f'\t- {model_name} ({model["description"]}) : {model["info"]}')


def get_model_path(model_name='default-multilingual'):
    """Return the local path of a segmentation model, downloading it if needed.

    Falls back to the 'default-multilingual' model when the requested name is
    unknown. Exits the process if the download or checksum fails.
    """
    if model_name not in MODELS:
        logger.error(f"Could not find model by name of \"{model_name}\". Using \"default-multilingual\" instead")
        model_name = 'default-multilingual'

    model = MODELS[model_name]

    logger.info(f"Segmentation model: \"{model_name}\"")
    logger.info(f"Model description: \"{model['description']}\"")
    logger.info(f"Release Date: \"{model['date']}\"")

    model_file = os.path.join(ERSATZ_DIR, model['destination'])
    if os.path.exists(model_file):
        logger.info(f"USING \"{model_name}\" model found at {model_file}")
        return model_file
    elif download_model(model_name) == 0:
        return model_file
    # download_model() exits on failure itself; this is a defensive fallback
    sys.exit(1)


# module-level progress bar shared across urlretrieve reporthook callbacks
pbar = None
def show_progress(block_num, block_size, total_size):
    """Reporthook for urllib.request.urlretrieve that renders a progress bar."""
    global pbar
    if pbar is None:
        pbar = progressbar.ProgressBar(maxval=total_size)
        pbar.start()

    downloaded = block_num * block_size
    if downloaded < total_size:
        pbar.update(downloaded)
    else:
        pbar.finish()
        pbar = None


def download_model(model_name='default-multilingual'):
    """
    Downloads the specified model into the ERSATZ directory and gunzips it.

    :param model_name: key into MODELS. (Fixed: the old default 'default'
        was not a valid MODELS key and raised KeyError when relied upon.)
    :return: 0 on success; exits the process on download or checksum failure
    """
    expected_checksum = MODELS[model_name].get('md5', None)
    model_source = MODELS[model_name]['source']
    model_file = os.path.join(ERSATZ_DIR, os.path.basename(model_source))
    model_destination = os.path.join(ERSATZ_DIR, MODELS[model_name]['destination'])

    os.makedirs(ERSATZ_DIR, exist_ok=True)
    os.makedirs(os.path.dirname(model_destination), exist_ok=True)

    logger.info(f"DOWNLOADING \"{model_name}\" model from {model_source}")

    # re-download when missing or when a previous attempt left an empty file
    if not os.path.exists(model_file) or os.path.getsize(model_file) == 0:
        try:
            urllib.request.urlretrieve(model_source, model_file, show_progress)
        except Exception as e:
            logger.error(e)
            sys.exit(1)

    if expected_checksum is not None:
        md5 = hashlib.md5()
        with open(model_file, 'rb') as infile:
            # fixed-size chunks; the old line-based loop could hold an
            # arbitrarily long "line" of binary data in memory at once
            for chunk in iter(lambda: infile.read(1 << 20), b''):
                md5.update(chunk)
        if md5.hexdigest() != expected_checksum:
            logger.error(f"Failed checksum: expected was {expected_checksum}, received {md5.hexdigest()}")
            sys.exit(1)

        logger.info(f"Checksum passed: {md5.hexdigest()}")

    logger.info(f"EXTRACTING {model_file} to {model_destination}")
    with gzip.open(model_file) as infile, open(model_destination, 'wb') as outfile:
        shutil.copyfileobj(infile, outfile)

    return 0
def parse_args():
    """Build and parse the command-line arguments for the trainer."""
    parser = argparse.ArgumentParser()
    # data / tokenization
    parser.add_argument('--train_path', type=str)
    parser.add_argument('--valid_path', type=str)
    parser.add_argument('--sentencepiece_path')
    parser.add_argument('--determiner_type', default='multilingual', choices=["en", "multilingual", "all"])
    parser.add_argument('--left_size', type=int, default=15)
    parser.add_argument('--right_size', type=int, default=5)
    # optimization schedule
    parser.add_argument('--batch_size', type=int, default=256)
    parser.add_argument('--min-epochs', type=int, default=25)
    parser.add_argument('--max-epochs', type=int, default=1000)
    parser.add_argument('--output_path', type=str, default='models')
    parser.add_argument('--checkpoint_path', type=str)
    parser.add_argument('--lr', type=float, default=0.0001)
    parser.add_argument('--dropout', type=float, default=0.1)
    # model architecture
    parser.add_argument('--embed_size', type=int, default=256)
    parser.add_argument('--source_factors', action='store_true')
    parser.add_argument('--factor_embed_size', type=int, default=8)
    parser.add_argument('--transformer_nlayers', type=int, default=2)
    parser.add_argument('--linear_nlayers', type=int, default=0)
    parser.add_argument('--activation_type', type=str, default="tanh", choices=["tanh"])
    parser.add_argument('--nhead', type=int, default=8)
    # logging / stopping
    parser.add_argument('--log_interval', type=int, default=1000)
    parser.add_argument('--validation_interval', type=int, default=25000)
    parser.add_argument('--early_stopping', type=int, default=25)
    parser.add_argument('--cpu', action='store_true')
    parser.add_argument('--eos_weight', type=float, default=1.0)
    parser.add_argument('--seed', type=int, default=14)
    parser.add_argument('--tb_dir', type=str, default=None)
    return parser.parse_args()


def main():
    """Entry point: build the trainer and run epochs until early stopping.

    (Removed the unused local `minloss = math.inf`; loss tracking lives in
    Results / run_epoch, not here.)
    """
    args = parse_args()

    # pick the rule-based candidate-site determiner matching the data regime
    if args.determiner_type == "en":
        determiner = PunctuationSpace()
    elif args.determiner_type == "multilingual":
        determiner = MultilingualPunctuation()
    else:
        determiner = Split()

    torch.manual_seed(args.seed)
    logging.info('Starting trainer...')
    trainer = ErsatzTrainer(args)

    logging.info(trainer.model)
    logging.info(args)
    status = {'type': 'TRAINING'}
    best_model = None
    results = Results(time.time())
    for epoch in range(args.max_epochs):
        status['epoch'] = epoch
        trainer.model.train()
        res, status, best_model = trainer.run_epoch(epoch, args.batch_size,
                                                    log_interval=args.log_interval,
                                                    validation_interval=args.validation_interval,
                                                    status=status,
                                                    results=results,
                                                    best_model=best_model,
                                                    min_epochs=args.min_epochs,
                                                    validation_threshold=args.early_stopping,
                                                    use_factors=args.source_factors,
                                                    determiner=determiner)
        # run_epoch returns 0 to signal that early stopping triggered
        if res == 0 and epoch > args.min_epochs:
            break
        trainer.scheduler.step()


if __name__ == '__main__':
    main()
class Results():
    """Accumulates training/validation statistics between log updates.

    Counts end-of-sentence (eos) vs mid-of-sentence (mos) predictions and
    derives precision/recall/F1/accuracy plus loss and throughput figures.
    """

    def __init__(self, time):
        self.total_loss = 0
        self.perplexity = 0
        self.num_obs_eos = 0      # gold eos labels seen
        self.num_pred_eos = 0     # eos labels predicted
        self.correct_eos = 0
        self.correct_mos = 0
        self.num_pred = 0
        self.update_num = 0
        self.batches = 0
        self.last_update = time
        self.start = time
        self.validations = 0
        self.type = 'TRAINING'

    def calculate(self, loss, ppl, predictions, labels, eos_ind):
        """Fold one batch of predictions into the running counts."""
        self.total_loss += loss
        self.perplexity += ppl
        self.num_pred += len(predictions)
        self.num_obs_eos += (labels==eos_ind).sum().item()
        self.num_pred_eos += (predictions==eos_ind).sum().item()

        # Work on a copy: the original mutated the caller's tensor in place.
        # XOR trick (eos_ind == 0): non-eos predictions become -1 so they can
        # never XOR with a label to 0; pred ^ label == 0 only for correct eos.
        predictions = predictions.clone()
        predictions[predictions!=eos_ind] = -1
        predictions = predictions ^ labels
        self.correct_eos += (predictions==0).sum().item()

        # correct mos derived from totals rather than counted directly
        self.correct_mos = self.num_pred - (self.num_obs_eos + self.num_pred_eos - self.correct_eos)
        self.batches += 1

    def get_results(self, lr):
        """Return a dict of derived metrics for logging; assumes num_pred > 0."""
        retVal = {}
        retVal['type'] = self.type
        retVal['update_num'] = self.update_num
        if self.num_pred_eos != 0:
            retVal['prec'] = self.correct_eos/self.num_pred_eos
        else:
            retVal['prec'] = 0
        if self.num_obs_eos != 0:
            retVal['recall'] = self.correct_eos/self.num_obs_eos
        else:
            retVal['recall'] = 0
        retVal['acc'] = (self.correct_eos + self.correct_mos)/self.num_pred
        if retVal['prec'] != 0 and retVal['recall'] != 0:
            retVal['f1'] = 2*((retVal['prec']*retVal['recall'])/(retVal['prec']+retVal['recall']))
        else:
            retVal['f1'] = 0
        retVal['lr'] = lr
        retVal['total_loss'] = self.total_loss
        retVal['average_loss'] = self.total_loss/self.num_pred
        retVal['ppl_per_pred'] = self.perplexity/self.num_pred
        retVal['time_since_last_update'] = time.time()-self.last_update
        retVal['predictions_per_second'] = self.num_pred/retVal['time_since_last_update']
        retVal['time_passed'] = time.time()-self.start
        retVal['correct_eos'] = self.correct_eos
        retVal['correct_mos'] = self.correct_mos
        retVal['num_pred_eos'] = self.num_pred_eos
        retVal['num_obs_eos'] = self.num_obs_eos
        retVal['validations'] = self.validations
        retVal['num_pred'] = self.num_pred
        return retVal

    def reset(self, time):
        """Zero the per-interval counters and bump the update number."""
        self.total_loss = 0
        self.perplexity = 0
        self.num_obs_eos = 0
        self.num_pred_eos = 0
        self.correct_eos = 0
        self.correct_mos = 0
        self.num_pred = 0
        self.update_num += 1
        self.last_update = time

    def validated(self):
        """Record that one validation pass has completed."""
        self.validations += 1


def load_model(checkpoint_path):
    """Reconstruct an ErsatzTransformer from a checkpoint file.

    NOTE(review): this reads a 'vocab' key, but save_model() below writes
    only 'weights'/'tokenizer'/'args' — the two formats look out of sync;
    confirm against the checkpoints actually produced before relying on this.
    """
    model_dict = torch.load(checkpoint_path)
    model = ErsatzTransformer(model_dict['vocab'], model_dict['args'])
    model.load_state_dict(model_dict['weights'])

    return model


def save_model(model, output_path):
    """Serialize model weights, tokenizer bytes, and args to output_path.

    Moves the model to CPU first so the checkpoint is device-independent.
    """
    model = model.cpu()
    # context manager so the tokenizer file handle is closed deterministically
    # (the original `open(...).read()` leaked it)
    with open(model.tokenizer.model_path, 'rb') as tokenizer_file:
        tokenizer_bytes = tokenizer_file.read()
    model_dict = {
        'weights': model.state_dict(),
        'tokenizer': tokenizer_bytes,
        'args': model.args
    }
    torch.save(model_dict, output_path)
self.with_cuda else "cpu") 206 | self.output_path = args.output_path 207 | self.batch_size = args.batch_size 208 | 209 | self.training_set = ErsatzDataset(args.train_path, 210 | self.device, 211 | sentencepiece_path=args.sentencepiece_path, 212 | left_context_size=args.left_size, 213 | right_context_size=args.right_size) 214 | self.validation_set = ErsatzDataset(args.valid_path, 215 | self.device, 216 | tokenizer=self.training_set.tokenizer, 217 | left_context_size=args.left_size, 218 | right_context_size=args.right_size) 219 | 220 | if args.tb_dir is None: 221 | log_dir = f'runs/{args.determiner_type}.L{args.left_size}.R{args.right_size}.T{args.transformer_nlayers}.LIN{args.linear_nlayers}.E{args.eos_weight}.EMB{args.embed_size}.VOC{len(self.training_set.tokenizer)}' 222 | else: 223 | log_dir = args.tb_dir 224 | self.writer = SummaryWriter(log_dir=log_dir) 225 | 226 | logging.info(f'{self.device}') 227 | if not os.path.exists(args.output_path): 228 | os.makedirs(args.output_path, exist_ok=True) 229 | 230 | if args.checkpoint_path is not None and os.path.exists(args.checkpoint_path): 231 | logging.info('Loading pre-existing model from checkpoint') 232 | self.model = torch.load(args.output_path, map_location=self.device) 233 | else: 234 | self.model = ErsatzTransformer(self.training_set.tokenizer, args).to(self.device) 235 | 236 | weights = torch.tensor([args.eos_weight, 1]).to(self.device) 237 | self.criterion = nn.NLLLoss(weight=weights) 238 | self.lr = args.lr 239 | self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr) 240 | self.scheduler = torch.optim.lr_scheduler.StepLR(self.optimizer, 1.0, gamma=0.95) 241 | 242 | total_params = sum([p.numel() for p in self.model.parameters()]) 243 | logging.info(f'Training with: {total_params}') 244 | if self.with_cuda and torch.cuda.device_count() > 1: 245 | logging.info("Using %d GPUSs for ET" % torch.cuda.device_count()) 246 | self.model = nn.DataParallel(self.model) 247 | self.model = self.model.cuda() 
248 | 249 | def validate(self, batch_size, determiner, use_factors=False): 250 | retVal = {} 251 | retVal['num_obs_eos'] = 0 252 | retVal['num_pred_eos'] = 0 253 | retVal['correct_eos'] = 0 254 | retVal['correct_mos'] = 0 255 | retVal['total_loss'] = 0 256 | retVal['ppl'] = 0 257 | retVal['num_pred'] = 0 258 | retVal['inference_correct_eos'] = 0 259 | retVal['inference_incorrect_eos'] = 0 260 | retVal['inference_correct_mos'] = 0 261 | retVal['inference_incorrect_mos'] = 0 262 | self.model.eval() 263 | eos_ind = 0 264 | mos_ind = 1 265 | with torch.no_grad(): 266 | for i, (contexts, factors, labels, context_strings) in enumerate(self.validation_set.batchify(batch_size)): 267 | data = contexts.to(self.device) 268 | 269 | if use_factors: 270 | factors = factors.to(self.device) 271 | else: 272 | factors = None 273 | 274 | labels = labels.to(self.device) 275 | output = self.model.forward(data, factors=factors) 276 | loss = self.criterion(output, labels) 277 | perplexity = torch.exp(F.cross_entropy(output, labels)).item() 278 | pred = output.argmax(1) 279 | 280 | retVal['num_pred'] += len(pred) 281 | retVal['num_obs_eos'] += (labels==eos_ind).sum().item() 282 | retVal['num_pred_eos'] += (pred==eos_ind).sum().item() 283 | 284 | pred[pred!=eos_ind] = -1 285 | pred = pred ^ labels 286 | retVal['correct_eos'] += (pred==0).sum().item() 287 | 288 | retVal['correct_mos'] = retVal['num_pred'] - (retVal['num_obs_eos'] + retVal['num_pred_eos'] - retVal['correct_eos']) 289 | 290 | retVal['total_loss'] += loss.item() 291 | retVal['ppl'] += perplexity 292 | for context_item, label_item, p in zip(context_strings, labels, torch.argmax(output, dim=1)): 293 | left_context = self.model.tokenizer.merge(context_item[0]) 294 | right_context = self.model.tokenizer.merge(context_item[1]) 295 | 296 | if determiner(left_context, right_context): 297 | if label_item == eos_ind: 298 | if p.item() == eos_ind: 299 | retVal['inference_correct_eos'] += 1 300 | else: 301 | 
retVal['inference_incorrect_eos'] += 1 302 | else: 303 | if p.item() == mos_ind: 304 | retVal['inference_correct_mos'] += 1 305 | else: 306 | retVal['inference_incorrect_mos'] += 1 307 | 308 | retVal['average_loss'] = retVal['total_loss'] / len(self.validation_set) 309 | retVal['ppl_per_pred'] = retVal['ppl'] / len(self.validation_set) 310 | if retVal['inference_correct_eos'] + retVal['inference_incorrect_mos'] != 0: 311 | retVal['inference_prec'] = retVal['inference_correct_eos']/(retVal['inference_correct_eos'] + retVal['inference_incorrect_mos']) 312 | else: 313 | retVal['inference_prec'] = 0 314 | if retVal['inference_correct_eos'] + retVal['inference_incorrect_eos'] != 0: 315 | retVal['inference_recall'] = retVal['inference_correct_eos']/(retVal['inference_correct_eos'] + retVal['inference_incorrect_eos']) 316 | else: 317 | retVal['inference_recall'] = 0 318 | if retVal['inference_prec'] != 0 and retVal['inference_recall'] != 0: 319 | retVal['inference_f1'] = 2*((retVal['inference_prec']*retVal['inference_recall'])/(retVal['inference_prec']+retVal['inference_recall'])) 320 | else: 321 | retVal['inference_f1'] = 0 322 | 323 | retVal['inference_acc'] = (retVal['inference_correct_eos'] + retVal['inference_correct_mos'])/(retVal['inference_correct_eos'] + retVal['inference_correct_mos'] + retVal['inference_incorrect_eos'] + retVal['inference_incorrect_mos']) 324 | retVal['average_loss'] = retVal['total_loss']/retVal['num_pred'] 325 | self.model.train() 326 | return retVal 327 | 328 | def run_epoch(self, epoch, batch_size, 329 | log_interval=1000, 330 | validation_interval=250000, 331 | results=None, 332 | status=None, 333 | best_model=None, 334 | min_epochs = 10, 335 | validation_threshold=10, 336 | use_factors=False, 337 | determiner=None): 338 | 339 | eos_ind = 0 340 | mos_ind = 1 341 | for i, (contexts, factors, labels, context_strings) in enumerate(self.training_set.batchify(batch_size)): 342 | data = contexts.to(self.device) 343 | 344 | if use_factors: 345 | 
factors = factors.to(self.device) 346 | else: 347 | factors = None 348 | 349 | labels = labels.to(self.device) 350 | output = self.model.forward(data, factors=factors) 351 | loss = self.criterion(output, labels) 352 | ppl = torch.exp(F.cross_entropy(output, labels)).item() 353 | pred = output.argmax(1) 354 | 355 | results.calculate(loss.item(), ppl, pred, labels, eos_ind) 356 | 357 | self.optimizer.zero_grad() 358 | loss.backward() 359 | torch.nn.utils.clip_grad_norm_(self.model.parameters(), 0.5) 360 | self.optimizer.step() 361 | 362 | if results.batches % log_interval == 1: 363 | status = results.get_results(self.scheduler.get_last_lr()[0]) 364 | logging.info(json.dumps(status)) 365 | for key in status: 366 | if type(status[key]) is float: 367 | time_mark = epoch * (len(self.training_set) * batch_size) + i 368 | self.writer.add_scalar(f'{key}/training', status[key], time_mark) 369 | results.reset(time.time()) 370 | 371 | if results.batches % validation_interval == 1: 372 | stats = self.validate(batch_size, determiner, use_factors=use_factors) 373 | stats['type'] = 'VALIDATION' 374 | results.validated() 375 | stats['average_loss'] = stats['total_loss']/stats['num_pred'] 376 | stats['acc'] = (stats['correct_eos'] + stats['correct_mos'])/stats['num_pred'] 377 | if stats['num_pred_eos'] != 0: 378 | stats['prec'] = stats['correct_eos']/stats['num_pred_eos'] 379 | else: 380 | stats['prec'] = 0 381 | if stats['num_obs_eos'] != 0: 382 | stats['recall'] = stats['correct_eos']/stats['num_obs_eos'] 383 | else: 384 | stats['recall'] = 0 385 | if stats['prec'] != 0 and stats['recall'] != 0: 386 | stats['f1'] = 2*(stats['prec']*stats['recall'])/(stats['prec']+stats['recall']) 387 | else: 388 | stats['f1'] = 0 389 | logging.info(json.dumps(stats)) 390 | for key in stats: 391 | if type(stats[key]) is float: 392 | time_mark = status['validations'] 393 | self.writer.add_scalar(f'{key}/validation', stats[key], time_mark) 394 | if best_model is not None: 395 | if 
stats['inference_f1'] > best_model['inference_f1']: 396 | save_model(self.model, os.path.join(self.output_path, 'checkpoint.best')) 397 | self.model = self.model.to(self.device) 398 | best_model = stats 399 | best_model['validation_num'] = status['validations'] 400 | logging.info(f'SAVING MODEL: { json.dumps(best_model)}') 401 | else: 402 | if epoch > min_epochs and status['validations'] - best_model['validation_num'] >= validation_threshold: 403 | logging.info(f'EARLY STOPPING {json.dumps(best_model)}') 404 | return 0, status, best_model 405 | else: 406 | save_model(self.model, os.path.join(self.output_path, 'checkpoint.best')) 407 | self.model = self.model.to(self.device) 408 | best_model = stats 409 | logging.info(f'SAVING MODEL: { json.dumps(best_model) }') 410 | best_model['validation_num'] = status['validations'] 411 | logging.info(f'SAVING MODEL: End of epoch {epoch}') 412 | save_model(self.model, os.path.join(self.output_path, f'checkpoint.e{epoch}')) 413 | self.model = self.model.to(self.device) 414 | return 1, status, best_model 415 | --------------------------------------------------------------------------------