├── README.md
├── config.json
├── data.py
├── data
│   ├── de_train.sample
│   ├── de_valid.sample
│   ├── en_train.sample
│   └── en_valid.sample
├── train_model.py
└── train_tokenizers.py

/README.md:
--------------------------------------------------------------------------------
# encdecmodel-hf

Sample code to demonstrate training EncoderDecoderModels using the Huggingface transformers library.

transformers version: 2.11.0
torch version: 1.5.0

The dataset used is available at: https://nlp.stanford.edu/projects/nmt/
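
## Usage

Train the WordPiece tokenizers first, then the translation model. Both scripts take the path to a config file as an optional argument and fall back to `config.json`. A typical run with the defaults in this repo (commands assume they are executed from the repo root) looks like:

```
python train_tokenizers.py config.json
python train_model.py config.json
```

With the default config, the tokenizers are written to `tokenizers/en_tok/` and `tokenizers/de_tok/`, and the trained model is saved to `models/encdec.mdl`.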
--------------------------------------------------------------------------------
/config.json:
--------------------------------------------------------------------------------
{
    "global_params" : {
        "train_en_file" : "data/en_train.sample",
        "train_de_file" : "data/de_train.sample",
        "valid_en_file" : "data/en_valid.sample",
        "valid_de_file" : "data/de_valid.sample"
    },

    "encoder_params" : {
        "tokenizer_path" : "tokenizers/en_tok/",
        "vocab_size" : 25000,
        "min_freq" : 3,
        "max_length" : 512,
        "num_attn_heads" : 8,
        "num_hidden_layers" : 8,
        "hidden_size" : 512
    },

    "decoder_params" : {
        "tokenizer_path" : "tokenizers/de_tok/",
        "vocab_size" : 25000,
        "min_freq" : 3,
        "max_length" : 256,
        "num_attn_heads" : 8,
        "num_hidden_layers" : 8,
        "hidden_size" : 512
    },

    "model_params" : {
        "batch_size" : 16,
        "num_epochs" : 1,
        "lr": 0.0001,
        "model_path" : "models/",
        "model_name": "encdec.mdl"
    }
}
--------------------------------------------------------------------------------
/data.py:
--------------------------------------------------------------------------------
import sys
import torch
import torch.utils.data as data
from torch.nn.utils.rnn import pad_sequence
import os

class TranslationDataset(data.Dataset):

    def __init__(self, inp_file, targ_file, inp_tokenizer, targ_tokenizer, inp_maxlength, targ_maxlength):

        self.inp_tokenizer = inp_tokenizer
        self.targ_tokenizer = targ_tokenizer
        self.inp_maxlength = inp_maxlength
        self.targ_maxlength = targ_maxlength

        print("Loading and Tokenizing the data ...")
        self.encoded_inp = []
        self.encoded_targ = []

        # Read the EN lines
        num_inp_lines = 0
        with open(inp_file, "r") as ef:
            for line in ef:
                enc = self.inp_tokenizer.encode(line.strip(), add_special_tokens=True, max_length=self.inp_maxlength)
                self.encoded_inp.append(torch.tensor(enc))
                num_inp_lines += 1

        # Read the DE lines
        num_targ_lines = 0
        with open(targ_file, "r") as df:
            for line in df:
                enc = self.targ_tokenizer.encode(line.strip(), add_special_tokens=True, max_length=self.targ_maxlength)
                self.encoded_targ.append(torch.tensor(enc))
                num_targ_lines += 1

        assert (num_inp_lines == num_targ_lines), "Mismatch in EN and DE lines"
        print("Read", num_inp_lines, "lines from EN and DE files.")

    def __getitem__(self, offset):
        en = self.encoded_inp[offset]
        de = self.encoded_targ[offset]

        return en, en.shape[0], de, de.shape[0]

    def __len__(self):
        return len(self.encoded_inp)

    def collate_function(self, batch):

        (inputs, inp_lengths, targets, targ_lengths) = zip(*batch)

        padded_inputs = self._collate_helper(inputs, self.inp_tokenizer)
        padded_targets = self._collate_helper(targets, self.targ_tokenizer)

        max_inp_seq_len = padded_inputs.shape[1]
        max_out_seq_len = padded_targets.shape[1]

        input_masks = [[1]*l + [0]*(max_inp_seq_len-l) for l in inp_lengths]
        target_masks = [[1]*l + [0]*(max_out_seq_len-l) for l in targ_lengths]

        input_tensor = padded_inputs.to(torch.int64)
        target_tensor = padded_targets.to(torch.int64)
        input_masks = torch.Tensor(input_masks)
        target_masks = torch.Tensor(target_masks)

        return input_tensor, input_masks, target_tensor, target_masks

    def _collate_helper(self, examples, tokenizer):
        length_of_first = examples[0].size(0)
        are_tensors_same_length = all(x.size(0) == length_of_first for x in examples)
        if are_tensors_same_length:
            return torch.stack(examples, dim=0)
        else:
            if tokenizer._pad_token is None:
                raise ValueError(
                    "You are attempting to pad samples but the tokenizer you are using"
                    f" ({tokenizer.__class__.__name__}) does not have a pad token."
                )
            return pad_sequence(examples, batch_first=True, padding_value=tokenizer.pad_token_id)
--------------------------------------------------------------------------------
/train_model.py:
--------------------------------------------------------------------------------
import sys
import os
import json
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

from data import TranslationDataset
from transformers import BertTokenizerFast
from transformers import BertModel, BertForMaskedLM, BertConfig, EncoderDecoderModel

# Identify the config file
if len(sys.argv) < 2:
    print("No config file specified. Using the default config.")
    configfile = "config.json"
else:
    configfile = sys.argv[1]

# Read the params
with open(configfile, "r") as f:
    config = json.load(f)

globalparams = config["global_params"]
encparams = config["encoder_params"]
decparams = config["decoder_params"]
modelparams = config["model_params"]

# Load the tokenizers (cased, to match the vocab trained by train_tokenizers.py)
en_tok_path = encparams["tokenizer_path"]
en_tokenizer = BertTokenizerFast(os.path.join(en_tok_path, "vocab.txt"), do_lower_case=False)
de_tok_path = decparams["tokenizer_path"]
de_tokenizer = BertTokenizerFast(os.path.join(de_tok_path, "vocab.txt"), do_lower_case=False)

# Init the datasets
train_en_file = globalparams["train_en_file"]
train_de_file = globalparams["train_de_file"]
valid_en_file = globalparams["valid_en_file"]
valid_de_file = globalparams["valid_de_file"]

enc_maxlength = encparams["max_length"]
dec_maxlength = decparams["max_length"]

batch_size = modelparams["batch_size"]
train_dataset = TranslationDataset(train_en_file, train_de_file, en_tokenizer, de_tokenizer, enc_maxlength, dec_maxlength)
train_dataloader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=False,
                                               drop_last=True, num_workers=1, collate_fn=train_dataset.collate_function)

valid_dataset = TranslationDataset(valid_en_file, valid_de_file, en_tokenizer, de_tokenizer, enc_maxlength, dec_maxlength)
valid_dataloader = torch.utils.data.DataLoader(dataset=valid_dataset, batch_size=batch_size, shuffle=False,
                                               drop_last=True, num_workers=1, collate_fn=valid_dataset.collate_function)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("Using device:", device)

print("Loading models ..")
vocabsize = encparams["vocab_size"]
max_length = encparams["max_length"]
encoder_config = BertConfig(vocab_size=vocabsize,
                            max_position_embeddings=max_length+64, # this should be some large value
                            num_attention_heads=encparams["num_attn_heads"],
                            num_hidden_layers=encparams["num_hidden_layers"],
                            hidden_size=encparams["hidden_size"],
                            type_vocab_size=1)

encoder = BertModel(config=encoder_config)

vocabsize = decparams["vocab_size"]
max_length = decparams["max_length"]
decoder_config = BertConfig(vocab_size=vocabsize,
                            max_position_embeddings=max_length+64, # this should be some large value
                            num_attention_heads=decparams["num_attn_heads"],
                            num_hidden_layers=decparams["num_hidden_layers"],
                            hidden_size=decparams["hidden_size"],
                            type_vocab_size=1,
                            is_decoder=True) # Very important

decoder = BertForMaskedLM(config=decoder_config)

# Define the encoder-decoder model
model = EncoderDecoderModel(encoder=encoder, decoder=decoder)
model.to(device)

def count_parameters(mdl):
    return sum(p.numel() for p in mdl.parameters() if p.requires_grad)

print(f'The encoder has {count_parameters(encoder):,} trainable parameters')
print(f'The decoder has {count_parameters(decoder):,} trainable parameters')
print(f'The model has {count_parameters(model):,} trainable parameters')

optimizer = optim.Adam(model.parameters(), lr=modelparams['lr'])
criterion = nn.NLLLoss(ignore_index=de_tokenizer.pad_token_id)

num_train_batches = len(train_dataloader)
num_valid_batches = len(valid_dataloader)

def compute_loss(predictions, targets):
    """Compute our custom loss"""
    predictions = predictions[:, :-1, :].contiguous()
    targets = targets[:, 1:]

    rearranged_output = predictions.view(predictions.shape[0]*predictions.shape[1], -1)
    rearranged_target = targets.contiguous().view(-1)

    loss = criterion(rearranged_output, rearranged_target)

    return loss

def train_model():
    model.train()
    epoch_loss = 0

    for i, (en_input, en_masks, de_output, de_masks) in enumerate(train_dataloader):

        optimizer.zero_grad()

        en_input = en_input.to(device)
        de_output = de_output.to(device)
        en_masks = en_masks.to(device)
        de_masks = de_masks.to(device)

        lm_labels = de_output.clone()
        out = model(input_ids=en_input, attention_mask=en_masks,
                    decoder_input_ids=de_output, decoder_attention_mask=de_masks, lm_labels=lm_labels)
        prediction_scores = out[1]
        predictions = F.log_softmax(prediction_scores, dim=2)
        loss = compute_loss(predictions, de_output)

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        epoch_loss += loss.item()

    print("Mean epoch loss:", (epoch_loss / num_train_batches))

def eval_model():
    model.eval()
    epoch_loss = 0

    with torch.no_grad():
        for i, (en_input, en_masks, de_output, de_masks) in enumerate(valid_dataloader):

            en_input = en_input.to(device)
            de_output = de_output.to(device)
            en_masks = en_masks.to(device)
            de_masks = de_masks.to(device)

            lm_labels = de_output.clone()

            out = model(input_ids=en_input, attention_mask=en_masks,
                        decoder_input_ids=de_output, decoder_attention_mask=de_masks, lm_labels=lm_labels)

            prediction_scores = out[1]
            predictions = F.log_softmax(prediction_scores, dim=2)
            loss = compute_loss(predictions, de_output)
            epoch_loss += loss.item()

    print("Mean validation loss:", (epoch_loss / num_valid_batches))


# MAIN TRAINING LOOP
for epoch in range(modelparams['num_epochs']):
    print("Starting epoch", epoch+1)
    train_model()
    eval_model()

print("Saving model ..")
save_location = modelparams['model_path']
model_name = modelparams['model_name']
if not os.path.exists(save_location):
    os.makedirs(save_location)
save_location = os.path.join(save_location, model_name)
torch.save(model, save_location)
--------------------------------------------------------------------------------
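
An aside on compute_loss in train_model.py above, not part of the repository: the one-token shift implements teacher forcing, so the score the decoder produces at position t is compared against the target token at position t+1, and padded positions are skipped through ignore_index. A toy, self-contained illustration (all ids are made up; 0 stands for [PAD]):

# Toy illustration of the shift in compute_loss (made-up ids; 0 stands for [PAD])
import torch
import torch.nn as nn
import torch.nn.functional as F

pad_id = 0
criterion = nn.NLLLoss(ignore_index=pad_id)

targets = torch.tensor([[2, 7, 8, 3, 0]])   # [CLS]=2, w1=7, w2=8, [SEP]=3, [PAD]=0
scores = torch.randn(1, 5, 10)              # (batch, seq_len, vocab_size) decoder scores
preds = F.log_softmax(scores, dim=2)

# the score at position t is matched against the target token at position t+1
shifted_preds = preds[:, :-1, :].contiguous()   # drop the score after the final position
shifted_targets = targets[:, 1:]                # drop [CLS]; the model never has to predict it

loss = criterion(shifted_preds.view(-1, shifted_preds.size(-1)),
                 shifted_targets.contiguous().view(-1))
print(loss)   # [PAD] positions contribute nothing because of ignore_index
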
/train_tokenizers.py:
--------------------------------------------------------------------------------
import sys
import os
import json
from tokenizers import BertWordPieceTokenizer
from tokenizers.processors import BertProcessing

def train_tokenizer(filename, params):
    """
    Train a BertWordPieceTokenizer with the specified params and save it
    """
    # Get tokenization params
    save_location = params["tokenizer_path"]
    max_length = params["max_length"]
    min_freq = params["min_freq"]
    vocabsize = params["vocab_size"]

    # Disable lowercasing via the constructor; assigning do_lower_case after construction has no effect
    tokenizer = BertWordPieceTokenizer(lowercase=False)
    special_tokens = ["[S]", "[PAD]", "[/S]", "[UNK]", "[MASK]", "[SEP]", "[CLS]"]
    tokenizer.train(files=[filename], vocab_size=vocabsize, min_frequency=min_freq, special_tokens=special_tokens)

    tokenizer._tokenizer.post_processor = BertProcessing(("[SEP]", tokenizer.token_to_id("[SEP]")), ("[CLS]", tokenizer.token_to_id("[CLS]")))
    tokenizer.enable_truncation(max_length=max_length)

    print("Saving tokenizer ...")
    if not os.path.exists(save_location):
        os.makedirs(save_location)
    tokenizer.save(save_location)

# Identify the config to use
if len(sys.argv) < 2:
    print("No config file specified. Using the default config.")
    configfile = "config.json"
else:
    configfile = sys.argv[1]

# Read the params
with open(configfile, "r") as f:
    config = json.load(f)

globalparams = config["global_params"]
encparams = config["encoder_params"]
decparams = config["decoder_params"]

# Get the dataset files
train_en_file = globalparams["train_en_file"]
train_de_file = globalparams["train_de_file"]

# Train the tokenizers
train_tokenizer(train_en_file, encparams)
train_tokenizer(train_de_file, decparams)
--------------------------------------------------------------------------------
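
For reference, a minimal inference sketch that is not part of the repository: it reloads the model saved by train_model.py and greedily decodes one sentence. It assumes the default config.json, the tokenizers trained by train_tokenizers.py, and that the first element returned by the model when no lm_labels are passed is the decoder's prediction scores (train_model.py reads index 1 when labels are passed). The translate helper and the file name are illustrative, not a library API.

# sketch_inference.py (illustrative only)
import os
import json
import torch
from transformers import BertTokenizerFast

with open("config.json", "r") as f:
    config = json.load(f)
encparams = config["encoder_params"]
decparams = config["decoder_params"]
modelparams = config["model_params"]

# The tokenizers trained by train_tokenizers.py
en_tokenizer = BertTokenizerFast(os.path.join(encparams["tokenizer_path"], "vocab.txt"), do_lower_case=False)
de_tokenizer = BertTokenizerFast(os.path.join(decparams["tokenizer_path"], "vocab.txt"), do_lower_case=False)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# train_model.py saves the whole module with torch.save, so torch.load restores it directly
model = torch.load(os.path.join(modelparams["model_path"], modelparams["model_name"]), map_location=device)
model.eval()

def translate(sentence, max_steps=64):
    """Greedy decoding: repeatedly feed the growing target prefix back into the decoder."""
    enc = en_tokenizer.encode(sentence, add_special_tokens=True, max_length=encparams["max_length"])
    en_input = torch.tensor([enc], device=device)
    en_mask = torch.ones_like(en_input)

    # Start the target with [CLS] and stop once [SEP] is produced, mirroring the training targets
    ys = torch.tensor([[de_tokenizer.cls_token_id]], device=device)
    for _ in range(max_steps):
        with torch.no_grad():
            # Assumption: without lm_labels, the model's first output is the decoder's prediction scores
            scores = model(input_ids=en_input, attention_mask=en_mask, decoder_input_ids=ys)[0]
        next_token = scores[:, -1, :].argmax(dim=-1, keepdim=True)
        ys = torch.cat([ys, next_token], dim=1)
        if next_token.item() == de_tokenizer.sep_token_id:
            break

    return de_tokenizer.decode(ys[0].tolist(), skip_special_tokens=True)

print(translate("Hello world !"))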