├── script
│   ├── __init__.py
│   ├── multi-bleu.perl
│   ├── PythonROUGE.py
│   └── tokenizer.perl
├── .gitattributes
├── models
│   ├── __pycache__
│   │   ├── rnn.cpython-35.pyc
│   │   ├── beam.cpython-35.pyc
│   │   ├── loss.cpython-35.pyc
│   │   ├── __init__.cpython-35.pyc
│   │   ├── optims.cpython-35.pyc
│   │   ├── seq2seq.cpython-35.pyc
│   │   └── attention.cpython-35.pyc
│   ├── __init__.py
│   ├── attention.py
│   ├── optims.py
│   ├── rnn.py
│   ├── beam.py
│   └── seq2seq.py
├── utils
│   ├── __pycache__
│   │   ├── __init__.cpython-35.pyc
│   │   ├── metrics.cpython-35.pyc
│   │   ├── data_helper.cpython-35.pyc
│   │   ├── dict_helper.cpython-35.pyc
│   │   └── misc_utils.cpython-35.pyc
│   ├── __init__.py
│   ├── metrics.py
│   ├── misc_utils.py
│   ├── data_helper.py
│   └── dict_helper.py
├── en_vi.yaml
├── pku.yaml
├── opts.py
├── README.md
├── preprocess.py
├── train.py
└── lr_scheduler.py

--------------------------------------------------------------------------------
/script/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/.gitattributes:
--------------------------------------------------------------------------------
1 | *.perl linguist-language=python
2 | *.py linguist-language=python
3 | *.pl linguist-language=python
--------------------------------------------------------------------------------
/models/__pycache__/rnn.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lancopku/SACT/HEAD/models/__pycache__/rnn.cpython-35.pyc
--------------------------------------------------------------------------------
/models/__pycache__/beam.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lancopku/SACT/HEAD/models/__pycache__/beam.cpython-35.pyc
--------------------------------------------------------------------------------
/models/__pycache__/loss.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lancopku/SACT/HEAD/models/__pycache__/loss.cpython-35.pyc
--------------------------------------------------------------------------------
/models/__pycache__/__init__.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lancopku/SACT/HEAD/models/__pycache__/__init__.cpython-35.pyc
--------------------------------------------------------------------------------
/models/__pycache__/optims.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lancopku/SACT/HEAD/models/__pycache__/optims.cpython-35.pyc
--------------------------------------------------------------------------------
/models/__pycache__/seq2seq.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lancopku/SACT/HEAD/models/__pycache__/seq2seq.cpython-35.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/__init__.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lancopku/SACT/HEAD/utils/__pycache__/__init__.cpython-35.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/metrics.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lancopku/SACT/HEAD/utils/__pycache__/metrics.cpython-35.pyc
--------------------------------------------------------------------------------
/models/__pycache__/attention.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lancopku/SACT/HEAD/models/__pycache__/attention.cpython-35.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/data_helper.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lancopku/SACT/HEAD/utils/__pycache__/data_helper.cpython-35.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/dict_helper.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lancopku/SACT/HEAD/utils/__pycache__/dict_helper.cpython-35.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/misc_utils.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lancopku/SACT/HEAD/utils/__pycache__/misc_utils.cpython-35.pyc
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .attention import *
2 | from .optims import *
3 | from .rnn import *
4 | from .seq2seq import *
5 | from .beam import *
6 | 
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | '''
2 | @Date : 2017/12/18
3 | @Author: Shuming Ma
4 | @mail : shumingma@pku.edu.cn
5 | @homepage: shumingma.com
6 | '''
7 | from .data_helper import *
8 | from .dict_helper import *
9 | from .misc_utils import *
10 | from .metrics import *
11 | 
--------------------------------------------------------------------------------
/en_vi.yaml:
--------------------------------------------------------------------------------
1 | data: '/home/linjunyang/en_vi/data/'
2 | logF: 'experiments/en_vi/'
3 | epoch: 120
4 | batch_size: 64
5 | optim: 'adam'
6 | cell: 'lstm'
7 | attention: 'luong_gate'
8 | learning_rate: 0.0003
9 | max_grad_norm: 10
10 | learning_rate_decay: 0.8
11 | start_decay_at: 12
12 | emb_size: 512
13 | hidden_size: 512
14 | dec_num_layers: 2
15 | enc_num_layers: 2
16 | bidirectional: True
17 | dropout: 0.4
18 | max_time_step: 50
19 | eval_interval: 1000
20 | save_interval: 3000
21 | metrics: ['bleu']
22 | shared_vocab: False
23 | beam_size: 4
24 | unk: False
25 | schedule: False
26 | schesamp: False
27 | length_norm: True
--------------------------------------------------------------------------------
/pku.yaml:
--------------------------------------------------------------------------------
1 | data: '/home/linjunyang/pku/data/'
2 | logF: 'experiments/pku/'
3 | epoch: 20
4 | batch_size: 64
5 | optim: 'adam'
6 | cell: 'lstm'
7 | attention: 'luong_gate'
8 | learning_rate: 0.0003
9 | max_grad_norm: 10
10 | learning_rate_decay: 0.5
11 | start_decay_at: 4
12 | emb_size: 512
13 | hidden_size: 512
14 | dec_num_layers: 3
15 | enc_num_layers: 3
16 | bidirectional: True
17 | dropout: 0.2
18 | max_time_step: 50
19 | eval_interval: 10000
20 | save_interval: 3000
21 | metrics: ['bleu']
22 | shared_vocab: False
23 | beam_size: 12
24 | unk: False
25 | schedule: False
26 | schesamp: False
27 | length_norm: True
28 | 
--------------------------------------------------------------------------------
/models/attention.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | 
5 | class luong_gate_attention(nn.Module):
6 | 
7 |     def __init__(self, hidden_size, emb_size, prob=0.2):
8 |         super(luong_gate_attention, self).__init__()
9 |         self.hidden_size, self.emb_size = hidden_size, emb_size
10 |         self.linear_in = nn.Sequential(
11 |             nn.Linear(hidden_size, hidden_size), nn.Dropout(p=prob))
12 |         self.feed = nn.Sequential(nn.Linear(2*hidden_size, hidden_size), nn.SELU(),
13 |                                   nn.Dropout(p=prob), nn.Linear(hidden_size, 1), nn.Tanh(), nn.Dropout(p=prob))
14 |         self.linear_out = nn.Sequential(nn.Linear(2*hidden_size, hidden_size), nn.SELU(), nn.Dropout(p=prob),
15 |                                         nn.Linear(hidden_size, hidden_size), nn.SELU(), nn.Dropout(p=prob))
16 |         self.softmax = nn.Softmax(dim=1)
17 | 
18 |     def init_context(self, context):
19 |         self.context = context.transpose(0, 1)
20 | 
21 |     def forward(self, h, c_t):
22 |         tau = torch.exp(self.feed(torch.cat([h, c_t], 1)))  # self-adaptive attention temperature (the SACT mechanism)
23 |         gamma_h = self.linear_in(h).unsqueeze(2)
24 |         weights = torch.bmm(self.context, gamma_h).squeeze(2) / tau  # scale scores by 1/tau before the softmax
25 |         weights = self.softmax(weights)
26 |         c_t = torch.bmm(weights.unsqueeze(1), self.context).squeeze(1)
27 |         output = self.linear_out(torch.cat([h, c_t], 1))
28 |         return output, weights, c_t
--------------------------------------------------------------------------------
/models/optims.py:
--------------------------------------------------------------------------------
1 | import torch.optim as optim
2 | from torch.nn.utils import clip_grad_norm_
3 | 
4 | 
5 | class Optim(object):
6 | 
7 |     def set_parameters(self, params):
8 |         self.params = list(params)  # careful: params may be a generator
9 |         if self.method == 'sgd':
10 |             self.optimizer = optim.SGD(self.params, lr=self.lr)
11 |         elif self.method == 'adagrad':
12 |             self.optimizer = optim.Adagrad(self.params, lr=self.lr)
13 |         elif self.method == 'adadelta':
14 |             self.optimizer = optim.Adadelta(self.params, lr=self.lr)
15 |         elif self.method == 'adam':
16 |             self.optimizer = optim.Adam(self.params, lr=self.lr, eps=1e-9)
17 |         else:
18 |             raise RuntimeError("Invalid optim method: " + self.method)
19 | 
20 |     def __init__(self, method, lr, max_grad_norm, lr_decay=1, start_decay_at=None, max_decay_times=2):
21 |         self.last_score = None
22 |         self.decay_times = 0
23 |         self.max_decay_times = max_decay_times
24 |         self.lr = lr
25 |         self.max_grad_norm = max_grad_norm
26 |         self.method = method
27 |         self.lr_decay = lr_decay
28 |         self.start_decay_at = start_decay_at
29 |         self.start_decay = False
30 | 
31 |     def step(self):
32 |         # Clip the gradient norm before the optimizer update.
33 |         if self.max_grad_norm:
34 |             clip_grad_norm_(self.params, self.max_grad_norm)
35 |         self.optimizer.step()
36 | 
37 |     # decay learning rate if val perf does not improve or we hit the start_decay_at limit
38 |     def updateLearningRate(self, score, epoch):
39 |         if self.start_decay_at is not None and epoch >= self.start_decay_at:
40 |             self.start_decay = True
41 | 
42 |         if self.start_decay:
43 |             self.lr = self.lr * self.lr_decay
44 |             print("Decaying learning rate to %g" % self.lr)
45 | 
46 |         self.last_score = score
47 |         self.optimizer.param_groups[0]['lr'] = self.lr
48 | 
--------------------------------------------------------------------------------
/opts.py:
--------------------------------------------------------------------------------
1 | def model_opts(parser):
2 | 
3 |     parser.add_argument('-config', default='default.yaml', type=str,
4 |                         help="config file")
5 |     parser.add_argument('-gpus', default=[], nargs='+', type=int,
6 |                         help="Use CUDA on the listed devices.")
7 |     parser.add_argument('-restore', default='', type=str,
8 |                         help="restore checkpoint")
9 |     parser.add_argument('-seed', type=int, default=1234,
10 |                         help="Random seed")
11 |     parser.add_argument('-model', default='seq2seq', type=str,
12 |                         help="Model selection")
13 |     parser.add_argument('-mode', default='train', type=str,
14 |                         help="Mode selection")
15 |     parser.add_argument('-module', default='seq2seq', type=str,
16 |                         help="Module selection")
17 |     parser.add_argument('-log', default='', type=str,
18 |                         help="log directory")
19 |     parser.add_argument('-num_processes', type=int, default=4,
20 |                         help="number of processes")
21 |     parser.add_argument('-refF', default='', type=str,
22 |                         help="reference file")
23 |     parser.add_argument('-unk', action='store_true', help='replace unk')
24 |     parser.add_argument('-char', action='store_true', help='char level decoding')
25 |     parser.add_argument('-length_norm', action='store_true', help='length normalization for beam search')
26 |     parser.add_argument('-pool_size', type=int, default=0, help="pool size of maxout layer")
27 |     parser.add_argument('-scale', type=float, default=1, help="proportion of the training set")
28 |     parser.add_argument('-max_split', type=int, default=0, help="max generator time steps for memory efficiency")
29 |     parser.add_argument('-split_num', type=int, default=0, help="split number for splitres")
30 |     parser.add_argument('-pretrain', default='', type=str, help="load pretrain encoder")
31 | 
32 | 
33 | def convert_to_config(opt, config):
34 |     opt = vars(opt)
35 |     for key in opt:
36 |         if key not in config:
37 |             config[key] = opt[key]
38 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Self-Adaptive Control of Temperature
2 | This is the code for our paper *Learning When to Concentrate or Divert Attention: Self-Adaptive Attention Temperature for Neural Machine Translation*, https://www.aclweb.org/anthology/papers/D/D18/D18-1331/
3 | 
4 | ***********************************************************
5 | 
6 | ## Requirements
7 | * Ubuntu 16.04
8 | * Python 3.6
9 | * PyTorch >= 0.4
10 | * pyrouge
11 | 
12 | 
13 | **************************************************************
14 | 
15 | ## Preprocessing
16 | ```
17 | python3 preprocess.py -load_data path_to_data -save_data path_to_store_data
18 | ```
19 | Remember to put the data into one folder and name the files *train.src*, *train.tgt*, *valid.src*, *valid.tgt*, *test.src* and *test.tgt*, and make a new folder inside called *data* for the preprocessed files, as sketched below.
20 | 
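For example, the data folder would look like this (a sketch; only the six data files are prescribed above, and the *data* subfolder is assumed to be the `-save_data` target that the `data:` path in *en_vi.yaml* / *pku.yaml* later points to):
```
path_to_data/
├── train.src
├── train.tgt
├── valid.src
├── valid.tgt
├── test.src
├── test.tgt
└── data/    # preprocessed output goes here
```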
21 | ***************************************************************
22 | 
23 | ## Training
24 | ```
25 | python3 train.py -log log_name -config config_yaml -gpus id
26 | ```
27 | 
28 | ****************************************************************
29 | 
30 | ## Evaluation
31 | ```
32 | python3 train.py -log log_name -config config_yaml -gpus id -restore checkpoint -mode eval
33 | ```
34 | 
35 | *******************************************************************
36 | 
37 | # Citation
38 | If you use this code for your research, please cite our paper *Learning When to Concentrate or Divert Attention: Self-Adaptive Attention Temperature for Neural Machine Translation*:
39 | ```
40 | @inproceedings{SACT,
41 |     title = "Learning When to Concentrate or Divert Attention: Self-Adaptive Attention Temperature for Neural Machine Translation",
42 |     author = "Lin, Junyang and
43 |       Sun, Xu and
44 |       Ren, Xuancheng and
45 |       Li, Muyu and
46 |       Su, Qi",
47 |     booktitle = "Proceedings of the 2018 Conference on Empirical Methods in Natural Language Processing",
48 |     month = oct # "-" # nov,
49 |     year = "2018",
50 |     address = "Brussels, Belgium",
51 |     publisher = "Association for Computational Linguistics",
52 |     url = "https://www.aclweb.org/anthology/D18-1331",
53 |     pages = "2985--2990"
54 | }
55 | ```
56 | 
--------------------------------------------------------------------------------
/utils/metrics.py:
--------------------------------------------------------------------------------
1 | import pyrouge
2 | import codecs
3 | import os
4 | import logging
5 | 
6 | def bleu(reference, candidate, log_path, print_log, config):
7 |     ref_file = log_path+'reference.txt'
8 |     cand_file = log_path+'candidate.txt'
9 |     with codecs.open(ref_file, 'w', 'utf-8') as f:
10 |         for s in reference:
11 |             if not config.char:
12 |                 f.write(" ".join(s)+'\n')
13 |             else:
14 |                 f.write("".join(s) + '\n')
15 |     with codecs.open(cand_file, 'w', 'utf-8') as f:
16 |         for s in candidate:
17 |             if not config.char:
18 |                 f.write(" ".join(s).strip()+'\n')
19 |             else:
20 |                 f.write("".join(s).strip() + '\n')
21 | 
22 |     if config.refF != '':
23 |         ref_file = config.refF
24 | 
25 |     temp = log_path + "result.txt"
26 |     command = "perl script/multi-bleu.perl " + ref_file + "<" + cand_file + "> " + temp
27 |     os.system(command)
28 |     with open(temp) as ft:
29 |         result = ft.read()
30 |     os.remove(temp)
31 |     print_log(result)
32 | 
33 |     return float(result.split()[2][:-1])
34 | 
35 | 
36 | def rouge(reference, candidate, log_path, print_log, config):
37 |     assert len(reference) == len(candidate)
38 | 
39 |     ref_dir = log_path + 'reference/'
40 |     cand_dir = log_path + 'candidate/'
41 |     if not os.path.exists(ref_dir):
42 |         os.mkdir(ref_dir)
43 |     if not os.path.exists(cand_dir):
44 |         os.mkdir(cand_dir)
45 | 
46 |     for i in range(len(reference)):
47 |         with codecs.open(ref_dir+"%06d_reference.txt" % i, 'w', 'utf-8') as f:
48 |             f.write(" ".join(reference[i]).replace(' <\s> ', '\n') + '\n')
49 |         with codecs.open(cand_dir+"%06d_candidate.txt" % i, 'w', 'utf-8') as f:
50 |             f.write(" ".join(candidate[i]).replace(' <\s> ', '\n').replace('<unk>', 'unk') + '\n')
51 | 
52 |     r = pyrouge.Rouge155()
53 |     r.model_filename_pattern = '#ID#_reference.txt'
54 |     r.system_filename_pattern = '(\d+)_candidate.txt'
55 |     r.model_dir = ref_dir
56 |     r.system_dir = cand_dir
57 |     logging.getLogger('global').setLevel(logging.WARNING)
58 |     rouge_results = r.convert_and_evaluate()
59 |     scores = r.output_to_dict(rouge_results)
60 |     recall = [round(scores["rouge_1_recall"] * 100, 2),
round(scores["rouge_2_recall"] * 100, 2), 62 | round(scores["rouge_l_recall"] * 100, 2)] 63 | precision = [round(scores["rouge_1_precision"] * 100, 2), 64 | round(scores["rouge_2_precision"] * 100, 2), 65 | round(scores["rouge_l_precision"] * 100, 2)] 66 | f_score = [round(scores["rouge_1_f_score"] * 100, 2), 67 | round(scores["rouge_2_f_score"] * 100, 2), 68 | round(scores["rouge_l_f_score"] * 100, 2)] 69 | print_log("F_measure: %s Recall: %s Precision: %s\n" 70 | % (str(f_score), str(recall), str(precision))) 71 | 72 | return f_score[:], recall[:], precision[:] 73 | -------------------------------------------------------------------------------- /utils/misc_utils.py: -------------------------------------------------------------------------------- 1 | ''' 2 | @Date : 2017/12/18 3 | @Author: Shuming Ma 4 | @mail : shumingma@pku.edu.cn 5 | @homepage: shumingma.com 6 | ''' 7 | import yaml 8 | import os 9 | import time 10 | import sys 11 | 12 | class AttrDict(dict): 13 | def __init__(self, *args, **kwargs): 14 | super(AttrDict, self).__init__(*args, **kwargs) 15 | self.__dict__ = self 16 | 17 | 18 | def read_config(path): 19 | return AttrDict(yaml.load(open(path, 'r'))) 20 | 21 | 22 | def print_log(file): 23 | def write_log(s): 24 | print(s, end='') 25 | with open(file, 'a') as f: 26 | f.write(s) 27 | return write_log 28 | 29 | 30 | 31 | _, term_width = os.popen('stty size', 'r').read().split() 32 | term_width = int(term_width) 33 | 34 | TOTAL_BAR_LENGTH = 86. 35 | last_time = time.time() 36 | begin_time = last_time 37 | def progress_bar(current, total, msg=None): 38 | global last_time, begin_time 39 | current = current % total 40 | if current == 0: 41 | begin_time = time.time() # Reset for new bar. 42 | 43 | cur_len = int(TOTAL_BAR_LENGTH*current/total) 44 | rest_len = int(TOTAL_BAR_LENGTH - cur_len) - 1 45 | 46 | sys.stdout.write(' [') 47 | for i in range(cur_len): 48 | sys.stdout.write('=') 49 | sys.stdout.write('>') 50 | for i in range(rest_len): 51 | sys.stdout.write('.') 52 | sys.stdout.write(']') 53 | 54 | cur_time = time.time() 55 | step_time = cur_time - last_time 56 | last_time = cur_time 57 | tot_time = cur_time - begin_time 58 | 59 | L = [] 60 | L.append(' Step: %s' % format_time(step_time)) 61 | L.append(' | Tot: %s' % format_time(tot_time)) 62 | if msg: 63 | L.append(' | ' + msg) 64 | 65 | msg = ''.join(L) 66 | sys.stdout.write(msg) 67 | for i in range(term_width-int(TOTAL_BAR_LENGTH)-len(msg)-3): 68 | sys.stdout.write(' ') 69 | 70 | # Go back to the center of the bar. 
71 |     for i in range(term_width-int(TOTAL_BAR_LENGTH/2)):
72 |         sys.stdout.write('\b')
73 |     sys.stdout.write(' %d/%d ' % (current+1, total))
74 | 
75 |     if current < total-1:
76 |         sys.stdout.write('\r')
77 |     else:
78 |         sys.stdout.write('\n')
79 |     sys.stdout.flush()
80 | 
81 | def format_time(seconds):
82 |     days = int(seconds / 3600/24)
83 |     seconds = seconds - days*3600*24
84 |     hours = int(seconds / 3600)
85 |     seconds = seconds - hours*3600
86 |     minutes = int(seconds / 60)
87 |     seconds = seconds - minutes*60
88 |     secondsf = int(seconds)
89 |     seconds = seconds - secondsf
90 |     millis = int(seconds*1000)
91 | 
92 |     f = ''
93 |     i = 1
94 |     if days > 0:
95 |         f += str(days) + 'D'
96 |         i += 1
97 |     if hours > 0 and i <= 2:
98 |         f += str(hours) + 'h'
99 |         i += 1
100 |     if minutes > 0 and i <= 2:
101 |         f += str(minutes) + 'm'
102 |         i += 1
103 |     if secondsf > 0 and i <= 2:
104 |         f += str(secondsf) + 's'
105 |         i += 1
106 |     if millis > 0 and i <= 2:
107 |         f += str(millis) + 'ms'
108 |         i += 1
109 |     if f == '':
110 |         f = '0ms'
111 |     return f
--------------------------------------------------------------------------------
/script/multi-bleu.perl:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env perl
2 | #
3 | # This file is part of moses. Its use is licensed under the GNU Lesser General
4 | # Public License version 2.1 or, at your option, any later version.
5 | 
6 | # $Id$
7 | use warnings;
8 | use strict;
9 | 
10 | my $lowercase = 0;
11 | if ($ARGV[0] eq "-lc") {
12 |     $lowercase = 1;
13 |     shift;
14 | }
15 | 
16 | my $stem = $ARGV[0];
17 | if (!defined $stem) {
18 |     print STDERR "usage: multi-bleu.pl [-lc] reference < hypothesis\n";
19 |     print STDERR "Reads the references from reference or reference0, reference1, ...\n";
20 |     exit(1);
21 | }
22 | 
23 | $stem .= ".ref" if !-e $stem && !-e $stem."0" && -e $stem.".ref0";
24 | 
25 | my @REF;
26 | my $ref=0;
27 | while(-e "$stem$ref") {
28 |     &add_to_ref("$stem$ref",\@REF);
29 |     $ref++;
30 | }
31 | &add_to_ref($stem,\@REF) if -e $stem;
32 | die("ERROR: could not find reference file $stem") unless scalar @REF;
33 | 
34 | sub add_to_ref {
35 |     my ($file,$REF) = @_;
36 |     my $s=0;
37 |     open(REF,$file) or die "Can't read $file";
38 |     while(<REF>) {
39 |         chop;
40 |         push @{$$REF[$s++]}, $_;
41 |     }
42 |     close(REF);
43 | }
44 | 
45 | my(@CORRECT,@TOTAL,$length_translation,$length_reference);
46 | my $s=0;
47 | while(<STDIN>) {
48 |     chop;
49 |     $_ = lc if $lowercase;
50 |     my @WORD = split;
51 |     my %REF_NGRAM = ();
52 |     my $length_translation_this_sentence = scalar(@WORD);
53 |     my ($closest_diff,$closest_length) = (9999,9999);
54 |     foreach my $reference (@{$REF[$s]}) {
55 |         # print "$s $_ <=> $reference\n";
56 |         $reference = lc($reference) if $lowercase;
57 |         my @WORD = split(' ',$reference);
58 |         my $length = scalar(@WORD);
59 |         my $diff = abs($length_translation_this_sentence-$length);
60 |         if ($diff < $closest_diff) {
61 |             $closest_diff = $diff;
62 |             $closest_length = $length;
63 |             # print STDERR "$s: closest diff ".abs($length_translation_this_sentence-$length)." = abs($length_translation_this_sentence-$length), setting len: $closest_length\n";
64 |         } elsif ($diff == $closest_diff) {
65 |             $closest_length = $length if $length < $closest_length;
66 |             # from two references with the same closeness to me
67 |             # take the *shorter* into account, not the "first" one.
68 |         }
69 |         for(my $n=1;$n<=4;$n++) {
70 |             my %REF_NGRAM_N = ();
71 |             for(my $start=0;$start<=$#WORD-($n-1);$start++) {
72 |                 my $ngram = "$n";
73 |                 for(my $w=0;$w<$n;$w++) {
74 |                     $ngram .= " ".$WORD[$start+$w];
75 |                 }
76 |                 $REF_NGRAM_N{$ngram}++;
77 |             }
78 |             foreach my $ngram (keys %REF_NGRAM_N) {
79 |                 if (!defined($REF_NGRAM{$ngram}) ||
80 |                     $REF_NGRAM{$ngram} < $REF_NGRAM_N{$ngram}) {
81 |                     $REF_NGRAM{$ngram} = $REF_NGRAM_N{$ngram};
82 |                     # print "$i: REF_NGRAM{$ngram} = $REF_NGRAM{$ngram}<BR>\n";
83 |                 }
84 |             }
85 |         }
86 |     }
87 |     $length_translation += $length_translation_this_sentence;
88 |     $length_reference += $closest_length;
89 |     for(my $n=1;$n<=4;$n++) {
90 |         my %T_NGRAM = ();
91 |         for(my $start=0;$start<=$#WORD-($n-1);$start++) {
92 |             my $ngram = "$n";
93 |             for(my $w=0;$w<$n;$w++) {
94 |                 $ngram .= " ".$WORD[$start+$w];
95 |             }
96 |             $T_NGRAM{$ngram}++;
97 |         }
98 |         foreach my $ngram (keys %T_NGRAM) {
99 |             $ngram =~ /^(\d+) /;
100 |             my $n = $1;
101 |             # my $corr = 0;
102 |             # print "$i e $ngram $T_NGRAM{$ngram}<BR>\n";
103 |             $TOTAL[$n] += $T_NGRAM{$ngram};
104 |             if (defined($REF_NGRAM{$ngram})) {
105 |                 if ($REF_NGRAM{$ngram} >= $T_NGRAM{$ngram}) {
106 |                     $CORRECT[$n] += $T_NGRAM{$ngram};
107 |                     # $corr = $T_NGRAM{$ngram};
108 |                     # print "$i e correct1 $T_NGRAM{$ngram}<BR>\n";
109 |                 }
110 |                 else {
111 |                     $CORRECT[$n] += $REF_NGRAM{$ngram};
112 |                     # $corr = $REF_NGRAM{$ngram};
113 |                     # print "$i e correct2 $REF_NGRAM{$ngram}<BR>\n";
114 |                 }
115 |             }
116 |             # $REF_NGRAM{$ngram} = 0 if !defined $REF_NGRAM{$ngram};
117 |             # print STDERR "$ngram: {$s, $REF_NGRAM{$ngram}, $T_NGRAM{$ngram}, $corr}\n"
118 |         }
119 |     }
120 |     $s++;
121 | }
122 | my $brevity_penalty = 1;
123 | my $bleu = 0;
124 | 
125 | my @bleu=();
126 | 
127 | for(my $n=1;$n<=4;$n++) {
128 |     if (defined ($TOTAL[$n])){
129 |         $bleu[$n]=($TOTAL[$n])?$CORRECT[$n]/$TOTAL[$n]:0;
130 |         # print STDERR "CORRECT[$n]:$CORRECT[$n] TOTAL[$n]:$TOTAL[$n]\n";
131 |     }else{
132 |         $bleu[$n]=0;
133 |     }
134 | }
135 | 
136 | if ($length_reference==0){
137 |     printf "BLEU = 0, 0/0/0/0 (BP=0, ratio=0, hyp_len=0, ref_len=0)\n";
138 |     exit(1);
139 | }
140 | 
141 | if ($length_translation<$length_reference) {
142 |     $brevity_penalty = exp(1-$length_reference/$length_translation);
143 | }
144 | $bleu = $brevity_penalty * exp((my_log( $bleu[1] ) +
145 |                                 my_log( $bleu[2] ) +
146 |                                 my_log( $bleu[3] ) +
147 |                                 my_log( $bleu[4] ) ) / 4) ;
148 | printf "BLEU = %.2f, %.1f/%.1f/%.1f/%.1f (BP=%.3f, ratio=%.3f, hyp_len=%d, ref_len=%d)\n",
149 |        100*$bleu,
150 |        100*$bleu[1],
151 |        100*$bleu[2],
152 |        100*$bleu[3],
153 |        100*$bleu[4],
154 |        $brevity_penalty,
155 |        $length_translation / $length_reference,
156 |        $length_translation,
157 |        $length_reference;
158 | 
159 | sub my_log {
160 |     return -9999999999 unless $_[0];
161 |     return log($_[0]);
162 | }
163 | 
--------------------------------------------------------------------------------
/models/rnn.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.nn.utils.rnn import pack_padded_sequence as pack
4 | from torch.nn.utils.rnn import pad_packed_sequence as unpack
5 | import models
6 | import math
7 | import numpy as np
8 | 
9 | 
10 | class rnn_encoder(nn.Module):
11 | 
12 |     def __init__(self, config, embedding=None):
13 |         super(rnn_encoder, self).__init__()
14 | 
15 |         self.embedding = embedding if embedding is not None else nn.Embedding(
16 |             config.src_vocab_size, config.emb_size)
17 |         self.hidden_size = config.hidden_size
18 |         self.config = config
19 | 
20 |         if config.cell == 'gru':
21 |             self.rnn = nn.GRU(input_size=config.emb_size, hidden_size=config.hidden_size,
22 |                               num_layers=config.enc_num_layers, dropout=config.dropout,
23 |                               bidirectional=config.bidirectional)
24 |         else:
25 |             self.rnn = nn.LSTM(input_size=config.emb_size, hidden_size=config.hidden_size,
26 |                                num_layers=config.enc_num_layers, dropout=config.dropout,
27 |                                bidirectional=config.bidirectional)
28 | 
29 |     def forward(self, inputs, lengths):
30 |         embs = pack(self.embedding(inputs), lengths)
31 |         outputs, state = self.rnn(embs)
32 |         outputs = unpack(outputs)[0]
33 |         if self.config.bidirectional:
34 |             outputs = outputs[:, :, :self.config.hidden_size] + \
35 |                 outputs[:, :, self.config.hidden_size:]
36 | 
37 |         if self.config.cell == 'gru':
38 |             state = state[:self.config.dec_num_layers]
39 |         else:
40 |             state = (state[0][::2], state[1][::2])
41 | 
42 |         return outputs, state
43 | 
44 | 
45 | class rnn_decoder(nn.Module):
46 | 
47 |     def __init__(self, config, embedding=None, use_attention=True):
48 |         super(rnn_decoder, self).__init__()
49 |         self.embedding = embedding if embedding is not None else nn.Embedding(
50 |             config.tgt_vocab_size, config.emb_size)
51 | 
52 |         input_size = config.emb_size
53 | 
54 |         if config.cell == 'gru':
55 |             self.rnn = StackedGRU(input_size=input_size, hidden_size=config.hidden_size,
56 |                                   num_layers=config.dec_num_layers, dropout=config.dropout)
57 |         else:
58 |             self.rnn = StackedLSTM(input_size=input_size, hidden_size=config.hidden_size,
59 |                                    num_layers=config.dec_num_layers, dropout=config.dropout)
60 | 
61 |         self.linear = nn.Linear(config.hidden_size, config.tgt_vocab_size)
62 | 
63 |         if not use_attention or config.attention == 'None':
64 |             self.attention = None
65 |         elif config.attention == 'luong_gate':
66 |             self.attention = models.luong_gate_attention(
67 |                 config.hidden_size, config.emb_size, prob=config.dropout)
68 | 
69 |         self.hidden_size = config.hidden_size
70 |         self.dropout = nn.Dropout(config.dropout)
71 |         self.config = config
72 | 
73 |     def forward(self, input, state, c_t):
74 |         embs = self.embedding(input)
75 |         output, state = self.rnn(embs, state)
76 | 
77 |         if self.attention is not None:
78 |             output, attn_weights, c_t = self.attention(output, c_t)
79 |         else:
80 |             attn_weights = None
81 | 
82 |         output = self.compute_score(output)
83 | 
84 |         return output, state, attn_weights, c_t
85 | 
86 |     def compute_score(self, hiddens):
87 |         scores = self.linear(hiddens)
88 |         return scores
89 | 
90 | 
91 | class StackedLSTM(nn.Module):
92 |     def __init__(self, num_layers, input_size, hidden_size, dropout):
93 |         super(StackedLSTM, self).__init__()
94 |         self.dropout = nn.Dropout(dropout)
95 |         self.num_layers = num_layers
96 |         self.layers = nn.ModuleList()
97 | 
98 |         for i in range(num_layers):
99 |             self.layers.append(nn.LSTMCell(input_size, hidden_size))
100 |             input_size = hidden_size
101 | 
102 |     def forward(self, input, hidden):
103 |         h_0, c_0 = hidden
104 |         h_1, c_1 = [], []
105 |         for i, layer in enumerate(self.layers):
106 |             h_1_i, c_1_i = layer(input, (h_0[i], c_0[i]))
107 |             input = h_1_i
108 |             if i + 1 != self.num_layers:
109 |                 input = self.dropout(input)
110 |             h_1 += [h_1_i]
111 |             c_1 += [c_1_i]
112 | 
113 |         h_1 = torch.stack(h_1)
114 |         c_1 = torch.stack(c_1)
115 | 
116 |         return input, (h_1, c_1)
117 | 
118 | 
119 | class StackedGRU(nn.Module):
120 |     def __init__(self, num_layers, input_size, hidden_size, dropout):
121 |         super(StackedGRU, self).__init__()
122 |         self.dropout = nn.Dropout(dropout)
123 |         self.num_layers = num_layers
124 |         self.layers = nn.ModuleList()
125 | 
126 |         for i in range(num_layers):
127 |             self.layers.append(nn.GRUCell(input_size, hidden_size))
128 |             input_size = hidden_size
129 | 
130 |     def forward(self, input, hidden):
131 |         h_0 = hidden
132 |         h_1 = []
133 |         for i, layer in enumerate(self.layers):
134 |             h_1_i = layer(input, h_0[i])
135 |             input = h_1_i
136 |             if i + 1 != self.num_layers:
137 |                 input = self.dropout(input)
138 |             h_1 += [h_1_i]
139 | 
140 |         h_1 = torch.stack(h_1)
141 | 
142 |         return input, h_1
143 | 
--------------------------------------------------------------------------------
/utils/data_helper.py:
--------------------------------------------------------------------------------
1 | import linecache
2 | import torch
3 | import torch.utils.data as torch_data
4 | from random import Random
5 | import utils
6 | 
7 | num_samples = 1
8 | 
9 | 
10 | class MonoDataset(torch_data.Dataset):
11 | 
12 |     def __init__(self, infos, indexes=None):
13 | 
14 |         self.srcF = infos['srcF']
15 |         self.original_srcF = infos['original_srcF']
16 |         self.length = infos['length']
17 |         self.infos = infos
18 |         if indexes is None:
19 |             self.indexes = list(range(self.length))
20 |         else:
21 |             self.indexes = indexes
22 | 
23 |     def __getitem__(self, index):
24 |         index = self.indexes[index]
25 |         src = list(map(int, linecache.getline(self.srcF, index+1).strip().split()))
26 |         original_src = linecache.getline(self.original_srcF, index+1).strip().split()
27 | 
28 |         return src, original_src
29 | 
30 |     def __len__(self):
31 |         return len(self.indexes)
32 | 
33 | 
34 | class BiDataset(torch_data.Dataset):
35 | 
36 |     def __init__(self, infos, indexes=None, char=False):
37 | 
38 |         self.srcF = infos['srcF']
39 |         self.tgtF = infos['tgtF']
40 |         self.original_srcF = infos['original_srcF']
41 |         self.original_tgtF = infos['original_tgtF']
42 |         self.length = infos['length']
43 |         self.infos = infos
44 |         self.char = char
45 |         if indexes is None:
46 |             self.indexes = list(range(self.length))
47 |         else:
48 |             self.indexes = indexes
49 | 
50 |     def __getitem__(self, index):
51 |         index = self.indexes[index]
52 |         src = list(map(int, linecache.getline(self.srcF, index+1).strip().split()))
53 |         tgt = list(map(int, linecache.getline(self.tgtF, index+1).strip().split()))
54 |         original_src = linecache.getline(self.original_srcF, index+1).strip().split()
55 |         original_tgt = linecache.getline(self.original_tgtF, index+1).strip().split() if not self.char else \
56 |             list(linecache.getline(self.original_tgtF, index + 1).strip())
57 | 
58 |         return src, tgt, original_src, original_tgt
59 | 
60 |     def __len__(self):
61 |         return len(self.indexes)
62 | 
63 | 
64 | def splitDataset(data_set, sizes):
65 |     length = len(data_set)
66 |     indexes = list(range(length))
67 |     rng = Random()
68 |     rng.seed(1234)
69 |     rng.shuffle(indexes)
70 | 
71 |     data_sets = []
72 |     part_len = int(length / sizes)
73 |     for i in range(sizes-1):
74 |         data_sets.append(BiDataset(data_set.infos, indexes[0:part_len]))
75 |         indexes = indexes[part_len:]
76 |     data_sets.append(BiDataset(data_set.infos, indexes))
77 |     return data_sets
78 | 
79 | 
80 | def padding(data):
81 |     src, tgt, original_src, original_tgt = zip(*data)
82 | 
83 |     src_len = [len(s) for s in src]
84 |     src_pad = torch.zeros(len(src), max(src_len)).long()
85 |     for i, s in enumerate(src):
86 |         end = src_len[i]
87 |         src_pad[i, :end] = torch.LongTensor(s[end-1::-1])
88 | 
89 |     tgt_len = [len(s) for s in tgt]
90 |     tgt_pad = torch.zeros(len(tgt), max(tgt_len)).long()
91 |     for i, s in enumerate(tgt):
92 |         end = tgt_len[i]
93 |         tgt_pad[i, :end] = torch.LongTensor(s)[:end]
94 | 
95 |     return src_pad, tgt_pad, \
96 |         torch.LongTensor(src_len), torch.LongTensor(tgt_len), \
97 |         original_src, original_tgt
98 | 
99 | 
100 | def ae_padding(data):
101 |     src, tgt, original_src, original_tgt = zip(*data)
102 | 
103 |     src_len = [len(s) for s in src]
104 |     src_pad = torch.zeros(len(src), max(src_len)).long()
105 |     for i, s in enumerate(src):
106 |         end = src_len[i]
107 |         src_pad[i, :end] = torch.LongTensor(s)[:end]
108 | 
109 |     tgt_len = [len(s) for s in tgt]
110 |     tgt_pad = torch.zeros(len(tgt), max(tgt_len)).long()
111 |     for i, s in enumerate(tgt):
112 |         end = tgt_len[i]
113 |         tgt_pad[i, :end] = torch.LongTensor(s)[:end]
114 | 
115 |     ae_len = [len(s)+2 for s in src]
116 |     ae_pad = torch.zeros(len(src), max(ae_len)).long()
117 |     for i, s in enumerate(src):
118 |         end = ae_len[i]
119 |         ae_pad[i, 0] = utils.BOS
120 |         ae_pad[i, 1:end-1] = torch.LongTensor(s)[:end-2]
121 |         ae_pad[i, end-1] = utils.EOS
122 | 
123 |     return src_pad, tgt_pad, ae_pad, \
124 |         torch.LongTensor(src_len), torch.LongTensor(tgt_len), torch.LongTensor(ae_len), \
125 |         original_src, original_tgt
126 | 
127 | 
128 | def split_padding(data):
129 |     src, tgt, original_src, original_tgt = zip(*data)
130 | 
131 |     split_samples = []
132 |     num_per_sample = int(len(src) / utils.num_samples)
133 | 
134 |     for i in range(utils.num_samples):
135 |         split_src = src[i*num_per_sample:(i+1)*num_per_sample]
136 |         split_tgt = tgt[i*num_per_sample:(i+1)*num_per_sample]
137 |         split_original_src = original_src[i * num_per_sample:(i + 1) * num_per_sample]
138 |         split_original_tgt = original_tgt[i * num_per_sample:(i + 1) * num_per_sample]
139 | 
140 |         src_len = [len(s) for s in split_src]
141 |         src_pad = torch.zeros(len(split_src), max(src_len)).long()
142 |         for i, s in enumerate(split_src):
143 |             end = src_len[i]
144 |             src_pad[i, :end] = torch.LongTensor(s)[:end]
145 | 
146 |         tgt_len = [len(s) for s in split_tgt]
147 |         tgt_pad = torch.zeros(len(split_tgt), max(tgt_len)).long()
148 |         for i, s in enumerate(split_tgt):
149 |             end = tgt_len[i]
150 |             tgt_pad[i, :end] = torch.LongTensor(s)[:end]
151 | 
152 |         split_samples.append([src_pad, tgt_pad,
153 |                               torch.LongTensor(src_len), torch.LongTensor(tgt_len),
154 |                               split_original_src, split_original_tgt])
155 | 
156 |     return split_samples
--------------------------------------------------------------------------------
/script/PythonROUGE.py:
--------------------------------------------------------------------------------
1 | """
2 | Created on Mon Aug 13 10:31:58 2012
3 | 
4 | author: Miguel B. Almeida
5 | mail: mba@priberam.pt
6 | """
7 | 
8 | import os
9 | import re
10 | import subprocess
11 | 
12 | # Wrapper function to use ROUGE from Python easily
13 | # Inputs:
14 | # guess_summ_list, a string with the absolute path to the file with your guess summary
15 | # ref_summ_list, a list of lists of paths to multiple reference summaries.
16 | # IMPORTANT: all the reference summaries must be in the same directory!
17 | # (optional) ngram_order, the order of the N-grams used to compute ROUGE
18 | # the default is 1 (unigrams)
19 | # Output: a tuple of the form (recall,precision,F_measure)
20 | #
21 | # Example usage: PythonROUGE('/home/foo/my_guess_summary.txt',[/home/bar/my_ref_summary_1.txt,/home/bar/my_ref_summary_2.txt])
22 | def PythonROUGE(guess_summ_list,ref_summ_list,ngram_order=2, byte=0):
23 |     """ Wrapper function to use ROUGE from Python easily. """
""" 24 | 25 | # even though we ask that the first argument is a list, 26 | # if it is a single string we can handle it 27 | if type(guess_summ_list) == str: 28 | temp = list() 29 | temp.append(guess_summ_list) 30 | guess_summ_list = temp 31 | del temp 32 | 33 | # even though we ask that the second argument is a list of lists, 34 | # if it is a single string we can handle it 35 | if type(ref_summ_list[0]) == str: 36 | temp = list() 37 | temp.append(ref_summ_list) 38 | ref_summ_list = temp 39 | del temp 40 | 41 | # this is the path to your ROUGE distribution 42 | ROUGE_path = 'data/script/ROUGE-1.5.5.pl' 43 | data_path = 'data/script/ROUGE' 44 | 45 | # these are the options used to call ROUGE 46 | # feel free to edit this is you want to call ROUGE with different options 47 | options = '-w 1.2 -a -m -n ' + str(ngram_order) 48 | if byte > 0: 49 | options = '-b ' + str(byte) + ' ' + options 50 | # this is a temporary XML file which will contain information 51 | # in the format ROUGE uses 52 | xml_path = 'data/script/temp.xml' 53 | xml_file = open(xml_path,'w') 54 | xml_file.write('\n') 55 | for guess_summ_index,guess_summ_file in enumerate(guess_summ_list): 56 | xml_file.write('\n') 57 | create_xml(xml_file,guess_summ_file,ref_summ_list[guess_summ_index]) 58 | xml_file.write('\n') 59 | xml_file.write('\n') 60 | xml_file.close() 61 | 62 | 63 | # this is the file where the output of ROUGE will be stored 64 | ROUGE_output_path = 'data/script/ROUGE_result.txt' 65 | # this is where we run ROUGE itself 66 | exec_command = 'perl ' + ROUGE_path + ' -e ' + data_path + ' ' + options + ' ' + xml_path + ' > ' + ROUGE_output_path 67 | os.system(exec_command) 68 | # here, we read the file with the ROUGE output and 69 | # look for the recall, precision, and F-measure scores 70 | recall_list = list() 71 | precision_list = list() 72 | F_measure_list = list() 73 | ROUGE_output_file = open(ROUGE_output_path,'r') 74 | 75 | index = ['1', '2', 'L'] 76 | for idx in index: 77 | ROUGE_output_file.seek(0) 78 | for line in ROUGE_output_file: 79 | match = re.findall('X ROUGE-' + idx + ' Average_R: ([0-9.]+)',line) 80 | if match != []: 81 | recall_list.append(float(match[0])) 82 | match = re.findall('X ROUGE-' + idx + ' Average_P: ([0-9.]+)',line) 83 | if match != []: 84 | precision_list.append(float(match[0])) 85 | match = re.findall('X ROUGE-' + idx + ' Average_F: ([0-9.]+)',line) 86 | if match != []: 87 | F_measure_list.append(float(match[0])) 88 | ROUGE_output_file.close() 89 | # remove temporary files which were created 90 | #os.remove(xml_path) 91 | #os.remove(ROUGE_output_path) 92 | 93 | return (recall_list,precision_list,F_measure_list) 94 | 95 | 96 | # This is an auxiliary function 97 | # It creates an XML file which ROUGE can read 98 | # Don't ask me how ROUGE works, because I don't know! 99 | def create_xml(xml_file,guess_summ_file,ref_summ_list): 100 | xml_file.write('\n') 101 | guess_summ_dir = os.path.dirname(guess_summ_file) 102 | xml_file.write(guess_summ_dir + '\n') 103 | xml_file.write('\n') 104 | xml_file.write('\n') 105 | ref_summ_dir = os.path.dirname(ref_summ_list[0] + '\n') 106 | xml_file.write(ref_summ_dir + '\n') 107 | xml_file.write('\n') 108 | xml_file.write('\n') 109 | xml_file.write('\n') 110 | xml_file.write('\n') 111 | guess_summ_basename = os.path.basename(guess_summ_file) 112 | xml_file.write('

' + guess_summ_basename + '

\n') 113 | xml_file.write('
\n') 114 | xml_file.write('') 115 | letter_list = ['A','B','C','D','E','F','G','H','I','J'] 116 | for ref_summ_index,ref_summ_file in enumerate(ref_summ_list): 117 | ref_summ_basename = os.path.basename(ref_summ_file) 118 | xml_file.write('' + ref_summ_basename + '\n') 119 | xml_file.write('\n') 120 | 121 | # This is only called if this file is executed as a script. 122 | # It shows an example of usage. 123 | if __name__ == '__main__': 124 | guess_summary_list = ['Example/Guess_Summ_1.txt','Example/Guess_Summ_2.txt'] 125 | ref_summ_list = [['Example/Ref_Summ_1_1.txt','Example/Ref_Summ_1_2.txt'] , ['Example/Ref_Summ_2_1.txt','Example/Ref_Summ_2_2.txt','Example/Ref_Summ_2_3.txt']] 126 | recall_list,precision_list,F_measure_list = PythonROUGE(guess_summary_list,ref_summ_list) 127 | #print 'recall = ' + str(recall_list) 128 | #print 'precision = ' + str(precision_list) 129 | #print 'F = ' + str(F_measure_list) 130 | -------------------------------------------------------------------------------- /models/beam.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import utils 3 | 4 | class Beam(object): 5 | def __init__(self, size, n_best=1, cuda=True, length_norm=False, minimum_length=0): 6 | 7 | self.size = size 8 | self.tt = torch.cuda if cuda else torch 9 | 10 | # The score for each translation on the beam. 11 | self.scores = self.tt.FloatTensor(size).zero_() 12 | self.allScores = [] 13 | 14 | # The backpointers at each time-step. 15 | self.prevKs = [] 16 | 17 | # The outputs at each time-step. 18 | self.nextYs = [self.tt.LongTensor(size) 19 | .fill_(utils.EOS)] 20 | self.nextYs[0][0] = utils.BOS 21 | 22 | # Has EOS topped the beam yet. 23 | self._eos = utils.EOS 24 | self.eosTop = False 25 | 26 | # The attentions (matrix) for each time. 27 | self.attn = [] 28 | 29 | # Time and k pair for finished. 30 | self.finished = [] 31 | self.n_best = n_best 32 | 33 | self.length_norm = length_norm 34 | self.minimum_length = minimum_length 35 | 36 | 37 | def getCurrentState(self): 38 | "Get the outputs for the current timestep." 39 | return self.nextYs[-1] 40 | 41 | def getCurrentOrigin(self): 42 | "Get the backpointers for the current timestep." 43 | return self.prevKs[-1] 44 | 45 | def advance(self, wordLk, attnOut): 46 | """ 47 | Given prob over words for every last beam `wordLk` and attention 48 | `attnOut`: Compute and update the beam search. 49 | Parameters: 50 | * `wordLk`- probs of advancing from the last step (K x words) 51 | * `attnOut`- attention at the last step 52 | Returns: True if beam search is complete. 53 | """ 54 | numWords = wordLk.size(1) 55 | 56 | # Sum the previous scores. 57 | if len(self.prevKs) > 0: 58 | beamLk = wordLk + self.scores.unsqueeze(1).expand_as(wordLk) 59 | 60 | # Don't let EOS have children. 
61 |             for i in range(self.nextYs[-1].size(0)):
62 |                 if self.nextYs[-1][i] == self._eos:
63 |                     beamLk[i] = -1e20
64 |             ngrams = []
65 |             le = len(self.nextYs)
66 |             for j in range(self.nextYs[-1].size(0)):
67 |                 hyp, _ = self.getHyp(le-1, j)
68 |                 ngrams = set()
69 |                 fail = False
70 |                 gram = []
71 |                 for i in range(le-1):
72 |                     # last n tokens, n = block_ngram_repeat
73 |                     gram = (gram + [hyp[i]])[-3:]
74 |                     # skip the blocking if it is in the exclusion list
75 |                     #if set(gram) & self.exclusion_tokens:
76 |                     #continue
77 |                     if tuple(gram) in ngrams:
78 |                         fail = True
79 |                     ngrams.add(tuple(gram))
80 |                 if fail:
81 |                     beamLk[j] = -1e20
82 |         else:
83 |             beamLk = wordLk[0]
84 |         flatBeamLk = beamLk.view(-1)
85 |         bestScores, bestScoresId = flatBeamLk.topk(self.size, 0, True, True)
86 | 
87 |         self.allScores.append(self.scores)
88 |         self.scores = bestScores
89 | 
90 |         # bestScoresId is flattened beam x word array, so calculate which
91 |         # word and beam each score came from
92 |         prevK = bestScoresId / numWords
93 |         self.prevKs.append(prevK)
94 |         self.nextYs.append((bestScoresId - prevK * numWords))
95 |         self.attn.append(attnOut.index_select(0, prevK))
96 | 
97 |         for i in range(self.nextYs[-1].size(0)):
98 |             if self.nextYs[-1][i] == self._eos:
99 |                 s = self.scores[i]
100 |                 if self.length_norm:
101 |                     s /= len(self.nextYs)
102 |                 if len(self.nextYs) - 1 >= self.minimum_length:
103 |                     self.finished.append((s, len(self.nextYs) - 1, i))
104 | 
105 |         # End condition is when top-of-beam is EOS and no global score.
106 |         if self.nextYs[-1][0] == utils.EOS:
107 |             self.allScores.append(self.scores)
108 |             self.eosTop = True
109 | 
110 |     def done(self):
111 |         return self.eosTop and len(self.finished) >= self.n_best
112 | 
113 |     def beam_update(self, state, idx):
114 |         positions = self.getCurrentOrigin()
115 |         for e in state:
116 |             a, br, d = e.size()
117 |             e = e.view(a, self.size, br // self.size, d)
118 |             sentStates = e[:, :, idx]
119 |             sentStates.copy_(sentStates.index_select(1, positions))
120 | 
121 |     def beam_update_gru(self, state, idx):
122 |         positions = self.getCurrentOrigin()
123 |         for e in state:
124 |             br, d = e.size()
125 |             e = e.view(self.size, br // self.size, d)
126 |             sentStates = e[:, idx]
127 |             sentStates.copy_(sentStates.index_select(0, positions))
128 | 
129 |     def beam_update_memory(self, state, idx):
130 |         positions = self.getCurrentOrigin()
131 |         e = state
132 |         br, d = e.size()
133 |         e = e.view(self.size, br // self.size, d)
134 |         sentStates = e[:, idx]
135 |         sentStates.copy_(sentStates.index_select(0, positions))
136 | 
137 |     def sortFinished(self, minimum=None):
138 |         if minimum is not None:
139 |             i = 0
140 |             # Add from beam until we have minimum outputs.
141 |             while len(self.finished) < minimum:
142 |                 s = self.scores[i].item()
143 |                 self.finished.append((s, len(self.nextYs) - 1, i))
144 |                 i += 1
145 | 
146 |         self.finished.sort(key=lambda a: -a[0])
147 |         scores = [sc for sc, _, _ in self.finished]
148 |         ks = [(t, k) for _, t, k in self.finished]
149 |         return scores, ks
150 | 
151 |     def getHyp(self, timestep, k):
152 |         """
153 |         Walk back to construct the full hypothesis. """
154 | """ 155 | hyp, attn = [], [] 156 | for j in range(len(self.prevKs[:timestep]) - 1, -1, -1): 157 | hyp.append(self.nextYs[j+1][k].item()) 158 | attn.append(self.attn[j][k]) 159 | k = self.prevKs[j][k].item() 160 | return hyp[::-1], torch.stack(attn[::-1]) 161 | -------------------------------------------------------------------------------- /utils/dict_helper.py: -------------------------------------------------------------------------------- 1 | ''' 2 | @Date : 2017/12/18 3 | @Author: Shuming Ma 4 | @mail : shumingma@pku.edu.cn 5 | @homepage: shumingma.com 6 | ''' 7 | 8 | import torch 9 | from collections import OrderedDict 10 | 11 | PAD = 0 12 | UNK = 1 13 | BOS = 2 14 | EOS = 3 15 | 16 | PAD_WORD = '' 17 | UNK_WORD = ' ' 18 | BOS_WORD = '' 19 | EOS_WORD = '' 20 | 21 | 22 | class Dict(object): 23 | def __init__(self, data=None, lower=True): 24 | self.idxToLabel = {} 25 | self.labelToIdx = {} 26 | self.frequencies = {} 27 | self.lower = lower 28 | # Special entries will not be pruned. 29 | self.special = [] 30 | 31 | if data is not None: 32 | if type(data) == str: 33 | self.loadFile(data) 34 | else: 35 | self.addSpecials(data) 36 | 37 | def size(self): 38 | return len(self.idxToLabel) 39 | 40 | # Load entries from a file. 41 | def loadFile(self, filename): 42 | for line in open(filename): 43 | fields = line.split() 44 | label = fields[0] 45 | idx = int(fields[1]) 46 | self.add(label, idx) 47 | 48 | # Write entries to a file. 49 | def writeFile(self, filename): 50 | with open(filename, 'w') as file: 51 | for i in range(self.size()): 52 | label = self.idxToLabel[i] 53 | file.write('%s %d\n' % (label, i)) 54 | 55 | file.close() 56 | 57 | def loadDict(self, idxToLabel): 58 | for i in range(len(idxToLabel)): 59 | label = idxToLabel[i] 60 | self.add(label, i) 61 | 62 | def lookup(self, key, default=None): 63 | key = key.lower() if self.lower else key 64 | try: 65 | return self.labelToIdx[key] 66 | except KeyError: 67 | return default 68 | 69 | def getLabel(self, idx, default=None): 70 | try: 71 | return self.idxToLabel[idx] 72 | except KeyError: 73 | return default 74 | 75 | # Mark this `label` and `idx` as special (i.e. will not be pruned). 76 | def addSpecial(self, label, idx=None): 77 | idx = self.add(label, idx) 78 | self.special += [idx] 79 | 80 | # Mark all labels in `labels` as specials (i.e. will not be pruned). 81 | def addSpecials(self, labels): 82 | for label in labels: 83 | self.addSpecial(label) 84 | 85 | # Add `label` in the dictionary. Use `idx` as its index if given. 86 | def add(self, label, idx=None): 87 | label = label.lower() if self.lower else label 88 | if idx is not None: 89 | self.idxToLabel[idx] = label 90 | self.labelToIdx[label] = idx 91 | else: 92 | if label in self.labelToIdx: 93 | idx = self.labelToIdx[label] 94 | else: 95 | idx = len(self.idxToLabel) 96 | self.idxToLabel[idx] = label 97 | self.labelToIdx[label] = idx 98 | 99 | if idx not in self.frequencies: 100 | self.frequencies[idx] = 1 101 | else: 102 | self.frequencies[idx] += 1 103 | 104 | return idx 105 | 106 | # Return a new dictionary with the `size` most frequent entries. 107 | def prune(self, size): 108 | if size > self.size(): 109 | return self 110 | 111 | # Only keep the `size` most frequent entries. 112 | freq = torch.Tensor( 113 | [self.frequencies[i] for i in range(len(self.frequencies))]) 114 | _, idx = torch.sort(freq, 0, True) 115 | 116 | newDict = Dict() 117 | newDict.lower = self.lower 118 | 119 | # Add special entries in all cases. 
120 |         for i in self.special:
121 |             newDict.addSpecial(self.idxToLabel[i])
122 | 
123 |         for i in idx[:size]:
124 |             newDict.add(self.idxToLabel[i])
125 | 
126 |         return newDict
127 | 
128 |     # Convert `labels` to indices. Use `unkWord` if not found.
129 |     # Optionally insert `bosWord` at the beginning and `eosWord` at the end.
130 |     def convertToIdx(self, labels, unkWord, bosWord=None, eosWord=None):
131 |         vec = []
132 | 
133 |         if bosWord is not None:
134 |             vec += [self.lookup(bosWord)]
135 | 
136 |         unk = self.lookup(unkWord)
137 |         vec += [self.lookup(label, default=unk) for label in labels]
138 | 
139 |         if eosWord is not None:
140 |             vec += [self.lookup(eosWord)]
141 | 
142 |         return vec
143 | 
144 | 
145 |     def convertToIdxandOOVs(self, labels, unkWord, bosWord=None, eosWord=None):
146 |         vec = []
147 |         oovs = OrderedDict()
148 | 
149 |         if bosWord is not None:
150 |             vec += [self.lookup(bosWord)]
151 | 
152 |         unk = self.lookup(unkWord)
153 |         for label in labels:
154 |             id = self.lookup(label, default=unk)
155 |             if id != unk:
156 |                 vec += [id]
157 |             else:
158 |                 if label not in oovs:
159 |                     oovs[label] = len(oovs)+self.size()
160 |                 oov_num = oovs[label]
161 |                 vec += [oov_num]
162 | 
163 |         if eosWord is not None:
164 |             vec += [self.lookup(eosWord)]
165 | 
166 |         return torch.LongTensor(vec), oovs
167 | 
168 |     def convertToIdxwithOOVs(self, labels, unkWord, bosWord=None, eosWord=None, oovs=None):
169 |         vec = []
170 | 
171 |         if bosWord is not None:
172 |             vec += [self.lookup(bosWord)]
173 | 
174 |         unk = self.lookup(unkWord)
175 |         for label in labels:
176 |             id = self.lookup(label, default=unk)
177 |             if id == unk and label in oovs:
178 |                 vec += [oovs[label]]
179 |             else:
180 |                 vec += [id]
181 | 
182 |         if eosWord is not None:
183 |             vec += [self.lookup(eosWord)]
184 | 
185 |         return torch.LongTensor(vec)
186 | 
187 | 
188 |     # Convert `idx` to labels. If index `stop` is reached, convert it and return.
189 |     def convertToLabels(self, idx, stop, oovs=None):
190 |         labels = []
191 | 
192 |         for i in idx:
193 |             if i == stop:
194 |                 break
195 |             if i < self.size():
196 |                 labels += [self.getLabel(i)]
197 |             else:
198 |                 labels += [oovs[i-self.size()]]
199 | 
200 |         return labels
201 | 
--------------------------------------------------------------------------------
/models/seq2seq.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import utils
4 | import models
5 | import random
6 | 
7 | 
8 | class seq2seq(nn.Module):
9 | 
10 |     def __init__(self, config, use_attention=True, encoder=None, decoder=None):
11 |         super(seq2seq, self).__init__()
12 | 
13 |         if encoder is not None:
14 |             self.encoder = encoder
15 |         else:
16 |             self.encoder = models.rnn_encoder(config)
17 |         tgt_embedding = self.encoder.embedding if config.shared_vocab else None
18 |         if decoder is not None:
19 |             self.decoder = decoder
20 |         else:
21 |             self.decoder = models.rnn_decoder(
22 |                 config, embedding=tgt_embedding, use_attention=use_attention)
23 |         self.log_softmax = nn.LogSoftmax(dim=-1)
24 |         self.use_cuda = config.use_cuda
25 |         self.config = config
26 |         self.criterion = nn.CrossEntropyLoss(
27 |             ignore_index=utils.PAD, reduction='none')
28 |         if config.use_cuda:
29 |             self.criterion.cuda()
30 | 
31 |     def compute_loss(self, scores, targets):
32 |         scores = scores.view(-1, scores.size(2))
33 |         loss = self.criterion(scores, targets.contiguous().view(-1))
34 |         return loss
35 | 
36 |     def forward(self, src, src_len, dec, targets, teacher_ratio=1.0):
37 |         src = src.t()
38 |         dec = dec.t()
39 |         targets = targets.t()
40 |         teacher = random.random() < teacher_ratio
41 | 
42 |         contexts, state = self.encoder(src, src_len.tolist())
43 | 
44 |         if self.decoder.attention is not None:
45 |             self.decoder.attention.init_context(context=contexts)
46 |         if self.config.cell == "lstm":
47 |             c_t = state[0][-1]
48 |         else:
49 |             c_t = state[-1]
50 | 
51 |         outputs = []
52 |         if teacher:
53 |             for input in dec.split(1):
54 |                 output, state, attn_weights, c_t = self.decoder(
55 |                     input.squeeze(0), state, c_t)
56 |                 outputs.append(output)
57 |             outputs = torch.stack(outputs)
58 |         else:
59 |             inputs = [dec.split(1)[0].squeeze(0)]
60 |             for i, _ in enumerate(dec.split(1)):
61 |                 output, state, attn_weights, c_t = self.decoder(
62 |                     inputs[i], state, c_t)
63 |                 predicted = output.max(1)[1]
64 |                 inputs += [predicted]
65 |                 outputs.append(output)
66 |             outputs = torch.stack(outputs)
67 | 
68 |         loss = self.compute_loss(outputs, targets)
69 |         return loss, outputs
70 | 
71 |     def sample(self, src, src_len):
72 | 
73 |         lengths, indices = torch.sort(src_len, dim=0, descending=True)
74 |         _, reverse_indices = torch.sort(indices)
75 |         src = torch.index_select(src, dim=0, index=indices)
76 |         bos = torch.ones(src.size(0)).long().fill_(utils.BOS)
77 |         src = src.t()
78 | 
79 |         if self.use_cuda:
80 |             bos = bos.cuda()
81 | 
82 |         contexts, state = self.encoder(src, lengths.tolist())
83 | 
84 |         if self.decoder.attention is not None:
85 |             self.decoder.attention.init_context(context=contexts)
86 |         if self.config.cell == "lstm":
87 |             c_t = state[0][-1]
88 |         else:
89 |             c_t = state[-1]
90 | 
91 |         inputs, outputs, attn_matrix = [bos], [], []
92 |         for i in range(self.config.max_time_step):
93 |             output, state, attn_weights, c_t = self.decoder(
94 |                 inputs[i], state, c_t)
95 |             predicted = output.max(1)[1]
96 |             inputs += [predicted]
97 |             outputs += [predicted]
98 |             attn_matrix += [attn_weights]
99 | 
100 |         outputs = torch.stack(outputs)
101 |         sample_ids = torch.index_select(
102 |             outputs, dim=1, index=reverse_indices).t().tolist()
103 | 
104 |         if self.decoder.attention is not None:
105 |             attn_matrix = torch.stack(attn_matrix)
106 |             alignments = attn_matrix.max(2)[1]
107 |             # alignments = torch.index_select(alignments, dim=1, index=reverse_indices).t().data
108 |             alignments = torch.index_select(
109 |                 alignments, dim=1, index=reverse_indices).t().tolist()
110 |         else:
111 |             alignments = None
112 | 
113 |         return sample_ids, alignments
114 | 
115 |     def beam_sample(self, src, src_len, beam_size=1, eval_=False):
116 | 
117 |         # (1) Run the encoder on the src.
118 | 
119 |         lengths, indices = torch.sort(src_len, dim=0, descending=True)
120 |         _, ind = torch.sort(indices)
121 |         src = torch.index_select(src, dim=0, index=indices)
122 |         src = src.t()
123 |         batch_size = src.size(1)
124 |         contexts, encState = self.encoder(src, lengths.tolist())
125 | 
126 |         # (1b) Initialize for the decoder.
127 |         def rvar(a):
128 |             return a.repeat(1, beam_size, 1)
129 | 
130 |         def bottle(m):
131 |             return m.view(batch_size * beam_size, -1)
132 | 
133 |         def unbottle(m):
134 |             return m.view(beam_size, batch_size, -1)
135 | 
136 |         # Repeat everything beam_size times.
137 |         contexts = rvar(contexts)
138 | 
139 |         if self.config.cell == 'lstm':
140 |             decState = (rvar(encState[0]), rvar(encState[1]))
141 |             c_t = decState[0][-1]
142 |         else:
143 |             decState = rvar(encState)
144 |             c_t = decState[-1]
145 | 
146 |         beam = [models.Beam(beam_size, n_best=1,
147 |                             cuda=self.use_cuda, length_norm=self.config.length_norm)
148 |                 for __ in range(batch_size)]
149 |         if self.decoder.attention is not None:
150 |             self.decoder.attention.init_context(contexts)
151 | 
152 |         # (2) run the decoder to generate sentences, using beam search.
153 | 
154 |         for i in range(self.config.max_time_step):
155 | 
156 |             if all((b.done() for b in beam)):
157 |                 break
158 | 
159 |             # Construct batch x beam_size nxt words.
160 |             # Get all the pending current beam words and arrange for forward.
161 |             inp = torch.stack([b.getCurrentState()
162 |                                for b in beam]).t().contiguous().view(-1)
163 | 
164 |             # Run one step.
165 |             output, decState, attn, c_t = self.decoder(inp, decState, c_t)
166 |             # decOut: beam x rnn_size
167 | 
168 |             # (b) Compute a vector of batch*beam word scores.
169 |             output = unbottle(self.log_softmax(output))
170 |             attn = unbottle(attn)
171 |             # beam x tgt_vocab
172 | 
173 |             # (c) Advance each beam.
174 |             # update state
175 |             for j, b in enumerate(beam):
176 |                 b.advance(output[:, j], attn[:, j])
177 |                 if self.config.cell == 'lstm':
178 |                     b.beam_update(decState, j)
179 |                 else:
180 |                     b.beam_update_gru(decState, j)
181 |                 b.beam_update_memory(c_t, j)
182 | 
183 |         # (3) Package everything up.
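        # For each source sentence (original order restored via `ind` below),
        # keep only the best finished hypothesis, its score, and its argmax
        # attention alignment; when eval_ is True the full attention weight
        # matrices are returned as well.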
184 | allHyps, allScores, allAttn = [], [], [] 185 | if eval_: 186 | allWeight = [] 187 | 188 | for j in ind: 189 | b = beam[j] 190 | n_best = 1 191 | scores, ks = b.sortFinished(minimum=n_best) 192 | hyps, attn = [], [] 193 | if eval_: 194 | weight = [] 195 | for i, (times, k) in enumerate(ks[:n_best]): 196 | hyp, att = b.getHyp(times, k) 197 | hyps.append(hyp) 198 | attn.append(att.max(1)[1]) 199 | if eval_: 200 | weight.append(att) 201 | allHyps.append(hyps[0]) 202 | allScores.append(scores[0]) 203 | allAttn.append(attn[0]) 204 | if eval_: 205 | allWeight.append(weight[0]) 206 | 207 | if eval_: 208 | return allHyps, allAttn, allWeight 209 | 210 | return allHyps, allAttn 211 | -------------------------------------------------------------------------------- /preprocess.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import utils 3 | import pickle 4 | 5 | parser = argparse.ArgumentParser(description='preprocess.py') 6 | 7 | parser.add_argument('-load_data', required=True, 8 | help="input file for the data") 9 | 10 | parser.add_argument('-save_data', required=True, 11 | help="Output file for the prepared data") 12 | 13 | parser.add_argument('-src_vocab_size', type=int, default=50000, 14 | help="Size of the source vocabulary") 15 | parser.add_argument('-tgt_vocab_size', type=int, default=50000, 16 | help="Size of the target vocabulary") 17 | parser.add_argument('-src_filter', type=int, default=0, 18 | help="Maximum source sequence length") 19 | parser.add_argument('-tgt_filter', type=int, default=0, 20 | help="Maximum target sequence length") 21 | parser.add_argument('-src_trun', type=int, default=0, 22 | help="Truncate source sequence length") 23 | parser.add_argument('-tgt_trun', type=int, default=0, 24 | help="Truncate target sequence length") 25 | parser.add_argument('-src_char', action='store_true', help='character based encoding') 26 | parser.add_argument('-tgt_char', action='store_true', help='character based decoding') 27 | parser.add_argument('-src_suf', default='src', 28 | help="the suffix of the source filename") 29 | parser.add_argument('-tgt_suf', default='tgt', 30 | help="the suffix of the target filename") 31 | 32 | parser.add_argument('-share', action='store_true', help='share the vocabulary between source and target') 33 | 34 | parser.add_argument('-report_every', type=int, default=100000, 35 | help="Report status every this many sentences") 36 | 37 | opt = parser.parse_args() 38 | 39 | 40 | def makeVocabulary(filename, trun_length, filter_length, char, vocab, size): 41 | 42 | print("%s: length limit = %d, truncate length = %d" % (filename, filter_length, trun_length)) 43 | max_length = 0 44 | with open(filename, encoding='utf8') as f: 45 | for sent in f.readlines(): 46 | if char: 47 | tokens = list(sent.strip()) 48 | else: 49 | tokens = sent.strip().split() 50 | if 0 < filter_length < len(sent.strip().split()): 51 | continue 52 | max_length = max(max_length, len(tokens)) 53 | if trun_length > 0: 54 | tokens = tokens[:trun_length] 55 | for word in tokens: 56 | vocab.add(word) 57 | 58 | print('Max length of %s = %d' % (filename, max_length)) 59 | 60 | if size > 0: 61 | originalSize = vocab.size() 62 | vocab = vocab.prune(size) 63 | print('Created dictionary of size %d (pruned from %d)' % 64 | (vocab.size(), originalSize)) 65 | 66 | return vocab 67 | 68 | 69 | def saveVocabulary(name, vocab, file): 70 | print('Saving ' + name + ' vocabulary to \'' + file + '\'...') 71 | vocab.writeFile(file) 72 | 73 | 74 | def 
makeData(srcFile, tgtFile, srcDicts, tgtDicts, save_srcFile, save_tgtFile, lim=0):
75 |     sizes = 0
76 |     count, empty_ignored, limit_ignored = 0, 0, 0
77 | 
78 |     print('Processing %s & %s ...' % (srcFile, tgtFile))
79 |     srcF = open(srcFile, encoding='utf8')
80 |     tgtF = open(tgtFile, encoding='utf8')
81 | 
82 |     srcIdF = open(save_srcFile + '.id', 'w')
83 |     tgtIdF = open(save_tgtFile + '.id', 'w')
84 |     srcStrF = open(save_srcFile + '.str', 'w', encoding='utf8')
85 |     tgtStrF = open(save_tgtFile + '.str', 'w', encoding='utf8')
86 | 
87 |     while True:
88 |         sline = srcF.readline()
89 |         tline = tgtF.readline()
90 | 
91 |         # normal end of file
92 |         if sline == "" and tline == "":
93 |             break
94 | 
95 |         # source or target does not have same number of lines
96 |         if sline == "" or tline == "":
97 |             print('WARNING: source and target do not have the same number of sentences')
98 |             break
99 | 
100 |         sline = sline.strip()
101 |         tline = tline.strip()
102 | 
103 |         # source and/or target are empty
104 |         if sline == "" or tline == "":
105 |             print('WARNING: ignoring an empty line ('+str(count+1)+')')
106 |             empty_ignored += 1
107 |             continue
108 | 
109 |         sline = sline.lower()
110 |         tline = tline.lower()
111 | 
112 |         srcWords = sline.split() if not opt.src_char else list(sline)
113 |         tgtWords = tline.split() if not opt.tgt_char else list(tline)
114 | 
115 | 
116 |         if (opt.src_filter == 0 or len(sline.split()) <= opt.src_filter) and \
117 |                 (opt.tgt_filter == 0 or len(tline.split()) <= opt.tgt_filter):
118 | 
119 |             if opt.src_trun > 0:
120 |                 srcWords = srcWords[:opt.src_trun]
121 |             if opt.tgt_trun > 0:
122 |                 tgtWords = tgtWords[:opt.tgt_trun]
123 | 
124 |             srcIds = srcDicts.convertToIdx(srcWords, utils.UNK_WORD)
125 |             tgtIds = tgtDicts.convertToIdx(tgtWords, utils.UNK_WORD, utils.BOS_WORD, utils.EOS_WORD)
126 | 
127 |             srcIdF.write(" ".join(list(map(str, srcIds)))+'\n')
128 |             tgtIdF.write(" ".join(list(map(str, tgtIds)))+'\n')
129 |             if not opt.src_char:
130 |                 srcStrF.write(" ".join(srcWords)+'\n')
131 |             else:
132 |                 srcStrF.write("".join(srcWords) + '\n')
133 |             if not opt.tgt_char:
134 |                 tgtStrF.write(" ".join(tgtWords)+'\n')
135 |             else:
136 |                 tgtStrF.write("".join(tgtWords) + '\n')
137 | 
138 |             sizes += 1
139 |         else:
140 |             limit_ignored += 1
141 | 
142 |         count += 1
143 | 
144 |         if count % opt.report_every == 0:
145 |             print('... %d sentences prepared' % count)
146 | 
147 |     srcF.close()
148 |     tgtF.close()
149 |     srcStrF.close()
150 |     tgtStrF.close()
151 |     srcIdF.close()
152 |     tgtIdF.close()
153 | 
154 |     print('Prepared %d sentences (%d ignored due to empty lines, %d due to the length limit)' %
155 |           (sizes, empty_ignored, limit_ignored))
156 | 
157 |     return {'srcF': save_srcFile + '.id', 'tgtF': save_tgtFile + '.id',
158 |             'original_srcF': save_srcFile + '.str', 'original_tgtF': save_tgtFile + '.str',
159 |             'length': sizes}
160 | 
161 | 
162 | def main():
163 | 
164 |     dicts = {}
165 | 
166 |     train_src, train_tgt = opt.load_data + 'train.' + opt.src_suf, opt.load_data + 'train.' + opt.tgt_suf
167 |     valid_src, valid_tgt = opt.load_data + 'valid.' + opt.src_suf, opt.load_data + 'valid.' + opt.tgt_suf
168 |     test_src, test_tgt = opt.load_data + 'test.' + opt.src_suf, opt.load_data + 'test.' + opt.tgt_suf
169 | 
170 |     save_train_src, save_train_tgt = opt.save_data + 'train.' + opt.src_suf, opt.save_data + 'train.' + opt.tgt_suf
171 |     save_valid_src, save_valid_tgt = opt.save_data + 'valid.' + opt.src_suf, opt.save_data + 'valid.' + opt.tgt_suf
172 |     save_test_src, save_test_tgt = opt.save_data + 'test.' + opt.src_suf, opt.save_data + 'test.' + opt.tgt_suf
173 | 
174 |     src_dict, tgt_dict = opt.save_data + 'src.dict', opt.save_data + 'tgt.dict'
175 | 
176 |     if opt.share:
177 |         assert opt.src_vocab_size == opt.tgt_vocab_size
178 |         print('Building source and target vocabulary...')
179 |         dicts['src'] = dicts['tgt'] = utils.Dict([utils.PAD_WORD, utils.UNK_WORD, utils.BOS_WORD, utils.EOS_WORD])
180 |         dicts['src'] = makeVocabulary(train_src, opt.src_trun, opt.src_filter, opt.src_char, dicts['src'], 0)
181 |         dicts['src'] = dicts['tgt'] = makeVocabulary(train_tgt, opt.tgt_trun, opt.tgt_filter, opt.tgt_char, dicts['src'], opt.tgt_vocab_size)
182 |     else:
183 |         print('Building source vocabulary...')
184 |         dicts['src'] = utils.Dict([utils.PAD_WORD, utils.UNK_WORD, utils.BOS_WORD, utils.EOS_WORD])
185 |         dicts['src'] = makeVocabulary(train_src, opt.src_trun, opt.src_filter, opt.src_char, dicts['src'], opt.src_vocab_size)
186 |         print('Building target vocabulary...')
187 |         dicts['tgt'] = utils.Dict([utils.PAD_WORD, utils.UNK_WORD, utils.BOS_WORD, utils.EOS_WORD])
188 |         dicts['tgt'] = makeVocabulary(train_tgt, opt.tgt_trun, opt.tgt_filter, opt.tgt_char, dicts['tgt'], opt.tgt_vocab_size)
189 | 
190 |     print('Preparing training ...')
191 |     train = makeData(train_src, train_tgt, dicts['src'], dicts['tgt'], save_train_src, save_train_tgt)
192 | 
193 |     print('Preparing validation ...')
194 |     valid = makeData(valid_src, valid_tgt, dicts['src'], dicts['tgt'], save_valid_src, save_valid_tgt)
195 | 
196 |     print('Preparing test ...')
197 |     test = makeData(test_src, test_tgt, dicts['src'], dicts['tgt'], save_test_src, save_test_tgt)
198 | 
199 |     print('Saving source vocabulary to \'' + src_dict + '\'...')
200 |     dicts['src'].writeFile(src_dict)
201 | 
202 |     print('Saving target vocabulary to \'' + tgt_dict + '\'...')
203 |     dicts['tgt'].writeFile(tgt_dict)
204 | 
205 |     data = {'train': train, 'valid': valid,
206 |             'test': test, 'dict': dicts}
207 |     pickle.dump(data, open(opt.save_data+'data.pkl', 'wb'))
208 | 
209 | 
210 | if __name__ == "__main__":
211 |     main()
212 | 
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.utils.data
3 | import lr_scheduler as L
4 | 
5 | import os
6 | import argparse
7 | import pickle
8 | import time
9 | from collections import OrderedDict
10 | 
11 | import opts
12 | import models
13 | import utils
14 | import codecs
15 | 
16 | import numpy as np
17 | import matplotlib
18 | matplotlib.use('Agg')
19 | import matplotlib.pyplot as plt
20 | import matplotlib.ticker as ticker
21 | 
22 | 
23 | parser = argparse.ArgumentParser(description='train.py')
24 | opts.model_opts(parser)
25 | 
26 | opt = parser.parse_args()
27 | config = utils.read_config(opt.config)
28 | torch.manual_seed(opt.seed)
29 | opts.convert_to_config(opt, config)
30 | 
31 | # cuda
32 | use_cuda = torch.cuda.is_available() and len(opt.gpus) > 0
33 | config.use_cuda = use_cuda
34 | if use_cuda:
35 |     torch.cuda.set_device(opt.gpus[0])
36 |     torch.cuda.manual_seed(opt.seed)
37 |     torch.backends.cudnn.benchmark = True
38 | 
39 | 
40 | def load_data():
41 |     print('loading data...\n')
42 |     data = pickle.load(open(config.data+'data.pkl', 'rb'))
43 |     data['train']['length'] = int(data['train']['length'] * opt.scale)
44 | 
45 |     trainset = utils.BiDataset(data['train'], char=config.char)
46 |     validset = utils.BiDataset(data['valid'], char=config.char)
47 | 
48 |     src_vocab = data['dict']['src']
49 |     tgt_vocab = data['dict']['tgt']
50 |     config.src_vocab_size = src_vocab.size()
51 |     config.tgt_vocab_size = tgt_vocab.size()
52 | 
53 |     trainloader = torch.utils.data.DataLoader(dataset=trainset,
54 |                                               batch_size=config.batch_size,
55 |                                               shuffle=True,
56 |                                               num_workers=0,
57 |                                               collate_fn=utils.padding)
58 |     if hasattr(config, 'valid_batch_size'):
59 |         valid_batch_size = config.valid_batch_size
60 |     else:
61 |         valid_batch_size = config.batch_size
62 |     validloader = torch.utils.data.DataLoader(dataset=validset,
63 |                                               batch_size=valid_batch_size,
64 |                                               shuffle=False,
65 |                                               num_workers=0,
66 |                                               collate_fn=utils.padding)
67 | 
68 |     return {'trainset': trainset, 'validset': validset,
69 |             'trainloader': trainloader, 'validloader': validloader,
70 |             'src_vocab': src_vocab, 'tgt_vocab': tgt_vocab}
71 | 
72 | 
73 | 
74 | def build_model(checkpoints, print_log):
75 |     # (config values are logged once below, together with the model summary)
76 | 
77 | 
78 |     # model
79 |     print('building model...\n')
80 |     model = getattr(models, opt.model)(config)
81 |     if checkpoints is not None:
82 |         model.load_state_dict(checkpoints['model'])
83 |     if opt.pretrain:
84 |         print('loading checkpoint from %s' % opt.pretrain)
85 |         pre_ckpt = torch.load(opt.pretrain)['model']
86 |         pre_ckpt = OrderedDict({key[8:]: pre_ckpt[key] for key in pre_ckpt if key.startswith('encoder')})
87 |         print(model.encoder.state_dict().keys())
88 |         print(pre_ckpt.keys())
89 |         model.encoder.load_state_dict(pre_ckpt)
90 |     if use_cuda:
91 |         model.cuda()
92 | 
93 |     # optimizer
94 |     if checkpoints is not None:
95 |         optim = checkpoints['optim']
96 |     else:
97 |         optim = models.Optim(config.optim, config.learning_rate, config.max_grad_norm,
98 |                              lr_decay=config.learning_rate_decay, start_decay_at=config.start_decay_at)
99 |     optim.set_parameters(model.parameters())
100 | 
101 |     # print log
102 |     param_count = 0
103 |     for param in model.parameters():
104 |         param_count += param.view(-1).size()[0]
105 |     for k, v in config.items():
106 |         print_log("%s:\t%s\n" % (str(k), str(v)))
107 |     print_log("\n")
108 |     print_log(repr(model) + "\n\n")
109 |     print_log('total number of parameters: %d\n\n' % param_count)
110 | 
111 |     return model, optim, print_log
112 | 
113 | 
114 | def train_model(model, data, optim, epoch, params):
115 | 
116 |     model.train()
117 |     trainloader = data['trainloader']
118 | 
119 |     for src, tgt, src_len, tgt_len, original_src, original_tgt in trainloader:
120 | 
121 |         model.zero_grad()
122 | 
123 |         if config.use_cuda:
124 |             src = src.cuda()
125 |             tgt = tgt.cuda()
126 |             src_len = src_len.cuda()
127 |         lengths, indices = torch.sort(src_len, dim=0, descending=True)
128 |         src = torch.index_select(src, dim=0, index=indices)
129 |         tgt = torch.index_select(tgt, dim=0, index=indices)
130 |         dec = tgt[:, :-1]
131 |         targets = tgt[:, 1:]
132 | 
133 |         try:
134 |             if config.schesamp:
135 |                 if epoch > 8:
136 |                     e = epoch - 8
137 |                     loss, outputs = model(src, lengths, dec, targets, teacher_ratio=0.9**e)
138 |                 else:
139 |                     loss, outputs = model(src, lengths, dec, targets)
140 |             else:
141 |                 loss, outputs = model(src, lengths, dec, targets)
142 |             pred = outputs.max(2)[1]
143 |             targets = targets.t()
144 |             num_correct = pred.eq(targets).masked_select(targets.ne(utils.PAD)).sum().item()
145 |             num_total = targets.ne(utils.PAD).sum().item()
146 |             if config.max_split == 0:
147 |                 loss = torch.sum(loss) / num_total
148 |                 loss.backward()
149 |                 optim.step()
150 | 
151 |             # params['report_loss'] += loss.data
152 |             params['report_loss'] += loss.item()
153 |             params['report_correct'] += num_correct
154 |             params['report_total'] += num_total
155 | 
156 |         except RuntimeError as e:
157 |             if 'out of memory' in str(e):
158 |                 print('| WARNING: ran out of memory')
159 |                 if hasattr(torch.cuda, 'empty_cache'):
160 |                     torch.cuda.empty_cache()
161 |             else:
162 |                 raise e
163 | 
164 |         utils.progress_bar(params['updates'], config.eval_interval)
165 |         params['updates'] += 1
166 | 
167 |         if params['updates'] % config.eval_interval == 0:
168 |             params['log']("epoch: %3d, loss: %6.3f, time: %6.3f, updates: %8d, accuracy: %2.2f\n"
169 |                           % (epoch, params['report_loss'], time.time()-params['report_time'],
170 |                              params['updates'], params['report_correct'] * 100.0 / params['report_total']))
171 |             print('evaluating after %d updates...\r' % params['updates'])
172 |             score = eval_model(model, data, params)
173 |             for metric in config.metrics:
174 |                 params[metric].append(score[metric])
175 |                 if score[metric] >= max(params[metric]):
176 |                     with codecs.open(params['log_path']+'best_'+metric+'_prediction.txt','w','utf-8') as f:
177 |                         f.write(codecs.open(params['log_path']+'candidate.txt','r','utf-8').read())
178 |                     save_model(params['log_path']+'best_'+metric+'_checkpoint.pt', model, optim, params['updates'])
179 |             model.train()
180 |             params['report_loss'], params['report_time'] = 0, time.time()
181 |             params['report_correct'], params['report_total'] = 0, 0
182 | 
183 |         if params['updates'] % config.save_interval == 0:
184 |             save_model(params['log_path']+'checkpoint.pt', model, optim, params['updates'])
185 | 
186 |     optim.updateLearningRate(score=0, epoch=epoch)
187 | 
188 | 
189 | def eval_model(model, data, params):
190 | 
191 |     model.eval()
192 |     reference, candidate, source, alignments = [], [], [], []
193 |     count, total_count = 0, len(data['validset'])
194 |     validloader = data['validloader']
195 |     tgt_vocab = data['tgt_vocab']
196 | 
197 | 
198 |     for src, tgt, src_len, tgt_len, original_src, original_tgt in validloader:
199 | 
200 |         if config.use_cuda:
201 |             src = src.cuda()
202 |             src_len = src_len.cuda()
203 | 
204 |         with torch.no_grad():
205 |             if config.beam_size > 1:
206 |                 samples, alignment, weight = model.beam_sample(src, src_len, beam_size=config.beam_size, eval_=True)
207 |             else:
208 |                 samples, alignment = model.sample(src, src_len)
209 | 
210 |         candidate += [tgt_vocab.convertToLabels(s, utils.EOS) for s in samples]
211 |         source += original_src
212 |         reference += original_tgt
213 |         if alignment is not None:
214 |             alignments += [align for align in alignment]
215 | 
216 |         count += len(original_src)
217 |         utils.progress_bar(count, total_count)
218 | 
219 |     if config.unk and config.attention != 'None':
220 |         cands = []
221 |         for s, c, align in zip(source, candidate, alignments):
222 |             cand = []
223 |             for word, idx in zip(c, align):
224 |                 if word == utils.UNK_WORD and idx < len(s):
225 |                     try:
226 |                         cand.append(s[idx])
227 |                     except IndexError:
228 |                         cand.append(word)
229 |                         print("%d %d\n" % (len(s), idx))
230 |                 else:
231 |                     cand.append(word)
232 |             cands.append(cand)
233 |             if len(cand) == 0:
234 |                 print('WARNING: empty candidate after UNK replacement')
235 |         candidate = cands
236 | 
237 |     with codecs.open(params['log_path']+'candidate.txt','w+','utf-8') as f:
238 |         for i in range(len(candidate)):
239 |             f.write(" ".join(candidate[i])+'\n')
240 | 
241 |     score = {}
242 |     for metric in config.metrics:
243 |         score[metric] = getattr(utils, metric)(reference, candidate, params['log_path'], params['log'], config)
244 | 
245 |     return score
246 | 
247 | 
248 | def save_model(path, model, optim, updates):
249 |     model_state_dict = model.state_dict()
250 |     checkpoints = {
251 |         'model': model_state_dict,
252 |         'config': config,
253 |         'optim': optim,
254 |         'updates': updates}
255 |     torch.save(checkpoints, path)
256 | 
257 | 
258 | def build_log():
259 |     # log
260 |     if not os.path.exists(config.logF):
261 |         os.mkdir(config.logF)
262 |     if opt.log == '':
263 |         log_path = config.logF + str(int(time.time() * 1000)) + '/'
264 |     else:
265 |         log_path = config.logF + opt.log + '/'
266 |     if not os.path.exists(log_path):
267 |         os.mkdir(log_path)
268 |     print_log = utils.print_log(log_path + 'log.txt')
269 |     return print_log, log_path
270 | 
271 | 
272 | def showAttention(path, s, c, attentions, index):
273 |     # Set up figure with colorbar
274 |     fig = plt.figure()
275 |     ax = fig.add_subplot(111)
276 |     cax = ax.matshow(attentions.numpy(), cmap='bone')
277 |     fig.colorbar(cax)
278 |     # Set up axes
279 |     ax.set_xticklabels([''] + s, rotation=90)
280 |     ax.set_yticklabels([''] + c)
281 |     # Show label at every tick
282 |     ax.xaxis.set_major_locator(ticker.MultipleLocator(1))
283 |     ax.yaxis.set_major_locator(ticker.MultipleLocator(1))
284 |     plt.savefig(path + str(index) + '.jpg')
285 |     plt.close(fig)
286 | 
287 | 
288 | def main():
289 |     # checkpoint
290 |     if opt.restore:
291 |         print('loading checkpoint...\n')
292 |         checkpoints = torch.load(opt.restore)
293 |     else:
294 |         checkpoints = None
295 | 
296 |     data = load_data()
297 |     print_log, log_path = build_log()
298 |     model, optim, print_log = build_model(checkpoints, print_log)
299 |     # scheduler
300 |     if config.schedule:
301 |         scheduler = L.CosineAnnealingLR(optim.optimizer, T_max=config.epoch)
302 |     params = {'updates': 0, 'report_loss': 0, 'report_total': 0,
303 |               'report_correct': 0, 'report_time': time.time(),
304 |               'log': print_log, 'log_path': log_path}
305 |     for metric in config.metrics:
306 |         params[metric] = []
307 |     if opt.restore:
308 |         params['updates'] = checkpoints['updates']
309 | 
310 |     if opt.mode == "train":
311 |         for i in range(1, config.epoch + 1):
312 |             if config.schedule:
313 |                 scheduler.step()
314 |                 print("Decaying learning rate to %g" % scheduler.get_lr()[0])
315 |             train_model(model, data, optim, i, params)
316 |         for metric in config.metrics:
317 |             print_log("Best %s score: %.2f\n" % (metric, max(params[metric])))
318 |     else:
319 |         score = eval_model(model, data, params)
320 | 
321 | 
322 | if __name__ == '__main__':
323 |     main()
324 | 
--------------------------------------------------------------------------------
/lr_scheduler.py:
--------------------------------------------------------------------------------
1 | import math
2 | from bisect import bisect_right
3 | from torch.optim.optimizer import Optimizer
4 | 
5 | 
6 | class _LRScheduler(object):
7 |     def __init__(self, optimizer, last_epoch=-1):
8 |         if not isinstance(optimizer, Optimizer):
9 |             raise TypeError('{} is not an Optimizer'.format(
10 |                 type(optimizer).__name__))
11 |         self.optimizer = optimizer
12 |         if last_epoch == -1:
13 |             for group in optimizer.param_groups:
14 |                 group.setdefault('initial_lr', group['lr'])
15 |         else:
16 |             for i, group in enumerate(optimizer.param_groups):
17 |                 if 'initial_lr' not in group:
18 |                     raise KeyError("param 'initial_lr' is not specified "
19 |                                    "in param_groups[{}] when resuming an optimizer".format(i))
20 |         self.base_lrs = list(map(lambda group: group['initial_lr'], optimizer.param_groups))
21 |         self.step(last_epoch + 1)
22 |         self.last_epoch = last_epoch
23 | 
24 |     def get_lr(self):
25 |         raise NotImplementedError
26 | 
27 |     def step(self, epoch=None):
28 |         if epoch is None:
29 |             epoch = self.last_epoch + 1
30 |         self.last_epoch = epoch
31 |         for param_group, lr in zip(self.optimizer.param_groups,
self.get_lr()): 32 | param_group['lr'] = lr 33 | 34 | 35 | class LambdaLR(_LRScheduler): 36 | """Sets the learning rate of each parameter group to the initial lr 37 | times a given function. When last_epoch=-1, sets initial lr as lr. 38 | Args: 39 | optimizer (Optimizer): Wrapped optimizer. 40 | lr_lambda (function or list): A function which computes a multiplicative 41 | factor given an integer parameter epoch, or a list of such 42 | functions, one for each group in optimizer.param_groups. 43 | last_epoch (int): The index of last epoch. Default: -1. 44 | Example: 45 | >>> # Assuming optimizer has two groups. 46 | >>> lambda1 = lambda epoch: epoch // 30 47 | >>> lambda2 = lambda epoch: 0.95 ** epoch 48 | >>> scheduler = LambdaLR(optimizer, lr_lambda=[lambda1, lambda2]) 49 | >>> for epoch in range(100): 50 | >>> scheduler.step() 51 | >>> train(...) 52 | >>> validate(...) 53 | """ 54 | def __init__(self, optimizer, lr_lambda, last_epoch=-1): 55 | self.optimizer = optimizer 56 | if not isinstance(lr_lambda, list) and not isinstance(lr_lambda, tuple): 57 | self.lr_lambdas = [lr_lambda] * len(optimizer.param_groups) 58 | else: 59 | if len(lr_lambda) != len(optimizer.param_groups): 60 | raise ValueError("Expected {} lr_lambdas, but got {}".format( 61 | len(optimizer.param_groups), len(lr_lambda))) 62 | self.lr_lambdas = list(lr_lambda) 63 | self.last_epoch = last_epoch 64 | super(LambdaLR, self).__init__(optimizer, last_epoch) 65 | 66 | def get_lr(self): 67 | return [base_lr * lmbda(self.last_epoch) 68 | for lmbda, base_lr in zip(self.lr_lambdas, self.base_lrs)] 69 | 70 | 71 | class StepLR(_LRScheduler): 72 | """Sets the learning rate of each parameter group to the initial lr 73 | decayed by gamma every step_size epochs. When last_epoch=-1, sets 74 | initial lr as lr. 75 | Args: 76 | optimizer (Optimizer): Wrapped optimizer. 77 | step_size (int): Period of learning rate decay. 78 | gamma (float): Multiplicative factor of learning rate decay. 79 | Default: 0.1. 80 | last_epoch (int): The index of last epoch. Default: -1. 81 | Example: 82 | >>> # Assuming optimizer uses lr = 0.5 for all groups 83 | >>> # lr = 0.05 if epoch < 30 84 | >>> # lr = 0.005 if 30 <= epoch < 60 85 | >>> # lr = 0.0005 if 60 <= epoch < 90 86 | >>> # ... 87 | >>> scheduler = StepLR(optimizer, step_size=30, gamma=0.1) 88 | >>> for epoch in range(100): 89 | >>> scheduler.step() 90 | >>> train(...) 91 | >>> validate(...) 92 | """ 93 | 94 | def __init__(self, optimizer, step_size, gamma=0.1, last_epoch=-1): 95 | self.step_size = step_size 96 | self.gamma = gamma 97 | super(StepLR, self).__init__(optimizer, last_epoch) 98 | 99 | def get_lr(self): 100 | return [base_lr * self.gamma ** (self.last_epoch // self.step_size) 101 | for base_lr in self.base_lrs] 102 | 103 | 104 | class MultiStepLR(_LRScheduler): 105 | """Set the learning rate of each parameter group to the initial lr decayed 106 | by gamma once the number of epoch reaches one of the milestones. When 107 | last_epoch=-1, sets initial lr as lr. 108 | Args: 109 | optimizer (Optimizer): Wrapped optimizer. 110 | milestones (list): List of epoch indices. Must be increasing. 111 | gamma (float): Multiplicative factor of learning rate decay. 112 | Default: 0.1. 113 | last_epoch (int): The index of last epoch. Default: -1. 
114 |     Example:
115 |         >>> # Assuming optimizer uses lr = 0.5 for all groups
116 |         >>> # lr = 0.05     if epoch < 30
117 |         >>> # lr = 0.005    if 30 <= epoch < 80
118 |         >>> # lr = 0.0005   if epoch >= 80
119 |         >>> scheduler = MultiStepLR(optimizer, milestones=[30,80], gamma=0.1)
120 |         >>> for epoch in range(100):
121 |         >>>     scheduler.step()
122 |         >>>     train(...)
123 |         >>>     validate(...)
124 |     """
125 | 
126 |     def __init__(self, optimizer, milestones, gamma=0.1, last_epoch=-1):
127 |         if not list(milestones) == sorted(milestones):
128 |             raise ValueError('Milestones should be a list of'
129 |                              ' increasing integers. Got {}'.format(milestones))
130 |         self.milestones = milestones
131 |         self.gamma = gamma
132 |         super(MultiStepLR, self).__init__(optimizer, last_epoch)
133 | 
134 |     def get_lr(self):
135 |         return [base_lr * self.gamma ** bisect_right(self.milestones, self.last_epoch)
136 |                 for base_lr in self.base_lrs]
137 | 
138 | 
139 | class ExponentialLR(_LRScheduler):
140 |     """Set the learning rate of each parameter group to the initial lr decayed
141 |     by gamma every epoch. When last_epoch=-1, sets initial lr as lr.
142 |     Args:
143 |         optimizer (Optimizer): Wrapped optimizer.
144 |         gamma (float): Multiplicative factor of learning rate decay.
145 |         last_epoch (int): The index of last epoch. Default: -1.
146 |     """
147 | 
148 |     def __init__(self, optimizer, gamma, last_epoch=-1):
149 |         self.gamma = gamma
150 |         super(ExponentialLR, self).__init__(optimizer, last_epoch)
151 | 
152 |     def get_lr(self):
153 |         return [base_lr * self.gamma ** self.last_epoch
154 |                 for base_lr in self.base_lrs]
155 | 
156 | 
157 | class CosineAnnealingLR(_LRScheduler):
158 |     """Set the learning rate of each parameter group using a cosine annealing
159 |     schedule, where :math:`\eta_{max}` is set to the initial lr and
160 |     :math:`T_{cur}` is the number of epochs since the last restart in SGDR:
161 |     .. math::
162 |         \eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})(1 +
163 |         \cos(\frac{T_{cur}}{T_{max}}\pi))
164 |     When last_epoch=-1, sets initial lr as lr.
165 |     It has been proposed in
166 |     `SGDR: Stochastic Gradient Descent with Warm Restarts`_. Note that this only
167 |     implements the cosine annealing part of SGDR, and not the restarts.
168 |     Args:
169 |         optimizer (Optimizer): Wrapped optimizer.
170 |         T_max (int): Maximum number of iterations.
171 |         eta_min (float): Minimum learning rate. Default: 0.
172 |         last_epoch (int): The index of last epoch. Default: -1.
173 |     .. _SGDR\: Stochastic Gradient Descent with Warm Restarts:
174 |         https://arxiv.org/abs/1608.03983
175 |     """
176 | 
177 |     def __init__(self, optimizer, T_max, eta_min=0, last_epoch=-1):
178 |         self.T_max = T_max
179 |         self.eta_min = eta_min
180 |         super(CosineAnnealingLR, self).__init__(optimizer, last_epoch)
181 | 
182 |     def get_lr(self):
183 |         return [self.eta_min + (base_lr - self.eta_min) *
184 |                 (1 + math.cos(self.last_epoch / self.T_max * math.pi)) / 2
185 |                 for base_lr in self.base_lrs]
186 | 
187 | 
188 | class ReduceLROnPlateau(object):
189 |     """Reduce learning rate when a metric has stopped improving.
190 |     Models often benefit from reducing the learning rate by a factor
191 |     of 2-10 once learning stagnates. This scheduler reads a metrics
192 |     quantity and if no improvement is seen for a 'patience' number
193 |     of epochs, the learning rate is reduced.
194 |     Args:
195 |         optimizer (Optimizer): Wrapped optimizer.
196 |         mode (str): One of `min`, `max`.
In `min` mode, lr will 197 | be reduced when the quantity monitored has stopped 198 | decreasing; in `max` mode it will be reduced when the 199 | quantity monitored has stopped increasing. Default: 'min'. 200 | factor (float): Factor by which the learning rate will be 201 | reduced. new_lr = lr * factor. Default: 0.1. 202 | patience (int): Number of epochs with no improvement after 203 | which learning rate will be reduced. Default: 10. 204 | verbose (bool): If True, prints a message to stdout for 205 | each update. Default: False. 206 | threshold (float): Threshold for measuring the new optimum, 207 | to only focus on significant changes. Default: 1e-4. 208 | threshold_mode (str): One of `rel`, `abs`. In `rel` mode, 209 | dynamic_threshold = best * ( 1 + threshold ) in 'max' 210 | mode or best * ( 1 - threshold ) in `min` mode. 211 | In `abs` mode, dynamic_threshold = best + threshold in 212 | `max` mode or best - threshold in `min` mode. Default: 'rel'. 213 | cooldown (int): Number of epochs to wait before resuming 214 | normal operation after lr has been reduced. Default: 0. 215 | min_lr (float or list): A scalar or a list of scalars. A 216 | lower bound on the learning rate of all param groups 217 | or each group respectively. Default: 0. 218 | eps (float): Minimal decay applied to lr. If the difference 219 | between new and old lr is smaller than eps, the update is 220 | ignored. Default: 1e-8. 221 | Example: 222 | >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9) 223 | >>> scheduler = ReduceLROnPlateau(optimizer, 'min') 224 | >>> for epoch in range(10): 225 | >>> train(...) 226 | >>> val_loss = validate(...) 227 | >>> # Note that step should be called after validate() 228 | >>> scheduler.step(val_loss) 229 | """ 230 | 231 | def __init__(self, optimizer, mode='min', factor=0.1, patience=10, 232 | verbose=False, threshold=1e-4, threshold_mode='rel', 233 | cooldown=0, min_lr=0, eps=1e-8): 234 | 235 | if factor >= 1.0: 236 | raise ValueError('Factor should be < 1.0.') 237 | self.factor = factor 238 | 239 | if not isinstance(optimizer, Optimizer): 240 | raise TypeError('{} is not an Optimizer'.format( 241 | type(optimizer).__name__)) 242 | self.optimizer = optimizer 243 | 244 | if isinstance(min_lr, list) or isinstance(min_lr, tuple): 245 | if len(min_lr) != len(optimizer.param_groups): 246 | raise ValueError("expected {} min_lrs, got {}".format( 247 | len(optimizer.param_groups), len(min_lr))) 248 | self.min_lrs = list(min_lr) 249 | else: 250 | self.min_lrs = [min_lr] * len(optimizer.param_groups) 251 | 252 | self.patience = patience 253 | self.verbose = verbose 254 | self.cooldown = cooldown 255 | self.cooldown_counter = 0 256 | self.mode = mode 257 | self.threshold = threshold 258 | self.threshold_mode = threshold_mode 259 | self.best = None 260 | self.num_bad_epochs = None 261 | self.mode_worse = None # the worse value for the chosen mode 262 | self.is_better = None 263 | self.eps = eps 264 | self.last_epoch = -1 265 | self._init_is_better(mode=mode, threshold=threshold, 266 | threshold_mode=threshold_mode) 267 | self._reset() 268 | 269 | def _reset(self): 270 | """Resets num_bad_epochs counter and cooldown counter.""" 271 | self.best = self.mode_worse 272 | self.cooldown_counter = 0 273 | self.num_bad_epochs = 0 274 | 275 | def step(self, metrics, epoch=None): 276 | current = metrics 277 | if epoch is None: 278 | epoch = self.last_epoch = self.last_epoch + 1 279 | self.last_epoch = epoch 280 | 281 | if self.is_better(current, self.best): 282 | self.best = 
current
283 |             self.num_bad_epochs = 0
284 |         else:
285 |             self.num_bad_epochs += 1
286 | 
287 |         if self.in_cooldown:
288 |             self.cooldown_counter -= 1
289 |             self.num_bad_epochs = 0  # ignore any bad epochs in cooldown
290 | 
291 |         if self.num_bad_epochs > self.patience:
292 |             self._reduce_lr(epoch)
293 |             self.cooldown_counter = self.cooldown
294 |             self.num_bad_epochs = 0
295 | 
296 |     def _reduce_lr(self, epoch):
297 |         for i, param_group in enumerate(self.optimizer.param_groups):
298 |             old_lr = float(param_group['lr'])
299 |             new_lr = max(old_lr * self.factor, self.min_lrs[i])
300 |             if old_lr - new_lr > self.eps:
301 |                 param_group['lr'] = new_lr
302 |                 if self.verbose:
303 |                     print('Epoch {:5d}: reducing learning rate'
304 |                           ' of group {} to {:.4e}.'.format(epoch, i, new_lr))
305 | 
306 |     @property
307 |     def in_cooldown(self):
308 |         return self.cooldown_counter > 0
309 | 
310 |     def _init_is_better(self, mode, threshold, threshold_mode):
311 |         if mode not in {'min', 'max'}:
312 |             raise ValueError('mode ' + mode + ' is unknown!')
313 |         if threshold_mode not in {'rel', 'abs'}:
314 |             raise ValueError('threshold mode ' + threshold_mode + ' is unknown!')
315 |         if mode == 'min' and threshold_mode == 'rel':
316 |             rel_epsilon = 1. - threshold
317 |             self.is_better = lambda a, best: a < best * rel_epsilon
318 |             self.mode_worse = float('Inf')
319 |         elif mode == 'min' and threshold_mode == 'abs':
320 |             self.is_better = lambda a, best: a < best - threshold
321 |             self.mode_worse = float('Inf')
322 |         elif mode == 'max' and threshold_mode == 'rel':
323 |             rel_epsilon = threshold + 1.
324 |             self.is_better = lambda a, best: a > best * rel_epsilon
325 |             self.mode_worse = -float('Inf')
326 |         else:  # mode == 'max' and threshold_mode == 'abs':
327 |             self.is_better = lambda a, best: a > best + threshold
328 |             self.mode_worse = -float('Inf')
329 | 
--------------------------------------------------------------------------------
/script/tokenizer.perl:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env perl
2 | #
3 | # This file is part of moses. Its use is licensed under the GNU Lesser General
4 | # Public License version 2.1 or, at your option, any later version.
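# Typical invocation (a sketch; the input/output file names are placeholders):
#   perl script/tokenizer.perl -l en -threads 4 < input.txt > output.tok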
5 | 
6 | use warnings;
7 | 
8 | # Sample Tokenizer
9 | ### Version 1.1
10 | # written by Pidong Wang, based on the code written by Josh Schroeder and Philipp Koehn
11 | # Version 1.1 updates:
12 | #   (1) add multithreading option "-threads NUM_THREADS" (default is 1);
13 | #   (2) add a timing option "-time" to calculate the average speed of this tokenizer;
14 | #   (3) add an option "-lines NUM_SENTENCES_PER_THREAD" to set the number of lines for each thread (default is 2000), and this option controls the memory amount needed: the larger this number is, the larger memory is required (the higher tokenization speed);
15 | ### Version 1.0
16 | # $Id: tokenizer.perl 915 2009-08-10 08:15:49Z philipp $
17 | # written by Josh Schroeder, based on code by Philipp Koehn
18 | 
19 | binmode(STDIN, ":utf8");
20 | binmode(STDOUT, ":utf8");
21 | 
22 | use warnings;
23 | use FindBin qw($RealBin);
24 | use strict;
25 | use Time::HiRes;
26 | 
27 | if (eval {require Thread;1;}) {
28 |   # module loaded
29 |   Thread->import();
30 | }
31 | 
32 | my $mydir = "$RealBin/nonbreaking_prefixes";
33 | 
34 | my %NONBREAKING_PREFIX = ();
35 | my @protected_patterns = ();
36 | my $protected_patterns_file = "";
37 | my $language = "en";
38 | my $QUIET = 0;
39 | my $HELP = 0;
40 | my $AGGRESSIVE = 0;
41 | my $SKIP_XML = 0;
42 | my $TIMING = 0;
43 | my $NUM_THREADS = 1;
44 | my $NUM_SENTENCES_PER_THREAD = 2000;
45 | my $PENN = 0;
46 | my $NO_ESCAPING = 0;
47 | while (@ARGV)
48 | {
49 |   $_ = shift;
50 |   /^-b$/ && ($| = 1, next);
51 |   /^-l$/ && ($language = shift, next);
52 |   /^-q$/ && ($QUIET = 1, next);
53 |   /^-h$/ && ($HELP = 1, next);
54 |   /^-x$/ && ($SKIP_XML = 1, next);
55 |   /^-a$/ && ($AGGRESSIVE = 1, next);
56 |   /^-time$/ && ($TIMING = 1, next);
57 |   # Option to add list of regexps to be protected
58 |   /^-protected/ && ($protected_patterns_file = shift, next);
59 |   /^-threads$/ && ($NUM_THREADS = int(shift), next);
60 |   /^-lines$/ && ($NUM_SENTENCES_PER_THREAD = int(shift), next);
61 |   /^-penn$/ && ($PENN = 1, next);
62 |   /^-no-escape/ && ($NO_ESCAPING = 1, next);
63 | }
64 | 
65 | # for time calculation
66 | my $start_time;
67 | if ($TIMING)
68 | {
69 |   $start_time = [ Time::HiRes::gettimeofday( ) ];
70 | }
71 | 
72 | # print help message
73 | if ($HELP)
74 | {
75 |   print "Usage ./tokenizer.perl (-l [en|de|...]) (-threads 4) < textfile > tokenizedfile\n";
76 |   print "Options:\n";
77 |   print "  -q     ... quiet.\n";
78 |   print "  -a     ... aggressive hyphen splitting.\n";
79 |   print "  -b     ... disable Perl buffering.\n";
80 |   print "  -time  ... enable processing time calculation.\n";
81 |   print "  -penn  ... use Penn treebank-like tokenization.\n";
82 |   print "  -protected FILE  ... specify file with patterns to be protected in tokenisation.\n";
83 |   print "  -no-escape ... 
don't perform HTML escaping on apostrophes, quotes, etc.\n";
84 |   exit;
85 | }
86 | 
87 | if (!$QUIET)
88 | {
89 |   print STDERR "Tokenizer Version 1.1\n";
90 |   print STDERR "Language: $language\n";
91 |   print STDERR "Number of threads: $NUM_THREADS\n";
92 | }
93 | 
94 | # load the language-specific non-breaking prefix info from files in the directory nonbreaking_prefixes
95 | load_prefixes($language,\%NONBREAKING_PREFIX);
96 | 
97 | if (scalar(%NONBREAKING_PREFIX) eq 0)
98 | {
99 |   print STDERR "Warning: No known abbreviations for language '$language'\n";
100 | }
101 | 
102 | # Load protected patterns
103 | if ($protected_patterns_file)
104 | {
105 |   open(PP,$protected_patterns_file) || die "Unable to open $protected_patterns_file";
106 |   while(<PP>) {
107 |     chomp;
108 |     push @protected_patterns, $_;
109 |   }
110 | }
111 | 
112 | my @batch_sentences = ();
113 | my @thread_list = ();
114 | my $count_sentences = 0;
115 | 
116 | if ($NUM_THREADS > 1)
117 | {# multi-threading tokenization
118 |   while(<STDIN>)
119 |   {
120 |     $count_sentences = $count_sentences + 1;
121 |     push(@batch_sentences, $_);
122 |     if (scalar(@batch_sentences)>=($NUM_SENTENCES_PER_THREAD*$NUM_THREADS))
123 |     {
124 |       # assign each thread work
125 |       for (my $i=0; $i<$NUM_THREADS; $i++)
126 |       {
127 |         my $start_index = $i*$NUM_SENTENCES_PER_THREAD;
128 |         my $end_index = $start_index+$NUM_SENTENCES_PER_THREAD-1;
129 |         my @subbatch_sentences = @batch_sentences[$start_index..$end_index];
130 |         my $new_thread = new Thread \&tokenize_batch, @subbatch_sentences;
131 |         push(@thread_list, $new_thread);
132 |       }
133 |       foreach (@thread_list)
134 |       {
135 |         my $tokenized_list = $_->join;
136 |         foreach (@$tokenized_list)
137 |         {
138 |           print $_;
139 |         }
140 |       }
141 |       # reset for the new run
142 |       @thread_list = ();
143 |       @batch_sentences = ();
144 |     }
145 |   }
146 |   # the last batch
147 |   if (scalar(@batch_sentences)>0)
148 |   {
149 |     # assign each thread work
150 |     for (my $i=0; $i<$NUM_THREADS; $i++)
151 |     {
152 |       my $start_index = $i*$NUM_SENTENCES_PER_THREAD;
153 |       if ($start_index >= scalar(@batch_sentences))
154 |       {
155 |         last;
156 |       }
157 |       my $end_index = $start_index+$NUM_SENTENCES_PER_THREAD-1;
158 |       if ($end_index >= scalar(@batch_sentences))
159 |       {
160 |         $end_index = scalar(@batch_sentences)-1;
161 |       }
162 |       my @subbatch_sentences = @batch_sentences[$start_index..$end_index];
163 |       my $new_thread = new Thread \&tokenize_batch, @subbatch_sentences;
164 |       push(@thread_list, $new_thread);
165 |     }
166 |     foreach (@thread_list)
167 |     {
168 |       my $tokenized_list = $_->join;
169 |       foreach (@$tokenized_list)
170 |       {
171 |         print $_;
172 |       }
173 |     }
174 |   }
175 | }
176 | else
177 | {# single thread only
178 |   while(<STDIN>)
179 |   {
180 |     if (($SKIP_XML && /^<.+>$/) || /^\s*$/)
181 |     {
182 |       # don't try to tokenize XML/HTML tag lines
183 |       print $_;
184 |     }
185 |     else
186 |     {
187 |       print &tokenize($_);
188 |     }
189 |   }
190 | }
191 | 
192 | if ($TIMING)
193 | {
194 |   my $duration = Time::HiRes::tv_interval( $start_time );
195 |   print STDERR ("TOTAL EXECUTION TIME: ".$duration."\n");
196 |   print STDERR ("TOKENIZATION SPEED: ".($duration/$count_sentences*1000)." 
milliseconds/line\n"); 197 | } 198 | 199 | ##################################################################################### 200 | # subroutines afterward 201 | 202 | # tokenize a batch of texts saved in an array 203 | # input: an array containing a batch of texts 204 | # return: another array containing a batch of tokenized texts for the input array 205 | sub tokenize_batch 206 | { 207 | my(@text_list) = @_; 208 | my(@tokenized_list) = (); 209 | foreach (@text_list) 210 | { 211 | if (($SKIP_XML && /^<.+>$/) || /^\s*$/) 212 | { 213 | #don't try to tokenize XML/HTML tag lines 214 | push(@tokenized_list, $_); 215 | } 216 | else 217 | { 218 | push(@tokenized_list, &tokenize($_)); 219 | } 220 | } 221 | return \@tokenized_list; 222 | } 223 | 224 | # the actual tokenize function which tokenizes one input string 225 | # input: one string 226 | # return: the tokenized string for the input string 227 | sub tokenize 228 | { 229 | my($text) = @_; 230 | 231 | if ($PENN) { 232 | return tokenize_penn($text); 233 | } 234 | 235 | chomp($text); 236 | $text = " $text "; 237 | 238 | # remove ASCII junk 239 | $text =~ s/\s+/ /g; 240 | $text =~ s/[\000-\037]//g; 241 | 242 | # Find protected patterns 243 | my @protected = (); 244 | foreach my $protected_pattern (@protected_patterns) { 245 | my $t = $text; 246 | while ($t =~ /(?$protected_pattern)(?.*)$/) { 247 | push @protected, $+{PATTERN}; 248 | $t = $+{TAIL}; 249 | } 250 | } 251 | 252 | for (my $i = 0; $i < scalar(@protected); ++$i) { 253 | my $subst = sprintf("THISISPROTECTED%.3d", $i); 254 | $text =~ s,\Q$protected[$i], $subst ,g; 255 | } 256 | $text =~ s/ +/ /g; 257 | $text =~ s/^ //g; 258 | $text =~ s/ $//g; 259 | 260 | # seperate out all "other" special characters 261 | $text =~ s/([^\p{IsAlnum}\s\.\'\`\,\-])/ $1 /g; 262 | 263 | # aggressive hyphen splitting 264 | if ($AGGRESSIVE) 265 | { 266 | $text =~ s/([\p{IsAlnum}])\-(?=[\p{IsAlnum}])/$1 \@-\@ /g; 267 | } 268 | 269 | #multi-dots stay together 270 | $text =~ s/\.([\.]+)/ DOTMULTI$1/g; 271 | while($text =~ /DOTMULTI\./) 272 | { 273 | $text =~ s/DOTMULTI\.([^\.])/DOTDOTMULTI $1/g; 274 | $text =~ s/DOTMULTI\./DOTDOTMULTI/g; 275 | } 276 | 277 | # seperate out "," except if within numbers (5,300) 278 | #$text =~ s/([^\p{IsN}])[,]([^\p{IsN}])/$1 , $2/g; 279 | 280 | # separate out "," except if within numbers (5,300) 281 | # previous "global" application skips some: A,B,C,D,E > A , B,C , D,E 282 | # first application uses up B so rule can't see B,C 283 | # two-step version here may create extra spaces but these are removed later 284 | # will also space digit,letter or letter,digit forms (redundant with next section) 285 | $text =~ s/([^\p{IsN}])[,]/$1 , /g; 286 | $text =~ s/[,]([^\p{IsN}])/ , $1/g; 287 | 288 | # separate "," after a number if it's the end of a sentence 289 | $text =~ s/([\p{IsN}])[,]$/$1 ,/g; 290 | 291 | # separate , pre and post number 292 | #$text =~ s/([\p{IsN}])[,]([^\p{IsN}])/$1 , $2/g; 293 | #$text =~ s/([^\p{IsN}])[,]([\p{IsN}])/$1 , $2/g; 294 | 295 | # turn `into ' 296 | #$text =~ s/\`/\'/g; 297 | 298 | #turn '' into " 299 | #$text =~ s/\'\'/ \" /g; 300 | 301 | if ($language eq "en") 302 | { 303 | #split contractions right 304 | $text =~ s/([^\p{IsAlpha}])[']([^\p{IsAlpha}])/$1 ' $2/g; 305 | $text =~ s/([^\p{IsAlpha}\p{IsN}])[']([\p{IsAlpha}])/$1 ' $2/g; 306 | $text =~ s/([\p{IsAlpha}])[']([^\p{IsAlpha}])/$1 ' $2/g; 307 | $text =~ s/([\p{IsAlpha}])[']([\p{IsAlpha}])/$1 '$2/g; 308 | #special case for "1990's" 309 | $text =~ s/([\p{IsN}])[']([s])/$1 '$2/g; 310 | } 311 | elsif 
(($language eq "fr") or ($language eq "it") or ($language eq "ga")) 312 | { 313 | #split contractions left 314 | $text =~ s/([^\p{IsAlpha}])[']([^\p{IsAlpha}])/$1 ' $2/g; 315 | $text =~ s/([^\p{IsAlpha}])[']([\p{IsAlpha}])/$1 ' $2/g; 316 | $text =~ s/([\p{IsAlpha}])[']([^\p{IsAlpha}])/$1 ' $2/g; 317 | $text =~ s/([\p{IsAlpha}])[']([\p{IsAlpha}])/$1' $2/g; 318 | } 319 | else 320 | { 321 | $text =~ s/\'/ \' /g; 322 | } 323 | 324 | #word token method 325 | my @words = split(/\s/,$text); 326 | $text = ""; 327 | for (my $i=0;$i<(scalar(@words));$i++) 328 | { 329 | my $word = $words[$i]; 330 | if ( $word =~ /^(\S+)\.$/) 331 | { 332 | my $pre = $1; 333 | if (($pre =~ /\./ && $pre =~ /\p{IsAlpha}/) || ($NONBREAKING_PREFIX{$pre} && $NONBREAKING_PREFIX{$pre}==1) || ($i/\>/g; # xml 377 | $text =~ s/\'/\'/g; # xml 378 | $text =~ s/\"/\"/g; # xml 379 | $text =~ s/\[/\[/g; # syntax non-terminal 380 | $text =~ s/\]/\]/g; # syntax non-terminal 381 | } 382 | 383 | #ensure final line break 384 | $text .= "\n" unless $text =~ /\n$/; 385 | 386 | return $text; 387 | } 388 | 389 | sub tokenize_penn 390 | { 391 | # Improved compatibility with Penn Treebank tokenization. Useful if 392 | # the text is to later be parsed with a PTB-trained parser. 393 | # 394 | # Adapted from Robert MacIntyre's sed script: 395 | # http://www.cis.upenn.edu/~treebank/tokenizer.sed 396 | 397 | my($text) = @_; 398 | chomp($text); 399 | 400 | # remove ASCII junk 401 | $text =~ s/\s+/ /g; 402 | $text =~ s/[\000-\037]//g; 403 | 404 | # attempt to get correct directional quotes 405 | $text =~ s/^``/`` /g; 406 | $text =~ s/^"/`` /g; 407 | $text =~ s/^`([^`])/` $1/g; 408 | $text =~ s/^'/` /g; 409 | $text =~ s/([ ([{<])"/$1 `` /g; 410 | $text =~ s/([ ([{<])``/$1 `` /g; 411 | $text =~ s/([ ([{<])`([^`])/$1 ` $2/g; 412 | $text =~ s/([ ([{<])'/$1 ` /g; 413 | # close quotes handled at end 414 | 415 | $text =~ s=\.\.\.= _ELLIPSIS_ =g; 416 | 417 | # separate out "," except if within numbers (5,300) 418 | $text =~ s/([^\p{IsN}])[,]([^\p{IsN}])/$1 , $2/g; 419 | # separate , pre and post number 420 | $text =~ s/([\p{IsN}])[,]([^\p{IsN}])/$1 , $2/g; 421 | $text =~ s/([^\p{IsN}])[,]([\p{IsN}])/$1 , $2/g; 422 | 423 | #$text =~ s=([;:@#\$%&\p{IsSc}])= $1 =g; 424 | $text =~ s=([;:@#\$%&\p{IsSc}\p{IsSo}])= $1 =g; 425 | 426 | # Separate out intra-token slashes. PTB tokenization doesn't do this, so 427 | # the tokens should be merged prior to parsing with a PTB-trained parser 428 | # (see syntax-hyphen-splitting.perl). 429 | $text =~ s/([\p{IsAlnum}])\/([\p{IsAlnum}])/$1 \@\/\@ $2/g; 430 | 431 | # Assume sentence tokenization has been done first, so split FINAL periods 432 | # only. 433 | $text =~ s=([^.])([.])([\]\)}>"']*) ?$=$1 $2$3 =g; 434 | # however, we may as well split ALL question marks and exclamation points, 435 | # since they shouldn't have the abbrev.-marker ambiguity problem 436 | $text =~ s=([?!])= $1 =g; 437 | 438 | # parentheses, brackets, etc. 439 | $text =~ s=([\]\[\(\){}<>])= $1 =g; 440 | $text =~ s/\(/-LRB-/g; 441 | $text =~ s/\)/-RRB-/g; 442 | $text =~ s/\[/-LSB-/g; 443 | $text =~ s/\]/-RSB-/g; 444 | $text =~ s/{/-LCB-/g; 445 | $text =~ s/}/-RCB-/g; 446 | 447 | $text =~ s=--= -- =g; 448 | 449 | # First off, add a space to the beginning and end of each line, to reduce 450 | # necessary number of regexps. 
451 |   $text =~ s=$= =;
452 |   $text =~ s=^= =;
453 | 
454 |   $text =~ s="= '' =g;
455 |   # possessive or close-single-quote
456 |   $text =~ s=([^'])' =$1 ' =g;
457 |   # as in it's, I'm, we'd
458 |   $text =~ s='([sSmMdD]) = '$1 =g;
459 |   $text =~ s='ll = 'll =g;
460 |   $text =~ s='re = 're =g;
461 |   $text =~ s='ve = 've =g;
462 |   $text =~ s=n't = n't =g;
463 |   $text =~ s='LL = 'LL =g;
464 |   $text =~ s='RE = 'RE =g;
465 |   $text =~ s='VE = 'VE =g;
466 |   $text =~ s=N'T = N'T =g;
467 | 
468 |   $text =~ s= ([Cc])annot = $1an not =g;
469 |   $text =~ s= ([Dd])'ye = $1' ye =g;
470 |   $text =~ s= ([Gg])imme = $1im me =g;
471 |   $text =~ s= ([Gg])onna = $1on na =g;
472 |   $text =~ s= ([Gg])otta = $1ot ta =g;
473 |   $text =~ s= ([Ll])emme = $1em me =g;
474 |   $text =~ s= ([Mm])ore'n = $1ore 'n =g;
475 |   $text =~ s= '([Tt])is = '$1 is =g;
476 |   $text =~ s= '([Tt])was = '$1 was =g;
477 |   $text =~ s= ([Ww])anna = $1an na =g;
478 | 
479 |   # word token method
480 |   my @words = split(/\s/,$text);
481 |   $text = "";
482 |   for (my $i=0;$i<(scalar(@words));$i++)
483 |   {
484 |     my $word = $words[$i];
485 |     if ( $word =~ /^(\S+)\.$/)
486 |     {
487 |       my $pre = $1;
488 |       if (($pre =~ /\./ && $pre =~ /\p{IsAlpha}/) || ($NONBREAKING_PREFIX{$pre} && $NONBREAKING_PREFIX{$pre}==1) || ($i<scalar(@words)-1 && ($words[$i+1] =~ /^[\p{IsLower}]/)))
489 |       {
490 |         #no change
491 |       }
492 |       elsif (($NONBREAKING_PREFIX{$pre} && $NONBREAKING_PREFIX{$pre}==2) && ($i<scalar(@words)-1 && ($words[$i+1] =~ /^[0-9]+/)))
493 |       {
494 |         #no change
495 |       }
496 |       else
497 |       {
498 |         $word = $pre." .";
499 |       }
500 |     }
501 |     $text .= $word." ";
502 |   }
503 |   # restore protected
504 |   for (my $i = 0; $i < scalar(@protected); ++$i) {
505 |     my $subst = sprintf("THISISPROTECTED%.3d", $i);
506 |     $text =~ s/$subst/$protected[$i]/g;
507 |   }
508 |   # clean up extraneous spaces
509 |   $text =~ s/ +/ /g;
510 |   $text =~ s/^ //g;
511 |   $text =~ s/ $//g;
512 |   # escape special chars
513 |   $text =~ s/\&/\&amp;/g;   # escape escape
514 |   $text =~ s/\|/\&#124;/g;  # factor separator
515 |   $text =~ s/\</\&lt;/g;    # xml
516 |   $text =~ s/\>/\&gt;/g;    # xml
517 |   $text =~ s/\'/\&apos;/g;  # xml
518 |   $text =~ s/\"/\&quot;/g;  # xml
519 |   $text =~ s/\[/\&#91;/g;   # syntax non-terminal
520 |   $text =~ s/\]/\&#93;/g;   # syntax non-terminal
521 | 
522 |   # ensure final line break
523 |   $text .= "\n" unless $text =~ /\n$/;
524 | 
525 |   return $text;
526 | }
527 | 
528 | sub load_prefixes
529 | {
530 |   my ($language, $PREFIX_REF) = @_;
531 | 
532 |   my $prefixfile = "$mydir/nonbreaking_prefix.$language";
533 | 
534 |   # default back to English if we don't have a language-specific prefix file
535 |   if (!(-e $prefixfile))
536 |   {
537 |     $prefixfile = "$mydir/nonbreaking_prefix.en";
538 |     print STDERR "WARNING: No known abbreviations for language '$language', attempting fall-back to English version...\n";
539 |     die ("ERROR: No abbreviations files found in $mydir\n") unless (-e $prefixfile);
540 |   }
541 | 
542 |   if (-e "$prefixfile")
543 |   {
544 |     open(PREFIX, "<:utf8", "$prefixfile");
545 |     while (<PREFIX>)
546 |     {
547 |       my $item = $_;
548 |       chomp($item);
549 |       if (($item) && (substr($item,0,1) ne "#"))
550 |       {
551 |         if ($item =~ /(.*)[\s]+(\#NUMERIC_ONLY\#)/)
552 |         {
553 |           $PREFIX_REF->{$1} = 2;
554 |         }
555 |         else
556 |         {
557 |           $PREFIX_REF->{$item} = 1;
558 |         }
559 |       }
560 |     }
561 |     close(PREFIX);
562 |   }
563 | }
564 | 
--------------------------------------------------------------------------------
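Note: a minimal usage sketch (not part of the repository; '/path/to/data/' is a placeholder) showing how the pickle written by preprocess.py is consumed, mirroring load_data() in train.py:

    import pickle

    # 'data.pkl' is written by preprocess.py main() under opt.save_data
    data = pickle.load(open('/path/to/data/' + 'data.pkl', 'rb'))
    print(data['train']['length'])      # number of prepared sentence pairs
    print(data['dict']['src'].size())   # source vocabulary size
    print(data['dict']['tgt'].size())   # target vocabulary size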