├── README.md
├── __init__.py
├── gpu_requirements.txt
├── images
│   ├── best_dev_ppl.svg
│   ├── best_train_ppl.svg
│   ├── gru_dev_ppl.svg
│   ├── gru_train_ppl.svg
│   ├── lstm_2_layer_dev_ppl.svg
│   ├── lstm_2_layer_train_ppl.svg
│   ├── lstm_dev_ppl.svg
│   ├── lstm_leaky_dev_ppl.svg
│   ├── lstm_leaky_train_ppl.svg
│   ├── lstm_relu_dev_ppl.svg
│   ├── lstm_relu_train_ppl.svg
│   └── lstm_train_ppl.svg
├── local_env.yml
├── model_embeddings.py
├── nmt_model.py
├── run.py
├── run.sh
├── utils.py
└── vocab.py
/README.md:
--------------------------------------------------------------------------------
1 | # ZH-EN NMT Chinese to English Neural Machine Translation
2 |
3 | > This project is inspired by Stanford's CS224N NMT Project
4 |
5 | > Dataset used in this project: [News Commentary v14](http://data.statmt.org/news-commentary/v14)
6 |
7 | ## Intro
8 |
9 | This is primarily a learning project, meant to make myself familiar with PyTorch, machine translation, and NLP model training.
10 |
11 | To investigate how various setups of the recurrent layer affect final performance, I compared the training efficiency and effectiveness of different RNN encoder configurations, changing one feature at a time while controlling all other parameters (each option maps to an argument of the `NMT` constructor in `nmt_model.py`; see the sketch after this list):
12 |
13 | - RNN types
14 |   - GRU
15 |   - LSTM
16 | - Activation Functions on Output Layer
17 |   - Tanh
18 |   - ReLU
19 |   - LeakyReLU
20 | - Number of layers
21 |
22 |   - single layer
23 |   - double layer
24 |
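Each experiment group is just a different instantiation of the same `NMT` class. A minimal sketch (constructor arguments taken from `nmt_model.py`; the vocab path is the one used in `run.sh`):

```python
import torch
from torch import nn

from nmt_model import NMT
from vocab import Vocab

vocab = Vocab.load("./zh_en_data/vocab_zh_en.json")

# A: bidirectional 1-layer GRU encoder, Tanh on the combined output
model_a = NMT(embed_size=512, hidden_size=512, vocab=vocab, dropout_rate=0.25,
              rnn_layer=nn.GRU, num_layers=1, activation=torch.tanh)

# C: bidirectional 2-layer LSTM encoder, Tanh
model_c = NMT(embed_size=512, hidden_size=512, vocab=vocab, dropout_rate=0.25,
              rnn_layer=nn.LSTM, num_layers=2, activation=torch.tanh)

# D: bidirectional 1-layer LSTM encoder, ReLU
model_d = NMT(embed_size=512, hidden_size=512, vocab=vocab, dropout_rate=0.25,
              rnn_layer=nn.LSTM, num_layers=1, activation=torch.relu)
```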
25 | ## Code Files
26 |
27 | ```
28 | _/
29 | ├─ utils.py # utilities
30 | ├─ vocab.py # generate vocab
31 | ├─ model_embeddings.py # embedding layer
32 | ├─ nmt_model.py # nmt model definition
33 | ├─ run.py # training and testing
34 | ```
35 |
36 | ## Good Translation Examples
37 |
38 | - ***source***: 相反,这意味着合作的基础应当是共同的长期战略利益,而不是共同的价值观。
39 | - ***target***: Instead, it means that cooperation must be anchored not in shared values, but in shared long-term strategic interests.
40 | - ***translation***: On the contrary, that means cooperation should be a common long-term strategic interests, rather than shared values.
41 |
42 | - ***source***: 但这个问题其实很简单: 谁来承受这些用以降低预算赤字的紧缩措施的冲击。
43 | - ***target***: But the issue is actually simple: Who will bear the brunt of measures to reduce the budget deficit?
44 | - ***translation***: But the question is simple: Who is to bear the impact of austerity measures to reduce budget deficits?
45 | - ***source***: 上述合作对打击恐怖主义、贩卖人口和移民可能发挥至关重要的作用。
46 | - ***target***: Such cooperation is essential to combat terrorism, human trafficking, and migration.
47 | - ***translation***: Such cooperation is essential to fighting terrorism, trafficking, and migration.
48 | - ***source***: 与此同时, 政治危机妨碍着政府追求艰难的改革。
49 | - ***target***: At the same time, political crisis is impeding the government’s pursuit of difficult reforms.
50 | - ***translation***: Meanwhile, political crises hamper the government’s pursuit of difficult reforms.
51 |
52 | ## Preprocessing
53 |
54 | > Preprocessing Colab [notebook](https://colab.research.google.com/drive/1IJTdk7hj3uoPEE0Ox7QaeW4rTuUzuxPJ?usp=sharing)
55 |
56 | - using [`jieba`](https://github.com/fxsjy/jieba) to separate Chinese words with spaces (see the sketch below)
57 |
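A minimal sketch of this segmentation step (the file names here are placeholders; only the `jieba.cut` call reflects what the preprocessing notebook and `utils.read_sent_zh` actually use):

```python
import jieba

def segment_zh(line: str) -> str:
    """Insert spaces between Chinese words so downstream tokenizers see word boundaries."""
    return " ".join(jieba.cut(line.strip(), HMM=True))

# Hypothetical input/output paths, for illustration only.
with open("news-commentary-v14.zh", encoding="utf8") as fin, \
        open("train.zh", "w", encoding="utf8") as fout:
    for line in fin:
        fout.write(segment_zh(line) + "\n")
```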
58 | ## Generate Vocab From Training Data
59 |
60 | - Input: training data of Chinese and English
61 |
62 | - Output: a vocab file containing the (sub)word-to-id mappings for Chinese and English -- a limited-size vocab is selected using [SentencePiece](https://github.com/google/sentencepiece) (essentially [Byte Pair Encoding](https://en.wikipedia.org/wiki/Byte_pair_encoding) of character n-grams) to cover around 99.95% of the training data (see the sketch below)
63 |
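This is what `vocab.py` does; condensed, its vocabulary-building step looks like:

```python
import sentencepiece as spm

def get_vocab_list(file_path, source, vocab_size):
    """Train a SentencePiece model on the corpus and return its list of subwords."""
    spm.SentencePieceTrainer.train(input=file_path, model_prefix=source, vocab_size=vocab_size)
    sp = spm.SentencePieceProcessor()
    sp.load('{}.model'.format(source))  # loads the freshly written src.model / tgt.model
    return [sp.id_to_piece(piece_id) for piece_id in range(sp.get_piece_size())]

src_subwords = get_vocab_list('./zh_en_data/train.zh', source='src', vocab_size=21000)
tgt_subwords = get_vocab_list('./zh_en_data/train.en', source='tgt', vocab_size=8000)
```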
64 | ## Model Definition
65 |
66 | - a Seq2Seq model with attention
67 |
68 | > This image is from the book [DIVE INTO DEEP LEARNING](https://zh-v2.d2l.ai/index.html)
69 |
70 | 
71 |
72 | - Encoder
73 |   - A Recurrent Layer (bidirectional GRU or LSTM, hidden_size=512)
74 | - Decoder
75 |   - LSTMCell (hidden_size=512)
76 | - Attention
77 |   - Multiplicative Attention (see the sketch below)
78 |
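The multiplicative attention used at each decoder step, condensed from `NMT.step` in `nmt_model.py` (dropout and padding masks omitted; random tensors stand in for real encoder/decoder states):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

b, src_len, h = 4, 7, 512                      # batch size, source length, hidden size
enc_hiddens = torch.randn(b, src_len, 2 * h)   # bidirectional encoder states
dec_hidden = torch.randn(b, h)                 # current decoder hidden state

att_projection = nn.Linear(2 * h, h, bias=False)
combined_output_projection = nn.Linear(3 * h, h, bias=False)

enc_hiddens_proj = att_projection(enc_hiddens)                          # (b, src_len, h)
e_t = torch.bmm(enc_hiddens_proj, dec_hidden.unsqueeze(2)).squeeze(2)   # (b, src_len) scores
alpha_t = F.softmax(e_t, dim=1)                                         # attention weights
a_t = torch.bmm(alpha_t.unsqueeze(1), enc_hiddens).squeeze(1)           # (b, 2h) context vector
U_t = torch.cat((dec_hidden, a_t), dim=1)                               # (b, 3h)
o_t = torch.tanh(combined_output_projection(U_t))                       # (b, h) combined output
```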
79 | ## Training And Testing Results
80 |
81 | > Training Colab [notebook](https://colab.research.google.com/drive/1HYbOh0AUMEasBAH7QPGNq9joH2dRRZwg?usp=sharing)
82 |
83 | - **Hyperparameters:**
84 |   - Embedding Size & Hidden Size: 512
85 |   - Dropout Rate: 0.25
86 |   - Starting Learning Rate: 5e-4
87 |   - Batch Size: 32
88 |   - Beam Size for Beam Search: 10
89 | - **NOTE:** The BLEU score reported here is computed on the `Test Set`, so it should only be used to compare the **relative effectiveness** of the models trained on this data (see the sketch below for how it is computed)
90 |
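Perplexity in `run.py` is computed as `exp(cumulative loss / number of predicted target words)`; the corpus-level BLEU comes from `sacrebleu` after detokenizing the SentencePiece pieces. A condensed version of `compute_corpus_level_bleu_score` from `run.py`:

```python
import sacrebleu

def corpus_bleu_score(references, hypotheses):
    """references: list of reference token lists; hypotheses: list of Hypothesis with .value token lists."""
    detok_refs = [''.join(pieces).replace('▁', ' ') for pieces in references]
    detok_hyps = [''.join(hyp.value).replace('▁', ' ') for hyp in hypotheses]
    # sacreBLEU accepts several reference sets per sentence; only one is supplied here
    return sacrebleu.corpus_bleu(detok_hyps, [detok_refs]).score
```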
91 | #### Experiment Setup
92 |
93 | - **Dataset:** the dataset is randomly split into a training set (~260,000), a validation set (~20,000), and a test set (~20,000); the same split is used for every experiment group
94 | - **Max Number of Iterations**: 50000
95 | - **NOTE:** I've tried a vanilla RNN (`nn.RNN`) in various ways, but its BLEU score turned out to be extremely low (the absence of `residual connections` might be the issue)
96 |   - I decided not to include it in the comparison until the issue is resolved
97 |
98 | | | Training Time(sec) | BLEU Score on Test Set | Training Perplexities | Validation Perplexities |
99 | | ------------------------------------------------ | ------------------ | ---------------------- | ------------------------------------------------------------ | ------------------------------------------------------------ |
100 | | **A.** Bidirectional 1-Layer GRU with Tanh | 5158.99 | 14.26 |  |  |
101 | | **B.** Bidirectional 1-Layer LSTM with Tanh | 5150.31 | 16.20 |  |  |
102 | | **C.** Bidirectional 2-Layer LSTM with Tanh | 6197.58 | **16.38** |  |  |
103 | | **D.** Bidirectional 1-Layer LSTM with ReLU | 5275.12 | 14.01 |  |  |
104 | | **E.** Bidirectional 1-Layer LSTM with LeakyReLU(slope=0.1) | 5292.58 | 14.87 |  |  |
105 |
106 | #### Current Best Version
107 | Bidirectional 2-Layer LSTM with Tanh, **1024 embed_size & hidden_size**, trained for 11517.19 sec (44000 iterations), BLEU score **17.95**
108 | |            | Training Time(sec) | BLEU Score on Test Set | Training Perplexities | Validation Perplexities |
109 | |:----------:|:------------:|:----------------------:|-----------------------|-------------------------|
110 | | Best Model | 11517.19 | **17.95** |  |  |
111 |
112 | #### Analysis
113 |
114 | - LSTM tends to perform better than GRU (it has an extra set of parameters)
115 | - Tanh tends to work better, presumably because ReLU/LeakyReLU discard or shrink negative values in the combined output, so less information is preserved with Tanh
116 | - Making the LSTM deeper (more layers) can improve performance, but it costs more time to train
117 | - Surprisingly, the training times for **A**, **B**, and **D** are roughly the same
118 |   - the reason may be that the dataset is not large enough, or that the cloud service I used to train the models does not perform consistently
119 |
120 | ## Bad Examples & Case Analysis
121 |
122 | - ***source***: **全球目击组织(Global Witness)**的报告记录, 光是2015年就有**16个国家**的185人被杀。
123 | - ***target***: A **Global Witness** report documented 185 killings across **16 countries** in 2015 alone.
124 | - ***translation***: According to the **Global eye**, the World Health Organization reported that 185 people were killed in 2015.
125 | - ***problems***:
126 |   - Information Loss: 16 countries
127 |   - Unknown Proper Noun: Global Witness
128 | - ***source***: 大自然给了足以满足每个人需要的东西, **但无法满足每个人的贪婪**。
129 | - ***target***: Nature provides enough for everyone’s needs, **but not for everyone’s greed**.
130 | - ***translation***: Nature provides enough to satisfy everyone.
131 | - ***problems***:
132 |   - Huge Information Loss
133 | - ***source***: 我衷心希望全球经济危机和巴拉克·奥巴马当选总统能对新冷战的荒唐理念进行正确的评估。
134 | - ***target***: It is my hope that the global economic crisis and Barack Obama’s presidency will put the farcical idea of a new Cold War into proper perspective.
135 | - ***translation***: I do hope that the global economic crisis and President Barack Obama will be corrected for a new Cold War.
136 | - ***problems***:
137 |   - Action Sender And Receiver Exchanged
138 |   - Failed To Translate Complex Sentence
139 | - ***source***: 人们纷纷**猜测**欧元区将崩溃。
140 | - ***target***: **Speculation** about a possible breakup was widespread.
141 | - ***translation***: The eurozone would collapse.
142 | - ***problems***:
143 |   - Significant Information Loss
144 |
145 | ## Means to Improve the NMT model
146 |
147 | - Dataset
148 |   - The dataset is fairly small, and the model is not trained thoroughly on all of the data
149 |   - Being a native Chinese speaker, I could not understand what some of the source sentences were saying
150 |   - The target sentences are not informationally comprehensive; they themselves need context to be understood (e.g. the target sentence in the last "Bad Examples")
151 |   - Even for humans, some of the source sentences are too hard to translate
152 | - Model Architecture
153 |   - CNN & Transformer
154 |   - character-based model
155 |   - Make the model even larger & deeper (... I need GPUs)
156 | - Tricks that might help
157 |   - Add a proper-noun dictionary to translate unknown proper nouns word-by-word (phrase-by-phrase)
158 |   - Initialize (sub)word embeddings with pretrained embeddings
159 |
160 | ## How To Run
161 | - Download the dataset you want, and change every "./zh_en_data" in `run.sh` to the path where your data is stored
162 | - To run locally on a CPU (mostly for sanity checks; a CPU cannot realistically train the model)
163 |   - set up the environment using conda/miniconda: `conda env create --file local_env.yml`
164 | - To run on a GPU
165 |   - set up the environment and run the training/testing process following the Colab [notebook](https://colab.research.google.com/drive/1HYbOh0AUMEasBAH7QPGNq9joH2dRRZwg?usp=sharing) (a minimal single-sentence decoding sketch follows below)
166 |
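Once training produces `model.bin`, `run.sh test` decodes the whole test set. For a quick check on a single sentence there is no dedicated script; a minimal sketch using the repo's own `NMT.load`, `read_sent_zh`, and `beam_search` (it assumes `model.bin` and the `src.model`/`tgt.model` SentencePiece files from the vocab step sit in the working directory):

```python
import torch

from nmt_model import NMT
from utils import read_sent_zh

model = NMT.load("model.bin")
model.eval()

src_tokens = read_sent_zh("但这个问题其实很简单。", source="src")  # jieba + SentencePiece pieces
with torch.no_grad():
    hypotheses = model.beam_search(src_tokens, beam_size=10, max_decoding_time_step=70)

best = hypotheses[0]  # hypotheses are sorted by log-likelihood, best first
print("".join(best.value).replace("▁", " ").strip())
print("log-likelihood:", best.score)
```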
167 |
168 | ## Contact
169 | If you have any questions or run into trouble running the code, feel free to contact me via [email](mailto:jasonfen@usc.edu).
170 |
171 |
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JasonFengGit/Neural-Model-Translation/5585389efc841a0cb3a26656caaba912b29d1770/__init__.py
--------------------------------------------------------------------------------
/gpu_requirements.txt:
--------------------------------------------------------------------------------
1 | nltk
2 | docopt
3 | tqdm==4.29.1
4 | sentencepiece
5 | sacrebleu
6 | torch
7 |
--------------------------------------------------------------------------------
/images/best_dev_ppl.svg:
--------------------------------------------------------------------------------
(SVG plot data omitted)
--------------------------------------------------------------------------------
/images/gru_dev_ppl.svg:
--------------------------------------------------------------------------------
(SVG plot data omitted)
--------------------------------------------------------------------------------
/images/lstm_2_layer_dev_ppl.svg:
--------------------------------------------------------------------------------
(SVG plot data omitted)
--------------------------------------------------------------------------------
/images/lstm_dev_ppl.svg:
--------------------------------------------------------------------------------
(SVG plot data omitted)
--------------------------------------------------------------------------------
/images/lstm_leaky_dev_ppl.svg:
--------------------------------------------------------------------------------
(SVG plot data omitted)
--------------------------------------------------------------------------------
/images/lstm_relu_dev_ppl.svg:
--------------------------------------------------------------------------------
(SVG plot data omitted)
--------------------------------------------------------------------------------
/local_env.yml:
--------------------------------------------------------------------------------
1 | name: local_nmt
2 | channels:
3 |   - pytorch
4 |   - defaults
5 | dependencies:
6 |   - python=3.7
7 |   - numpy
8 |   - scipy
9 |   - tqdm
10 |   - docopt
11 |   - pytorch
12 |   - nltk
13 |   - torchvision
14 |   - pip
15 |   - pip:
16 |     - sentencepiece
17 |     - sacrebleu
18 |     - jieba
19 |
--------------------------------------------------------------------------------
/model_embeddings.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | import torch.nn as nn
5 |
6 | class ModelEmbeddings(nn.Module):
7 | """
8 | Class that converts input words to their embeddings.
9 | """
10 | def __init__(self, embed_size, vocab):
11 | """
12 | Init the Embedding layers.
13 |
14 | @param embed_size (int): Embedding size (dimensionality)
15 | @param vocab (Vocab): Vocabulary object containing src and tgt languages
16 | See vocab.py for documentation.
17 | """
18 | super(ModelEmbeddings, self).__init__()
19 | self.embed_size = embed_size
20 |
21 | # default values
22 | self.source = None
23 | self.target = None
24 |
25 | src_pad_token_idx = vocab.src['<pad>']
26 | tgt_pad_token_idx = vocab.tgt['<pad>']
27 |
28 | self.source = nn.Embedding(len(vocab.src), self.embed_size, padding_idx = src_pad_token_idx)
29 | self.target = nn.Embedding(len(vocab.tgt), self.embed_size, padding_idx = tgt_pad_token_idx)
30 |
31 |
32 |
--------------------------------------------------------------------------------
/nmt_model.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | import sys
5 | from collections import namedtuple
6 | from typing import Dict, List, Set, Tuple, Union
7 |
8 | import torch
9 | import torch.nn as nn
10 | import torch.nn.functional as F
11 | import torch.nn.utils
12 | from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
13 |
14 | from model_embeddings import ModelEmbeddings
15 |
16 | Hypothesis = namedtuple('Hypothesis', ['value', 'score'])
17 |
18 |
19 | class NMT(nn.Module):
20 | """ Simple Neural Machine Translation Model:
21 | - RNN Encoder
22 | - Unidirectional LSTM Decoder
23 | - Global Attention Model (Luong, et al. 2015)
24 | """
25 | def __init__(self, embed_size, hidden_size, vocab, dropout_rate=0.2, rnn_layer=nn.LSTM, num_layers=1, activation=torch.tanh):
26 | """ Init NMT Model.
27 |
28 | @param embed_size (int): Embedding size (dimensionality)
29 | @param hidden_size (int): Hidden Size, the size of hidden states (dimensionality)
30 | @param vocab (Vocab): Vocabulary object containing src and tgt languages
31 | See vocab.py for documentation.
32 | @param dropout_rate (float): Dropout probability applied to the combined attention output (rnn_layer, num_layers, and activation configure the encoder RNN type, its depth, and the output activation)
33 | """
34 | super(NMT, self).__init__()
35 | self.model_embeddings = ModelEmbeddings(embed_size, vocab)
36 | self.hidden_size = hidden_size
37 | self.dropout_rate = dropout_rate
38 | self.vocab = vocab
39 | self.rnn_layer = rnn_layer
40 | self.num_layers = num_layers
41 | self.activation = activation
42 |
43 | # default values
44 | self.encoder = None
45 | self.decoder = None
46 | self.h_projection = None
47 | self.c_projection = None
48 | self.att_projection = None
49 | self.combined_output_projection = None
50 | self.target_vocab_projection = None
51 | self.dropout = None
53 | # model layers
54 | self.is_lstm = (rnn_layer == nn.LSTM)
55 | self.encoder = rnn_layer(input_size=embed_size, hidden_size=hidden_size, bidirectional=True, bias=True, num_layers=num_layers)
56 | self.decoder = nn.LSTMCell(input_size=embed_size+hidden_size, hidden_size=hidden_size, bias=True)
57 | self.h_projection = nn.Linear(hidden_size*2, hidden_size, bias=False)
58 | if self.is_lstm:
59 | self.c_projection = nn.Linear(hidden_size*2, hidden_size, bias=False)
60 | self.att_projection = nn.Linear(hidden_size*2, hidden_size, bias=False)
61 | self.combined_output_projection = nn.Linear(hidden_size*3, hidden_size, bias=False)
62 | self.target_vocab_projection = nn.Linear(hidden_size, len(vocab.tgt), bias=False)
63 | self.dropout = nn.Dropout(p=self.dropout_rate)
64 |
65 |
66 | def forward(self, source: List[List[str]], target: List[List[str]]) -> torch.Tensor:
67 | """ Take a mini-batch of source and target sentences, compute the log-likelihood of
68 | target sentences under the language models learned by the NMT system.
69 |
70 | @param source (List[List[str]]): list of source sentence tokens
71 | @param target (List[List[str]]): list of target sentence tokens, wrapped by `<s>` and `</s>`
72 |
73 | @returns scores (Tensor): a variable/tensor of shape (b, ) representing the
74 | log-likelihood of generating the gold-standard target sentence for
75 | each example in the input batch. Here b = batch size.
76 | """
77 | # Compute sentence lengths
78 | source_lengths = [len(s) for s in source]
79 |
80 | # Convert list of lists into tensors
81 | source_padded = self.vocab.src.to_input_tensor(source, device=self.device) # Tensor: (src_len, b)
82 | target_padded = self.vocab.tgt.to_input_tensor(target, device=self.device) # Tensor: (tgt_len, b)
83 |
84 | enc_hiddens, dec_init_state = self.encode(source_padded, source_lengths)
85 | enc_masks = self.generate_sent_masks(enc_hiddens, source_lengths)
86 | combined_outputs = self.decode(enc_hiddens, enc_masks, dec_init_state, target_padded)
87 | P = F.log_softmax(self.target_vocab_projection(combined_outputs), dim=-1)
88 |
89 | # Zero out probabilities for which we have nothing in the target text (i.e. padding positions)
90 | target_masks = (target_padded != self.vocab.tgt['<pad>']).float()
91 |
92 | # Compute log probability of generating true target words
93 | target_gold_words_log_prob = torch.gather(P, index=target_padded[1:].unsqueeze(-1), dim=-1).squeeze(-1) * target_masks[1:]
94 | scores = target_gold_words_log_prob.sum(dim=0)
95 | return scores
96 |
97 |
98 | def encode(self, source_padded: torch.Tensor, source_lengths: List[int]) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
99 | """ Apply the encoder to source sentences to obtain encoder hidden states.
100 | Additionally, take the final states of the encoder and project them to obtain initial states for decoder.
101 |
102 | @param source_padded (Tensor): Tensor of padded source sentences with shape (src_len, b), where
103 | b = batch_size, src_len = maximum source sentence length. Note that
104 | these have already been sorted in order of longest to shortest sentence.
105 | @param source_lengths (List[int]): List of actual lengths for each of the source sentences in the batch
106 | @returns enc_hiddens (Tensor): Tensor of hidden units with shape (b, src_len, h*2), where
107 | b = batch size, src_len = maximum source sentence length, h = hidden size.
108 | @returns dec_init_state (tuple(Tensor, Tensor)): Tuple of tensors representing the decoder's initial
109 | hidden state and cell.
110 | """
111 | enc_hiddens, dec_init_state = None, None
112 | last_hidden, last_cell = None, None
113 | X = self.model_embeddings.source(source_padded) # (src_len, b, e)
114 | enc_hiddens, last_state = self.encoder(pack_padded_sequence(X, source_lengths))
115 | if self.is_lstm:
116 | last_hidden, last_cell = last_state
117 | else:
118 | last_hidden = last_state
119 | enc_hiddens, _ = pad_packed_sequence(enc_hiddens) # returns seq_unpacked, lens_unpacked
120 | enc_hiddens = enc_hiddens.permute(1, 0, 2)
121 |
122 | init_decoder_hidden = self.h_projection(torch.cat((last_hidden[0], last_hidden[1]), 1))
123 |
124 | if self.is_lstm:
125 | init_decoder_cell = self.c_projection(torch.cat((last_cell[0], last_cell[1]), 1))
126 | dec_init_state = (init_decoder_hidden, init_decoder_cell)
127 | else:
128 | dec_init_state = (init_decoder_hidden, torch.zeros_like(init_decoder_hidden))
129 |
132 | return enc_hiddens, dec_init_state
133 |
134 |
135 | def decode(self, enc_hiddens: torch.Tensor, enc_masks: torch.Tensor,
136 | dec_init_state: Tuple[torch.Tensor, torch.Tensor], target_padded: torch.Tensor) -> torch.Tensor:
137 | """Compute combined output vectors for a batch.
138 |
139 | @param enc_hiddens (Tensor): Hidden states (b, src_len, h*2), where
140 | b = batch size, src_len = maximum source sentence length, h = hidden size.
141 | @param enc_masks (Tensor): Tensor of sentence masks (b, src_len), where
142 | b = batch size, src_len = maximum source sentence length.
143 | @param dec_init_state (tuple(Tensor, Tensor)): Initial state and cell for decoder
144 | @param target_padded (Tensor): Gold-standard padded target sentences (tgt_len, b), where
145 | tgt_len = maximum target sentence length, b = batch size.
146 |
147 | @returns combined_outputs (Tensor): combined output tensor (tgt_len, b, h), where
148 | tgt_len = maximum target sentence length, b = batch_size, h = hidden size
149 | """
150 | # Chop off the <END> token for max length sentences.
151 | target_padded = target_padded[:-1]
152 |
153 | # Initialize the decoder state (hidden and cell)
154 | dec_state = dec_init_state
155 |
156 | # Initialize previous combined output vector o_{t-1} as zero
157 | batch_size = enc_hiddens.size(0)
158 | o_prev = torch.zeros(batch_size, self.hidden_size, device=self.device)
159 |
160 | # Initialize a list we will use to collect the combined output o_t on each step
161 | combined_outputs = []
162 |
163 | enc_hiddens_proj = self.att_projection(enc_hiddens) # (b, src_len, h)
164 | Y = self.model_embeddings.target(target_padded) # (tgt_len, b, e)
165 | for Y_t in torch.split(Y, 1, dim=0):
166 | Y_t = torch.squeeze(Y_t, 0) # (b, e)
167 | Ybar_t = torch.cat((Y_t, o_prev), 1)
168 | dec_state, o_t, e_t = self.step(Ybar_t, dec_state, enc_hiddens, enc_hiddens_proj, enc_masks)
169 | combined_outputs.append(o_t)
170 | o_prev = o_t
171 | combined_outputs = torch.stack(combined_outputs, dim=0) # (tgt_len, b, h)
172 |
173 | return combined_outputs
174 |
175 |
176 | def step(self, Ybar_t: torch.Tensor,
177 | dec_state: Tuple[torch.Tensor, torch.Tensor],
178 | enc_hiddens: torch.Tensor,
179 | enc_hiddens_proj: torch.Tensor,
180 | enc_masks: torch.Tensor) -> Tuple[Tuple, torch.Tensor, torch.Tensor]:
181 | """ Compute one forward step of the LSTM decoder, including the attention computation.
182 |
183 | @param Ybar_t (Tensor): Concatenated Tensor of [Y_t o_prev], with shape (b, e + h). The input for the decoder,
184 | where b = batch size, e = embedding size, h = hidden size.
185 | @param dec_state (tuple(Tensor, Tensor)): Tuple of tensors both with shape (b, h), where b = batch size, h = hidden size.
186 | First tensor is decoder's prev hidden state, second tensor is decoder's prev cell.
187 | @param enc_hiddens (Tensor): Encoder hidden states Tensor, with shape (b, src_len, h * 2), where b = batch size,
188 | src_len = maximum source length, h = hidden size.
189 | @param enc_hiddens_proj (Tensor): Encoder hidden states Tensor, projected from (h * 2) to h. Tensor is with shape (b, src_len, h),
190 | where b = batch size, src_len = maximum source length, h = hidden size.
191 | @param enc_masks (Tensor): Tensor of sentence masks shape (b, src_len),
192 | where b = batch size, src_len is maximum source length.
193 |
194 | @returns dec_state (tuple (Tensor, Tensor)): Tuple of tensors both shape (b, h), where b = batch size, h = hidden size.
195 | First tensor is decoder's new hidden state, second tensor is decoder's new cell.
196 | @returns combined_output (Tensor): Combined output Tensor at timestep t, shape (b, h), where b = batch size, h = hidden size.
197 | @returns e_t (Tensor): Tensor of shape (b, src_len). It is attention scores distribution.
198 | Note: You will not use this outside of this function.
199 | We are simply returning this value so that we can sanity check
200 | your implementation.
201 | """
202 |
203 | combined_output = None
204 | dec_state = self.decoder(Ybar_t, dec_state)
205 | dec_hidden, dec_cell = dec_state
206 | e_t = torch.squeeze(torch.bmm(enc_hiddens_proj, torch.unsqueeze(dec_hidden, dim=2)), dim=2)
207 |
208 | # Set e_t to -inf where enc_masks has 1, to ignore <pad> tokens
209 | if enc_masks is not None:
210 | e_t.data.masked_fill_(enc_masks.bool(), -float('inf'))
211 |
212 | alpha_t = F.softmax(e_t, dim=1) # (b, src_len)
213 | a_t = torch.squeeze(torch.bmm(torch.unsqueeze(alpha_t, 1), enc_hiddens), dim=1) # (b, 2h)
214 | U_t = torch.cat((dec_hidden, a_t), dim=1)
215 | V_t = self.combined_output_projection(U_t)
216 | O_t = self.dropout(self.activation(V_t))
217 |
218 | combined_output = O_t
219 | return dec_state, combined_output, e_t
220 |
221 | def generate_sent_masks(self, enc_hiddens: torch.Tensor, source_lengths: List[int]) -> torch.Tensor:
222 | """ Generate sentence masks for encoder hidden states.
223 |
224 | @param enc_hiddens (Tensor): encodings of shape (b, src_len, 2*h), where b = batch size,
225 | src_len = max source length, h = hidden size.
226 | @param source_lengths (List[int]): List of actual lengths for each of the sentences in the batch.
227 |
228 | @returns enc_masks (Tensor): Tensor of sentence masks of shape (b, src_len),
229 | where src_len = max source length, h = hidden size.
230 | """
231 | enc_masks = torch.zeros(enc_hiddens.size(0), enc_hiddens.size(1), dtype=torch.float)
232 | for e_id, src_len in enumerate(source_lengths):
233 | enc_masks[e_id, src_len:] = 1
234 | return enc_masks.to(self.device)
235 |
236 |
237 | def beam_search(self, src_sent: List[str], beam_size: int=5, max_decoding_time_step: int=70) -> List[Hypothesis]:
238 | """ Given a single source sentence, perform beam search, yielding translations in the target language.
239 | @param src_sent (List[str]): a single source sentence (words)
240 | @param beam_size (int): beam size
241 | @param max_decoding_time_step (int): maximum number of time steps to unroll the decoding RNN
242 | @returns hypotheses (List[Hypothesis]): a list of hypothesis, each hypothesis has two fields:
243 | value: List[str]: the decoded target sentence, represented as a list of words
244 | score: float: the log-likelihood of the target sentence
245 | """
246 | src_sents_var = self.vocab.src.to_input_tensor([src_sent], self.device)
247 |
248 | src_encodings, dec_init_vec = self.encode(src_sents_var, [len(src_sent)])
249 | src_encodings_att_linear = self.att_projection(src_encodings)
250 |
251 | h_tm1 = dec_init_vec
252 | att_tm1 = torch.zeros(1, self.hidden_size, device=self.device)
253 |
254 | eos_id = self.vocab.tgt['</s>']
255 |
256 | hypotheses = [['<s>']]
257 | hyp_scores = torch.zeros(len(hypotheses), dtype=torch.float, device=self.device)
258 | completed_hypotheses = []
259 |
260 | t = 0
261 | while len(completed_hypotheses) < beam_size and t < max_decoding_time_step:
262 | t += 1
263 | hyp_num = len(hypotheses)
264 |
265 | exp_src_encodings = src_encodings.expand(hyp_num,
266 | src_encodings.size(1),
267 | src_encodings.size(2))
268 |
269 | exp_src_encodings_att_linear = src_encodings_att_linear.expand(hyp_num,
270 | src_encodings_att_linear.size(1),
271 | src_encodings_att_linear.size(2))
272 |
273 | y_tm1 = torch.tensor([self.vocab.tgt[hyp[-1]] for hyp in hypotheses], dtype=torch.long, device=self.device)
274 | y_t_embed = self.model_embeddings.target(y_tm1)
275 |
276 | x = torch.cat([y_t_embed, att_tm1], dim=-1)
277 |
278 | (h_t, cell_t), att_t, _ = self.step(x, h_tm1,
279 | exp_src_encodings, exp_src_encodings_att_linear, enc_masks=None)
280 |
281 | # log probabilities over target words
282 | log_p_t = F.log_softmax(self.target_vocab_projection(att_t), dim=-1)
283 |
284 | live_hyp_num = beam_size - len(completed_hypotheses)
285 | continuing_hyp_scores = (hyp_scores.unsqueeze(1).expand_as(log_p_t) + log_p_t).view(-1)
286 | top_cand_hyp_scores, top_cand_hyp_pos = torch.topk(continuing_hyp_scores, k=live_hyp_num)
287 |
288 | prev_hyp_ids = top_cand_hyp_pos // len(self.vocab.tgt)
289 | hyp_word_ids = top_cand_hyp_pos % len(self.vocab.tgt)
290 |
291 | new_hypotheses = []
292 | live_hyp_ids = []
293 | new_hyp_scores = []
294 |
295 | for prev_hyp_id, hyp_word_id, cand_new_hyp_score in zip(prev_hyp_ids, hyp_word_ids, top_cand_hyp_scores):
296 | prev_hyp_id = prev_hyp_id.item()
297 | hyp_word_id = hyp_word_id.item()
298 | cand_new_hyp_score = cand_new_hyp_score.item()
299 |
300 | hyp_word = self.vocab.tgt.id2word[hyp_word_id]
301 | new_hyp_sent = hypotheses[prev_hyp_id] + [hyp_word]
302 | if hyp_word == '</s>':
303 | completed_hypotheses.append(Hypothesis(value=new_hyp_sent[1:-1],
304 | score=cand_new_hyp_score))
305 | else:
306 | new_hypotheses.append(new_hyp_sent)
307 | live_hyp_ids.append(prev_hyp_id)
308 | new_hyp_scores.append(cand_new_hyp_score)
309 |
310 | if len(completed_hypotheses) == beam_size:
311 | break
312 |
313 | live_hyp_ids = torch.tensor(live_hyp_ids, dtype=torch.long, device=self.device)
314 | h_tm1 = (h_t[live_hyp_ids], cell_t[live_hyp_ids])
315 | att_tm1 = att_t[live_hyp_ids]
316 |
317 | hypotheses = new_hypotheses
318 | hyp_scores = torch.tensor(new_hyp_scores, dtype=torch.float, device=self.device)
319 |
320 | if len(completed_hypotheses) == 0:
321 | completed_hypotheses.append(Hypothesis(value=hypotheses[0][1:],
322 | score=hyp_scores[0].item()))
323 |
324 | completed_hypotheses.sort(key=lambda hyp: hyp.score, reverse=True)
325 |
326 | return completed_hypotheses
327 |
328 | @property
329 | def device(self) -> torch.device:
330 | """ Determine which device to place the Tensors upon, CPU or GPU.
331 | """
332 | return self.model_embeddings.source.weight.device
333 |
334 | @staticmethod
335 | def load(model_path: str):
336 | """ Load the model from a file.
337 | @param model_path (str): path to model
338 | """
339 | params = torch.load(model_path, map_location=lambda storage, loc: storage)
340 | args = params['args']
341 | model = NMT(vocab=params['vocab'], **args)
342 | model.load_state_dict(params['state_dict'])
343 |
344 | return model
345 |
346 | def save(self, path: str):
347 | """ Save the model to a file.
348 | @param path (str): path to the model
349 | """
350 | print('save model parameters to [%s]' % path, file=sys.stderr)
351 |
352 | params = {
353 | 'args': dict(embed_size=self.model_embeddings.embed_size, hidden_size=self.hidden_size, dropout_rate=self.dropout_rate, rnn_layer=self.rnn_layer, num_layers=self.num_layers, activation=self.activation),
354 | 'vocab': self.vocab,
355 | 'state_dict': self.state_dict()
356 | }
357 |
358 | torch.save(params, path)
359 |
--------------------------------------------------------------------------------
/run.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | """
5 | Usage:
6 | run.py train --train-src=<file> --train-tgt=<file> --dev-src=<file> --dev-tgt=<file> --vocab=<file> [options]
7 | run.py decode [options] MODEL_PATH TEST_SOURCE_FILE OUTPUT_FILE
8 | run.py decode [options] MODEL_PATH TEST_SOURCE_FILE TEST_TARGET_FILE OUTPUT_FILE
9 |
10 | Options:
11 | -h --help show this screen.
12 | --cuda use GPU
13 | --train-src=<file> train source file
14 | --train-tgt=<file> train target file
15 | --dev-src=<file> dev source file
16 | --dev-tgt=<file> dev target file
17 | --vocab=<file> vocab file
18 | --seed=<int> seed [default: 0]
19 | --batch-size=<int> batch size [default: 32]
20 | --embed-size=<int> embedding size [default: 256]
21 | --hidden-size=<int> hidden size [default: 256]
22 | --clip-grad=<float> gradient clipping [default: 5.0]
23 | --log-every=<int> log every [default: 10]
24 | --max-epoch=<int> max epoch [default: 30]
25 | --input-feed use input feeding
26 | --patience=<int> wait for how many iterations to decay learning rate [default: 5]
27 | --max-num-trial=<int> terminate training after how many trials [default: 5]
28 | --lr-decay=<float> learning rate decay [default: 0.5]
29 | --beam-size=<int> beam size [default: 5]
30 | --sample-size=<int> sample size [default: 5]
31 | --lr=<float> learning rate [default: 0.001]
32 | --uniform-init=<float> uniformly initialize all parameters [default: 0.1]
33 | --save-to=<file> model save path [default: model.bin]
34 | --valid-niter=<int> perform validation after how many iterations [default: 2000]
35 | --dropout=<float> dropout [default: 0.3]
36 | --max-decoding-time-step=<int> maximum number of decoding time steps [default: 70]
37 | """
38 | import math
39 | import sys
40 | import time
41 |
42 |
43 | from docopt import docopt
44 | # from nltk.translate.bleu_score import corpus_bleu, sentence_bleu, SmoothingFunction
45 | import sacrebleu
47 | from nmt_model import Hypothesis, NMT
48 | import numpy as np
49 | from typing import List, Tuple, Dict, Set, Union
50 | from tqdm import tqdm
51 | from utils import read_corpus, batch_iter, read_sent_zh
52 | from vocab import Vocab
53 |
54 | import torch
55 | import torch.nn.utils
56 | from torch import nn
57 |
58 |
59 | def evaluate_ppl(model, dev_data, batch_size=32):
60 | """ Evaluate perplexity on dev sentences
61 | @param model (NMT): NMT Model
62 | @param dev_data (list of (src_sent, tgt_sent)): list of tuples containing source and target sentence
63 | @param batch_size (batch size)
64 | @returns ppl (perplexity on dev sentences)
65 | """
66 | was_training = model.training
67 | model.eval()
68 |
69 | cum_loss = 0.
70 | cum_tgt_words = 0.
71 |
72 | # no_grad() signals backend to throw away all gradients
73 | with torch.no_grad():
74 | for src_sents, tgt_sents in batch_iter(dev_data, batch_size):
75 | loss = -model(src_sents, tgt_sents).sum()
76 |
77 | cum_loss += loss.item()
78 | tgt_word_num_to_predict = sum(len(s[1:]) for s in tgt_sents) # omitting leading `<s>`
79 | cum_tgt_words += tgt_word_num_to_predict
80 |
81 | ppl = np.exp(cum_loss / cum_tgt_words)
82 |
83 | if was_training:
84 | model.train()
85 |
86 | return ppl
87 |
88 |
89 | def compute_corpus_level_bleu_score(references: List[List[str]], hypotheses: List[Hypothesis]) -> float:
90 | """ Given decoding results and reference sentences, compute corpus-level BLEU score.
91 | @param references (List[List[str]]): a list of gold-standard reference target sentences
92 | @param hypotheses (List[Hypothesis]): a list of hypotheses, one for each reference
93 | @returns bleu_score: corpus-level BLEU score
94 | """
95 | # remove the start and end tokens
96 | if references[0][0] == '<s>':
97 | references = [ref[1:-1] for ref in references]
98 |
99 | # detokenize the subword pieces to get full sentences
100 | detokened_refs = [''.join(pieces).replace('▁', ' ') for pieces in references]
101 | detokened_hyps = [''.join(hyp.value).replace('▁', ' ') for hyp in hypotheses]
104 | # sacreBLEU can take multiple references (golden example per sentence) but we only feed it one
105 | bleu = sacrebleu.corpus_bleu(detokened_hyps, [detokened_refs])
106 |
107 | return bleu.score
108 |
109 |
110 | def train(args: Dict):
111 | """ Train the NMT Model.
112 | @param args (Dict): args from cmd line
113 | """
114 | train_data_src = read_corpus(args['--train-src'], source='src', vocab_size=21000)
115 | train_data_tgt = read_corpus(args['--train-tgt'], source='tgt', vocab_size=8000)
116 |
117 | dev_data_src = read_corpus(args['--dev-src'], source='src', vocab_size=3000)
118 | dev_data_tgt = read_corpus(args['--dev-tgt'], source='tgt', vocab_size=2000)
119 |
120 | train_data = list(zip(train_data_src, train_data_tgt))
121 | dev_data = list(zip(dev_data_src, dev_data_tgt))
122 |
123 | train_batch_size = int(args['--batch-size'])
124 | clip_grad = float(args['--clip-grad'])
125 | valid_niter = int(args['--valid-niter'])
126 | log_every = int(args['--log-every'])
127 | model_save_path = args['--save-to']
128 |
129 | vocab = Vocab.load(args['--vocab'])
130 |
131 | model = NMT(embed_size=512,
132 | hidden_size=512,
133 | dropout_rate=float(args['--dropout']),
134 | vocab=vocab,
135 | rnn_layer=nn.LSTM,
136 | num_layers=1)
137 |
138 |
139 | model.train()
140 |
141 | uniform_init = float(args['--uniform-init'])
142 | if np.abs(uniform_init) > 0.:
143 | print('uniformly initialize parameters [-%f, +%f]' % (uniform_init, uniform_init), file=sys.stderr)
144 | for p in model.parameters():
145 | p.data.uniform_(-uniform_init, uniform_init)
146 |
147 | vocab_mask = torch.ones(len(vocab.tgt))
148 | vocab_mask[vocab.tgt['<pad>']] = 0
149 |
150 | device = torch.device("cuda:0" if args['--cuda'] else "cpu")
151 | print('use device: %s' % device, file=sys.stderr)
152 |
153 | model = model.to(device)
154 |
155 | optimizer = torch.optim.Adam(model.parameters(), lr=float(args['--lr']))
156 |
157 | num_trial = 0
158 | train_iter = patience = cum_loss = report_loss = cum_tgt_words = report_tgt_words = 0
159 | cum_examples = report_examples = epoch = valid_num = 0
160 | hist_valid_scores = []
161 | train_time = begin_time = time.time()
162 | print('begin Maximum Likelihood training')
163 | train_ppl_log = open("ppl.log", "w")
164 | dev_ppl_log = open("dev_ppl.log", "w")
165 | while True:
166 | epoch += 1
167 |
168 | for src_sents, tgt_sents in batch_iter(train_data, batch_size=train_batch_size, shuffle=True):
169 | train_iter += 1
170 |
171 | optimizer.zero_grad()
172 |
173 | batch_size = len(src_sents)
174 |
175 | example_losses = -model(src_sents, tgt_sents) # (batch_size,)
176 | batch_loss = example_losses.sum()
177 | loss = batch_loss / batch_size
178 |
179 | loss.backward()
180 |
181 | # clip gradient
182 | grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), clip_grad)
183 |
184 | optimizer.step()
185 |
186 | batch_losses_val = batch_loss.item()
187 | report_loss += batch_losses_val
188 | cum_loss += batch_losses_val
189 |
190 | tgt_words_num_to_predict = sum(len(s[1:]) for s in tgt_sents) # omitting leading `<s>`
191 | report_tgt_words += tgt_words_num_to_predict
192 | cum_tgt_words += tgt_words_num_to_predict
193 | report_examples += batch_size
194 | cum_examples += batch_size
195 |
196 | if train_iter % log_every == 0:
197 | print('epoch %d, iter %d, avg. loss %.2f, avg. ppl %.2f ' \
198 | 'cum. examples %d, speed %.2f words/sec, time elapsed %.2f sec' % (epoch, train_iter,
199 | report_loss / report_examples,
200 | math.exp(report_loss / report_tgt_words),
201 | cum_examples,
202 | report_tgt_words / (time.time() - train_time),
203 | time.time() - begin_time), file=sys.stderr)
204 | train_ppl_log.write("{} {}\n".format(train_iter, math.exp(report_loss / report_tgt_words)))
205 | train_time = time.time()
206 | report_loss = report_tgt_words = report_examples = 0.
207 |
208 | # perform validation
209 | if train_iter % valid_niter == 0:
210 | print('epoch %d, iter %d, cum. loss %.2f, cum. ppl %.2f cum. examples %d' % (epoch, train_iter,
211 | cum_loss / cum_examples,
212 | np.exp(cum_loss / cum_tgt_words),
213 | cum_examples), file=sys.stderr)
214 |
215 | cum_loss = cum_examples = cum_tgt_words = 0.
216 | valid_num += 1
217 |
218 | print('begin validation ...', file=sys.stderr)
219 |
220 | # compute dev. ppl and bleu
221 | dev_ppl = evaluate_ppl(model, dev_data, batch_size=128) # dev batch size can be a bit larger
222 | valid_metric = -dev_ppl
223 |
224 | print('validation: iter %d, dev. ppl %f' % (train_iter, dev_ppl), file=sys.stderr)
225 | dev_ppl_log.write("{} {}\n".format(train_iter, dev_ppl))
226 | is_better = len(hist_valid_scores) == 0 or valid_metric > max(hist_valid_scores)
227 | hist_valid_scores.append(valid_metric)
228 |
229 | if is_better:
230 | patience = 0
231 | print('save currently the best model to [%s]' % model_save_path, file=sys.stderr)
232 | model.save(model_save_path)
233 |
234 | # also save the optimizers' state
235 | torch.save(optimizer.state_dict(), model_save_path + '.optim')
236 | elif patience < int(args['--patience']):
237 | patience += 1
238 | print('hit patience %d' % patience, file=sys.stderr)
239 |
240 | if patience == int(args['--patience']):
241 | num_trial += 1
242 | print('hit #%d trial' % num_trial, file=sys.stderr)
243 | if num_trial == int(args['--max-num-trial']):
244 | print('early stop!', file=sys.stderr)
245 | exit(0)
246 |
247 | # decay lr, and restore from previously best checkpoint
248 | lr = optimizer.param_groups[0]['lr'] * float(args['--lr-decay'])
249 | print('load previously best model and decay learning rate to %f' % lr, file=sys.stderr)
250 |
251 | # load model
252 | params = torch.load(model_save_path, map_location=lambda storage, loc: storage)
253 | model.load_state_dict(params['state_dict'])
254 | model = model.to(device)
255 |
256 | print('restore parameters of the optimizers', file=sys.stderr)
257 | optimizer.load_state_dict(torch.load(model_save_path + '.optim'))
258 |
259 | # set new lr
260 | for param_group in optimizer.param_groups:
261 | param_group['lr'] = lr
262 |
263 | # reset patience
264 | patience = 0
265 |
266 | if epoch == int(args['--max-epoch']):
267 | print('reached maximum number of epochs!', file=sys.stderr)
268 | exit(0)
269 |
270 |
271 | def decode(args: Dict[str, str]):
272 | """ Performs decoding on a test set, and save the best-scoring decoding results.
273 | If the target gold-standard sentences are given, the function also computes
274 | corpus-level BLEU score.
275 | @param args (Dict): args from cmd line
276 | """
277 |
278 | print("load test source sentences from [{}]".format(args['TEST_SOURCE_FILE']), file=sys.stderr)
279 | test_data_src = read_corpus(args['TEST_SOURCE_FILE'], source='src', vocab_size=3000)
280 | if args['TEST_TARGET_FILE']:
281 | print("load test target sentences from [{}]".format(args['TEST_TARGET_FILE']), file=sys.stderr)
282 | test_data_tgt = read_corpus(args['TEST_TARGET_FILE'], source='tgt', vocab_size=2000)
283 | print("load model from {}".format(args['MODEL_PATH']), file=sys.stderr)
284 | model = NMT.load(args['MODEL_PATH'])
285 |
286 | if args['--cuda']:
287 | model = model.to(torch.device("cuda:0"))
288 |
289 | hypotheses = beam_search(model, test_data_src,
290 | beam_size=10,
291 | max_decoding_time_step=int(args['--max-decoding-time-step']))
292 |
293 | if args['TEST_TARGET_FILE']:
294 | top_hypotheses = [hyps[0] for hyps in hypotheses]
295 | bleu_score = compute_corpus_level_bleu_score(test_data_tgt, top_hypotheses)
296 | print('Corpus BLEU: {}'.format(bleu_score), file=sys.stderr)
297 |
298 | with open(args['OUTPUT_FILE'], 'w') as f:
299 | for src_sent, hyps in zip(test_data_src, hypotheses):
300 | top_hyp = hyps[0]
301 | src_sent = ''.join(src_sent).replace('▁', ' ')
302 | hyp_sent = ''.join(top_hyp.value).replace('▁', ' ')
303 | f.write(src_sent+'\n'+hyp_sent + '\n\n')
304 |
305 | def beam_search(model: NMT, test_data_src: List[List[str]], beam_size: int, max_decoding_time_step: int) -> List[List[Hypothesis]]:
306 | """ Run beam search to construct hypotheses for a list of src-language sentences.
307 | @param model (NMT): NMT Model
308 | @param test_data_src (List[List[str]]): List of sentences (words) in source language, from test set.
309 | @param beam_size (int): beam_size (# of hypotheses to hold for a translation at every step)
310 | @param max_decoding_time_step (int): maximum sentence length that Beam search can produce
311 | @returns hypotheses (List[List[Hypothesis]]): List of Hypothesis translations for every source sentence.
312 | """
313 | was_training = model.training
314 | model.eval()
315 |
316 | hypotheses = []
317 | with torch.no_grad():
318 | for src_sent in tqdm(test_data_src, desc='Decoding', file=sys.stdout):
319 | example_hyps = model.beam_search(src_sent, beam_size=beam_size, max_decoding_time_step=max_decoding_time_step)
320 |
321 | hypotheses.append(example_hyps)
322 |
323 | if was_training:
324 | model.train(was_training)
325 |
326 | return hypotheses
327 |
328 |
329 | if __name__ == '__main__':
330 | args = docopt(__doc__)
331 | # Check pytorch version
332 | assert(torch.__version__ >= "1.0.0"), "Please update your installation of PyTorch. You have {} and you should have version 1.0.0".format(torch.__version__)
333 |
334 | # seed the random number generators
335 | seed = int(args['--seed'])
336 | torch.manual_seed(seed)
337 | if args['--cuda']:
338 | torch.cuda.manual_seed(seed)
339 | np.random.seed(seed * 13 // 7)
340 | if args['train']:
341 | train(args)
342 | elif args['decode']:
343 | decode(args)
344 | else:
345 | raise RuntimeError('invalid run mode')
346 |
--------------------------------------------------------------------------------
/run.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # PLEASE change all "./zh_en_data" to the path where your data is stored
3 |
4 | if [ "$1" = "train" ]; then
5 | CUDA_VISIBLE_DEVICES=0 python run.py train --train-src=./zh_en_data/train.zh --train-tgt=./zh_en_data/train.en --dev-src=./zh_en_data/dev.zh --dev-tgt=./zh_en_data/dev.en --vocab=./zh_en_data/vocab_zh_en.json --cuda --lr=5e-4 --patience=1 --valid-niter=1000 --batch-size=32 --dropout=.25
6 | elif [ "$1" = "test" ]; then
7 | if [ "$2" = "" ]; then
8 | CUDA_VISIBLE_DEVICES=0 python run.py decode model.bin ./zh_en_data/test.zh ./zh_en_data/test.en outputs/test_outputs.txt --cuda
9 | else
10 | CUDA_VISIBLE_DEVICES=0 python run.py decode $2 ./zh_en_data/test.zh ./zh_en_data/test.en outputs/test_outputs.txt --cuda
11 | fi
12 | elif [ "$1" = "train_local" ]; then
13 | python run.py train --train-src=./zh_en_data/train.zh --train-tgt=./zh_en_data/train.en --dev-src=./zh_en_data/dev.zh --dev-tgt=./zh_en_data/dev.en --vocab=./zh_en_data/vocab_zh_en.json --lr=5e-4
14 | elif [ "$1" = "test_local" ]; then
15 | python run.py decode model.bin ./zh_en_data/test.zh ./zh_en_data/test.en outputs/test_outputs.txt
16 | elif [ "$1" = "vocab" ]; then
17 | python vocab.py --train-src=./zh_en_data/train.zh --train-tgt=./zh_en_data/train.en ./zh_en_data/vocab_zh_en.json
18 | else
19 | echo "Invalid Option Selected"
20 | fi
21 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | import math
5 |
6 | import numpy as np
7 | import nltk
8 | import sentencepiece as spm
9 | import jieba
10 |
11 | nltk.download('punkt')
12 |
13 |
14 | def pad_sents(sents, pad_token):
15 | """ Pad list of sentences according to the longest sentence in the batch.
16 | The paddings should be at the end of each sentence.
17 | @param sents (list[list[str]]): list of sentences, where each sentence
18 | is represented as a list of words
19 | @param pad_token (str): padding token
20 | @returns sents_padded (list[list[str]]): list of sentences where sentences shorter
21 | than the max length sentence are padded out with the pad_token, such that
22 | each sentences in the batch now has equal length.
23 | """
24 | sents_padded = []
25 |
26 | max_len = max(len(sent) for sent in sents)
27 | for sent in sents:
28 | # append a padded copy so the caller's sentence lists are not modified in place
29 | sents_padded.append(sent + [pad_token] * (max_len - len(sent)))
33 |
34 | return sents_padded
35 |
36 |
37 | def read_corpus(file_path, source, vocab_size=2500):
38 | """ Read file, where each sentence is delineated by a `\n`.
39 | @param file_path (str): path to file containing corpus
40 | @param source (str): "tgt" or "src" indicating whether text
41 | is of the source language or target language
42 | @param vocab_size (int): number of unique subwords in
43 | vocabulary when reading and tokenizing
44 | """
45 | data = []
46 | sp = spm.SentencePieceProcessor()
47 | sp.load('{}.model'.format(source))
48 |
49 | with open(file_path, 'r', encoding='utf8') as f:
50 | for line in f:
51 | subword_tokens = sp.encode_as_pieces(line)
52 |
53 | # only append <s> and </s> to the target sentence
54 | if source == 'tgt':
55 | subword_tokens = ["<s>"] + subword_tokens + ["</s>"]
56 |
57 | data.append(subword_tokens)
58 |
59 | return data
60 |
61 |
62 | def read_sent_zh(sent, source):
63 | """ Read a Chinese sentence, separate the words using jieba, and generate subword tokens
64 | @param sent (str): a raw Chinese sentence
65 | @param source (str): "tgt" or "src" for selecting sp model
66 | """
67 | sp = spm.SentencePieceProcessor()
68 | sp.load('{}.model'.format(source))
69 |
70 | sent = " ".join(jieba.cut(sent, HMM=True))
71 | subword_tokens = sp.encode_as_pieces(sent)
72 |
73 | return subword_tokens
74 |
75 |
76 | def batch_iter(data, batch_size, shuffle=False):
77 | """ Yield batches of source and target sentences reverse sorted by length (largest to smallest).
78 | @param data (list of (src_sent, tgt_sent)): list of tuples containing source and target sentence
79 | @param batch_size (int): batch size
80 | @param shuffle (boolean): whether to randomly shuffle the dataset
81 | """
82 | batch_num = math.ceil(len(data) / batch_size)
83 | index_array = list(range(len(data)))
84 |
85 | if shuffle:
86 | np.random.shuffle(index_array)
87 |
88 | for i in range(batch_num):
89 | indices = index_array[i * batch_size: (i + 1) * batch_size]
90 | examples = [data[idx] for idx in indices]
91 |
92 | examples = sorted(examples, key=lambda e: len(e[0]), reverse=True)
93 | examples = [examples[i] for i in range(len(examples)) if len(examples[i][0]) > 0 and len(examples[i][1]) > 0]
94 |
95 | src_sents = [e[0] for e in examples]
96 | tgt_sents = [e[1] for e in examples]
97 | yield src_sents, tgt_sents
98 |
99 |
--------------------------------------------------------------------------------
/vocab.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | """
5 | Usage:
6 | vocab.py --train-src=<file> --train-tgt=<file> [options] VOCAB_FILE
7 |
8 | Options:
9 | -h --help Show this screen.
10 | --train-src=<file> File of training source sentences
11 | --train-tgt=<file> File of training target sentences
12 | --size=<int> vocab size [default: 50000]
13 | --freq-cutoff=<int> frequency cutoff [default: 2]
14 | """
15 |
16 | from collections import Counter
17 | from docopt import docopt
18 | from itertools import chain
19 | import json
20 | import torch
21 | from typing import List
22 | from utils import pad_sents
23 | import sentencepiece as spm
24 |
25 |
26 | class VocabEntry(object):
27 | """ Vocabulary Entry, i.e. structure containing either
28 | src or tgt language terms.
29 | """
30 | def __init__(self, word2id=None):
31 | """ Init VocabEntry Instance.
32 | @param word2id (dict): dictionary mapping words 2 indices
33 | """
34 | if word2id:
35 | self.word2id = word2id
36 | else:
37 | self.word2id = dict()
38 | self.word2id['<pad>'] = 0 # Pad Token
39 | self.word2id['<s>'] = 1 # Start Token
40 | self.word2id['</s>'] = 2 # End Token
41 | self.word2id['<unk>'] = 3 # Unknown Token
42 | self.unk_id = self.word2id['<unk>']
43 | self.id2word = {v: k for k, v in self.word2id.items()}
44 |
45 | def __getitem__(self, word):
46 | """ Retrieve word's index. Return the index for the unk
47 | token if the word is out of vocabulary.
48 | @param word (str): word to look up.
49 | @returns index (int): index of word
50 | """
51 | return self.word2id.get(word, self.unk_id)
52 |
53 | def __contains__(self, word):
54 | """ Check if word is captured by VocabEntry.
55 | @param word (str): word to look up
56 | @returns contains (bool): whether word is contained
57 | """
58 | return word in self.word2id
59 |
60 | def __len__(self):
61 | """ Compute number of words in VocabEntry.
62 | @returns len (int): number of words in VocabEntry
63 | """
64 | return len(self.word2id)
65 |
66 | def __repr__(self):
67 | """ Representation of VocabEntry to be used
68 | when printing the object.
69 | """
70 | return 'Vocabulary[size=%d]' % len(self)
71 |
72 | def id2word(self, wid):
73 | """ Return mapping of index to word.
74 | @param wid (int): word index
75 | @returns word (str): word corresponding to index
76 | """
77 | return self.id2word[wid]
78 |
79 | def add(self, word):
80 | """ Add word to VocabEntry, if it is previously unseen.
81 | @param word (str): word to add to VocabEntry
82 | @return index (int): index that the word has been assigned
83 | """
84 | if word not in self:
85 | wid = self.word2id[word] = len(self)
86 | self.id2word[wid] = word
87 | return wid
88 | else:
89 | return self[word]
90 |
91 | def words2indices(self, sents):
92 | """ Convert list of words or list of sentences of words
93 | into list or list of list of indices.
94 | @param sents (list[str] or list[list[str]]): sentence(s) in words
95 | @return word_ids (list[int] or list[list[int]]): sentence(s) in indices
96 | """
97 | if type(sents[0]) == list:
98 | return [[self[w] for w in s] for s in sents]
99 | else:
100 | return [self[w] for w in sents]
101 |
102 | def indices2words(self, word_ids):
103 | """ Convert list of indices into words.
104 | @param word_ids (list[int]): list of word ids
105 | @return sents (list[str]): list of words
106 | """
107 | return [self.id2word[w_id] for w_id in word_ids]
108 |
109 | def to_input_tensor(self, sents: List[List[str]], device: torch.device) -> torch.Tensor:
110 | """ Convert list of sentences (words) into tensor with necessary padding for
111 | shorter sentences.
112 |
113 | @param sents (List[List[str]]): list of sentences (words)
114 | @param device: device on which to load the tensor, i.e. CPU or GPU
115 |
116 | @returns sents_var: tensor of (max_sentence_length, batch_size)
117 | """
118 | word_ids = self.words2indices(sents)
119 | sents_t = pad_sents(word_ids, self['<pad>'])
120 | sents_var = torch.tensor(sents_t, dtype=torch.long, device=device)
121 | return torch.t(sents_var)
122 |
123 | @staticmethod
124 | def from_corpus(corpus, size, freq_cutoff=2):
125 | """ Given a corpus construct a Vocab Entry.
126 | @param corpus (list[str]): corpus of text produced by read_corpus function
127 | @param size (int): # of words in vocabulary
128 | @param freq_cutoff (int): if word occurs n < freq_cutoff times, drop the word
129 | @returns vocab_entry (VocabEntry): VocabEntry instance produced from provided corpus
130 | """
131 | vocab_entry = VocabEntry()
132 | word_freq = Counter(chain(*corpus))
133 | valid_words = [w for w, v in word_freq.items() if v >= freq_cutoff]
134 | print('number of word types: {}, number of word types w/ frequency >= {}: {}'
135 | .format(len(word_freq), freq_cutoff, len(valid_words)))
136 | top_k_words = sorted(valid_words, key=lambda w: word_freq[w], reverse=True)[:size]
137 | for word in top_k_words:
138 | vocab_entry.add(word)
139 | return vocab_entry
140 |
141 | @staticmethod
142 | def from_subword_list(subword_list):
143 | vocab_entry = VocabEntry()
144 | for subword in subword_list:
145 | vocab_entry.add(subword)
146 | return vocab_entry
147 |
148 |
149 | class Vocab(object):
150 | """ Vocab encapsulating src and target languages.
151 | """
152 | def __init__(self, src_vocab: VocabEntry, tgt_vocab: VocabEntry):
153 | """ Init Vocab.
154 | @param src_vocab (VocabEntry): VocabEntry for source language
155 | @param tgt_vocab (VocabEntry): VocabEntry for target language
156 | """
157 | self.src = src_vocab
158 | self.tgt = tgt_vocab
159 |
160 | @staticmethod
161 | def build(src_sents, tgt_sents) -> 'Vocab':
162 | """ Build Vocabulary.
163 | @param src_sents (list[str]): Source subwords provided by SentencePiece
164 | @param tgt_sents (list[str]): Target subwords provided by SentencePiece
165 | """
166 |
167 | print('initialize source vocabulary ..')
168 | src = VocabEntry.from_subword_list(src_sents)
169 |
170 | print('initialize target vocabulary ..')
171 | tgt = VocabEntry.from_subword_list(tgt_sents)
172 |
173 | return Vocab(src, tgt)
174 |
175 | def save(self, file_path):
176 | """ Save Vocab to file as JSON dump.
177 | @param file_path (str): file path to vocab file
178 | """
179 | with open(file_path, 'w') as f:
180 | json.dump(dict(src_word2id=self.src.word2id, tgt_word2id=self.tgt.word2id), f, indent=2)
181 |
182 | @staticmethod
183 | def load(file_path):
184 | """ Load vocabulary from JSON dump.
185 | @param file_path (str): file path to vocab file
186 | @returns Vocab object loaded from JSON dump
187 | """
188 | entry = json.load(open(file_path, 'r'))
189 | src_word2id = entry['src_word2id']
190 | tgt_word2id = entry['tgt_word2id']
191 | return Vocab(VocabEntry(src_word2id), VocabEntry(tgt_word2id))
192 |
193 | def __repr__(self):
194 | """ Representation of Vocab to be used
195 | when printing the object.
196 | """
197 | return 'Vocab(source %d words, target %d words)' % (len(self.src), len(self.tgt))
198 |
199 |
200 | def get_vocab_list(file_path, source, vocab_size):
201 | """ Use SentencePiece to tokenize and acquire list of unique subwords.
202 | @param file_path (str): file path to corpus
203 | @param source (str): tgt or src
204 | @param vocab_size: desired vocabulary size
205 | """
206 | spm.SentencePieceTrainer.train(input=file_path, model_prefix=source, vocab_size=vocab_size) # train the spm model
207 | sp = spm.SentencePieceProcessor() # create an instance; this saves .model and .vocab files
208 | sp.load('{}.model'.format(source)) # loads tgt.model or src.model
209 | sp_list = [sp.id_to_piece(piece_id) for piece_id in range(sp.get_piece_size())] # this is the list of subwords
210 | return sp_list
211 |
212 |
213 |
214 | if __name__ == '__main__':
215 | args = docopt(__doc__)
216 |
217 | print('read in source sentences: %s' % args['--train-src'])
218 | print('read in target sentences: %s' % args['--train-tgt'])
219 |
220 | src_sents = get_vocab_list(args['--train-src'], source='src', vocab_size=21000)
221 | tgt_sents = get_vocab_list(args['--train-tgt'], source='tgt', vocab_size=8000)
222 | vocab = Vocab.build(src_sents, tgt_sents)
223 | print('generated vocabulary, source %d words, target %d words' % (len(src_sents), len(tgt_sents)))
224 |
225 | vocab.save(args['VOCAB_FILE'])
226 | print('vocabulary saved to %s' % args['VOCAB_FILE'])
227 |
--------------------------------------------------------------------------------