├── README.md ├── transformer_discriminator.py ├── transformer_elmo.py ├── char_cnn_discriminator.py ├── utils.py ├── Embedder.py ├── my_transformer_cyclegan.py ├── v3_cyclegan.py ├── v2_cyclegan.py ├── attn.py ├── v4_cyclegan.py ├── maskgan.py ├── cyclegan.py ├── maskgan.1.py └── v5_cyclegan.py /README.md: -------------------------------------------------------------------------------- 1 | # Transformer_CycleGAN_Text_Style_Transfer-pytorch 2 | Implementation of CycleGAN for Text style transfer with PyTorch. 3 | 4 | ## CycleGAN Architecture 5 | ![](https://i.imgur.com/tHl26oG.png) 6 | 7 | ## Model Detail 8 | ![](https://i.imgur.com/UrLR9qS.png) 9 | ![](https://i.imgur.com/tmMivIm.png) 10 | 11 | - We simply use softmax weighted sum over all word vectors to overcome the discrete gradient issue in GAN training process. 12 | -------------------------------------------------------------------------------- /transformer_discriminator.py: -------------------------------------------------------------------------------- 1 | from attn import Transformer, LabelSmoothing, \ 2 | data_gen, NoamOpt, Generator, SimpleLossCompute, \ 3 | greedy_decode, subsequent_mask, TransformerEncoder 4 | import torch 5 | import torch.nn.functional as F 6 | from torch import nn 7 | 8 | class Discriminator(nn.Module): 9 | def __init__(self, word_dim, inner_dim, seq_len, N=3, device="cuda:0", kernel_size=3): 10 | super(Discriminator, self).__init__() 11 | self.encoder = TransformerEncoder(N, word_dim, inner_dim) 12 | self.seq_len = seq_len 13 | self.conv_1 = nn.Conv1d(word_dim, inner_dim, kernel_size, padding=int((kernel_size-1)/2)) 14 | self.conv_2 = nn.Conv1d(inner_dim, inner_dim, kernel_size, padding=int((kernel_size-1)/2)) 15 | W = seq_len*inner_dim 16 | self.fc_1 = nn.Linear(W, int(W/8)) 17 | self.fc_2 = nn.Linear(int(W/8), int(W/32)) 18 | self.fc_3 = nn.Linear(int(W/32), int(W/64)) 19 | self.fc_4 = nn.Linear(int(W / 64), 2) 20 | self.relu = nn.LeakyReLU() 21 | 22 | def feed_fc(self, inputs): 
23 | output = self.relu(self.fc_1(inputs)) 24 | output = self.relu(self.fc_2(output)) 25 | output = self.relu(self.fc_3(output)) 26 | return self.fc_4(output) 27 | 28 | def forward(self, src, src_mask=None): 29 | this_bs = src.shape[0] 30 | x = self.encoder(src, src_mask) 31 | inputs = x.permute(0, 2, 1).float() 32 | if inputs.shape[-1] != self.seq_len: 33 | # print("Warning: seq_len(%d) != fixed_seq_len(%d), auto-pad."%(inputs.shape[-1], self.seq_len)) 34 | p1d = (0, self.seq_len - inputs.shape[-1]) 35 | inputs = F.pad(inputs, p1d, "constant", 0) 36 | # print("after padding,", inputs.shape) 37 | x = self.conv_1(inputs) 38 | x = self.conv_2(x) 39 | x = x.view(this_bs, -1) 40 | return self.feed_fc(x) -------------------------------------------------------------------------------- /transformer_elmo.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import math, copy, time 6 | from torch.autograd import Variable 7 | import os, re, sys 8 | from jexus import Clock 9 | global device 10 | device = "cuda:0" 11 | from transformer import MultiHeadedAttention, PositionwiseFeedForward, \ 12 | PositionalEncoding, EncoderDecoder, \ 13 | Encoder, EncoderLayer, Decoder, DecoderLayer, \ 14 | subsequent_mask 15 | 16 | from Embedder import * 17 | 18 | def make_model_elmo(N=6, d_model=1024, d_ff=2048, h=8, dropout=0.1): 19 | "Helper: Construct a model from hyperparameters." 
20 | c = copy.deepcopy 21 | attn = MultiHeadedAttention(h, d_model) 22 | ff = PositionwiseFeedForward(d_model, d_ff, dropout) 23 | position = PositionalEncoding(d_model, dropout) 24 | model = EncoderDecoder( 25 | Encoder(EncoderLayer(d_model, c(attn), c(ff), dropout), N), 26 | Decoder(DecoderLayer(d_model, c(attn), c(attn), 27 | c(ff), dropout), N), 28 | nn.Sequential(Embedder(), c(position)), 29 | nn.Sequential(Embedder(), c(position)), 30 | generator=None) 31 | 32 | # This was important from their code. 33 | # Initialize parameters with Glorot / fan_avg. 34 | for p in model.parameters(): 35 | if p.dim() > 1: 36 | nn.init.xavier_uniform(p) 37 | return model 38 | 39 | class Batch(): 40 | "Object for holding a batch of data with mask during training." 41 | def __init__(self, src, trg=None, max_len=40): 42 | self.src = src 43 | length = [min(max_len, len(x))+2 for x in self.src] 44 | self.src_mask = torch.zeros((len(length), max_len + 2)) 45 | self.max_len = max_len 46 | for i,j in enumerate(length): 47 | self.src_mask[i,range(j)]=1 48 | 49 | if trg is not None: 50 | self.trg = trg 51 | # self.trg_y = trg 52 | self.trg_mask = \ 53 | self.make_std_mask(self.trg, max_len) 54 | self.ntokens = self.src_mask.data.sum() 55 | 56 | @staticmethod 57 | def make_std_mask(tgt, max_len): 58 | "Create a mask to hide padding and future words." 59 | length = [min(max_len, len(x))+1 for x in tgt] 60 | tgt_mask = torch.zeros((len(length), max_len + 1)) 61 | for i,j in enumerate(length): 62 | tgt_mask[i,range(j)]=1 63 | # tgt_mask = (tgt != pad).unsqueeze(-2) 64 | tgt_mask = tgt_mask & Variable( 65 | subsequent_mask(max_len + 1).type_as(tgt_mask.data)) 66 | return tgt_mask 67 | 68 | def pretrain_data_gen(iterator): 69 | "Generate random data for a src-tgt copy task." 
70 | for i in iterator: 71 | yield Batch(i, i) 72 | 73 | def pretrain_run_epoch(data_iter, model, loss_compute, train_step_num): 74 | "Standard Training and Logging Function" 75 | start = time.time() 76 | total_tokens = 0 77 | total_loss = 0 78 | tokens = 0 79 | ct = Clock(train_step_num) 80 | for i, batch in enumerate(data_iter): 81 | out = model.forward(batch.src, batch.trg, 82 | batch.src_mask, batch.trg_mask) 83 | loss = loss_compute(out, batch.trg_y, batch.ntokens) 84 | total_loss += loss 85 | total_tokens += batch.ntokens 86 | tokens += batch.ntokens 87 | ct.flush(info={"loss":loss / batch.ntokens.float().to(device)}) 88 | return total_loss / total_tokens.float().to(device) -------------------------------------------------------------------------------- /char_cnn_discriminator.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import math, copy, time 6 | from torch.autograd import Variable 7 | import os, re, sys 8 | from jexus import Clock 9 | global device 10 | device = "cuda:0" 11 | 12 | 13 | DIM = 512 14 | SEQ_LEN = 15 + 2 15 | WORD_DIM = 1024 16 | 17 | class Resblock(nn.Module): 18 | def __init__(self, inner_dim, kernel_size): 19 | super(Resblock, self).__init__() 20 | self.inner_dim = inner_dim 21 | self.kernel_size = kernel_size 22 | self.relu = nn.ReLU() 23 | if kernel_size % 2 != 1: 24 | raise Exception("kernel size must be odd number!") 25 | self.conv_1 = nn.Conv1d(self.inner_dim, self.inner_dim, self.kernel_size, padding=int((kernel_size-1)/2)) 26 | self.conv_2 = nn.Conv1d(self.inner_dim, self.inner_dim, self.kernel_size, padding=int((kernel_size-1)/2)) 27 | 28 | def forward(self, inputs): 29 | output = self.relu(inputs) 30 | output = self.conv_1(output) 31 | output = self.relu(output) 32 | output = self.conv_2(output) 33 | return inputs + (0.3*output) 34 | 35 | 36 | class Discriminator(nn.Module): 37 | def 
__init__(self, word_dim, inner_dim, seq_len, kernel_size=3, device="cuda:0", two_out=False): 38 | super(Discriminator, self).__init__() 39 | self.device = device 40 | self.word_dim = word_dim 41 | self.inner_dim = inner_dim 42 | self.seq_len = seq_len 43 | self.kernel_size = kernel_size 44 | if kernel_size % 2 != 1: 45 | raise Exception("kernel size must be odd number!") 46 | self.conv_1 = nn.Conv1d(self.word_dim, self.inner_dim, self.kernel_size, padding=int((kernel_size-1)/2)) 47 | self.resblock_1 = Resblock(inner_dim, kernel_size) 48 | self.resblock_2 = Resblock(inner_dim, kernel_size) 49 | self.resblock_3 = Resblock(inner_dim, kernel_size) 50 | self.resblock_4 = Resblock(inner_dim, kernel_size) 51 | W = seq_len*inner_dim 52 | self.fc_1 = nn.Linear(W, int(W/8)) 53 | self.fc_2 = nn.Linear(int(W/8), int(W/32)) 54 | self.fc_3 = nn.Linear(int(W/32), int(W/64)) 55 | self.fc_4 = nn.Linear(int(W / 64), 2 if two_out else 1) 56 | self.relu = nn.LeakyReLU() 57 | 58 | def feed_fc(self, inputs): 59 | output = self.relu(self.fc_1(inputs)) 60 | output = self.relu(self.fc_2(output)) 61 | output = self.relu(self.fc_3(output)) 62 | return self.fc_4(output) 63 | 64 | def forward(self, inputs): 65 | this_bs = inputs.shape[0] 66 | inputs = inputs.permute(0, 2, 1).float() 67 | if inputs.shape[-1] != self.seq_len: 68 | # print("Warning: seq_len(%d) != fixed_seq_len(%d), auto-pad."%(inputs.shape[-1], self.seq_len)) 69 | p1d = (0, self.seq_len - inputs.shape[-1]) 70 | inputs = F.pad(inputs, p1d, "constant", 0) 71 | # print("after padding,", inputs.shape) 72 | output = self.conv_1(inputs) 73 | output = self.resblock_1(output) 74 | output = self.resblock_2(output) 75 | output = self.resblock_3(output) 76 | output = self.resblock_4(output) 77 | output = output.view(this_bs, -1) 78 | # print(output.shape) 79 | return self.feed_fc(output) 80 | 81 | 82 | 83 | 84 | # def ResBlock(name, inputs): 85 | # output = inputs 86 | # output = tf.nn.relu(output) 87 | # output = 
tflib.ops.conv1d.Conv1D(name+'.1', DIM, DIM, 3, output) 88 | # output = tf.nn.relu(output) 89 | # output = tflib.ops.conv1d.Conv1D(name+'.2', DIM, DIM, 3, output) 90 | # return inputs + (0.3*output) 91 | 92 | 93 | # def discriminator_X(inputs): 94 | # output = tf.transpose(inputs, [0,2,1]) 95 | # output = tflib.ops.conv1d.Conv1D('discriminator_x.Input',WORD_DIM, DIM, 1, output) 96 | # output = ResBlock('discriminator_x.1', output) 97 | # output = ResBlock('discriminator_x.2', output) 98 | # output = ResBlock('discriminator_x.3', output) 99 | # output = ResBlock('discriminator_x.4', output) 100 | # #output = ResBlock('Discriminator.5', output) 101 | 102 | # output = tf.reshape(output, [-1, SEQ_LEN*DIM]) 103 | # output = tflib.ops.linear.Linear('discriminator_x.Output', SEQ_LEN*DIM, 1, output) 104 | # return tf.squeeze(output,[1]) -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | from attn import * 2 | 3 | def load_embedding(limit=100000): 4 | idx2word = ["",""] + list(np.load(os.path.join(os.path.dirname(__file__),"WordEmb/idx2word.npy")))[:limit-2] 5 | word2idx = dict([(word, i) for i, word in enumerate(idx2word)]) 6 | syn0 = np.load("WordEmb/word2vec_weights.npy")[:limit - 2] 7 | syn0 = np.concatenate((np.zeros((2, syn0.shape[1])), syn0),axis=0) 8 | chgex = re.compile(r'[\u4e00-\u9fff]+') 9 | non_ch = [] 10 | for i in range(2, len(idx2word)): 11 | if chgex.findall(idx2word[i]).__len__()==0: 12 | non_ch.append(syn0[i]) 13 | syn0[1] = np.mean(non_ch, axis=0) 14 | # syn0[1] = non_ch[87] 15 | return idx2word, word2idx, syn0 16 | 17 | def f2h(s): 18 | s = list(s) 19 | for i in range(len(s)): 20 | num = ord(s[i]) 21 | if num == 0x3000: 22 | num = 32 23 | elif 0xFF01 <= num <= 0xFF5E: 24 | num -= 0xfee0 25 | s[i] = chr(num).translate(str.maketrans('﹕﹐﹑。﹔﹖﹗﹘ ', ':,、。;?!- ')) 26 | return re.sub(r"( | )+", " ", "".join(s)).strip() 27 | 28 | class 
Utils(): 29 | def __init__(self, 30 | X_data_path, 31 | Y_data_path, 32 | X_test_path, 33 | Y_test_path, 34 | batch_size = 32, vocab_lim=10000): 35 | self.X_data_path = X_data_path 36 | self.X_line_num = int(os.popen("wc -l %s"%self.X_data_path).read().split(' ')[0]) 37 | self.Y_data_path = Y_data_path 38 | self.Y_line_num = int(os.popen("wc -l %s" % self.Y_data_path).read().split(' ')[0]) 39 | 40 | self.X_test_path = X_test_path 41 | self.X_test_num = int(os.popen("wc -l %s"%self.X_test_path).read().split(' ')[0]) 42 | self.Y_test_path = Y_test_path 43 | self.Y_test_num = int(os.popen("wc -l %s" % self.Y_test_path).read().split(' ')[0]) 44 | 45 | self.idx2word, self.word2idx, self.emb_mat = load_embedding(limit=vocab_lim) 46 | self.batch_size = batch_size 47 | self.train_step_num = math.floor(self.X_line_num / batch_size) 48 | self.test_step_num = math.floor(self.X_test_num / batch_size) 49 | self.device = "cuda:0" 50 | self.ch_gex = re.compile(r'[\u4e00-\u9fff]+') 51 | self.eng_gex = re.compile(r'[a-zA-Z0-90123456789\s]+') 52 | self.max_len = 15 53 | self.vocab_lim = vocab_lim 54 | 55 | def string2list(self, line): 56 | ret = [] 57 | temp_str = [] 58 | for char in line: 59 | if self.eng_gex.findall(char).__len__() == 0: 60 | if temp_str.__len__() > 0: 61 | ret.append("".join(temp_str).strip()) 62 | temp_str = [] 63 | ret.append(char) 64 | else: 65 | temp_str.append(char) 66 | if temp_str.__len__() > 0: 67 | ret.append("".join(temp_str).strip()) 68 | return ret 69 | 70 | def process_sent(self, sent): 71 | sent = f2h(sent) 72 | word_list = re.split(r"[\s|\u3000]+", sent.strip()) 73 | # char_list = self.string2list("".join(word_list)) 74 | # for i, char in enumerate(char_list): 75 | # if char not in self.word2idx: 76 | # char_list[i] = "" 77 | for i, word in enumerate(word_list): 78 | if word not in self.word2idx: 79 | word_list[i] = "" 80 | return word_list 81 | 82 | def data_generator(self, mode="X", write_actual_data=False): 83 | path = eval("self.%s_data_path" % 
mode) 84 | while True: 85 | file = open(path) 86 | sents = [] 87 | for sent in file: 88 | if len(sent.strip()) == 0: 89 | continue 90 | word_list = self.process_sent(sent) 91 | sents.append(word_list) 92 | if len(sents) == self.batch_size: 93 | yield sents 94 | sents = [] 95 | if len(sents)!=0: 96 | yield sents 97 | 98 | def test_generator(self, mode="X", write_actual_data=False): 99 | path = eval("self.%s_test_path" % mode) 100 | while True: 101 | file = open(path) 102 | sents = [] 103 | for sent in file: 104 | if len(sent.strip()) == 0: 105 | continue 106 | word_list = self.process_sent(sent) 107 | sents.append(word_list) 108 | if len(sents) == self.batch_size: 109 | yield sents 110 | sents = [] 111 | if len(sents)!=0: 112 | yield sents 113 | 114 | def sents2idx(self, sents, pad=0, add_eos=True, eos=3): 115 | idx_mat = np.zeros((len(sents), self.max_len + 1), dtype=np.int32) + pad 116 | for i in range(len(sents)): 117 | for j in range(min(len(sents[i]), self.max_len)): 118 | idx_mat[i][j] = self.word2idx[sents[i][j]] 119 | eos_pos = min(len(sents[i]), self.max_len) 120 | idx_mat[i][eos_pos] = eos 121 | return idx_mat 122 | 123 | def idx2sent(self, idxs, pad=0): 124 | ret = [] 125 | for i in range(len(idxs)): 126 | sent = [] 127 | for j in range(len(idxs[i])): 128 | sent.append(self.idx2word[idxs[i][j]]) 129 | ret.append(sent) 130 | return ret -------------------------------------------------------------------------------- /Embedder.py: -------------------------------------------------------------------------------- 1 | import re 2 | import torch 3 | from torch import nn 4 | import os 5 | from ELMoForManyLangs import elmo 6 | import numpy as np 7 | 8 | class Embedder(nn.Module): 9 | def __init__(self, seq_len=0, use_cuda=True, run_device=None, target_device=None ,d_model=1024): 10 | super(Embedder, self).__init__() 11 | self.embedder = elmo.Embedder(model_dir="new.model", batch_size=512, use_cuda=use_cuda) 12 | self.seq_len = seq_len 13 | self.device = run_device 
14 | self.target_device = target_device 15 | if self.device != None: 16 | self.embedder.model.to(self.device) 17 | self.bos_vec, self.eos_vec, self.pad, self.oov = self.embedder.sents2elmo([["","","",""]], output_layer=0)[0] 18 | 19 | def __call__(self, sents, max_len=0, with_bos_eos=True, layer=-1, pad_matters=False): 20 | seq_lens = np.array([len(x) for x in sents], dtype=np.int64) 21 | sents = [[self.sub_unk(x) for x in sent] for sent in sents] 22 | if max_len != 0: 23 | pass 24 | elif self.seq_len != 0: 25 | max_len = self.seq_len 26 | else: 27 | max_len = seq_lens.max() 28 | emb_list = self.embedder.sents2elmo(sents, output_layer=layer) 29 | if not with_bos_eos: 30 | for i in range(len(emb_list)): 31 | if max_len - seq_lens[i] > 0: 32 | if pad_matters: 33 | emb_list[i] = np.concatenate([emb_list[i], np.tile(self.pad,[max_len - seq_lens[i],1])], axis=0) 34 | else: 35 | emb_list[i] = np.concatenate([emb_list[i], np.zeros((max_len - seq_lens[i], emb_list[i].shape[1]))]) 36 | else: 37 | emb_list[i] = emb_list[i][:max_len] 38 | elif with_bos_eos: 39 | for i in range(len(emb_list)): 40 | if max_len - seq_lens[i] > 0: 41 | if pad_matters: 42 | emb_list[i] = np.concatenate([ 43 | self.bos_vec[np.newaxis], 44 | emb_list[i], 45 | self.eos_vec[np.newaxis], 46 | np.tile(self.pad, [max_len - seq_lens[i], 1])], axis=0) 47 | else: 48 | emb_list[i] = np.concatenate([ 49 | self.bos_vec[np.newaxis], 50 | emb_list[i], 51 | self.eos_vec[np.newaxis], 52 | np.zeros((max_len - seq_lens[i], emb_list[i].shape[1]))], axis=0) 53 | else: 54 | emb_list[i] = np.concatenate([self.bos_vec[np.newaxis], emb_list[i][:max_len],self.eos_vec[np.newaxis]], axis=0) 55 | embedded = np.array(emb_list, dtype=np.float32) 56 | seq_lens = seq_lens+2 if with_bos_eos else seq_lens 57 | return embedded, seq_lens 58 | 59 | def forward(self, sents, max_len=0, with_bos_eos=True, layer=-1, pad_matters=False): 60 | return torch.from_numpy(self.__call__(sents, max_len=0, with_bos_eos=True, layer=-1, 
pad_matters=False)[0]).to(self.target_device) 61 | 62 | def sub_unk(self, e): 63 | e = e.replace(',',',') 64 | e = e.replace(':',':') 65 | e = e.replace(';',';') 66 | e = e.replace('?','?') 67 | e = e.replace('!', '!') 68 | return e 69 | 70 | class oov_handler(): 71 | def __init__(self): 72 | self.ch_gex = re.compile(r'[\u4e00-\u9fff]+') 73 | self.num_gex = re.compile(r'[0-9]+') 74 | self.eng_gex = re.compile(r'[a-zA-Z]+') 75 | self.sym_list = list(np.load(os.path.join(os.path.dirname(__file__),"sym_list.npy"))) 76 | def __call__(self, word): 77 | if self.ch_gex.findall(word) != []: 78 | return "" 79 | if self.eng_gex.findall(word) != []: 80 | return "" 81 | if self.num_gex.findall(word) != []: 82 | return "" 83 | if word in self.sym_list: 84 | return "" 85 | else: 86 | return "" 87 | 88 | class invELMo(nn.Module): 89 | def __init__(self, 90 | elmo=None, 91 | batch_size=32, 92 | input_size=1024, 93 | hidden_size=300, 94 | h_size=500, 95 | n_layers=3, 96 | dropout=0.33): 97 | super(invELMo, self).__init__() 98 | self.batch_size = batch_size 99 | self.vocab_lim = 100000 100 | self.n_layers = n_layers 101 | self.hidden_size = hidden_size 102 | # self.load_elmo() 103 | if elmo == None: 104 | print("ELMo model not provided. 
You can't use this model to train, but you can test.") 105 | self.elmo = elmo 106 | self.total_line = 50563844 107 | self.gru = nn.LSTM(input_size, hidden_size, n_layers, 108 | dropout=(0 if n_layers == 1 else dropout), 109 | bidirectional=True, 110 | batch_first=True) 111 | self.fc1 = nn.Linear(2*hidden_size, self.vocab_lim) 112 | self.criterion = nn.CrossEntropyLoss(ignore_index=0) 113 | self.optimizer = torch.optim.Adam(self.parameters()) 114 | self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 115 | self.load_corpus_dict(self.vocab_lim) 116 | self.handle_oov = oov_handler() 117 | self.bos_vec, self.eos_vec = np.load(os.path.join(os.path.dirname(__file__),"bos_eos.npy")) 118 | 119 | def process(self, sent): 120 | return [x if x in self.word2idx else self.handle_oov(x) for x in sent] 121 | 122 | def load_elmo(self): 123 | print("loading ELMo model ...") 124 | self.elmo = Embedder() 125 | print("ELMo model loaded!") 126 | 127 | def load_corpus_dict(self, limit): 128 | self.idx2word = [""] + list(np.load(os.path.join(os.path.dirname(__file__),"idx2word_new.npy")))[:limit-1] 129 | self.word2idx = dict([(word, i) for i, word in enumerate(self.idx2word)]) 130 | 131 | def corpus_generator(self, filename="shuff_corpus.txt"): 132 | f = open(os.path.join(os.path.dirname(__file__),filename)) 133 | batch_list = [] 134 | for i in f: 135 | batch_list.append(self.process([""]+sub_unk(i.strip()).split(' ')+[""])) 136 | if len(batch_list) == self.batch_size: 137 | yield batch_list 138 | batch_list = [] 139 | 140 | 141 | def padded_corpus_generator(self, filename="shuff_corpus.txt", max_len=25): 142 | f = open(os.path.join(os.path.dirname(__file__),filename)) 143 | batch_list = [] 144 | for i in f: 145 | org_sent = sub_unk(i.strip()).split(' ') 146 | if len(org_sent) > max_len - 2: 147 | org_sent = org_sent[:max_len - 2] 148 | batch_list.append(self.process([""] + org_sent + [""] + ["" for _ in range(max_len - 2 - len(org_sent))])) 149 | if 
len(batch_list) == self.batch_size: 150 | yield batch_list 151 | batch_list = [] 152 | 153 | 154 | def sent2idx(self, sents, max_len = 0): 155 | if max_len==0: 156 | for i in sents: 157 | if len(i) > max_len: 158 | max_len = len(i) 159 | sents_lens = [] 160 | sent_mat = np.zeros((len(sents), max_len), dtype=np.int64) 161 | for i in range(len(sents)): 162 | sents_i_len = len(sents[i]) 163 | sents_lens.append(sents_i_len) 164 | for j in range(max_len): 165 | if j < sents_i_len: 166 | sent_mat[i][j] = self.word2idx[sents[i][j]] 167 | return [sent_mat, np.array(sents_lens)] 168 | 169 | def forward(self, input_seq, input_lengths, hidden=None): 170 | embedded = torch.from_numpy(input_seq).to(self.device) 171 | packed = torch.nn.utils.rnn.pack_padded_sequence(embedded, input_lengths, batch_first=True) 172 | outputs, hidden = self.gru(packed, hidden) # output: (seq_len, batch, hidden*n_dir) 173 | outputs, _ = torch.nn.utils.rnn.pad_packed_sequence(outputs, batch_first=True) 174 | pred_prob = self.fc1(outputs)#nn.Softmax(dim=-1)(self.fc1(outputs)) 175 | embedded.cpu() 176 | return pred_prob 177 | 178 | def train_model(self, num_epochs=1, step_num=100000, step_to_save_model=1000, filename="shuff_corpus.txt"): 179 | self.to(self.device) 180 | for epoch in range(num_epochs): 181 | ct = Clock(step_num) 182 | His_loss = History(title="Loss", xlabel="step", ylabel="loss", 183 | item_name=["train_loss"]) 184 | His_ppl = History(title="Perplexity", xlabel="step", ylabel="loss", 185 | item_name=["train_ppl"]) 186 | for step, batch_x in enumerate(self.padded_corpus_generator(filename=filename)): 187 | batch_y, x_lens = self.sent2idx(batch_x) 188 | elmo_x, x_lens = self.elmo(batch_x) 189 | (elmo_x, x_lens, batch_y), _ind = sort_numpy([elmo_x, x_lens, batch_y], piv=1) 190 | target = torch.from_numpy(batch_y).cuda() if self.device!="cpu" else torch.from_numpy(batch_y) 191 | pred = self.forward(elmo_x, x_lens) 192 | loss = self.criterion(pred.transpose(1, 2), target) 193 | 
self.optimizer.zero_grad() 194 | loss.backward() 195 | self.optimizer.step() 196 | ppl = math.exp(loss.cpu().item()) 197 | info_dict = {"loss": loss, "ppl": ppl} 198 | ct.flush(info=info_dict) 199 | His_loss.append_history(0, (step, loss)) 200 | His_ppl.append_history(0, (step, ppl)) 201 | target.cpu() 202 | if step == step_num: 203 | break 204 | if (step + 1) % step_to_save_model == 0: 205 | torch.save(self.state_dict(), os.path.join(os.path.dirname(__file__),'model.ckpt')) 206 | His_loss.plot("loss_plot") 207 | His_ppl.plot("pll_plot") 208 | test_corpus(self, "small_cou.txt") 209 | torch.save(self.state_dict(), os.path.join(os.path.dirname(__file__),'model.ckpt')) 210 | His_loss.plot() 211 | His_ppl.plot() 212 | 213 | 214 | def load_model(self, filename='model.ckpt', device="cuda:0"): 215 | self.load_state_dict(torch.load(os.path.join(os.path.dirname(__file__),filename), map_location=device)) 216 | print("model.ckpt load!") 217 | 218 | def test(self, input_seq, input_lengths, hidden=None): 219 | pred_prob = self.forward(input_seq, input_lengths) 220 | pred_idx = pred_prob.argmax(2) 221 | return [[self.idx2word[x] for x in r] for r in pred_idx] -------------------------------------------------------------------------------- /my_transformer_cyclegan.py: -------------------------------------------------------------------------------- 1 | from new_trans import * 2 | from char_cnn_discriminator import * 3 | import argparse 4 | seq_len = 17 5 | 6 | def continuous_decode(model, src, src_mask, max_len, start_symbol=2): 7 | memory = model.encode(src, src_mask) # encode is discrete 8 | ys = torch.ones(src.shape[0], 1).fill_(start_symbol).type_as(src.data).to(device) 9 | collect_out = model.tgt_embed[0](ys).float() 10 | # word_col = ys[:] 11 | # ys = model.tgt_embed(ys) 12 | for i in range(max_len-1): 13 | out = model.conti_decode(memory, src_mask, Variable(model.tgt_embed[1](collect_out)), 14 | Variable(subsequent_mask(collect_out.size(1)) 15 | .type_as(src.data))) 16 | # 
collect_out.append(out[:, -1]) 17 | collect_out = torch.cat([collect_out, out[:, -1].unsqueeze(1)], dim=1) 18 | # prob = model.generator(out[:, -1]) 19 | # _, next_word = torch.max(prob, dim = 1) 20 | # word_col = torch.cat([word_col, next_word.unsqueeze(-1)], dim=1) 21 | # ys = torch.cat([ys, next_word.unsqueeze(-1)], dim=1) 22 | # ys = torch.cat([ys, model.src_embed(next_word.unsqueeze(1))], dim=1) 23 | # ys = torch.cat([ys, out[:, -1].unsqueeze(1)], dim=1) 24 | return collect_out#ys #, word_col 25 | 26 | def decode_with_output(model, src, src_mask, max_len, start_symbol=2): 27 | memory = model.encode(src, src_mask) # encode is discrete 28 | ys = torch.ones(src.shape[0], 1).fill_(start_symbol).type_as(src.data).to(device) 29 | collect_out = model.tgt_embed(ys) 30 | for i in range(max_len-1): 31 | out = model.decode(memory, src_mask, Variable(ys), 32 | Variable(subsequent_mask(ys.size(1)) 33 | .type_as(src.data))) 34 | collect_out = torch.cat([collect_out, out[:, -1].unsqueeze(1)], dim=1) 35 | prob = model.generator(out[:, -1]) 36 | _, next_word = torch.max(prob, dim = 1) 37 | ys = torch.cat([ys, next_word.unsqueeze(-1)], dim=1) 38 | return ys, collect_out 39 | 40 | def prob_backward(model, src, src_mask, max_len, start_symbol=2): 41 | memory = model.encode(src, src_mask) 42 | ys = torch.ones(src.shape[0], 1).fill_(start_symbol).type_as(src.data).to(device) 43 | ret_back = model.tgt_embed[0].pure_emb(ys).float() 44 | for i in range(max_len-1): 45 | out = model.decode(memory, src_mask, 46 | Variable(ys), 47 | Variable(subsequent_mask(ys.size(1)) 48 | .type_as(src.data))) 49 | prob = model.generator(out[:, -1], scale=10) 50 | back = torch.matmul(prob ,model.tgt_embed[0].lut.weight.data.float()) 51 | _, next_word = torch.max(prob, dim = 1) 52 | ys = torch.cat([ys, next_word.unsqueeze(-1)], dim=1) 53 | ret_back = torch.cat([ret_back, back.unsqueeze(1)], dim=1) 54 | return ret_back 55 | 56 | def reconstruct(model, src, max_len, start_symbol=2): 57 | memory = 
model.encoder(model.src_embed[1](src), None) 58 | ys = torch.ones(src.shape[0], 1).fill_(start_symbol).long().to(device) 59 | ret_back = model.tgt_embed[0].pure_emb(ys).float() 60 | for i in range(max_len-1): 61 | out = model.decode(memory, None, 62 | Variable(ys), 63 | Variable(subsequent_mask(ys.size(1)) 64 | .type_as(src.data))) 65 | prob = model.generator(out[:, -1]) 66 | back = torch.matmul(prob ,model.tgt_embed[0].lut.weight.data.float()) 67 | _, next_word = torch.max(prob, dim = 1) 68 | ys = torch.cat([ys, next_word.unsqueeze(-1)], dim=1) 69 | ret_back = torch.cat([ret_back, back.unsqueeze(1)], dim=1) 70 | return ret_back 71 | 72 | class CycleGAN(nn.Module): 73 | def __init__(self, discriminator, generator, utils): 74 | super(CycleGAN, self).__init__() 75 | self.D = discriminator 76 | self.G = generator 77 | self.R = copy.deepcopy(generator) 78 | self.D_opt = torch.optim.Adam(self.D.parameters()) 79 | self.G_opt = torch.optim.Adam(self.G.parameters()) 80 | self.R_opt = torch.optim.Adam(self.R.parameters()) 81 | 82 | self.utils = utils 83 | self.criterion = nn.CrossEntropyLoss(ignore_index=-1) 84 | self.mse = nn.MSELoss() 85 | 86 | def save_model(self, d_path="Dis_model.ckpt", g_path="Gen_model.ckpt", r_path="Res_model.ckpt"): 87 | torch.save(self.D.state_dict(), d_path) 88 | torch.save(self.G.state_dict(), g_path) 89 | torch.save(self.R.state_dict(), r_path) 90 | 91 | def load_model(self, path="", g_file=None, d_file=None, r_file=None): 92 | if g_file!=None: 93 | self.D.load_state_dict(torch.load(os.path.join(path, g_file))) 94 | if d_file!=None: 95 | self.G.load_state_dict(torch.load(os.path.join(path, d_file))) 96 | if r_file!=None: 97 | self.R.load_state_dict(torch.load(os.path.join(path, r_file))) 98 | print("model loaded!") 99 | 100 | 101 | def train_model(self, num_epochs=100, d_steps=50, g_steps=70, main_device="cuda:0", sec_device="cuda:1"): 102 | # self.D.to(self.D.device) 103 | # self.G.to(self.G.device) 104 | # self.R.to(self.R.device) 105 | 
X_datagen = self.utils.data_generator("X") 106 | Y_datagen = self.utils.data_generator("Y") 107 | for epoch in range(num_epochs): 108 | d_ct = Clock(d_steps, title="Train Discriminator(%d/%d)"%(epoch, num_epochs)) 109 | for i, X_data, Y_data in zip(range(d_steps), data_gen(X_datagen, self.utils.sents2idx), data_gen(Y_datagen, self.utils.sents2idx)): 110 | # 1. Train D on real+fake 111 | self.D.zero_grad() 112 | 113 | # 1A: Train D on real 114 | d_real_pred = self.D(self.G.tgt_embed[0](Y_data.src)) 115 | d_real_error = self.criterion(d_real_pred, torch.ones((d_real_pred.shape[0],), dtype=torch.int64).to(self.D.device)) # ones = true 116 | d_real_error.backward() # compute/store gradients, but don't change params 117 | self.D_opt.step() 118 | 119 | # 1B: Train D on fake 120 | self.G.to(main_device) 121 | d_fake_data = prob_backward(self.G, X_data.src, X_data.src_mask, max_len=seq_len).detach() # detach to avoid training G on these labels 122 | d_fake_pred = self.D(d_fake_data) 123 | d_fake_error = self.criterion(d_fake_pred, torch.zeros((d_fake_pred.shape[0],), dtype=torch.int64).to(self.D.device)) # zeros = fake 124 | d_fake_error.backward() 125 | self.D_opt.step() # Only optimizes D's parameters; changes based on stored gradients from backward() 126 | d_ct.flush(info={"D_loss":d_fake_error.item()}) 127 | 128 | g_ct = Clock(g_steps, title="Train Generator(%d/%d)"%(epoch, num_epochs)) 129 | for i, X_data in zip(range(g_steps), data_gen(X_datagen, self.utils.sents2idx)): 130 | # 2. 
Train G on D's response (but DO NOT train D on these labels) 131 | self.G.zero_grad() 132 | g_fake_data = prob_backward(self.G, X_data.src, X_data.src_mask, max_len=seq_len) 133 | dg_fake_pred = self.D(g_fake_data) 134 | g_error = self.criterion(dg_fake_pred, torch.ones((dg_fake_pred.shape[0],), dtype=torch.int64).to(self.D.device)) # we want to fool, so pretend it's all genuine 135 | 136 | g_error.backward(retain_graph=True) 137 | self.G_opt.step() # Only optimizes G's parameters 138 | self.G.zero_grad() 139 | g_ct.flush(info={"G_loss": g_error.item()}) 140 | 141 | # 3. reconstructor 142 | r_reco_data = reconstruct(self.R, g_fake_data, max_len=seq_len) 143 | x_orgi_data = self.R.tgt_embed[0].pure_emb(X_data.src) 144 | r_error = self.mse(r_reco_data.float(), x_orgi_data.float()) 145 | r_error.backward() 146 | self.R.zero_grad() 147 | self.R_opt.step() 148 | self.G_opt.step() 149 | g_ct.flush(info={"G_loss": g_error.item(), 150 | "R_loss": r_error.item()}) 151 | 152 | with torch.no_grad(): 153 | x = utils.idx2sent(greedy_decode(model, X_test_batch.src, X_test_batch.src_mask, max_len=20, start_symbol=2)) 154 | y = utils.idx2sent(greedy_decode(model, Y_test_batch.src, Y_test_batch.src_mask, max_len=20, start_symbol=2)) 155 | 156 | for i,j in zip(X_test_batch.src, x): 157 | print("===") 158 | k = utils.idx2sent([i])[0] 159 | print("ORG:", " ".join(k[:k.index('')+1])) 160 | print("--") 161 | print("GEN:", " ".join(j[:j.index('')+1] if '' in j else j)) 162 | print("===") 163 | print("=====") 164 | for i, j in zip(Y_test_batch.src, y): 165 | print("===") 166 | k = utils.idx2sent([i])[0] 167 | print("ORG:", " ".join(k[:k.index('')+1])) 168 | print("--") 169 | print("GEN:", " ".join(j[:j.index('')+1] if '' in j else j)) 170 | print("===") 171 | self.save_model() 172 | 173 | if __name__ == "__main__": 174 | parser = argparse.ArgumentParser() 175 | parser.add_argument("mode", help="execute mode") 176 | # parser.add_argument("-filename", default=None, required=False, 
help="test filename") 177 | # parser.add_argument("-actual_name", default=None, required=False, help="test filename") 178 | # parser.add_argument("-load_model", default=False, required=False, help="test filename") 179 | # parser.add_argument("-model_name", default="model.ckpt", required=False, help="test filename") 180 | # parser.add_argument("-save_path", default="", required=False, help="test filename") 181 | # parser.add_argument("-train_file", default=None, required=False, help="test filename") 182 | # parser.add_argument("-test_file", default=None, required=True, help="test filename") 183 | # parser.add_argument("-epoch", default=1, required=False, help="test filename") 184 | args = parser.parse_args() 185 | 186 | utils = Utils(X_data_path="small_cou.txt", Y_data_path="small_cna.txt") 187 | # Train the simple copy task. 188 | V = 10000 189 | # _,_, emb_mat = load_embedding(limit=100000) 190 | criterion = LabelSmoothing(size=V, padding_idx=0, smoothing=0.0) 191 | model = make_model(V, V, utils.emb_mat, utils.emb_mat) 192 | d = Discriminator(utils.emb_mat.shape[1], int(utils.emb_mat.shape[1]/2), 15+2) 193 | cyclegan = CycleGAN(d, model, utils) 194 | cyclegan.D.src_embed = cyclegan.D.tgt_embed = cyclegan.R.src_embed = cyclegan.R.tgt_embed 195 | # cyclegan.D = torch.nn.DataParallel(cyclegan.D, device_ids=[0, 1]).cuda().module 196 | cyclegan.G.load_model(filename="model_9.ckpt") 197 | # cyclegan.G = torch.nn.DataParallel(cyclegan.G, device_ids=[0, 1]).cuda().module 198 | cyclegan.R.load_model(filename="model_9.ckpt") 199 | cyclegan = torch.nn.DataParallel(cyclegan, device_ids=[0, 1]).cuda().module 200 | if args.mode == "train": 201 | cyclegan.train_model() -------------------------------------------------------------------------------- /v3_cyclegan.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import math, copy, time 6 | 
from torch.autograd import Variable
import os, re, sys
from jexus import Clock
from attn import Transformer, LabelSmoothing, \
    data_gen, NoamOpt, Generator, SimpleLossCompute, \
    greedy_decode, subsequent_mask
from utils import Utils
from transformer_discriminator import Discriminator
import argparse

device = "cuda:0"


def prob_backward(model, embed, src, src_mask, max_len, start_symbol=2, raw=False):
    """Greedy-decode ``src`` and return the per-step output distributions.

    Returns a ``(batch, max_len+1, vocab)`` tensor of generator log-probs;
    at each step the argmax token is fed back as the next decoder input.
    When ``raw`` is True, ``src`` is already a continuous (embedded)
    sequence and is encoded without the embedding lookup.
    """
    if not raw:
        memory = model.encode(embed(src.to(device)), src_mask)
    else:
        memory = model.encode(src.to(device), src_mask)

    ys = torch.ones(src.shape[0], 1, dtype=torch.int64).fill_(start_symbol).to(device)
    probs = []
    for _ in range(max_len + 2 - 1):
        out = model.decode(memory, src_mask,
                           embed(Variable(ys)),
                           Variable(subsequent_mask(ys.size(1)).type_as(src.data)))
        prob = model.generator(out[:, -1])  # log-probs over the vocabulary
        probs.append(prob.unsqueeze(1))
        _, next_word = torch.max(prob, dim=1)
        ys = torch.cat([ys, next_word.unsqueeze(-1)], dim=1)
    return torch.cat(probs, dim=1)


def backward_decode(model, embed, src, src_mask, max_len, start_symbol=2, raw=False, return_term=-1):
    """Greedy decode that also returns differentiable "soft" word vectors.

    At every step a sharpened softmax over the vocabulary is used as
    weights over the embedding matrix (softmax-weighted sum), keeping the
    generated sequence differentiable w.r.t. the generator — the
    discrete-gradient workaround described in the README.

    return_term: -1 -> (soft_vectors, token_ids); 0 -> soft_vectors only;
                  1 -> token_ids only; anything else -> None.
    """
    if not raw:
        memory = model.encode(embed(src.to(device)), src_mask)
    else:
        memory = model.encode(src.to(device), src_mask)

    ys = torch.ones(src.shape[0], 1, dtype=torch.int64).fill_(start_symbol).to(device)
    ret_back = embed(ys).float()
    for _ in range(max_len + 2 - 1):
        out = model.decode(memory, src_mask,
                           embed(Variable(ys)),
                           Variable(subsequent_mask(ys.size(1)).type_as(src.data)))
        # Temperature-sharpened softmax -> near one-hot; matmul with the
        # embedding table yields the softmax-weighted word vector.
        prob = model.generator.scaled_forward(out[:, -1], scale=10.0)
        back = torch.matmul(prob, embed.weight.data.float())
        _, next_word = torch.max(prob, dim=1)
        ys = torch.cat([ys, next_word.unsqueeze(-1)], dim=1)
        ret_back = torch.cat([ret_back, back.unsqueeze(1)], dim=1)
    if return_term == -1:
        return ret_back, ys
    if return_term == 0:
        return ret_back
    if return_term == 1:
        return ys
    return None


def reconstruct(model, src, max_len, start_symbol=2):
    """Decode a continuous (already embedded) source back into soft word
    vectors — the cycle-consistency reconstruction step."""
    memory = model.encoder(model.src_embed[1](src), None)
    ys = torch.ones(src.shape[0], 1).fill_(start_symbol).long().to(device)
    ret_back = model.tgt_embed[0].pure_emb(ys).float()
    for _ in range(max_len - 1):
        out = model.decode(memory, None,
                           Variable(ys),
                           Variable(subsequent_mask(ys.size(1)).type_as(src.data)))
        prob = model.generator(out[:, -1])
        back = torch.matmul(prob, model.tgt_embed[0].lut.weight.data.float())
        _, next_word = torch.max(prob, dim=1)
        ys = torch.cat([ys, next_word.unsqueeze(-1)], dim=1)
        ret_back = torch.cat([ret_back, back.unsqueeze(1)], dim=1)
    return ret_back


class CycleGAN(nn.Module):
    """Text-style-transfer CycleGAN.

    Holds a generator ``G`` (X -> Y), a discriminator ``D`` over domain Y,
    and a reconstructor ``R`` (a deep copy of ``G``) that maps generated
    sequences back to the source for cycle consistency.
    """

    def __init__(self, discriminator, generator, utils, embedder):
        super(CycleGAN, self).__init__()
        self.D = discriminator
        self.G = generator
        self.R = copy.deepcopy(generator)
        self.D_opt = torch.optim.Adam(self.D.parameters())
        # Noam warmup schedule for both transformer optimizers.
        self.G_opt = NoamOpt(utils.emb_mat.shape[1], 1, 4000,
                             torch.optim.Adam(self.G.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9))
        self.R_opt = NoamOpt(utils.emb_mat.shape[1], 1, 4000,
                             torch.optim.Adam(self.R.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9))
        self.embed = embedder

        self.utils = utils
        self.criterion = nn.CrossEntropyLoss(ignore_index=-1)
        self.mse = nn.MSELoss()
        self.cos = nn.CosineSimilarity(dim=-1)
        self.cosloss = nn.CosineEmbeddingLoss()
        self.r_criterion = LabelSmoothing(size=utils.emb_mat.shape[0], padding_idx=0, smoothing=0.0)
        self.r_loss_compute = SimpleLossCompute(self.R.generator, self.r_criterion, self.R_opt)

    def save_model(self, d_path="Dis_model.ckpt", g_path="Gen_model.ckpt",
                   r_path="Res_model.ckpt"):
        """Checkpoint D, G and R state dicts to the given paths."""
        torch.save(self.D.state_dict(), d_path)
        torch.save(self.G.state_dict(), g_path)
        torch.save(self.R.state_dict(), r_path)

    def load_model(self, path="", g_file=None, d_file=None, r_file=None):
        """Restore any subset of G/D/R from checkpoints under ``path``."""
        if g_file is not None:
            self.G.load_state_dict(torch.load(os.path.join(path, g_file)))
        if d_file is not None:
            self.D.load_state_dict(torch.load(os.path.join(path, d_file)))
        if r_file is not None:
            self.R.load_state_dict(torch.load(os.path.join(path, r_file)))
        print("model loaded!")

    def pretrain_disc(self, num_epochs=100):
        """Pretrain D to separate real Y-domain batches (label 1) from raw
        X-domain batches (label 0); checkpoints after every epoch."""
        X_datagen = self.utils.data_generator("X")
        Y_datagen = self.utils.data_generator("Y")
        for epoch in range(num_epochs):
            d_steps = self.utils.train_step_num
            d_ct = Clock(d_steps, title="Train Discriminator(%d/%d)" % (epoch, num_epochs))
            for step, X_data, Y_data in zip(range(d_steps),
                                            data_gen(X_datagen, self.utils.sents2idx),
                                            data_gen(Y_datagen, self.utils.sents2idx)):
                self.D.zero_grad()

                # Real (Y domain) -> label 1.
                d_real_pred = self.D(self.embed(Y_data.src.to(device)))
                d_real_error = self.criterion(
                    d_real_pred,
                    torch.ones((d_real_pred.shape[0],), dtype=torch.int64).to(device))

                # Fake (raw X domain) -> label 0.
                d_fake_pred = self.D(self.embed(X_data.src.to(device)))
                d_fake_error = self.criterion(
                    d_fake_pred,
                    torch.zeros((d_fake_pred.shape[0],), dtype=torch.int64).to(device))
                (d_fake_error + d_real_error).backward()
                self.D_opt.step()  # only D's parameters change here
                d_ct.flush(info={"D_loss": d_fake_error.item()})
            torch.save(self.D.state_dict(), "model_disc_pretrain.ckpt")

    def train_model(self, num_epochs=100, d_steps=20, g_steps=80, g_scale=1.0, r_scale=1.0,
                    main_device="cuda:0", sec_device="cuda:1"):
        """Adversarial training loop: alternate D steps, G steps and
        reconstruction (cycle) steps; prints sample decodes each epoch.

        Epoch 0 is skipped for both phases so the first samples reflect the
        pretrained weights.
        """
        for batch in data_gen(self.utils.data_generator("X"), self.utils.sents2idx):
            X_test_batch = batch
            break
        for batch in data_gen(self.utils.data_generator("Y"), self.utils.sents2idx):
            Y_test_batch = batch
            break
        X_datagen = self.utils.data_generator("X")
        Y_datagen = self.utils.data_generator("Y")
        for epoch in range(num_epochs):
            d_ct = Clock(d_steps, title="Train Discriminator(%d/%d)" % (epoch, num_epochs))
            if epoch > 0:
                for i, X_data, Y_data in zip(range(d_steps),
                                             data_gen(X_datagen, self.utils.sents2idx),
                                             data_gen(Y_datagen, self.utils.sents2idx)):
                    # 1. Train D on real + fake.
                    self.D.zero_grad()

                    # 1A: real Y-domain batch -> label 1.
                    d_real_pred = self.D(self.embed(Y_data.src.to(device)))
                    d_real_error = self.criterion(
                        d_real_pred,
                        torch.ones((d_real_pred.shape[0],), dtype=torch.int64).to(device))

                    # 1B: generated batch -> label 0; detach so G is not
                    # trained on D's labels here.
                    self.G.to(main_device)
                    d_fake_data = backward_decode(self.G, self.embed, X_data.src, X_data.src_mask,
                                                  max_len=self.utils.max_len, return_term=0).detach()
                    d_fake_pred = self.D(d_fake_data)
                    d_fake_error = self.criterion(
                        d_fake_pred,
                        torch.zeros((d_fake_pred.shape[0],), dtype=torch.int64).to(device))
                    (d_fake_error + d_real_error).backward()
                    self.D_opt.step()  # only D's parameters change here
                    d_ct.flush(info={"D_loss": d_fake_error.item()})

            g_ct = Clock(g_steps, title="Train Generator(%d/%d)" % (epoch, num_epochs))
            r_ct = Clock(g_steps, title="Train Reconstructor(%d/%d)" % (epoch, num_epochs))
            if epoch > 0:
                for i, X_data in zip(range(g_steps), data_gen(X_datagen, self.utils.sents2idx)):
                    # 2. Train G on D's response (D itself is not updated).
                    self.G.zero_grad()
                    g_fake_data = backward_decode(self.G, self.embed, X_data.src, X_data.src_mask,
                                                  max_len=self.utils.max_len, return_term=0)
                    dg_fake_pred = self.D(g_fake_data)
                    # We want to fool D, so pretend the fakes are genuine.
                    g_error = self.criterion(
                        dg_fake_pred,
                        torch.ones((dg_fake_pred.shape[0],), dtype=torch.int64).to(device))

                    g_error.backward(retain_graph=True)
                    self.G_opt.step()  # only G's parameters
                    self.G.zero_grad()
                    g_ct.flush(info={"G_loss": g_error.item()})

                    # 3. Reconstructor: teacher-forced R over the soft fake
                    # data, compared against the original X tokens.
                    # (Was the module-level `embedding_layer` global; use the
                    # instance embedder so the class works standalone.)
                    out = self.R.forward(g_fake_data, self.embed(X_data.trg.to(device)),
                                         None, X_data.trg_mask)
                    r_loss = self.r_loss_compute(out, X_data.trg_y, X_data.ntokens)
                    # r_loss_compute already stepped R_opt; this extra step
                    # propagates the cycle loss back into G as well.
                    self.G_opt.step()
                    self.G_opt.optimizer.zero_grad()
                    r_ct.flush(info={"G_loss": g_error.item(),
                                     "R_loss": r_loss / X_data.ntokens.float().to(device)})

            with torch.no_grad():
                # Sample decodes for eyeballing progress (was the module-level
                # `model`/`utils` globals; use self.G / self.utils).
                x_cont, x_ys = backward_decode(self.G, self.embed, X_test_batch.src,
                                               X_test_batch.src_mask, max_len=25, start_symbol=2)
                x = self.utils.idx2sent(x_ys)
                y_cont, y_ys = backward_decode(self.G, self.embed, Y_test_batch.src,
                                               Y_test_batch.src_mask, max_len=25, start_symbol=2)
                y = self.utils.idx2sent(y_ys)
                r_x = self.utils.idx2sent(backward_decode(self.R, self.embed, x_cont, None,
                                                          max_len=self.utils.max_len, raw=True, return_term=1))
                r_y = self.utils.idx2sent(backward_decode(self.R, self.embed, y_cont, None,
                                                          max_len=self.utils.max_len, raw=True, return_term=1))

                # NOTE(review): the '' sentinel below looks like a stripped
                # end-of-sentence tag (</s>?) lost in the dump — confirm
                # against utils.idx2sent before changing it.
                for i, j, l in zip(X_test_batch.src, x, r_x):
                    print("===")
                    k = self.utils.idx2sent([i])[0]
                    print("ORG:", " ".join(k[:k.index('')+1]))
                    print("--")
                    print("GEN:", " ".join(j[:j.index('')+1] if '' in j else j))
                    print("--")
                    print("REC:", " ".join(l[:l.index('')+1] if '' in l else l))
                    print("===")
                print("=====")
                for i, j, l in zip(Y_test_batch.src, y, r_y):
                    print("===")
                    k = self.utils.idx2sent([i])[0]
                    print("ORG:", " ".join(k[:k.index('')+1]))
                    print("--")
                    print("GEN:", " ".join(j[:j.index('')+1] if '' in j else j))
                    print("--")
                    print("REC:", " ".join(l[:l.index('')+1] if '' in l else l))
                    print("===")
            # self.save_model()


def pretrain_run_epoch(data_iter, model, loss_compute, train_step_num, embedding_layer):
    "Standard Training and Logging Function"
    start = time.time()
    total_tokens = 0
    total_loss = 0
    tokens = 0
    ct = Clock(train_step_num)
    embedding_layer.to(device)
    model.to(device)
    for i, batch in enumerate(data_iter):
        batch.to(device)
        out = model.forward(embedding_layer(batch.src.to(device)), embedding_layer(batch.trg.to(device)),
                            batch.src_mask, batch.trg_mask)
        loss = loss_compute(out, batch.trg_y, batch.ntokens)
        total_loss += loss
        total_tokens += batch.ntokens
        tokens += batch.ntokens
        batch.to("cpu")
        if i % 50 == 1:
            elapsed = time.time() - start
            ct.flush(info={"loss": loss / batch.ntokens.float().to(device),
                           "tok/sec": tokens.float().to(device) / elapsed})
            start = time.time()
            tokens = 0
        else:
            ct.flush(info={"loss": loss / batch.ntokens.float().to(device)})
    return total_loss / total_tokens.float().to(device)


def get_embedding_layer(utils):
    """Build a frozen nn.Embedding initialized from utils.emb_mat."""
    vocab, d_model = utils.emb_mat.shape[0], utils.emb_mat.shape[1]
    embedding_layer = nn.Embedding(vocab, d_model)
    embedding_layer.weight.data = torch.tensor(utils.emb_mat)
    embedding_layer.weight.requires_grad = False  # pretrained, kept fixed
    embedding_layer.to(device)
    return embedding_layer


def pretrain(model, embedding_layer, utils, epoch_num=1):
    """Seq2seq (autoencoder) pretraining of the generator on both domains;
    checkpoints each epoch and prints sample greedy decodes."""
    criterion = LabelSmoothing(size=utils.emb_mat.shape[0], padding_idx=0, smoothing=0.0)
    model_opt = NoamOpt(utils.emb_mat.shape[1], 1, 4000,
                        torch.optim.Adam(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9))
    X_test_batch = None
    Y_test_batch = None
    for batch in data_gen(utils.data_generator("X"), utils.sents2idx):
        X_test_batch = batch
        break
    for batch in data_gen(utils.data_generator("Y"), utils.sents2idx):
        Y_test_batch = batch
        break
    model.to(device)
    for epoch in range(epoch_num):
        model.train()
        print("EPOCH %d:" % (epoch + 1))
        pretrain_run_epoch(data_gen(utils.data_generator("Y"), utils.sents2idx), model,
                           SimpleLossCompute(model.generator, criterion, model_opt),
                           utils.train_step_num, embedding_layer)
        pretrain_run_epoch(data_gen(utils.data_generator("X"), utils.sents2idx), model,
                           SimpleLossCompute(model.generator, criterion, model_opt),
                           utils.train_step_num, embedding_layer)
        model.eval()
        torch.save(model.state_dict(), 'model_pretrain.ckpt')
        x = utils.idx2sent(greedy_decode(model, embedding_layer, X_test_batch.src,
                                         X_test_batch.src_mask, max_len=20, start_symbol=2))
        y = utils.idx2sent(greedy_decode(model, embedding_layer, Y_test_batch.src,
                                         Y_test_batch.src_mask, max_len=20, start_symbol=2))

        # NOTE(review): '' sentinel — see note in CycleGAN.train_model.
        for i, j in zip(X_test_batch.src, x):
            print("===")
            k = utils.idx2sent([i])[0]
            print("ORG:", " ".join(k[:k.index('')+1]))
            print("--")
            print("GEN:", " ".join(j[:j.index('')+1] if '' in j else j))
            print("===")
        print("=====")
        for i, j in zip(Y_test_batch.src, y):
            print("===")
            k = utils.idx2sent([i])[0]
            print("ORG:", " ".join(k[:k.index('')+1]))
            print("--")
            print("GEN:", " ".join(j[:j.index('')+1] if '' in j else j))
            print("===")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
parser.add_argument("mode", help="execute mode") 317 | parser.add_argument("-filename", default=None, required=False, help="test filename") 318 | parser.add_argument("-load_model", default=False, required=False, help="test filename") 319 | parser.add_argument("-model_name", default="model.ckpt", required=False, help="test filename") 320 | parser.add_argument("-save_path", default="", required=False, help="test filename") 321 | parser.add_argument("-train_file", default=None, required=False, help="test filename") 322 | parser.add_argument("-test_file", default=None, required=False, help="test filename") 323 | parser.add_argument("-epoch", default=1, required=False, help="test filename") 324 | parser.add_argument("-max_len", default=20, required=False, help="test filename") 325 | args = parser.parse_args() 326 | 327 | model = Transformer(N=2) 328 | utils = Utils(X_data_path="big_cna.txt", Y_data_path="big_cou.txt") 329 | embedding_layer = get_embedding_layer(utils).to(device) 330 | model.generator = Generator(d_model = utils.emb_mat.shape[1], vocab=utils.emb_mat.shape[0]) 331 | if args.load_model: 332 | model.load_state_dict(torch.load(args.model_name)) 333 | if args.mode == "pretrain": 334 | pretrain(model, embedding_layer, utils, int(args.epoch)) 335 | if args.mode == "cycle": 336 | disc = Discriminator(word_dim=utils.emb_mat.shape[1], inner_dim=2048, seq_len=20) 337 | main_model = CycleGAN(disc, model, utils, embedding_layer) 338 | main_model.to(device) 339 | main_model.load_model(g_file="model_pretrain.ckpt", r_file="model_pretrain.ckpt", d_file="model_disc_pretrain.ckpt") 340 | main_model.train_model() 341 | if args.mode == "disc": 342 | disc = Discriminator(word_dim=utils.emb_mat.shape[1], inner_dim=2048, seq_len=20) 343 | main_model = CycleGAN(disc, model, utils, embedding_layer) 344 | main_model.to(device) 345 | main_model.pretrain_disc() 346 | 347 | if args.mode == "dev": 348 | model = Transformer(N=2) 349 | utils = Utils(X_data_path="big_cou.txt", 
Y_data_path="big_cna.txt") 350 | model.generator = Generator(d_model = utils.emb_mat.shape[1], vocab=utils.emb_mat.shape[0]) 351 | criterion = LabelSmoothing(size=utils.emb_mat.shape[0], padding_idx=0, smoothing=0.0) 352 | model_opt = NoamOpt(utils.emb_mat.shape[1], 1, 400, 353 | torch.optim.Adam(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9)) 354 | X_test_batch = None 355 | Y_test_batch = None 356 | for i, batch in enumerate(data_gen(utils.data_generator("X"), utils.sents2idx)): 357 | X_test_batch = batch 358 | break 359 | 360 | for i, batch in enumerate(data_gen(utils.data_generator("Y"), utils.sents2idx)): 361 | Y_test_batch = batch 362 | break 363 | 364 | # if args.load_model: 365 | # model.load_model(filename=args.model_name) 366 | # if args.mode == "train": 367 | # model.train_model(num_epochs=int(args.epoch)) 368 | # print("========= Testing =========") 369 | # model.load_model() 370 | # model.test_corpus() 371 | # if args.mode == "test": 372 | # model.test_corpus() 373 | -------------------------------------------------------------------------------- /v2_cyclegan.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import math, copy, time 6 | from torch.autograd import Variable 7 | import os, re, sys 8 | from jexus import Clock 9 | from attn import Transformer, LabelSmoothing, \ 10 | data_gen, NoamOpt, Generator, SimpleLossCompute, \ 11 | greedy_decode, subsequent_mask 12 | from utils import Utils 13 | from char_cnn_discriminator import Discriminator 14 | import argparse 15 | 16 | device = "cuda:1" 17 | 18 | 19 | def prob_backward(model, embed, src, src_mask, max_len, start_symbol=2, raw=False): 20 | if raw==False: 21 | memory = model.encode(embed(src.to(device)), src_mask) 22 | else: 23 | memory = model.encode(src.to(device), src_mask) 24 | 25 | ys = torch.ones(src.shape[0], 1, 
dtype=torch.int64).fill_(start_symbol).to(device) 26 | probs = [] 27 | for i in range(max_len+2-1): 28 | out = model.decode(memory, src_mask, 29 | embed(Variable(ys)), 30 | Variable(subsequent_mask(ys.size(1)) 31 | .type_as(src.data))) 32 | prob = model.generator(out[:, -1]) 33 | probs.append(prob.unsqueeze(1)) 34 | _, next_word = torch.max(prob, dim = 1) 35 | ys = torch.cat([ys, next_word.unsqueeze(-1)], dim=1) 36 | ret = torch.cat(probs, dim=1) 37 | return ret 38 | 39 | def backward_decode(model, embed, src, src_mask, max_len, start_symbol=2, raw=False, return_term=-1): 40 | if raw==False: 41 | memory = model.encode(embed(src.to(device)), src_mask) 42 | else: 43 | memory = model.encode(src.to(device), src_mask) 44 | 45 | ys = torch.ones(src.shape[0], 1, dtype=torch.int64).fill_(start_symbol).to(device) 46 | ret_back = embed(ys).float() 47 | for i in range(max_len+2-1): 48 | out = model.decode(memory, src_mask, 49 | embed(Variable(ys)), 50 | Variable(subsequent_mask(ys.size(1)) 51 | .type_as(src.data))) 52 | prob = model.generator.scaled_forward(out[:, -1], scale=10.0) 53 | back = torch.matmul(prob ,embed.weight.data.float()) 54 | _, next_word = torch.max(prob, dim = 1) 55 | ys = torch.cat([ys, next_word.unsqueeze(-1)], dim=1) 56 | ret_back = torch.cat([ret_back, back.unsqueeze(1)], dim=1) 57 | return (ret_back, ys) if return_term == -1 else ret_back if return_term == 0 else ys if return_term == 1 else None 58 | 59 | def reconstruct(model, src, max_len, start_symbol=2): 60 | memory = model.encoder(model.src_embed[1](src), None) 61 | ys = torch.ones(src.shape[0], 1).fill_(start_symbol).long().to(device) 62 | ret_back = model.tgt_embed[0].pure_emb(ys).float() 63 | for i in range(max_len-1): 64 | out = model.decode(memory, None, 65 | Variable(ys), 66 | Variable(subsequent_mask(ys.size(1)) 67 | .type_as(src.data))) 68 | prob = model.generator(out[:, -1]) 69 | back = torch.matmul(prob ,model.tgt_embed[0].lut.weight.data.float()) 70 | _, next_word = torch.max(prob, dim 
= 1) 71 | ys = torch.cat([ys, next_word.unsqueeze(-1)], dim=1) 72 | ret_back = torch.cat([ret_back, back.unsqueeze(1)], dim=1) 73 | return ret_back 74 | 75 | class CycleGAN(nn.Module): 76 | def __init__(self, discriminator, generator, utils, embedder): 77 | super(CycleGAN, self).__init__() 78 | self.D = discriminator 79 | self.G = generator 80 | self.R = copy.deepcopy(generator) 81 | self.D_opt = torch.optim.Adam(self.D.parameters()) 82 | # self.G_opt = torch.optim.Adam(self.G.parameters()) 83 | self.G_opt = NoamOpt(utils.emb_mat.shape[1], 1, 4000, 84 | torch.optim.Adam(self.G.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9)) 85 | # self.R_opt = torch.optim.Adam(self.R.parameters()) 86 | self.R_opt = NoamOpt(utils.emb_mat.shape[1], 1, 4000, 87 | torch.optim.Adam(self.R.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9)) 88 | self.embed = embedder 89 | 90 | self.utils = utils 91 | self.criterion = nn.CrossEntropyLoss(ignore_index=-1) 92 | self.mse = nn.MSELoss() 93 | self.cos = nn.CosineSimilarity(dim=-1) 94 | self.cosloss=nn.CosineEmbeddingLoss() 95 | self.r_criterion = LabelSmoothing(size=utils.emb_mat.shape[0], padding_idx=0, smoothing=0.0) 96 | self.r_loss_compute = SimpleLossCompute(self.R.generator, self.r_criterion, self.R_opt) 97 | 98 | def save_model(self, d_path="Dis_model.ckpt", g_path="Gen_model.ckpt", r_path="Res_model.ckpt"): 99 | torch.save(self.D.state_dict(), d_path) 100 | torch.save(self.G.state_dict(), g_path) 101 | torch.save(self.R.state_dict(), r_path) 102 | 103 | def load_model(self, path="", g_file=None, d_file=None, r_file=None): 104 | if g_file!=None: 105 | self.G.load_state_dict(torch.load(os.path.join(path, g_file), map_location=device)) 106 | if d_file!=None: 107 | self.D.load_state_dict(torch.load(os.path.join(path, d_file), map_location=device)) 108 | if r_file!=None: 109 | self.R.load_state_dict(torch.load(os.path.join(path, r_file), map_location=device)) 110 | print("model loaded!") 111 | 112 | def pretrain_disc(self, num_epochs=100): 
113 | # self.D.to(self.D.device) 114 | # self.G.to(self.G.device) 115 | # self.R.to(self.R.device) 116 | X_datagen = self.utils.data_generator("X") 117 | Y_datagen = self.utils.data_generator("Y") 118 | for epoch in range(num_epochs): 119 | d_steps = self.utils.train_step_num 120 | d_ct = Clock(d_steps, title="Train Discriminator(%d/%d)"%(epoch, num_epochs)) 121 | for step, X_data, Y_data in zip(range(d_steps), data_gen(X_datagen, self.utils.sents2idx), data_gen(Y_datagen, self.utils.sents2idx)): 122 | # 1. Train D on real+fake 123 | # if epoch == 0: 124 | # break 125 | self.D.zero_grad() 126 | 127 | # 1A: Train D on real 128 | d_real_pred = self.D(self.embed(Y_data.src.to(device))) 129 | d_real_error = self.criterion(d_real_pred, torch.ones((d_real_pred.shape[0],), dtype=torch.int64).to(self.D.device)) # ones = true 130 | 131 | # 1B: Train D on fake 132 | d_fake_pred = self.D(self.embed(X_data.src.to(device))) 133 | d_fake_error = self.criterion(d_fake_pred, torch.zeros((d_fake_pred.shape[0],), dtype=torch.int64).to(self.D.device)) # zeros = fake 134 | (d_fake_error + d_real_error).backward() 135 | self.D_opt.step() # Only optimizes D's parameters; changes based on stored gradients from backward() 136 | d_ct.flush(info={"D_loss": d_fake_error.item()}) 137 | torch.save(self.D.state_dict(), "model_disc_pretrain.ckpt") 138 | 139 | def train_model(self, num_epochs=100, d_steps=20, g_steps=80, g_scale=1.0, r_scale=1.0): 140 | # self.D.to(self.D.device) 141 | # self.G.to(self.G.device) 142 | # self.R.to(self.R.device) 143 | for i, batch in enumerate(data_gen(utils.data_generator("X"), utils.sents2idx)): 144 | X_test_batch = batch 145 | break 146 | 147 | for i, batch in enumerate(data_gen(utils.data_generator("Y"), utils.sents2idx)): 148 | Y_test_batch = batch 149 | break 150 | X_datagen = self.utils.data_generator("X") 151 | Y_datagen = self.utils.data_generator("Y") 152 | for epoch in range(num_epochs): 153 | d_ct = Clock(d_steps, title="Train Discriminator(%d/%d)" % 
(epoch, num_epochs)) 154 | if epoch>0: 155 | for i, X_data, Y_data in zip(range(d_steps), data_gen(X_datagen, self.utils.sents2idx), data_gen(Y_datagen, self.utils.sents2idx)): 156 | # 1. Train D on real+fake 157 | # if epoch == 0: 158 | # break 159 | self.D.zero_grad() 160 | 161 | # 1A: Train D on real 162 | d_real_pred = self.D(self.embed(Y_data.src.to(device))) 163 | d_real_error = self.criterion(d_real_pred, torch.ones((d_real_pred.shape[0],), dtype=torch.int64).to(self.D.device)) # ones = true 164 | 165 | # 1B: Train D on fake 166 | self.G.to(device) 167 | d_fake_data = backward_decode(self.G, self.embed, X_data.src, X_data.src_mask, max_len=self.utils.max_len, return_term=0).detach() # detach to avoid training G on these labels 168 | d_fake_pred = self.D(d_fake_data) 169 | d_fake_error = self.criterion(d_fake_pred, torch.zeros((d_fake_pred.shape[0],), dtype=torch.int64).to(self.D.device)) # zeros = fake 170 | (d_fake_error + d_real_error).backward() 171 | self.D_opt.step() # Only optimizes D's parameters; changes based on stored gradients from backward() 172 | d_ct.flush(info={"D_loss":d_fake_error.item()}) 173 | 174 | g_ct = Clock(g_steps, title="Train Generator(%d/%d)"%(epoch, num_epochs)) 175 | r_ct = Clock(g_steps, title="Train Reconstructor(%d/%d)" % (epoch, num_epochs)) 176 | if epoch>0: 177 | for i, X_data in zip(range(g_steps), data_gen(X_datagen, self.utils.sents2idx)): 178 | # 2. 
Train G on D's response (but DO NOT train D on these labels) 179 | self.G.zero_grad() 180 | g_fake_data = backward_decode(self.G, self.embed, X_data.src, X_data.src_mask, max_len=self.utils.max_len, return_term=0) 181 | dg_fake_pred = self.D(g_fake_data) 182 | g_error = self.criterion(dg_fake_pred, torch.ones((dg_fake_pred.shape[0],), dtype=torch.int64).to(self.D.device)) # we want to fool, so pretend it's all genuine 183 | 184 | g_error.backward(retain_graph=True) 185 | self.G_opt.step() # Only optimizes G's parameters 186 | self.G.zero_grad() 187 | g_ct.flush(info={"G_loss": g_error.item()}) 188 | 189 | # 3. reconstructor 643988636173-69t5i8ehelccbq85o3esu11jgh61j8u5.apps.googleusercontent.com 190 | # way_3 191 | out = self.R.forward(g_fake_data, embedding_layer(X_data.trg.to(device)), 192 | None, X_data.trg_mask) 193 | r_loss = self.r_loss_compute(out, X_data.trg_y, X_data.ntokens) 194 | # way_2 195 | # r_reco_data = prob_backward(self.R, self.embed, g_fake_data, None, max_len=self.utils.max_len, raw=True) 196 | # x_orgi_data = X_data.src[:, 1:] 197 | # r_loss = SimpleLossCompute(None, criterion, self.R_opt)(r_reco_data, x_orgi_data, X_data.ntokens) 198 | # way_1 199 | # viewed_num = r_reco_data.shape[0]*r_reco_data.shape[1] 200 | # r_error = r_scale*self.cosloss(r_reco_data.float().view(-1, self.embed.weight.shape[1]), x_orgi_data.float().view(-1, self.embed.weight.shape[1]), torch.ones(viewed_num, dtype=torch.float32).to(device)) 201 | self.G_opt.step() 202 | self.G_opt.optimizer.zero_grad() 203 | r_ct.flush(info={"G_loss": g_error.item(), 204 | "R_loss": r_loss / X_data.ntokens.float().to(device)}) 205 | 206 | with torch.no_grad(): 207 | x_cont, x_ys = backward_decode(model, self.embed, X_test_batch.src, X_test_batch.src_mask, max_len=25, start_symbol=2) 208 | x = utils.idx2sent(x_ys) 209 | y_cont, y_ys = backward_decode(model, self.embed, Y_test_batch.src, Y_test_batch.src_mask, max_len=25, start_symbol=2) 210 | y = utils.idx2sent(y_ys) 211 | r_x = 
utils.idx2sent(backward_decode(self.R, self.embed, x_cont, None, max_len=self.utils.max_len, raw=True, return_term=1)) 212 | r_y = utils.idx2sent(backward_decode(self.R, self.embed, y_cont, None, max_len=self.utils.max_len, raw=True, return_term=1)) 213 | 214 | for i,j,l in zip(X_test_batch.src, x, r_x): 215 | print("===") 216 | k = utils.idx2sent([i])[0] 217 | print("ORG:", " ".join(k[:k.index('')+1])) 218 | print("--") 219 | print("GEN:", " ".join(j[:j.index('')+1] if '' in j else j)) 220 | print("--") 221 | print("REC:", " ".join(l[:l.index('')+1] if '' in l else l)) 222 | print("===") 223 | print("=====") 224 | for i, j, l in zip(Y_test_batch.src, y, r_y): 225 | print("===") 226 | k = utils.idx2sent([i])[0] 227 | print("ORG:", " ".join(k[:k.index('')+1])) 228 | print("--") 229 | print("GEN:", " ".join(j[:j.index('')+1] if '' in j else j)) 230 | print("--") 231 | print("REC:", " ".join(l[:l.index('')+1] if '' in l else l)) 232 | print("===") 233 | # self.save_model() 234 | 235 | def pretrain_run_epoch(data_iter, model, loss_compute, train_step_num, embedding_layer): 236 | "Standard Training and Logging Function" 237 | start = time.time() 238 | total_tokens = 0 239 | total_loss = 0 240 | tokens = 0 241 | ct = Clock(train_step_num) 242 | embedding_layer.to(device) 243 | model.to(device) 244 | for i, batch in enumerate(data_iter): 245 | batch.to(device) 246 | out = model.forward(embedding_layer(batch.src.to(device)), embedding_layer(batch.trg.to(device)), 247 | batch.src_mask, batch.trg_mask) 248 | loss = loss_compute(out, batch.trg_y, batch.ntokens) 249 | total_loss += loss 250 | total_tokens += batch.ntokens 251 | tokens += batch.ntokens 252 | batch.to("cpu") 253 | if i % 50 == 1: 254 | elapsed = time.time() - start 255 | ct.flush(info={"loss":loss / batch.ntokens.float().to(device), "tok/sec":tokens.float().to(device) / elapsed}) 256 | # print("Epoch Step: %d Loss: %f Tokens per Sec: %f" % 257 | # (i, loss / batch.ntokens.float().to(device), 
tokens.float().to(device) / elapsed)) 258 | start = time.time() 259 | tokens = 0 260 | else: 261 | ct.flush(info={"loss":loss / batch.ntokens.float().to(device)}) 262 | return total_loss / total_tokens.float().to(device) 263 | 264 | def get_embedding_layer(utils): 265 | d_model = utils.emb_mat.shape[1] 266 | vocab = utils.emb_mat.shape[0] 267 | embedding_layer = nn.Embedding(vocab, d_model) 268 | embedding_layer.weight.data = torch.tensor(utils.emb_mat) 269 | embedding_layer.weight.requires_grad = False 270 | embedding_layer.to(device) 271 | return embedding_layer 272 | 273 | def pretrain(model, embedding_layer, utils, epoch_num=1): 274 | criterion = LabelSmoothing(size=utils.emb_mat.shape[0], padding_idx=0, smoothing=0.0) 275 | model_opt = NoamOpt(utils.emb_mat.shape[1], 1, 4000, 276 | torch.optim.Adam(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9)) 277 | X_test_batch = None 278 | Y_test_batch = None 279 | for i, batch in enumerate(data_gen(utils.data_generator("X"), utils.sents2idx)): 280 | X_test_batch = batch 281 | break 282 | 283 | for i, batch in enumerate(data_gen(utils.data_generator("Y"), utils.sents2idx)): 284 | Y_test_batch = batch 285 | break 286 | model.to(device) 287 | for epoch in range(epoch_num): 288 | model.train() 289 | print("EPOCH %d:"%(epoch+1)) 290 | pretrain_run_epoch(data_gen(utils.data_generator("Y"), utils.sents2idx), model, 291 | SimpleLossCompute(model.generator, criterion, model_opt), utils.train_step_num, embedding_layer) 292 | pretrain_run_epoch(data_gen(utils.data_generator("X"), utils.sents2idx), model, 293 | SimpleLossCompute(model.generator, criterion, model_opt), utils.train_step_num, embedding_layer) 294 | model.eval() 295 | torch.save(model.state_dict(), 'model_pretrain.ckpt') 296 | x = utils.idx2sent(greedy_decode(model, embedding_layer, X_test_batch.src, X_test_batch.src_mask, max_len=20, start_symbol=2)) 297 | y = utils.idx2sent(greedy_decode(model, embedding_layer, Y_test_batch.src, Y_test_batch.src_mask, 
max_len=20, start_symbol=2)) 298 | 299 | for i,j in zip(X_test_batch.src, x): 300 | print("===") 301 | k = utils.idx2sent([i])[0] 302 | print("ORG:", " ".join(k[:k.index('')+1])) 303 | print("--") 304 | print("GEN:", " ".join(j[:j.index('')+1] if '' in j else j)) 305 | print("===") 306 | print("=====") 307 | for i, j in zip(Y_test_batch.src, y): 308 | print("===") 309 | k = utils.idx2sent([i])[0] 310 | print("ORG:", " ".join(k[:k.index('')+1])) 311 | print("--") 312 | print("GEN:", " ".join(j[:j.index('')+1] if '' in j else j)) 313 | print("===") 314 | 315 | if __name__ == "__main__": 316 | parser = argparse.ArgumentParser() 317 | parser.add_argument("mode", help="execute mode") 318 | parser.add_argument("-filename", default=None, required=False, help="test filename") 319 | parser.add_argument("-load_model", default=False, required=False, help="test filename") 320 | parser.add_argument("-model_name", default="model.ckpt", required=False, help="test filename") 321 | parser.add_argument("-disc_name", default="cnn_disc/model_disc_pretrain_xy_inv.ckpt", required=False, help="test filename") 322 | parser.add_argument("-save_path", default="", required=False, help="test filename") 323 | parser.add_argument("-X_file", default="big_cna.txt", required=False, help="X domain text filename") 324 | parser.add_argument("-Y_file", default="big_cou.txt", required=False, help="Y domain text filename") 325 | parser.add_argument("-epoch", default=1, required=False, help="test filename") 326 | parser.add_argument("-max_len", default=20, required=False, help="test filename") 327 | parser.add_argument("-batch_size", default=32, required=False, help="batch size") 328 | args = parser.parse_args() 329 | 330 | model = Transformer(N=2) 331 | utils = Utils(X_data_path=args.X_file, Y_data_path=args.Y_file, batch_size=int(args.batch_size)) 332 | embedding_layer = get_embedding_layer(utils).to(device) 333 | model.generator = Generator(d_model = utils.emb_mat.shape[1], 
vocab=utils.emb_mat.shape[0]) 334 | if args.load_model: 335 | model.load_state_dict(torch.load(args.model_name)) 336 | if args.mode == "pretrain": 337 | pretrain(model, embedding_layer, utils, int(args.epoch)) 338 | if args.mode == "cycle": 339 | disc = Discriminator(word_dim=utils.emb_mat.shape[1], inner_dim=512, seq_len=20) 340 | main_model = CycleGAN(disc, model, utils, embedding_layer) 341 | main_model.to(device) 342 | main_model.load_model(g_file="model_pretrain.ckpt", r_file="model_pretrain.ckpt", d_file=args.disc_name) 343 | main_model.train_model() 344 | if args.mode == "disc": 345 | disc = Discriminator(word_dim=utils.emb_mat.shape[1], inner_dim=512, seq_len=20) 346 | main_model = CycleGAN(disc, model, utils, embedding_layer) 347 | main_model.to(device) 348 | main_model.pretrain_disc(2) 349 | 350 | if args.mode == "dev": 351 | model = Transformer(N=2) 352 | utils = Utils(X_data_path="big_cou.txt", Y_data_path="big_cna.txt") 353 | model.generator = Generator(d_model = utils.emb_mat.shape[1], vocab=utils.emb_mat.shape[0]) 354 | criterion = LabelSmoothing(size=utils.emb_mat.shape[0], padding_idx=0, smoothing=0.0) 355 | model_opt = NoamOpt(utils.emb_mat.shape[1], 1, 400, 356 | torch.optim.Adam(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9)) 357 | X_test_batch = None 358 | Y_test_batch = None 359 | for i, batch in enumerate(data_gen(utils.data_generator("X"), utils.sents2idx)): 360 | X_test_batch = batch 361 | break 362 | 363 | for i, batch in enumerate(data_gen(utils.data_generator("Y"), utils.sents2idx)): 364 | Y_test_batch = batch 365 | break 366 | 367 | # if args.load_model: 368 | # model.load_model(filename=args.model_name) 369 | # if args.mode == "train": 370 | # model.train_model(num_epochs=int(args.epoch)) 371 | # print("========= Testing =========") 372 | # model.load_model() 373 | # model.test_corpus() 374 | # if args.mode == "test": 375 | # model.test_corpus() 376 | 
class Generator(nn.Module):
    """Project decoder features to vocabulary scores and normalize them."""

    def __init__(self, d_model, vocab):
        super(Generator, self).__init__()
        # Linear projection onto the vocabulary; attribute names are kept so
        # existing checkpoints (state_dict keys) still load.
        self.proj = nn.Linear(d_model, vocab)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        # Log-probabilities, as expected by KL-divergence style losses.
        logits = self.proj(x)
        return F.log_softmax(logits, dim=-1)

    def scaled_forward(self, x, scale=1.0):
        # Temperature-scaled probabilities (larger scale -> sharper softmax).
        scaled_logits = self.proj(x) * scale
        return self.softmax(scaled_logits)
class LayerNorm(nn.Module):
    """Layer normalization over the last dimension with learnable gain/bias."""

    def __init__(self, features, eps=1e-6):
        super(LayerNorm, self).__init__()
        # Learnable per-feature scale (a_2) and shift (b_2); parameter names
        # kept so existing checkpoints still load.
        self.a_2 = nn.Parameter(torch.ones(features))
        self.b_2 = nn.Parameter(torch.zeros(features))
        self.eps = eps

    def forward(self, x):
        # Normalize each feature vector to zero mean / unit (sample) std,
        # then apply the learned affine transform.
        mu = x.mean(-1, keepdim=True)
        sigma = x.std(-1, keepdim=True)
        normalized = (x - mu) / (sigma + self.eps)
        return self.a_2 * normalized + self.b_2
class EncoderLayer(nn.Module):
    """One encoder block: self-attention then position-wise feed-forward,
    each wrapped in a residual + (pre-)norm connection."""

    def __init__(self, size, self_attn, feed_forward, dropout):
        super(EncoderLayer, self).__init__()
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        # Two residual wrappers: one for attention, one for the FFN.
        self.sublayer = clones(SublayerConnection(size, dropout), 2)
        self.size = size

    def forward(self, x, mask):
        # Self-attention uses the same tensor as query, key and value.
        attend = lambda inp: self.self_attn(inp, inp, inp, mask)
        x = self.sublayer[0](x, attend)
        return self.sublayer[1](x, self.feed_forward)
def attention(query, key, value, mask=None, dropout=None):
    """Scaled dot-product attention.

    Returns (attended values, attention weights); masked positions (mask == 0)
    receive ~zero weight.
    """
    dim = query.size(-1)
    similarity = query.matmul(key.transpose(-2, -1)) / math.sqrt(dim)
    if mask is not None:
        # A large negative score collapses to ~0 probability after softmax.
        similarity = similarity.masked_fill(mask == 0, -1e9)
    weights = F.softmax(similarity, dim=-1)
    if dropout is not None:
        weights = dropout(weights)
    return weights.matmul(value), weights
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings, then apply dropout."""

    def __init__(self, d_model, dropout, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        # Precompute the table once: even columns sin, odd columns cos,
        # frequencies geometrically spaced as in "Attention Is All You Need".
        positions = torch.arange(0, max_len).unsqueeze(1).float()
        freqs = torch.exp(torch.arange(0, d_model, 2).float() *
                          -(math.log(10000.0) / d_model))
        table = torch.zeros(max_len, d_model, dtype=torch.float32)
        table[:, 0::2] = torch.sin(positions * freqs)
        table[:, 1::2] = torch.cos(positions * freqs)
        # Registered as a buffer: saved with the model but not a parameter.
        self.register_buffer('pe', table.unsqueeze(0))

    def forward(self, x):
        # Truncate the table to the sequence length and add it element-wise.
        augmented = x.float() + self.pe[:, :x.size(1)]
        return self.dropout(augmented)
def TransformerEncoder(N=6, d_model=1024, d_ff=2048, h=8, dropout=0.1):
    """Build an encoder-only transformer: scaled input + positional encoding
    feeding a stack of N self-attention layers."""
    dup = copy.deepcopy
    self_attn = MultiHeadedAttention(h, d_model)
    ffn = PositionwiseFeedForward(d_model, d_ff, dropout)
    pos_enc = PositionalEncoding(d_model, dropout)
    encoder = FullEncoder(
        Encoder(EncoderLayer(d_model, dup(self_attn), dup(ffn), dropout), N),
        nn.Sequential(Scale(d_model), dup(pos_enc)))

    # Glorot / fan_avg initialization for every weight matrix, as in the
    # original "Attention Is All You Need" reference implementation.
    for param in encoder.parameters():
        if param.dim() > 1:
            nn.init.xavier_uniform_(param)
    return encoder
def pretrain_run_epoch(data_iter, model, loss_compute, train_step_num):
    """Run one pass over data_iter, logging per-batch loss (plus tokens/sec
    every 50 steps); returns the token-normalized total loss."""
    t0 = time.time()
    total_tokens = 0
    total_loss = 0
    window_tokens = 0
    ct = Clock(train_step_num)
    for step, batch in enumerate(data_iter):
        out = model.forward(batch.src, batch.trg,
                            batch.src_mask, batch.trg_mask)
        loss = loss_compute(out, batch.trg_y, batch.ntokens)
        total_loss += loss
        total_tokens += batch.ntokens
        window_tokens += batch.ntokens
        norm_loss = loss / batch.ntokens.float().to(device)
        if step % 50 == 1:
            # Periodically also report throughput over the elapsed window.
            elapsed = time.time() - t0
            ct.flush(info={"loss": norm_loss,
                           "tok/sec": window_tokens.float().to(device) / elapsed})
            t0 = time.time()
            window_tokens = 0
        else:
            ct.flush(info={"loss": norm_loss})
    return total_loss / total_tokens.float().to(device)
class SimpleLossCompute:
    """Apply the generator (if any), compute the normalized loss, backprop,
    and optionally take an optimizer step."""

    def __init__(self, generator, criterion, opt=None):
        self.generator = generator
        self.criterion = criterion
        self.opt = opt

    def __call__(self, x, y, norm):
        if self.generator is not None:
            x = self.generator(x)
        # Flatten batch and time dimensions so the criterion sees [N, vocab].
        flat_pred = x.contiguous().view(-1, x.size(-1)).to(device)
        flat_gold = y.contiguous().view(-1)
        loss = self.criterion(flat_pred, flat_gold).to(device) / norm.float().to(device)
        loss.backward()
        if self.opt is not None:
            self.opt.step()
            self.opt.optimizer.zero_grad()
        # Report the un-normalized loss value (detached from the graph).
        return loss.data.to(device) * norm.float().to(device)
def greedy_decode(model, embed, src, src_mask, max_len, start_symbol=2):
    """Decode token-by-token, always taking the arg-max word at each step."""
    memory = model.encode(embed(src.to(device)), src_mask)
    batch = src.shape[0]
    # Every sequence starts with the start symbol (id 2 by default).
    ys = torch.ones(batch, 1).fill_(start_symbol).type_as(src.data).to(device)
    for _ in range(max_len - 1):
        step_mask = Variable(subsequent_mask(ys.size(1)).type_as(src.data))
        out = model.decode(memory, src_mask, Variable(embed(ys)), step_mask)
        # Only the newest position's distribution matters for the next token.
        log_probs = model.generator(out[:, -1])
        _, best = torch.max(log_probs, dim=1)
        ys = torch.cat([ys, best.unsqueeze(-1)], dim=1)
    return ys
430 | def __init__(self, size, padding_idx, smoothing=0.0): 431 | super(LabelSmoothing, self).__init__() 432 | self.criterion = nn.KLDivLoss(size_average=False) 433 | self.padding_idx = padding_idx 434 | self.confidence = 1.0 - smoothing 435 | self.smoothing = smoothing 436 | self.size = size 437 | self.true_dist = None 438 | 439 | def forward(self, x, target): 440 | assert x.size(1) == self.size 441 | true_dist = x.data.clone() 442 | true_dist.fill_(self.smoothing / (self.size - 2)) 443 | true_dist.scatter_(1, target.data.unsqueeze(1).to(device), self.confidence) 444 | true_dist[:, self.padding_idx] = 0 445 | mask = torch.nonzero(target.data == self.padding_idx) 446 | if mask.dim() > 0: 447 | true_dist.index_fill_(0, mask.squeeze().to(device), 0.0) 448 | self.true_dist = true_dist 449 | return self.criterion(x, Variable(true_dist, requires_grad=False)) 450 | 451 | 452 | 453 | if __name__ == "__main__": 454 | utils = Utils(X_data_path="small_cou.txt", Y_data_path="small_cna.txt") 455 | # Train the simple copy task. 
456 | V = utils.emb_mat.shape[0] 457 | criterion = LabelSmoothing(size=V, padding_idx=0, smoothing=0.0) 458 | model = make_model(V, V, utils.emb_mat, utils.emb_mat) 459 | model.to("cuda:0") 460 | model_opt = NoamOpt(model.src_embed[0].d_model, 1, 400, 461 | torch.optim.Adam(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9)) 462 | 463 | X_test_batch = None 464 | Y_test_batch = None 465 | for i, batch in enumerate(data_gen(utils.data_generator("X"), utils.sents2idx)): 466 | X_test_batch = batch 467 | break 468 | 469 | for i, batch in enumerate(data_gen(utils.data_generator("Y"), utils.sents2idx)): 470 | Y_test_batch = batch 471 | break 472 | 473 | if sys.argv[1] == "train": 474 | for epoch in range(int(sys.argv[2])): 475 | model.train() 476 | print("EPOCH %d:"%(epoch+1)) 477 | pretrain_run_epoch(data_gen(utils.data_generator("Y"), utils.sents2idx), model, 478 | SimpleLossCompute(model.generator, criterion, model_opt), utils.train_step_num) 479 | model.eval() 480 | 481 | x = utils.idx2sent(greedy_decode(model, X_test_batch.src, X_test_batch.src_mask, max_len=20, start_symbol=2)) 482 | y = utils.idx2sent(greedy_decode(model, Y_test_batch.src, Y_test_batch.src_mask, max_len=20, start_symbol=2)) 483 | 484 | for i,j in zip(X_test_batch.src, x): 485 | print("===") 486 | k = utils.idx2sent([i])[0] 487 | print("ORG:", " ".join(k[:k.index('')+1])) 488 | print("--") 489 | print("GEN:", " ".join(j[:j.index('')+1] if '' in j else j)) 490 | print("===") 491 | print("=====") 492 | for i, j in zip(Y_test_batch.src, y): 493 | print("===") 494 | k = utils.idx2sent([i])[0] 495 | print("ORG:", " ".join(k[:k.index('')+1])) 496 | print("--") 497 | print("GEN:", " ".join(j[:j.index('')+1] if '' in j else j)) 498 | print("===") 499 | 500 | # print(pretrain_run_epoch(data_gen(utils.data_generator("X"), utils.sents2idx), model, 501 | # SimpleLossCompute(model.generator, criterion, None), utils.train_step_num)) 502 | 503 | torch.save(model.state_dict(), 'model.ckpt') 
def perplexity(prob):
    """Mean per-token perplexity of a batch of probability distributions.

    Args:
        prob: tensor whose last dimension is a probability distribution.

    Returns:
        Scalar tensor: exp of the mean entropy over all positions.

    BUGFIX: 0 * log(0) evaluates to NaN, and softmax outputs can underflow to
    exactly zero in float32; clamp inside the log so zero-probability entries
    contribute 0 (their correct entropy share) instead of poisoning the mean.
    """
    safe = prob.clamp(min=1e-12)
    entropy = -(prob * torch.log(safe)).sum(-1)
    return torch.exp(entropy.mean())
def wgan_pg(netD, fake_data, real_data, lamb=10):
    """WGAN-GP gradient penalty: penalize the critic's gradient norm deviating
    from 1 at random interpolates between real and fake samples."""
    batch_size = fake_data.shape[0]
    # 1. Random convex combination of real and fake, one alpha per sample.
    alpha = torch.rand(batch_size, 1, 1).expand(real_data.size()).to(device)
    mixed = alpha * real_data + (1 - alpha) * fake_data
    mixed = Variable(mixed.to(device), requires_grad=True)
    # 2. Gradient of the critic output w.r.t. the interpolated input.
    critic_out = netD(mixed).view(batch_size, )
    grads = torch.autograd.grad(
        outputs=critic_out, inputs=mixed,
        grad_outputs=torch.ones(critic_out.size()).to(device),
        create_graph=True, retain_graph=True, only_inputs=True)[0]
    # 3. Two-sided penalty pushing the gradient norm towards 1.
    return ((grads.norm(2, dim=1) - 1) ** 2).mean() * lamb
map_location=device)) 130 | if d_file!=None: 131 | self.D.load_state_dict(torch.load(os.path.join(path, d_file), map_location=device)) 132 | if r_file!=None: 133 | self.R.load_state_dict(torch.load(os.path.join(path, r_file), map_location=device)) 134 | print("model loaded!") 135 | 136 | def pretrain_disc(self, num_epochs=100): 137 | X_datagen = self.utils.data_generator("X") 138 | Y_datagen = self.utils.data_generator("Y") 139 | for epoch in range(num_epochs): 140 | d_steps = self.utils.train_step_num 141 | d_ct = Clock(d_steps, title="Train Discriminator(%d/%d)"%(epoch, num_epochs)) 142 | for step, X_data, Y_data in zip(range(d_steps), data_gen(X_datagen, self.utils.sents2idx), data_gen(Y_datagen, self.utils.sents2idx)): 143 | # 1. Train D on real+fake 144 | # if epoch == 0: 145 | # break 146 | self.D.zero_grad() 147 | 148 | # 1A: Train D on real 149 | d_real_pred = self.D(self.embed(Y_data.src.to(device))) 150 | d_real_error = self.criterion(d_real_pred, torch.ones((d_real_pred.shape[0],), dtype=torch.int64).to(device)) # ones = true 151 | 152 | # 1B: Train D on fake 153 | d_fake_pred = self.D(self.embed(X_data.src.to(device))) 154 | d_fake_error = self.criterion(d_fake_pred, torch.zeros((d_fake_pred.shape[0],), dtype=torch.int64).to(device)) # zeros = fake 155 | (d_fake_error + d_real_error).backward() 156 | self.D_opt.step() # Only optimizes D's parameters; changes based on stored gradients from backward() 157 | d_ct.flush(info={"D_loss": d_fake_error.item()}) 158 | torch.save(self.D.state_dict(), "model_disc_pretrain.ckpt") 159 | 160 | def train_model(self, num_epochs=100, g_scale=1.0, r_scale=1.0): 161 | for i, batch in enumerate(data_gen(utils.test_generator("X"), utils.sents2idx)): 162 | X_test_batch = batch 163 | break 164 | 165 | for i, batch in enumerate(data_gen(utils.test_generator("Y"), utils.sents2idx)): 166 | Y_test_batch = batch 167 | break 168 | X_datagen = self.utils.data_generator("X") 169 | Y_datagen = self.utils.data_generator("Y") 170 | for 
epoch in range(num_epochs): 171 | ct = Clock(self.utils.train_step_num, title="Train G/D (%d/%d)" % (epoch, num_epochs)) 172 | for i, X_data, Y_data in zip(range(self.utils.train_step_num), data_gen(X_datagen, self.utils.sents2idx), data_gen(Y_datagen, self.utils.sents2idx)): 173 | for d_step in range(4): 174 | # 1. Train D on real+fake 175 | # if epoch == 0: 176 | # break 177 | self.D.zero_grad() 178 | 179 | # 1A: Train D on real 180 | d_real_data = self.embed(Y_data.src.to(device)).float() 181 | d_real_pred = self.D(d_real_data) 182 | # d_real_error = self.criterion(d_real_pred, torch.ones((d_real_pred.shape[0],), dtype=torch.int64).to(device)) # ones = true 183 | 184 | # 1B: Train D on fake 185 | self.G.to(device) 186 | d_fake_data = backward_decode(self.G, self.embed, X_data.src, X_data.src_mask, max_len=self.utils.max_len, return_term=0).detach() # detach to avoid training G on these labels 187 | d_fake_pred = self.D(d_fake_data) 188 | # d_fake_error = self.criterion(d_fake_pred, torch.zeros((d_fake_pred.shape[0],), dtype=torch.int64).to(device)) # zeros = fake 189 | # (d_fake_error + d_real_error).backward() 190 | d_loss = d_fake_pred.mean() - d_real_pred.mean() 191 | d_loss += wgan_pg(self.D, d_fake_data, d_real_data, lamb=10) 192 | 193 | d_loss.backward() 194 | self.D_opt.step() # Only optimizes D's parameters; changes based on stored gradients from backward() 195 | 196 | self.G.zero_grad() 197 | g_fake_data, g_ppl = backward_decode(self.G, self.embed, X_data.src, X_data.src_mask, max_len=self.utils.max_len, return_term=2) 198 | dg_fake_pred = self.D(g_fake_data) 199 | # g_error = self.criterion(dg_fake_pred, torch.ones((dg_fake_pred.shape[0],), dtype=torch.int64).to(device)) # we want to fool, so pretend it's all genuine 200 | g_loss = -dg_fake_pred.mean() + 0.1*(60.0 - g_ppl)**2 201 | 202 | g_loss.backward(retain_graph=True) 203 | self.G_opt.step() # Only optimizes G's parameters 204 | self.G.zero_grad() 205 | 206 | # 3. 
reconstructor 643988636173-69t5i8ehelccbq85o3esu11jgh61j8u5.apps.googleusercontent.com 207 | # way_3 208 | out = self.R.forward(g_fake_data, embedding_layer(X_data.trg.to(device)), 209 | None, X_data.trg_mask) 210 | r_loss = 10000*self.r_loss_compute(out, X_data.trg_y, X_data.ntokens) 211 | self.G_opt.step() 212 | self.G_opt.optimizer.zero_grad() 213 | ct.flush(info={"D": d_loss.item(), 214 | "G": g_loss.item(), 215 | "R": r_loss / X_data.ntokens.float().to(device)}) 216 | # way_2 217 | # r_reco_data = prob_backward(self.R, self.embed, g_fake_data, None, max_len=self.utils.max_len, raw=True) 218 | # x_orgi_data = X_data.src[:, 1:] 219 | # r_loss = SimpleLossCompute(None, criterion, self.R_opt)(r_reco_data, x_orgi_data, X_data.ntokens) 220 | # way_1 221 | # viewed_num = r_reco_data.shape[0]*r_reco_data.shape[1] 222 | # r_error = r_scale*self.cosloss(r_reco_data.float().view(-1, self.embed.weight.shape[1]), x_orgi_data.float().view(-1, self.embed.weight.shape[1]), torch.ones(viewed_num, dtype=torch.float32).to(device)) 223 | if i%100 == 0: 224 | with torch.no_grad(): 225 | x_cont, x_ys = backward_decode(model, self.embed, X_test_batch.src, X_test_batch.src_mask, max_len=25, start_symbol=2) 226 | x = utils.idx2sent(x_ys) 227 | y_cont, y_ys = backward_decode(model, self.embed, Y_test_batch.src, Y_test_batch.src_mask, max_len=25, start_symbol=2) 228 | y = utils.idx2sent(y_ys) 229 | r_x = utils.idx2sent(backward_decode(self.R, self.embed, x_cont, None, max_len=self.utils.max_len, raw=True, return_term=1)) 230 | r_y = utils.idx2sent(backward_decode(self.R, self.embed, y_cont, None, max_len=self.utils.max_len, raw=True, return_term=1)) 231 | 232 | for i,j,l in zip(X_test_batch.src, x, r_x): 233 | print("===") 234 | k = utils.idx2sent([i])[0] 235 | print("ORG:", " ".join(k[:k.index('')+1])) 236 | print("GEN:", " ".join(j[:j.index('')+1] if '' in j else j)) 237 | print("REC:", " ".join(l[:l.index('')+1] if '' in l else l)) 238 | print("=====") 239 | for i, j, l in 
zip(Y_test_batch.src, y, r_y): 240 | print("===") 241 | k = utils.idx2sent([i])[0] 242 | print("ORG:", " ".join(k[:k.index('')+1])) 243 | print("GEN:", " ".join(j[:j.index('')+1] if '' in j else j)) 244 | print("REC:", " ".join(l[:l.index('')+1] if '' in l else l)) 245 | self.save_model() 246 | 247 | def pretrain_run_epoch(data_iter, model, loss_compute, train_step_num, embedding_layer): 248 | "Standard Training and Logging Function" 249 | start = time.time() 250 | total_tokens = 0 251 | total_loss = 0 252 | tokens = 0 253 | ct = Clock(train_step_num) 254 | embedding_layer.to(device) 255 | model.to(device) 256 | for i, batch in enumerate(data_iter): 257 | batch.to(device) 258 | out = model.forward(embedding_layer(batch.src.to(device)), embedding_layer(batch.trg.to(device)), 259 | batch.src_mask, batch.trg_mask) 260 | loss = loss_compute(out, batch.trg_y, batch.ntokens) 261 | total_loss += loss 262 | total_tokens += batch.ntokens 263 | tokens += batch.ntokens 264 | batch.to("cpu") 265 | if i % 50 == 1: 266 | elapsed = time.time() - start 267 | ct.flush(info={"loss":loss / batch.ntokens.float().to(device), "tok/sec":tokens.float().to(device) / elapsed}) 268 | # print("Epoch Step: %d Loss: %f Tokens per Sec: %f" % 269 | # (i, loss / batch.ntokens.float().to(device), tokens.float().to(device) / elapsed)) 270 | start = time.time() 271 | tokens = 0 272 | else: 273 | ct.flush(info={"loss":loss / batch.ntokens.float().to(device)}) 274 | return total_loss / total_tokens.float().to(device) 275 | 276 | def get_embedding_layer(utils): 277 | d_model = utils.emb_mat.shape[1] 278 | vocab = utils.emb_mat.shape[0] 279 | embedding_layer = nn.Embedding(vocab, d_model) 280 | embedding_layer.weight.data = torch.tensor(utils.emb_mat) 281 | embedding_layer.weight.requires_grad = False 282 | embedding_layer.to(device) 283 | return embedding_layer 284 | 285 | def pretrain(model, embedding_layer, utils, epoch_num=1): 286 | criterion = LabelSmoothing(size=utils.emb_mat.shape[0], 
padding_idx=0, smoothing=0.0) 287 | model_opt = NoamOpt(utils.emb_mat.shape[1], 1, 4000, 288 | torch.optim.Adam(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9)) 289 | X_test_batch = None 290 | Y_test_batch = None 291 | for i, batch in enumerate(data_gen(utils.data_generator("X"), utils.sents2idx)): 292 | X_test_batch = batch 293 | break 294 | 295 | for i, batch in enumerate(data_gen(utils.data_generator("Y"), utils.sents2idx)): 296 | Y_test_batch = batch 297 | break 298 | model.to(device) 299 | for epoch in range(epoch_num): 300 | model.train() 301 | print("EPOCH %d:"%(epoch+1)) 302 | pretrain_run_epoch(data_gen(utils.data_generator("Y"), utils.sents2idx), model, 303 | SimpleLossCompute(model.generator, criterion, model_opt), utils.train_step_num, embedding_layer) 304 | pretrain_run_epoch(data_gen(utils.data_generator("X"), utils.sents2idx), model, 305 | SimpleLossCompute(model.generator, criterion, model_opt), utils.train_step_num, embedding_layer) 306 | model.eval() 307 | torch.save(model.state_dict(), 'model_pretrain.ckpt') 308 | x = utils.idx2sent(greedy_decode(model, embedding_layer, X_test_batch.src, X_test_batch.src_mask, max_len=20, start_symbol=2)) 309 | y = utils.idx2sent(greedy_decode(model, embedding_layer, Y_test_batch.src, Y_test_batch.src_mask, max_len=20, start_symbol=2)) 310 | 311 | for i,j in zip(X_test_batch.src, x): 312 | print("===") 313 | k = utils.idx2sent([i])[0] 314 | print("ORG:", " ".join(k[:k.index('')+1])) 315 | print("--") 316 | print("GEN:", " ".join(j[:j.index('')+1] if '' in j else j)) 317 | print("===") 318 | print("=====") 319 | for i, j in zip(Y_test_batch.src, y): 320 | print("===") 321 | k = utils.idx2sent([i])[0] 322 | print("ORG:", " ".join(k[:k.index('')+1])) 323 | print("--") 324 | print("GEN:", " ".join(j[:j.index('')+1] if '' in j else j)) 325 | print("===") 326 | 327 | if __name__ == "__main__": 328 | parser = argparse.ArgumentParser() 329 | parser.add_argument("mode", help="execute mode") 330 | 
parser.add_argument("-filename", default=None, required=False, help="test filename") 331 | parser.add_argument("-load_model", default=False, required=False, help="test filename") 332 | parser.add_argument("-model_name", default="model.ckpt", required=False, help="test filename") 333 | parser.add_argument("-disc_name", default="cnn_disc/model_disc_pretrain_xy_inv.ckpt", required=False, help="test filename") 334 | parser.add_argument("-save_path", default="", required=False, help="test filename") 335 | parser.add_argument("-X_file", default="shuf_cna.txt", required=False, help="X domain text filename") 336 | parser.add_argument("-Y_file", default="shuf_cou.txt", required=False, help="Y domain text filename") 337 | parser.add_argument("-X_test", default="test.cna", required=False, help="X domain text filename") 338 | parser.add_argument("-Y_test", default="test.cou", required=False, help="Y domain text filename") 339 | parser.add_argument("-epoch", default=1, required=False, help="test filename") 340 | parser.add_argument("-max_len", default=20, required=False, help="test filename") 341 | parser.add_argument("-batch_size", default=32, required=False, help="batch size") 342 | args = parser.parse_args() 343 | 344 | model = Transformer(N=2) 345 | utils = Utils(X_data_path=args.X_file, Y_data_path=args.Y_file, 346 | X_test_path=args.X_test, Y_test_path=args.Y_test, 347 | batch_size=int(args.batch_size)) 348 | embedding_layer = get_embedding_layer(utils).to(device) 349 | model.generator = Generator(d_model = utils.emb_mat.shape[1], vocab=utils.emb_mat.shape[0]) 350 | if args.load_model: 351 | model.load_state_dict(torch.load(args.model_name)) 352 | if args.mode == "pretrain": 353 | pretrain(model, embedding_layer, utils, int(args.epoch)) 354 | if args.mode == "cycle": 355 | disc = Discriminator(word_dim=utils.emb_mat.shape[1], inner_dim=512, seq_len=20) 356 | main_model = CycleGAN(disc, model, utils, embedding_layer) 357 | main_model.to(device) 358 | 
main_model.load_model(g_file="model_pretrain.ckpt", r_file="model_pretrain.ckpt", d_file=None) 359 | main_model.train_model() 360 | if args.mode == "disc": 361 | disc = Discriminator(word_dim=utils.emb_mat.shape[1], inner_dim=512, seq_len=20) 362 | main_model = CycleGAN(disc, model, utils, embedding_layer) 363 | main_model.to(device) 364 | main_model.pretrain_disc(2) 365 | 366 | if args.mode == "dev": 367 | model = Transformer(N=2) 368 | utils = Utils(X_data_path="big_cou.txt", Y_data_path="big_cna.txt") 369 | model.generator = Generator(d_model = utils.emb_mat.shape[1], vocab=utils.emb_mat.shape[0]) 370 | criterion = LabelSmoothing(size=utils.emb_mat.shape[0], padding_idx=0, smoothing=0.0) 371 | model_opt = NoamOpt(utils.emb_mat.shape[1], 1, 400, 372 | torch.optim.Adam(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9)) 373 | X_test_batch = None 374 | Y_test_batch = None 375 | for i, batch in enumerate(data_gen(utils.data_generator("X"), utils.sents2idx)): 376 | X_test_batch = batch 377 | break 378 | 379 | for i, batch in enumerate(data_gen(utils.data_generator("Y"), utils.sents2idx)): 380 | Y_test_batch = batch 381 | break 382 | 383 | # if args.load_model: 384 | # model.load_model(filename=args.model_name) 385 | # if args.mode == "train": 386 | # model.train_model(num_epochs=int(args.epoch)) 387 | # print("========= Testing =========") 388 | # model.load_model() 389 | # model.test_corpus() 390 | # if args.mode == "test": 391 | # model.test_corpus() 392 | -------------------------------------------------------------------------------- /maskgan.py: -------------------------------------------------------------------------------- 1 | import re 2 | import torch 3 | from torch import nn 4 | import os 5 | from ELMoForManyLangs import elmo 6 | import numpy as np 7 | from jexus import Clock, History 8 | import math 9 | from sklearn.metrics import f1_score 10 | from sklearn.metrics import accuracy_score 11 | import argparse 12 | import sys 13 | import random, time 14 | 
cwd = os.getcwd()
sys.path.append(os.path.join(os.path.dirname(__file__)))
sys.path.append(os.path.join(cwd, "../InverseELMo"))
sys.path.append(os.path.join(cwd, "../CycleGAN-sentiment-transfer"))
from invELMo import invELMo

def f2h(s):
    """Normalise full-width characters in `s` to half-width and collapse space runs."""
    s = list(s)
    for i in range(len(s)):
        num = ord(s[i])
        if num == 0x3000:  # ideographic space -> ASCII space
            num = 32
        elif 0xFF01 <= num <= 0xFF5E:  # full-width ASCII block -> half-width
            num -= 0xfee0
        s[i] = chr(num).translate(str.maketrans('﹕﹐﹑。﹔﹖﹗﹘ ', ':,、。;?!- '))
    return re.sub(r"( | )+", " ", "".join(s)).strip()

def sort_list(li, piv=2, unsort_ind=None):
    """Reorder every list in `li` by the sort order of li[piv] (or a given index)."""
    ind = []
    if unsort_ind is None:  # FIX: was `== None` (PEP 8: compare to None with `is`)
        ind = sorted(range(len(li[piv])), key=(lambda k: li[piv][k]))
    else:
        ind = unsort_ind
    for i in range(len(li)):
        li[i] = [li[i][j] for j in ind]
    return li, ind

def sort_numpy(li, piv=2, unsort=False):
    """Reorder every element of `li` by descending (or original) order of li[piv]."""
    ind = np.argsort(-li[piv] if not unsort else li[piv], axis=0)
    for i in range(len(li)):
        if type(li[i]).__module__ == np.__name__ or type(li[i]).__module__ == torch.__name__:
            li[i] = li[i][ind]
        else:
            li[i] = [li[i][j] for j in ind]
    return li, ind

def sort_torch(li, piv=2, unsort=False):
    """Torch variant of sort_numpy: sort li[piv] in place, index the rest by it."""
    li[piv], ind = torch.sort(li[piv], dim=0, descending=(not unsort))
    for i in range(len(li)):
        if i == piv:
            continue
        else:
            li[i] = li[i][ind]
    return li, ind

def sort_by(li, piv=2, unsort=False):
    """Dispatch to the numpy / torch / plain-list sorter based on type(li[piv])."""
    if type(li[piv]).__module__ == np.__name__:
        return sort_numpy(li, piv, unsort)
    elif type(li[piv]).__module__ == torch.__name__:
        return sort_torch(li, piv, unsort)
    else:
        return sort_list(li, piv, unsort)


class Embedder():
    """Wrapper around ELMoForManyLangs adding BOS/EOS vectors and padding."""
    def __init__(self, seq_len=0, use_cuda=True, device=None):
        self.embedder = elmo.Embedder(batch_size=512, use_cuda=use_cuda)
        self.seq_len = seq_len
        # Special-token vectors precomputed offline into these .npy files.
        self.bos_vec, self.eos_vec = np.load("bos_eos.npy")
        self.pad, self.oov = np.load("pad_oov.npy")
self.device = device 75 | if self.device != None: 76 | self.embedder.model.to(self.device) 77 | 78 | def __call__(self, sents, max_len=0, with_bos_eos=True, layer=-1, pad_matters=False): 79 | seq_lens = np.array([len(x) for x in sents], dtype=np.int64) 80 | sents = [[self.sub_unk(x) for x in sent] for sent in sents] 81 | if max_len != 0: 82 | pass 83 | elif self.seq_len != 0: 84 | max_len = self.seq_len 85 | else: 86 | max_len = seq_lens.max() 87 | emb_list = self.embedder.sents2elmo(sents, output_layer=layer) 88 | if not with_bos_eos: 89 | for i in range(len(emb_list)): 90 | if max_len - seq_lens[i] > 0: 91 | if pad_matters: 92 | emb_list[i] = np.concatenate([emb_list[i], np.tile(self.pad,[max_len - seq_lens[i],1])], axis=0) 93 | else: 94 | emb_list[i] = np.concatenate([emb_list[i], np.zeros((max_len - seq_lens[i], emb_list[i].shape[1]))]) 95 | else: 96 | emb_list[i] = emb_list[i][:max_len] 97 | elif with_bos_eos: 98 | for i in range(len(emb_list)): 99 | if max_len - seq_lens[i] > 0: 100 | if pad_matters: 101 | emb_list[i] = np.concatenate([ 102 | self.bos_vec[np.newaxis], 103 | emb_list[i], 104 | self.eos_vec[np.newaxis], 105 | np.tile(self.pad, [max_len - seq_lens[i], 1])], axis=0) 106 | else: 107 | emb_list[i] = np.concatenate([ 108 | self.bos_vec[np.newaxis], 109 | emb_list[i], 110 | self.eos_vec[np.newaxis], 111 | np.zeros((max_len - seq_lens[i], emb_list[i].shape[1]))], axis=0) 112 | else: 113 | emb_list[i] = np.concatenate([self.bos_vec[np.newaxis], emb_list[i][:max_len],self.eos_vec[np.newaxis]], axis=0) 114 | embedded = np.array(emb_list, dtype=np.float32) 115 | seq_lens = seq_lens+2 if with_bos_eos else seq_lens 116 | return embedded, seq_lens 117 | 118 | def sub_unk(self, e): 119 | e = e.replace(',',',') 120 | e = e.replace(':',':') 121 | e = e.replace(';',';') 122 | e = e.replace('?','?') 123 | e = e.replace('!', '!') 124 | return e 125 | 126 | 127 | class Utils(): 128 | def __init__(self, 129 | training_data_path, 130 | testing_data_path, 131 | 
batch_size = 32, elmo_device=None): 132 | self.training_data_path = training_data_path 133 | self.training_line_num = int(os.popen("wc -l %s"%self.training_data_path).read().split(' ')[0]) 134 | self.testing_data_path = testing_data_path 135 | self.testing_line_num = int(os.popen("wc -l %s"%self.testing_data_path).read().split(' ')[0]) 136 | self.elmo = Embedder(device=elmo_device, use_cuda=elmo_device!="cpu") 137 | self.batch_size = batch_size 138 | self.train_step_num = math.floor(self.training_line_num / batch_size) 139 | self.test_step_num = math.floor(self.testing_line_num / batch_size) 140 | self.device="cuda:0" 141 | 142 | def process_sent(self, sent): 143 | sent = f2h(sent) 144 | word_list = re.split(r"[\s|\u3000]+", sent.strip()) 145 | char_list = list("".join(word_list)) 146 | label_list = [] 147 | for word in word_list: 148 | label_list += [0] * (len(word) - 1) + [1] 149 | return char_list, label_list 150 | 151 | def sent2list(self, sent): 152 | sent = f2h(sent) 153 | word_list = re.split(r"[\s|\u3000]+", sent.strip()) 154 | return word_list 155 | 156 | def data_generator(self, mode="train"): 157 | path = eval("self.%sing_data_path" % mode) 158 | file = open(path) 159 | sents = [] 160 | for sent in file: 161 | if len(sent.strip()) == 0: 162 | continue 163 | word_list = self.sent2list(sent) 164 | sents.append(word_list) 165 | if len(sents) == self.batch_size: 166 | yield sents 167 | sents = [] 168 | fw.close() 169 | if len(sents)!=0: 170 | yield sents 171 | 172 | def raw2elmo(self, batch, with_bos_eos=True): 173 | embedded, seq_lens = self.elmo(batch, with_bos_eos=with_bos_eos) 174 | return embedded, seq_lens 175 | 176 | def elmo2mask(self, embedded, seq_lens, with_bos_eos=True, mask_rate=0.0): 177 | # embedded, seq_lens = self.elmo(batch) 178 | mask = np.full((embedded.shape[0], embedded.shape[1]), -1, dtype=np.int64) 179 | if mask_rate: 180 | (begin, end) = (1, -1) if with_bos_eos else(0, 0) 181 | rand_mat = np.random.rand(embedded.shape[0], 
embedded.shape[1]) 182 | for row in range(embedded.shape[0]): 183 | for col in range(begin, seq_lens[row] + end): 184 | if rand_mat[row, col] < mask_rate: 185 | embedded[row, col] = torch.zeros(embedded[row, col].shape) 186 | mask[row, col] = 0 187 | else: 188 | mask[row, col] = 1 189 | return embedded, seq_lens, mask 190 | 191 | class Generator(nn.Module): 192 | def __init__(self, 193 | batch_size=32, 194 | device="cuda:0", 195 | hidden_size=300, 196 | input_size=1024, 197 | encode_size=30, 198 | n_layers=3, 199 | dropout=0.33): 200 | super(Generator, self).__init__() 201 | self.device = device if torch.cuda.is_available() else "cpu" 202 | self.gru = nn.LSTM(input_size, hidden_size, n_layers, 203 | dropout=(0 if n_layers == 1 else dropout), 204 | bidirectional=True, 205 | batch_first=True) 206 | self.fc1 = nn.Linear(2*hidden_size, input_size) 207 | self.hidden_expander_1 = nn.Linear(encode_size, hidden_size) 208 | self.hidden_expander_2 = nn.Linear(encode_size, hidden_size) 209 | self.optimizer = torch.optim.Adam(self.parameters()) 210 | 211 | def forward(self, input_seq, input_lengths, hidden=None, sort=True, unsort=False): 212 | embedded = torch.from_numpy(input_seq).to(self.device) 213 | if sort: 214 | [embedded, input_lengths], ind = sort_by([embedded, input_lengths], piv=1) 215 | hidden[0] = self.hidden_expander_1(hidden[0]) 216 | hidden[1] = self.hidden_expander_2(hidden[1]) 217 | packed = torch.nn.utils.rnn.pack_padded_sequence(embedded, input_lengths, batch_first=True) 218 | outputs, hidden = self.gru(packed, hidden) # output: (seq_len, batch, hidden*n_dir) 219 | outputs, _ = torch.nn.utils.rnn.pad_packed_sequence(outputs, batch_first=True) 220 | pred_seq = self.fc1(outputs)#nn.Softmax(dim=-1)(self.fc1(outputs)) 221 | embedded.cpu() 222 | if unsort: 223 | [pred_seq, _], _ = sort_by([pred_seq, ind], piv=1, unsort=True) 224 | return pred_seq 225 | 226 | class Discriminator(nn.Module): 227 | def __init__(self, 228 | batch_size=32, 229 | device="cuda:0", 230 | 
hidden_size=300, 231 | input_size=1024, 232 | n_layers=3, 233 | dropout=0.33): 234 | super(Discriminator, self).__init__() 235 | self.device = device if torch.cuda.is_available() else "cpu" 236 | self.gru = nn.LSTM(input_size, hidden_size, n_layers, 237 | dropout=(0 if n_layers == 1 else dropout), 238 | bidirectional=True, 239 | batch_first=True) 240 | self.fc1 = nn.Linear(2*hidden_size, 2) 241 | self.optimizer = torch.optim.Adam(self.parameters()) 242 | 243 | def forward(self, input_seq, input_lengths, hidden=None, numpy=True, sort=True, unsort=False): 244 | if numpy: 245 | embedded = torch.from_numpy(input_seq).to(self.device) 246 | else: 247 | embedded = input_seq.to(self.device) 248 | if sort: 249 | [embedded, input_lengths], ind = sort_by([embedded, input_lengths], piv=1) 250 | packed = torch.nn.utils.rnn.pack_padded_sequence(embedded, input_lengths, batch_first=True) 251 | outputs, hidden = self.gru(packed, hidden) # output: (seq_len, batch, hidden*n_dir) 252 | outputs, _ = torch.nn.utils.rnn.pad_packed_sequence(outputs, batch_first=True) 253 | pred_prob = self.fc1(outputs)#nn.Softmax(dim=-1)(self.fc1(outputs)) 254 | embedded.cpu() 255 | if unsort: 256 | [pred_prob, _], _ = sort_by([pred_prob, ind], piv=1, unsort=True) 257 | return pred_prob 258 | 259 | 260 | class Encoder(nn.Module): 261 | def __init__(self, 262 | batch_size=32, 263 | device="cuda:0", 264 | hidden_size=300, 265 | encode_size=30, 266 | input_size=1024, 267 | n_layers=3, 268 | dropout=0.33): 269 | super(Encoder, self).__init__() 270 | self.device = device if torch.cuda.is_available() else "cpu" 271 | self.gru = nn.LSTM(input_size, hidden_size, n_layers, 272 | dropout=(0 if n_layers == 1 else dropout), 273 | bidirectional=True, 274 | batch_first=True) 275 | self.fc1 = nn.Linear(hidden_size, encode_size) 276 | self.fc2 = nn.Linear(hidden_size, encode_size) 277 | self.criterion = nn.CrossEntropyLoss(ignore_index=-1) 278 | self.optimizer = torch.optim.Adam(self.parameters()) 279 | 280 | def 
forward(self, input_seq, input_lengths, hidden=None, sort=True, unsort=False): 281 | embedded = torch.from_numpy(input_seq).to(self.device) 282 | if sort: 283 | [embedded, input_lengths], ind = sort_by([embedded, input_lengths], piv=1) 284 | packed = torch.nn.utils.rnn.pack_padded_sequence(embedded, input_lengths, batch_first=True) 285 | outputs, (h_n, c_n) = self.gru(packed, hidden) # output: (seq_len, batch, hidden*n_dir) 286 | outputs = torch.nn.utils.rnn.pad_packed_sequence(outputs, batch_first=True) 287 | code1 = self.fc1(h_n) #nn.Softmax(dim=-1)(self.fc1(outputs)) 288 | code2 = self.fc2(c_n) 289 | embedded.cpu() 290 | if unsort: 291 | [code1, code2, _], _ = sort_by([code1, code2, ind], piv=2, unsort=True) 292 | return [code1, code2] 293 | 294 | def load_cpu_invelmo(): 295 | elmo = invELMo() 296 | elmo.device = "cpu" 297 | elmo.load_model(device="cpu") 298 | elmo.eval() 299 | return elmo 300 | 301 | class MaskGAN(): 302 | def __init__(self, embedder, discriminator, generator, encoder, utils): 303 | self.embedder = embedder 304 | self.D = discriminator 305 | self.G = generator 306 | self.encoder = encoder 307 | self.utils = utils 308 | self.criterion = nn.CrossEntropyLoss(ignore_index=-1) 309 | self.invelmo = load_cpu_invelmo() 310 | self.mse = nn.MSELoss() 311 | 312 | def save_model(self, d_path="Dis_model.ckpt", g_path="Gen_model.ckpt"): 313 | torch.save(self.D.state_dict(), d_path) 314 | torch.save(self.G.state_dict(), g_path) 315 | 316 | def load_model(self, path=""): 317 | self.D.load_state_dict(torch.load(os.path.join(path,"Dis_model.ckpt"))) 318 | self.G.load_state_dict(torch.load(os.path.join(path,"Gen_model.ckpt"))) 319 | print("model loaded!") 320 | 321 | def pretrain(self, num_epochs=1): 322 | self.G.to(self.G.device) 323 | self.encoder.to(self.encoder.device) 324 | real_datagen = self.utils.data_generator("train") 325 | test_datagen = self.utils.data_generator("test") 326 | for epoch in range(num_epochs): 327 | ct = Clock(self.utils.train_step_num, 
title="Pretrain(%d/%d)" % (epoch, num_epochs)) 328 | for real_data in real_datagen: 329 | # 2. Train G on D's response (but DO NOT train D on these labels) 330 | self.G.zero_grad() 331 | 332 | g_org_data, g_data_seqlen = self.utils.raw2elmo(real_data) 333 | 334 | gen_input = self.encoder(g_org_data, g_data_seqlen) 335 | g_fake_data = self.G(g_org_data, g_data_seqlen, hidden=gen_input) 336 | loss = self.mse(g_fake_data, torch.from_numpy(g_org_data).to(self.G.device)) 337 | 338 | loss.backward() 339 | self.G.optimizer.step() # Only optimizes G's parameters 340 | self.encoder.optimizer.step() 341 | ct.flush(info={"G_loss": loss.item()}) 342 | 343 | with torch.no_grad(): 344 | for _, real_data in zip(range(2), test_datagen): 345 | g_org_data, g_data_seqlen = self.utils.raw2elmo(real_data) 346 | [g_org_data, g_data_seqlen], _ind = sort_by([g_org_data, g_data_seqlen], piv=1) 347 | g_mask_data, g_data_seqlen, g_mask_label = \ 348 | self.utils.elmo2mask(g_org_data, g_data_seqlen, mask_rate=epoch/num_epochs) 349 | gen_input = self.encoder(g_org_data, g_data_seqlen, sort=False) 350 | g_fake_data = self.G(g_mask_data, g_data_seqlen, hidden=gen_input, sort=False) 351 | 352 | gen_sents = self.invelmo.test(g_fake_data.cpu().numpy(), g_data_seqlen) 353 | for i, j in zip(real_data, gen_sents): 354 | print("="*50) 355 | print(' '.join(i)) 356 | print("---") 357 | print(' '.join(j)) 358 | print("=" * 50) 359 | torch.save(self.G.state_dict(), "pretrain_model.ckpt") 360 | 361 | 362 | def train_model(self, num_epochs=100, d_steps=10, g_steps=10): 363 | self.D.to(self.D.device) 364 | self.G.to(self.G.device) 365 | self.encoder.to(self.encoder.device) 366 | real_datagen = self.utils.data_generator("train") 367 | test_datagen = self.utils.data_generator("test") 368 | for epoch in range(num_epochs): 369 | d_ct = Clock(d_steps, title="Train Discriminator(%d/%d)"%(epoch, num_epochs)) 370 | for d_step, real_data in zip(range(d_steps), real_datagen): 371 | # 1. 
Train D on real+fake 372 | self.D.zero_grad() 373 | 374 | # 1A: Train D on real 375 | d_org_data, d_data_seqlen = self.utils.raw2elmo(real_data) 376 | d_mask_data, d_data_seqlen, d_mask_label = \ 377 | self.utils.elmo2mask(d_org_data, d_data_seqlen, mask_rate=epoch/num_epochs) 378 | d_real_pred = self.D(d_org_data, d_data_seqlen) 379 | d_real_error = self.criterion(d_real_pred.transpose(1, 2), torch.ones(d_mask_label.shape, dtype=torch.int64).to(self.D.device)) # ones = true 380 | d_real_error.backward() # compute/store gradients, but don't change params 381 | self.D.optimizer.step() 382 | 383 | # 1B: Train D on fake 384 | d_gen_input = self.encoder(d_org_data, d_data_seqlen) 385 | d_fake_data = self.G(d_mask_data, d_data_seqlen, hidden=d_gen_input).detach() # detach to avoid training G on these labels 386 | d_fake_pred = self.D(d_fake_data, d_data_seqlen, numpy=False) 387 | d_fake_error = self.criterion(d_fake_pred.transpose(1, 2), torch.from_numpy(d_mask_label).to(self.D.device)) # zeros = fake 388 | d_fake_error.backward() 389 | self.D.optimizer.step() # Only optimizes D's parameters; changes based on stored gradients from backward() 390 | d_ct.flush(info={"D_loss":d_fake_error.item()}) 391 | 392 | g_ct = Clock(g_steps, title="Train Generator(%d/%d)"%(epoch, num_epochs)) 393 | for g_step, real_data in zip(range(g_steps), real_datagen): 394 | # 2. 
Train G on D's response (but DO NOT train D on these labels) 395 | self.G.zero_grad() 396 | 397 | g_org_data, g_data_seqlen = self.utils.raw2elmo(real_data) 398 | g_mask_data, g_data_seqlen, g_mask_label = \ 399 | self.utils.elmo2mask(g_org_data, g_data_seqlen, mask_rate=epoch/num_epochs) 400 | gen_input = self.encoder(g_org_data, g_data_seqlen) 401 | g_fake_data = self.G(g_mask_data, g_data_seqlen, hidden=gen_input) 402 | dg_fake_pred = self.D(g_fake_data, g_data_seqlen, numpy=False) 403 | g_error = self.criterion(dg_fake_pred.transpose(1, 2), torch.ones(g_mask_label.shape, dtype=torch.int64).to(self.D.device)) # we want to fool, so pretend it's all genuine 404 | 405 | g_error.backward() 406 | self.G.optimizer.step() # Only optimizes G's parameters 407 | self.encoder.optimizer.step() 408 | g_ct.flush(info={"G_loss": g_error.item()}) 409 | 410 | with torch.no_grad(): 411 | for _, real_data in zip(range(2), test_datagen): 412 | g_org_data, g_data_seqlen = self.utils.raw2elmo(real_data) 413 | [g_org_data, g_data_seqlen], _ind = sort_by([g_org_data, g_data_seqlen], piv=1) 414 | g_mask_data, g_data_seqlen, g_mask_label = \ 415 | self.utils.elmo2mask(g_org_data, g_data_seqlen, mask_rate=epoch/num_epochs) 416 | gen_input = self.encoder(g_org_data, g_data_seqlen, sort=False) 417 | g_fake_data = self.G(g_mask_data, g_data_seqlen, hidden=gen_input, sort=False) 418 | 419 | gen_sents = self.invelmo.test(g_fake_data.cpu().numpy(), g_data_seqlen) 420 | for i, j in zip(real_data, gen_sents): 421 | print("="*50) 422 | print(' '.join(i)) 423 | print("---") 424 | print(' '.join(j)) 425 | print("=" * 50) 426 | self.save_model() 427 | 428 | 429 | 430 | if __name__ == "__main__": 431 | parser = argparse.ArgumentParser() 432 | parser.add_argument("mode", help="execute mode") 433 | # parser.add_argument("-train_file", default=None, required=False, help="test filename") 434 | # parser.add_argument("-test_file", default=None, required=True, help="test filename") 435 | 
parser.add_argument("-load_model_path", default=None, required=False, help="test filename") 436 | parser.add_argument("-epoch", default=1000, required=False, help="test filename") 437 | parser.add_argument("-d_step", default=100, required=False, help="test filename") 438 | parser.add_argument("-g_step", default=100, required=False, help="test filename") 439 | args = parser.parse_args() 440 | 441 | 442 | embedder, discriminator, generator, encoder, utils = \ 443 | Embedder(), Discriminator(), Generator(), Encoder(), \ 444 | Utils(training_data_path="data/train_as.txt", 445 | testing_data_path="data/test_as.txt", elmo_device="cuda:0") 446 | model = MaskGAN(embedder, discriminator, generator, encoder, utils) 447 | if args.load_model_path != None: 448 | model.load_model(args.load_model_path) 449 | if args.mode == "train": 450 | model.train_model(num_epochs=int(args.epoch), d_steps=int(args.d_step), g_steps=int(args.g_step)) 451 | if args.mode == "pretrain": 452 | model.pretrain(num_epochs=int(args.epoch)) 453 | 454 | -------------------------------------------------------------------------------- /cyclegan.py: -------------------------------------------------------------------------------- 1 | import re 2 | import torch 3 | from torch import nn 4 | import os 5 | from ELMoForManyLangs import elmo 6 | import numpy as np 7 | from jexus import Clock, History 8 | import math 9 | from sklearn.metrics import f1_score 10 | from sklearn.metrics import accuracy_score 11 | import argparse 12 | import sys 13 | import random, time 14 | cwd = os.getcwd() 15 | sys.path.append(os.path.join(os.path.dirname(__file__))) 16 | sys.path.append(os.path.join(cwd, "../InverseELMo")) 17 | sys.path.append(os.path.join(cwd, "../CycleGAN-sentiment-transfer")) 18 | from invELMo import invELMo 19 | 20 | def f2h(s): 21 | s = list(s) 22 | for i in range(len(s)): 23 | num = ord(s[i]) 24 | if num == 0x3000: 25 | num = 32 26 | elif 0xFF01 <= num <= 0xFF5E: 27 | num -= 0xfee0 28 | s[i] = 
chr(num).translate(str.maketrans('﹕﹐﹑。﹔﹖﹗﹘ ', ':,、。;?!- ')) 29 | return re.sub(r"( | )+", " ", "".join(s)).strip() 30 | 31 | def sort_list(li, piv=2,unsort_ind=None): 32 | ind = [] 33 | if unsort_ind == None: 34 | ind = sorted(range(len(li[piv])), key=(lambda k: li[piv][k])) 35 | else: 36 | ind = unsort_ind 37 | for i in range(len(li)): 38 | li[i] = [li[i][j] for j in ind] 39 | return li, ind 40 | 41 | def sort_numpy(li, piv=2,unsort=False): 42 | ind = np.argsort(-li[piv] if not unsort else li[piv], axis=0) 43 | for i in range(len(li)): 44 | if type(li[i]).__module__ == np.__name__ or type(li[i]).__module__ == torch.__name__: 45 | li[i] = li[i][ind] 46 | else: 47 | li[i] = [li[i][j] for j in ind] 48 | return li, ind 49 | 50 | def sort_torch(li, piv=2,unsort=False): 51 | li[piv], ind = torch.sort(li[piv], dim=0, descending=(not unsort)) 52 | for i in range(len(li)): 53 | if i == piv: 54 | continue 55 | else: 56 | li[i] = li[i][ind] 57 | return li, ind 58 | 59 | def sort_by(li, piv=2, unsort=False): 60 | if type(li[piv]).__module__ == np.__name__: 61 | return sort_numpy(li, piv, unsort) 62 | elif type(li[piv]).__module__ == torch.__name__: 63 | return sort_torch(li, piv, unsort) 64 | else: 65 | return sort_list(li, piv, unsort) 66 | 67 | 68 | class Embedder(): 69 | def __init__(self, seq_len=0, use_cuda=True, device=None): 70 | self.embedder = elmo.Embedder(batch_size=512, use_cuda=use_cuda) 71 | self.seq_len = seq_len 72 | self.bos_vec, self.eos_vec = np.load("bos_eos.npy") 73 | self.pad, self.oov = np.load("pad_oov.npy") 74 | self.device = device 75 | if self.device != None: 76 | self.embedder.model.to(self.device) 77 | 78 | def __call__(self, sents, max_len=0, with_bos_eos=True, layer=-1, pad_matters=False): 79 | seq_lens = np.array([len(x) for x in sents], dtype=np.int64) 80 | sents = [[self.sub_unk(x) for x in sent] for sent in sents] 81 | if max_len != 0: 82 | pass 83 | elif self.seq_len != 0: 84 | max_len = self.seq_len 85 | else: 86 | max_len = seq_lens.max() 
class Utils():
    """Data-pipeline helpers: line counting, tokenisation, batching and masking."""

    def __init__(self,
                 training_data_path,
                 testing_data_path,
                 batch_size=32, elmo_device=None):
        self.training_data_path = training_data_path
        self.training_line_num = self._count_lines(training_data_path)
        self.testing_data_path = testing_data_path
        self.testing_line_num = self._count_lines(testing_data_path)
        # Embedder is the ELMo wrapper defined earlier in this file.
        self.elmo = Embedder(device=elmo_device, use_cuda=elmo_device != "cpu")
        self.batch_size = batch_size
        self.train_step_num = self.training_line_num // batch_size
        self.test_step_num = self.testing_line_num // batch_size
        self.device = "cuda:0"

    @staticmethod
    def _count_lines(path):
        """Count newlines in `path` (same semantics as `wc -l`).

        Replaces the original `os.popen("wc -l %s" % path)`: portable,
        no shell-injection via the path, and no dependence on wc's
        locale-/platform-specific output format.
        """
        with open(path, 'rb') as f:
            return sum(chunk.count(b'\n')
                       for chunk in iter(lambda: f.read(1 << 20), b''))

    def process_sent(self, sent):
        """Split a (full-width-normalised) sentence into characters plus
        per-character word-boundary labels (1 = last char of a word)."""
        sent = f2h(sent)
        word_list = re.split(r"[\s|\u3000]+", sent.strip())
        char_list = list("".join(word_list))
        label_list = []
        for word in word_list:
            label_list += [0] * (len(word) - 1) + [1]
        return char_list, label_list

    def sent2list(self, sent):
        """Normalise and split a sentence into a list of words."""
        sent = f2h(sent)
        word_list = re.split(r"[\s|\u3000]+", sent.strip())
        return word_list

    def data_generator(self, mode="train", write_actual_data=False):
        """Yield batches (lists of word lists) from the train/test file.

        Sentences longer than 150 tokens are skipped.  When
        `write_actual_data` is set, kept sentences are echoed to
        actual_test_data.utf8.
        """
        fw = open("actual_test_data.utf8", 'w') if write_actual_data else None
        # getattr instead of eval(): same attribute lookup, no code execution.
        path = getattr(self, "%sing_data_path" % mode)
        with open(path) as file:
            sents = []
            for sent in file:
                if len(sent.strip()) == 0:
                    continue
                word_list = self.sent2list(sent)
                if len(word_list) > 150:  # drop over-long sentences
                    continue
                if fw is not None:
                    fw.write(' '.join(word_list) + '\n')
                sents.append(word_list)
                if len(sents) == self.batch_size:
                    yield sents
                    sents = []
        # BUG FIX: the original called fw.close() unconditionally, raising
        # NameError when write_actual_data was False.
        if fw is not None:
            fw.close()
        if len(sents) != 0:
            yield sents

    def raw2elmo(self, batch, with_bos_eos=True):
        """Embed a raw batch through ELMo; returns (embedded, seq_lens)."""
        embedded, seq_lens = self.elmo(batch, with_bos_eos=with_bos_eos)
        return embedded, seq_lens

    def elmo2mask(self, embedded, seq_lens, with_bos_eos=True, mask_rate=0.0):
        """Randomly zero word vectors in place.

        Returns (embedded, seq_lens, mask) where mask is 0 = dropped,
        1 = kept, -1 = padding / special token (ignored by the loss).
        """
        mask = np.full((embedded.shape[0], embedded.shape[1]), -1, dtype=np.int64)
        if mask_rate:
            # never mask the <bos>/<eos> slots
            (begin, end) = (1, -1) if with_bos_eos else (0, 0)
            rand_mat = np.random.rand(embedded.shape[0], embedded.shape[1])
            for row in range(embedded.shape[0]):
                for col in range(begin, seq_lens[row] + end):
                    if rand_mat[row, col] < mask_rate:
                        # plain scalar fill; the original built a torch tensor
                        # just to zero a numpy row
                        embedded[row, col] = 0.0
                        mask[row, col] = 0
                    else:
                        mask[row, col] = 1
        return embedded, seq_lens, mask

class Generator(nn.Module):
    """BiLSTM generator: reconstructs ELMo vectors from a (masked) sequence,
    with its LSTM state initialised from the Encoder's sentence code."""

    def __init__(self,
                 batch_size=32,
                 device="cuda:0",
                 hidden_size=300,
                 input_size=1024,
                 encode_size=30,
                 n_layers=3,
                 dropout=0.33):
        super(Generator, self).__init__()
        self.device = device if torch.cuda.is_available() else "cpu"
        self.gru = nn.LSTM(input_size, hidden_size, n_layers,
                           dropout=(0 if n_layers == 1 else dropout),
                           bidirectional=True,
                           batch_first=True)
        self.fc1 = nn.Linear(2 * hidden_size, input_size)
        # project the (encode_size) sentence code up to the LSTM state size
        self.hidden_expander_1 = nn.Linear(encode_size, hidden_size)
        self.hidden_expander_2 = nn.Linear(encode_size, hidden_size)
        self.optimizer = torch.optim.Adam(self.parameters())

    def forward(self, input_seq, input_lengths, hidden=None, sort=True, unsort=False):
        """input_seq: float32 numpy array (batch, seq, input_size);
        hidden: [h0_code, c0_code] from the Encoder (required)."""
        embedded = torch.from_numpy(input_seq).to(self.device)
        if sort:
            [embedded, input_lengths], ind = sort_by([embedded, input_lengths], piv=1)
        # Build a fresh state tuple instead of mutating the caller's list in
        # place (the original clobbered hidden[0]/hidden[1] for the caller).
        hidden = (self.hidden_expander_1(hidden[0]),
                  self.hidden_expander_2(hidden[1]))
        packed = torch.nn.utils.rnn.pack_padded_sequence(embedded, input_lengths, batch_first=True)
        outputs, hidden = self.gru(packed, hidden)  # (batch, seq, 2*hidden)
        outputs, _ = torch.nn.utils.rnn.pad_packed_sequence(outputs, batch_first=True)
        pred_seq = self.fc1(outputs)
        embedded.cpu()
        if unsort:
            [pred_seq, _], _ = sort_by([pred_seq, ind], piv=1, unsort=True)
        return pred_seq
class Discriminator(nn.Module):
    """BiLSTM discriminator: emits per-token real/fake logits (2 classes)."""

    def __init__(self,
                 batch_size=32,
                 device="cuda:0",
                 hidden_size=300,
                 input_size=1024,
                 n_layers=3,
                 dropout=0.33):
        super(Discriminator, self).__init__()
        self.device = device if torch.cuda.is_available() else "cpu"
        self.gru = nn.LSTM(input_size, hidden_size, n_layers,
                           dropout=(0 if n_layers == 1 else dropout),
                           bidirectional=True,
                           batch_first=True)
        self.fc1 = nn.Linear(2 * hidden_size, 2)
        self.optimizer = torch.optim.Adam(self.parameters())

    def forward(self, input_seq, input_lengths, hidden=None, numpy=True, sort=True, unsort=False):
        """Returns (batch, seq, 2) logits.  `numpy=True` means input_seq is a
        float32 numpy array; otherwise it is already a torch tensor (e.g. the
        generator's differentiable output)."""
        if numpy:
            embedded = torch.from_numpy(input_seq).to(self.device)
        else:
            embedded = input_seq.to(self.device)
        if sort:
            [embedded, input_lengths], ind = sort_by([embedded, input_lengths], piv=1)
        packed = torch.nn.utils.rnn.pack_padded_sequence(embedded, input_lengths, batch_first=True)
        outputs, hidden = self.gru(packed, hidden)
        outputs, _ = torch.nn.utils.rnn.pad_packed_sequence(outputs, batch_first=True)
        pred_prob = self.fc1(outputs)
        embedded.cpu()
        if unsort:
            [pred_prob, _], _ = sort_by([pred_prob, ind], piv=1, unsort=True)
        return pred_prob


class Encoder(nn.Module):
    """BiLSTM sentence encoder: compresses a sequence into two small codes
    (from the final h/c states) used to seed the Generator."""

    def __init__(self,
                 batch_size=32,
                 device="cuda:0",
                 hidden_size=300,
                 encode_size=30,
                 input_size=1024,
                 n_layers=3,
                 dropout=0.33):
        super(Encoder, self).__init__()
        self.device = device if torch.cuda.is_available() else "cpu"
        self.gru = nn.LSTM(input_size, hidden_size, n_layers,
                           dropout=(0 if n_layers == 1 else dropout),
                           bidirectional=True,
                           batch_first=True)
        self.fc1 = nn.Linear(hidden_size, encode_size)
        self.fc2 = nn.Linear(hidden_size, encode_size)
        self.criterion = nn.CrossEntropyLoss(ignore_index=-1)
        self.optimizer = torch.optim.Adam(self.parameters())

    def forward(self, input_seq, input_lengths, hidden=None, sort=True, unsort=False):
        """Returns [code1, code2], each (n_layers*2, batch, encode_size)."""
        embedded = torch.from_numpy(input_seq).to(self.device)
        if sort:
            [embedded, input_lengths], ind = sort_by([embedded, input_lengths], piv=1)
        packed = torch.nn.utils.rnn.pack_padded_sequence(embedded, input_lengths, batch_first=True)
        outputs, (h_n, c_n) = self.gru(packed, hidden)
        # FIX: pad_packed_sequence returns a (tensor, lengths) tuple; the
        # original assigned the whole tuple to `outputs` (harmless only
        # because outputs is unused below, but wrong).
        outputs, _ = torch.nn.utils.rnn.pad_packed_sequence(outputs, batch_first=True)
        code1 = self.fc1(h_n)
        code2 = self.fc2(c_n)
        embedded.cpu()
        if unsort:
            [code1, code2, _], _ = sort_by([code1, code2, ind], piv=2, unsort=True)
        return [code1, code2]


def load_cpu_invelmo():
    """Load the inverse-ELMo decoder on CPU (only used for eval printing)."""
    # renamed local: the original `elmo` shadowed the imported `elmo` module
    inv = invELMo()
    inv.device = "cpu"
    inv.load_model(device="cpu")
    inv.eval()
    return inv


class MaskGAN():
    """Ties together the masked-reconstruction GAN: ELMo embedder, BiLSTM
    encoder/generator and the per-token discriminator."""

    def __init__(self, embedder, discriminator, generator, encoder, utils):
        self.embedder = embedder
        self.D = discriminator
        self.G = generator
        self.encoder = encoder
        self.utils = utils
        # ignore_index=-1 skips the positions Utils.elmo2mask labels as padding
        self.criterion = nn.CrossEntropyLoss(ignore_index=-1)
        self.invelmo = load_cpu_invelmo()
        self.mse = nn.MSELoss()

    def save_model(self, d_path="Dis_model.ckpt", g_path="Gen_model.ckpt"):
        """Checkpoint D and G.  NOTE(review): the encoder is not saved, so a
        resumed run re-initialises it -- confirm whether that is intentional."""
        torch.save(self.D.state_dict(), d_path)
        torch.save(self.G.state_dict(), g_path)

    def load_model(self, path=""):
        """Restore D and G from checkpoint files inside directory `path`."""
        self.D.load_state_dict(torch.load(os.path.join(path, "Dis_model.ckpt")))
        self.G.load_state_dict(torch.load(os.path.join(path, "Gen_model.ckpt")))
        print("model loaded!")
Train G on D's response (but DO NOT train D on these labels) 337 | self.G.zero_grad() 338 | 339 | g_org_data, g_data_seqlen = self.utils.raw2elmo(real_data) 340 | 341 | gen_input = self.encoder(g_org_data, g_data_seqlen) 342 | g_fake_data = self.G(g_org_data, g_data_seqlen, hidden=gen_input) 343 | loss = self.mse(g_fake_data, torch.from_numpy(g_org_data).to(self.G.device)) 344 | 345 | loss.backward() 346 | self.G.optimizer.step() # Only optimizes G's parameters 347 | self.encoder.optimizer.step() 348 | ct.flush(info={"G_loss": loss.item()}) 349 | 350 | with torch.no_grad(): 351 | for _, real_data in zip(range(2), test_datagen): 352 | g_org_data, g_data_seqlen = self.utils.raw2elmo(real_data) 353 | [g_org_data, g_data_seqlen], _ind = sort_by([g_org_data, g_data_seqlen], piv=1) 354 | g_mask_data, g_data_seqlen, g_mask_label = \ 355 | self.utils.elmo2mask(g_org_data, g_data_seqlen, mask_rate=epoch/num_epochs) 356 | gen_input = self.encoder(g_org_data, g_data_seqlen, sort=False) 357 | g_fake_data = self.G(g_mask_data, g_data_seqlen, hidden=gen_input, sort=False) 358 | 359 | gen_sents = self.invelmo.test(g_fake_data.cpu().numpy(), g_data_seqlen) 360 | for i, j in zip(real_data, gen_sents): 361 | print("="*50) 362 | print(' '.join(i)) 363 | print("---") 364 | print(' '.join(j)) 365 | print("=" * 50) 366 | torch.save(self.G.state_dict(), "pretrain_model.ckpt") 367 | 368 | 369 | def train_model(self, num_epochs=100, d_steps=10, g_steps=10): 370 | self.D.to(self.D.device) 371 | self.G.to(self.G.device) 372 | self.encoder.to(self.encoder.device) 373 | real_datagen = self.utils.data_generator("train") 374 | test_datagen = self.utils.data_generator("test") 375 | for epoch in range(num_epochs): 376 | d_ct = Clock(d_steps, title="Train Discriminator(%d/%d)"%(epoch, num_epochs)) 377 | for d_step, real_data in zip(range(d_steps), real_datagen): 378 | # 1. 
Train D on real+fake 379 | self.D.zero_grad() 380 | 381 | # 1A: Train D on real 382 | d_org_data, d_data_seqlen = self.utils.raw2elmo(real_data) 383 | d_mask_data, d_data_seqlen, d_mask_label = \ 384 | self.utils.elmo2mask(d_org_data, d_data_seqlen, mask_rate=epoch/num_epochs) 385 | d_real_pred = self.D(d_org_data, d_data_seqlen) 386 | d_real_error = self.criterion(d_real_pred.transpose(1, 2), torch.ones(d_mask_label.shape, dtype=torch.int64).to(self.D.device)) # ones = true 387 | d_real_error.backward() # compute/store gradients, but don't change params 388 | self.D.optimizer.step() 389 | 390 | # 1B: Train D on fake 391 | d_gen_input = self.encoder(d_org_data, d_data_seqlen) 392 | d_fake_data = self.G(d_mask_data, d_data_seqlen, hidden=d_gen_input).detach() # detach to avoid training G on these labels 393 | d_fake_pred = self.D(d_fake_data, d_data_seqlen, numpy=False) 394 | d_fake_error = self.criterion(d_fake_pred.transpose(1, 2), torch.from_numpy(d_mask_label).to(self.D.device)) # zeros = fake 395 | d_fake_error.backward() 396 | self.D.optimizer.step() # Only optimizes D's parameters; changes based on stored gradients from backward() 397 | d_ct.flush(info={"D_loss":d_fake_error.item()}) 398 | 399 | g_ct = Clock(g_steps, title="Train Generator(%d/%d)"%(epoch, num_epochs)) 400 | for g_step, real_data in zip(range(g_steps), real_datagen): 401 | # 2. 
Train G on D's response (but DO NOT train D on these labels) 402 | self.G.zero_grad() 403 | 404 | g_org_data, g_data_seqlen = self.utils.raw2elmo(real_data) 405 | g_mask_data, g_data_seqlen, g_mask_label = \ 406 | self.utils.elmo2mask(g_org_data, g_data_seqlen, mask_rate=epoch/num_epochs) 407 | gen_input = self.encoder(g_org_data, g_data_seqlen) 408 | g_fake_data = self.G(g_mask_data, g_data_seqlen, hidden=gen_input) 409 | dg_fake_pred = self.D(g_fake_data, g_data_seqlen, numpy=False) 410 | g_error = self.criterion(dg_fake_pred.transpose(1, 2), torch.ones(g_mask_label.shape, dtype=torch.int64).to(self.D.device)) # we want to fool, so pretend it's all genuine 411 | 412 | g_error.backward() 413 | self.G.optimizer.step() # Only optimizes G's parameters 414 | self.encoder.optimizer.step() 415 | g_ct.flush(info={"G_loss": g_error.item()}) 416 | 417 | with torch.no_grad(): 418 | for _, real_data in zip(range(2), test_datagen): 419 | g_org_data, g_data_seqlen = self.utils.raw2elmo(real_data) 420 | [g_org_data, g_data_seqlen], _ind = sort_by([g_org_data, g_data_seqlen], piv=1) 421 | g_mask_data, g_data_seqlen, g_mask_label = \ 422 | self.utils.elmo2mask(g_org_data, g_data_seqlen, mask_rate=epoch/num_epochs) 423 | gen_input = self.encoder(g_org_data, g_data_seqlen, sort=False) 424 | g_fake_data = self.G(g_mask_data, g_data_seqlen, hidden=gen_input, sort=False) 425 | 426 | gen_sents = self.invelmo.test(g_fake_data.cpu().numpy(), g_data_seqlen) 427 | for i, j in zip(real_data, gen_sents): 428 | print("="*50) 429 | print(' '.join(i)) 430 | print("---") 431 | print(' '.join(j)) 432 | print("=" * 50) 433 | self.save_model() 434 | 435 | 436 | 437 | if __name__ == "__main__": 438 | parser = argparse.ArgumentParser() 439 | parser.add_argument("mode", help="execute mode") 440 | # parser.add_argument("-train_file", default=None, required=False, help="test filename") 441 | # parser.add_argument("-test_file", default=None, required=True, help="test filename") 442 | 
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("mode", help="execute mode: train | pretrain")
    # FIX: each flag now has a truthful help string and numeric flags are
    # parsed with type=int (the original cast str->int at every use site).
    parser.add_argument("-load_model_path", default=None, required=False,
                        help="directory containing Dis_model.ckpt / Gen_model.ckpt")
    parser.add_argument("-epoch", type=int, default=1000, required=False,
                        help="number of training epochs")
    parser.add_argument("-d_step", type=int, default=100, required=False,
                        help="discriminator updates per epoch")
    parser.add_argument("-g_step", type=int, default=100, required=False,
                        help="generator updates per epoch")
    args = parser.parse_args()

    embedder, discriminator, generator, encoder, utils = \
        Embedder(), Discriminator(), Generator(), Encoder(), \
        Utils(training_data_path="data/train_as.txt",
              testing_data_path="data/test_as.txt", elmo_device="cuda:0")
    model = MaskGAN(embedder, discriminator, generator, encoder, utils)
    if args.load_model_path is not None:
        model.load_model(args.load_model_path)
    if args.mode == "train":
        model.train_model(num_epochs=args.epoch, d_steps=args.d_step, g_steps=args.g_step)
    elif args.mode == "pretrain":
        model.pretrain(num_epochs=args.epoch)
# Small-form CJK punctuation -> standard punctuation, built once at import
# time instead of once per character inside f2h's loop.
_SMALL_FORM_MAP = str.maketrans('﹕﹐﹑。﹔﹖﹗﹘ ', ':,、。;?!- ')

def f2h(s):
    """Normalise full-width characters to half-width.

    Ideographic space (U+3000) becomes ASCII space, the full-width ASCII
    block (U+FF01-U+FF5E) is shifted down, small-form punctuation is mapped
    via _SMALL_FORM_MAP, and runs of spaces are collapsed.
    """
    out = []
    for ch in s:
        num = ord(ch)
        if num == 0x3000:                 # ideographic space
            num = 32
        elif 0xFF01 <= num <= 0xFF5E:     # full-width ASCII block
            num -= 0xfee0
        out.append(chr(num).translate(_SMALL_FORM_MAP))
    return re.sub(r"( | )+", " ", "".join(out)).strip()

def sort_list(li, piv=2, unsort_ind=None):
    """Sort every list in `li` by ascending order of li[piv].

    BUG FIX: sort_by() forwards its boolean `unsort` flag as the third
    argument here; the original then iterated over that bool and crashed
    with a TypeError.  A bool (or None) now means "compute the permutation
    locally"; an explicit index list is still honoured.
    """
    if unsort_ind is None or isinstance(unsort_ind, bool):
        ind = sorted(range(len(li[piv])), key=(lambda k: li[piv][k]))
    else:
        ind = unsort_ind
    for i in range(len(li)):
        li[i] = [li[i][j] for j in ind]
    return li, ind

def sort_numpy(li, piv=2, unsort=False):
    """Sort all entries of `li` by li[piv]: descending when sorting
    (pack_padded_sequence wants longest-first), ascending when unsorting."""
    ind = np.argsort(-li[piv] if not unsort else li[piv], axis=0)
    for i in range(len(li)):
        if isinstance(li[i], (np.ndarray, torch.Tensor)):
            li[i] = li[i][ind]
        else:
            li[i] = [li[i][j] for j in ind]
    return li, ind

def sort_torch(li, piv=2, unsort=False):
    """Torch analogue of sort_numpy; the pivot tensor itself is replaced by
    its sorted version."""
    li[piv], ind = torch.sort(li[piv], dim=0, descending=(not unsort))
    for i in range(len(li)):
        if i == piv:
            continue
        li[i] = li[i][ind]
    return li, ind

def sort_by(li, piv=2, unsort=False):
    """Dispatch to the numpy / torch / plain-list sorter by li[piv]'s type."""
    if type(li[piv]).__module__ == np.__name__:
        return sort_numpy(li, piv, unsort)
    elif type(li[piv]).__module__ == torch.__name__:
        return sort_torch(li, piv, unsort)
    else:
        return sort_list(li, piv, unsort)


class Embedder():
    """Batch ELMo embedder with optional fixed-length padding.

    bos_eos.npy / pad_oov.npy hold pre-computed vectors for the special
    tokens; seq_len != 0 forces every batch to that fixed length.
    """

    def __init__(self, seq_len=0, use_cuda=True, device=None):
        self.embedder = elmo.Embedder(batch_size=512, use_cuda=use_cuda)
        self.seq_len = seq_len
        self.bos_vec, self.eos_vec = np.load("bos_eos.npy")
        self.pad, self.oov = np.load("pad_oov.npy")
        self.device = device
        if self.device is not None:  # move the ELMo weights onto the device
            self.embedder.model.to(self.device)
87 | emb_list = self.embedder.sents2elmo(sents, output_layer=layer) 88 | if not with_bos_eos: 89 | for i in range(len(emb_list)): 90 | if max_len - seq_lens[i] > 0: 91 | if pad_matters: 92 | emb_list[i] = np.concatenate([emb_list[i], np.tile(self.pad,[max_len - seq_lens[i],1])], axis=0) 93 | else: 94 | emb_list[i] = np.concatenate([emb_list[i], np.zeros((max_len - seq_lens[i], emb_list[i].shape[1]))]) 95 | else: 96 | emb_list[i] = emb_list[i][:max_len] 97 | elif with_bos_eos: 98 | for i in range(len(emb_list)): 99 | if max_len - seq_lens[i] > 0: 100 | if pad_matters: 101 | emb_list[i] = np.concatenate([ 102 | self.bos_vec[np.newaxis], 103 | emb_list[i], 104 | self.eos_vec[np.newaxis], 105 | np.tile(self.pad, [max_len - seq_lens[i], 1])], axis=0) 106 | else: 107 | emb_list[i] = np.concatenate([ 108 | self.bos_vec[np.newaxis], 109 | emb_list[i], 110 | self.eos_vec[np.newaxis], 111 | np.zeros((max_len - seq_lens[i], emb_list[i].shape[1]))], axis=0) 112 | else: 113 | emb_list[i] = np.concatenate([self.bos_vec[np.newaxis], emb_list[i][:max_len],self.eos_vec[np.newaxis]], axis=0) 114 | embedded = np.array(emb_list, dtype=np.float32) 115 | seq_lens = seq_lens+2 if with_bos_eos else seq_lens 116 | return embedded, seq_lens 117 | 118 | def sub_unk(self, e): 119 | e = e.replace(',',',') 120 | e = e.replace(':',':') 121 | e = e.replace(';',';') 122 | e = e.replace('?','?') 123 | e = e.replace('!', '!') 124 | return e 125 | 126 | 127 | class Utils(): 128 | def __init__(self, 129 | training_data_path, 130 | testing_data_path, 131 | batch_size = 32, elmo_device=None): 132 | self.training_data_path = training_data_path 133 | self.training_line_num = int(os.popen("wc -l %s"%self.training_data_path).read().split(' ')[0]) 134 | self.testing_data_path = testing_data_path 135 | self.testing_line_num = int(os.popen("wc -l %s"%self.testing_data_path).read().split(' ')[0]) 136 | self.elmo = Embedder(device=elmo_device, use_cuda=elmo_device!="cpu") 137 | self.batch_size = batch_size 138 | 
class Utils():
    """Data-pipeline helpers: line counting, tokenisation, batching and masking."""

    def __init__(self,
                 training_data_path,
                 testing_data_path,
                 batch_size=32, elmo_device=None):
        self.training_data_path = training_data_path
        self.training_line_num = self._count_lines(training_data_path)
        self.testing_data_path = testing_data_path
        self.testing_line_num = self._count_lines(testing_data_path)
        # Embedder is the ELMo wrapper defined earlier in this file.
        self.elmo = Embedder(device=elmo_device, use_cuda=elmo_device != "cpu")
        self.batch_size = batch_size
        self.train_step_num = self.training_line_num // batch_size
        self.test_step_num = self.testing_line_num // batch_size
        self.device = "cuda:0"

    @staticmethod
    def _count_lines(path):
        """Count newlines in `path` (same semantics as `wc -l`).

        Replaces the original `os.popen("wc -l %s" % path)`: portable,
        no shell-injection via the path, and no dependence on wc's
        locale-/platform-specific output format.
        """
        with open(path, 'rb') as f:
            return sum(chunk.count(b'\n')
                       for chunk in iter(lambda: f.read(1 << 20), b''))

    def process_sent(self, sent):
        """Split a (full-width-normalised) sentence into characters plus
        per-character word-boundary labels (1 = last char of a word)."""
        sent = f2h(sent)
        word_list = re.split(r"[\s|\u3000]+", sent.strip())
        char_list = list("".join(word_list))
        label_list = []
        for word in word_list:
            label_list += [0] * (len(word) - 1) + [1]
        return char_list, label_list

    def sent2list(self, sent):
        """Normalise and split a sentence into a list of words."""
        sent = f2h(sent)
        word_list = re.split(r"[\s|\u3000]+", sent.strip())
        return word_list

    def data_generator(self, mode="train", write_actual_data=False):
        """Yield batches (lists of word lists) from the train/test file.

        Sentences longer than 150 tokens are skipped.  When
        `write_actual_data` is set, kept sentences are echoed to
        actual_test_data.utf8.
        """
        fw = open("actual_test_data.utf8", 'w') if write_actual_data else None
        # getattr instead of eval(): same attribute lookup, no code execution.
        path = getattr(self, "%sing_data_path" % mode)
        with open(path) as file:
            sents = []
            for sent in file:
                if len(sent.strip()) == 0:
                    continue
                word_list = self.sent2list(sent)
                if len(word_list) > 150:  # drop over-long sentences
                    continue
                if fw is not None:
                    fw.write(' '.join(word_list) + '\n')
                sents.append(word_list)
                if len(sents) == self.batch_size:
                    yield sents
                    sents = []
        # BUG FIX: the original called fw.close() unconditionally, raising
        # NameError when write_actual_data was False.
        if fw is not None:
            fw.close()
        if len(sents) != 0:
            yield sents

    def raw2elmo(self, batch, with_bos_eos=True):
        """Embed a raw batch through ELMo; returns (embedded, seq_lens)."""
        embedded, seq_lens = self.elmo(batch, with_bos_eos=with_bos_eos)
        return embedded, seq_lens

    def elmo2mask(self, embedded, seq_lens, with_bos_eos=True, mask_rate=0.0):
        """Randomly zero word vectors in place.

        Returns (embedded, seq_lens, mask) where mask is 0 = dropped,
        1 = kept, -1 = padding / special token (ignored by the loss).
        """
        mask = np.full((embedded.shape[0], embedded.shape[1]), -1, dtype=np.int64)
        if mask_rate:
            # never mask the <bos>/<eos> slots
            (begin, end) = (1, -1) if with_bos_eos else (0, 0)
            rand_mat = np.random.rand(embedded.shape[0], embedded.shape[1])
            for row in range(embedded.shape[0]):
                for col in range(begin, seq_lens[row] + end):
                    if rand_mat[row, col] < mask_rate:
                        # plain scalar fill; the original built a torch tensor
                        # just to zero a numpy row
                        embedded[row, col] = 0.0
                        mask[row, col] = 0
                    else:
                        mask[row, col] = 1
        return embedded, seq_lens, mask

class Generator(nn.Module):
    """BiLSTM generator: reconstructs ELMo vectors from a (masked) sequence,
    with its LSTM state initialised from the Encoder's sentence code."""

    def __init__(self,
                 batch_size=32,
                 device="cuda:0",
                 hidden_size=300,
                 input_size=1024,
                 encode_size=30,
                 n_layers=3,
                 dropout=0.33):
        super(Generator, self).__init__()
        self.device = device if torch.cuda.is_available() else "cpu"
        self.gru = nn.LSTM(input_size, hidden_size, n_layers,
                           dropout=(0 if n_layers == 1 else dropout),
                           bidirectional=True,
                           batch_first=True)
        self.fc1 = nn.Linear(2 * hidden_size, input_size)
        # project the (encode_size) sentence code up to the LSTM state size
        self.hidden_expander_1 = nn.Linear(encode_size, hidden_size)
        self.hidden_expander_2 = nn.Linear(encode_size, hidden_size)
        self.optimizer = torch.optim.Adam(self.parameters())

    def forward(self, input_seq, input_lengths, hidden=None, sort=True, unsort=False):
        """input_seq: float32 numpy array (batch, seq, input_size);
        hidden: [h0_code, c0_code] from the Encoder (required)."""
        embedded = torch.from_numpy(input_seq).to(self.device)
        if sort:
            [embedded, input_lengths], ind = sort_by([embedded, input_lengths], piv=1)
        # Build a fresh state tuple instead of mutating the caller's list in
        # place (the original clobbered hidden[0]/hidden[1] for the caller).
        hidden = (self.hidden_expander_1(hidden[0]),
                  self.hidden_expander_2(hidden[1]))
        packed = torch.nn.utils.rnn.pack_padded_sequence(embedded, input_lengths, batch_first=True)
        outputs, hidden = self.gru(packed, hidden)  # (batch, seq, 2*hidden)
        outputs, _ = torch.nn.utils.rnn.pad_packed_sequence(outputs, batch_first=True)
        pred_seq = self.fc1(outputs)
        embedded.cpu()
        if unsort:
            [pred_seq, _], _ = sort_by([pred_seq, ind], piv=1, unsort=True)
        return pred_seq
class Discriminator(nn.Module):
    """BiLSTM discriminator: emits per-token real/fake logits (2 classes)."""

    def __init__(self,
                 batch_size=32,
                 device="cuda:0",
                 hidden_size=300,
                 input_size=1024,
                 n_layers=3,
                 dropout=0.33):
        super(Discriminator, self).__init__()
        self.device = device if torch.cuda.is_available() else "cpu"
        self.gru = nn.LSTM(input_size, hidden_size, n_layers,
                           dropout=(0 if n_layers == 1 else dropout),
                           bidirectional=True,
                           batch_first=True)
        self.fc1 = nn.Linear(2 * hidden_size, 2)
        self.optimizer = torch.optim.Adam(self.parameters())

    def forward(self, input_seq, input_lengths, hidden=None, numpy=True, sort=True, unsort=False):
        """Returns (batch, seq, 2) logits.  `numpy=True` means input_seq is a
        float32 numpy array; otherwise it is already a torch tensor (e.g. the
        generator's differentiable output)."""
        if numpy:
            embedded = torch.from_numpy(input_seq).to(self.device)
        else:
            embedded = input_seq.to(self.device)
        if sort:
            [embedded, input_lengths], ind = sort_by([embedded, input_lengths], piv=1)
        packed = torch.nn.utils.rnn.pack_padded_sequence(embedded, input_lengths, batch_first=True)
        outputs, hidden = self.gru(packed, hidden)
        outputs, _ = torch.nn.utils.rnn.pad_packed_sequence(outputs, batch_first=True)
        pred_prob = self.fc1(outputs)
        embedded.cpu()
        if unsort:
            [pred_prob, _], _ = sort_by([pred_prob, ind], piv=1, unsort=True)
        return pred_prob


class Encoder(nn.Module):
    """BiLSTM sentence encoder: compresses a sequence into two small codes
    (from the final h/c states) used to seed the Generator."""

    def __init__(self,
                 batch_size=32,
                 device="cuda:0",
                 hidden_size=300,
                 encode_size=30,
                 input_size=1024,
                 n_layers=3,
                 dropout=0.33):
        super(Encoder, self).__init__()
        self.device = device if torch.cuda.is_available() else "cpu"
        self.gru = nn.LSTM(input_size, hidden_size, n_layers,
                           dropout=(0 if n_layers == 1 else dropout),
                           bidirectional=True,
                           batch_first=True)
        self.fc1 = nn.Linear(hidden_size, encode_size)
        self.fc2 = nn.Linear(hidden_size, encode_size)
        self.criterion = nn.CrossEntropyLoss(ignore_index=-1)
        self.optimizer = torch.optim.Adam(self.parameters())

    def forward(self, input_seq, input_lengths, hidden=None, sort=True, unsort=False):
        """Returns [code1, code2], each (n_layers*2, batch, encode_size)."""
        embedded = torch.from_numpy(input_seq).to(self.device)
        if sort:
            [embedded, input_lengths], ind = sort_by([embedded, input_lengths], piv=1)
        packed = torch.nn.utils.rnn.pack_padded_sequence(embedded, input_lengths, batch_first=True)
        outputs, (h_n, c_n) = self.gru(packed, hidden)
        # FIX: pad_packed_sequence returns a (tensor, lengths) tuple; the
        # original assigned the whole tuple to `outputs` (harmless only
        # because outputs is unused below, but wrong).
        outputs, _ = torch.nn.utils.rnn.pad_packed_sequence(outputs, batch_first=True)
        code1 = self.fc1(h_n)
        code2 = self.fc2(c_n)
        embedded.cpu()
        if unsort:
            [code1, code2, _], _ = sort_by([code1, code2, ind], piv=2, unsort=True)
        return [code1, code2]


def load_cpu_invelmo():
    """Load the inverse-ELMo decoder on CPU (only used for eval printing)."""
    # renamed local: the original `elmo` shadowed the imported `elmo` module
    inv = invELMo()
    inv.device = "cpu"
    inv.load_model(device="cpu")
    inv.eval()
    return inv


class MaskGAN():
    """Ties together the masked-reconstruction GAN: ELMo embedder, BiLSTM
    encoder/generator and the per-token discriminator."""

    def __init__(self, embedder, discriminator, generator, encoder, utils):
        self.embedder = embedder
        self.D = discriminator
        self.G = generator
        self.encoder = encoder
        self.utils = utils
        # ignore_index=-1 skips the positions Utils.elmo2mask labels as padding
        self.criterion = nn.CrossEntropyLoss(ignore_index=-1)
        self.invelmo = load_cpu_invelmo()
        self.mse = nn.MSELoss()

    def save_model(self, d_path="Dis_model.ckpt", g_path="Gen_model.ckpt"):
        """Checkpoint D and G.  NOTE(review): the encoder is not saved, so a
        resumed run re-initialises it -- confirm whether that is intentional."""
        torch.save(self.D.state_dict(), d_path)
        torch.save(self.G.state_dict(), g_path)

    def load_model(self, path=""):
        """Restore D and G from checkpoint files inside directory `path`."""
        self.D.load_state_dict(torch.load(os.path.join(path, "Dis_model.ckpt")))
        self.G.load_state_dict(torch.load(os.path.join(path, "Gen_model.ckpt")))
        print("model loaded!")
Train G on D's response (but DO NOT train D on these labels) 337 | self.G.zero_grad() 338 | 339 | g_org_data, g_data_seqlen = self.utils.raw2elmo(real_data) 340 | 341 | gen_input = self.encoder(g_org_data, g_data_seqlen) 342 | g_fake_data = self.G(g_org_data, g_data_seqlen, hidden=gen_input) 343 | loss = self.mse(g_fake_data, torch.from_numpy(g_org_data).to(self.G.device)) 344 | 345 | loss.backward() 346 | self.G.optimizer.step() # Only optimizes G's parameters 347 | self.encoder.optimizer.step() 348 | ct.flush(info={"G_loss": loss.item()}) 349 | 350 | with torch.no_grad(): 351 | for _, real_data in zip(range(2), test_datagen): 352 | g_org_data, g_data_seqlen = self.utils.raw2elmo(real_data) 353 | [g_org_data, g_data_seqlen], _ind = sort_by([g_org_data, g_data_seqlen], piv=1) 354 | g_mask_data, g_data_seqlen, g_mask_label = \ 355 | self.utils.elmo2mask(g_org_data, g_data_seqlen, mask_rate=epoch/num_epochs) 356 | gen_input = self.encoder(g_org_data, g_data_seqlen, sort=False) 357 | g_fake_data = self.G(g_mask_data, g_data_seqlen, hidden=gen_input, sort=False) 358 | 359 | gen_sents = self.invelmo.test(g_fake_data.cpu().numpy(), g_data_seqlen) 360 | for i, j in zip(real_data, gen_sents): 361 | print("="*50) 362 | print(' '.join(i)) 363 | print("---") 364 | print(' '.join(j)) 365 | print("=" * 50) 366 | torch.save(self.G.state_dict(), "pretrain_model.ckpt") 367 | 368 | 369 | def train_model(self, num_epochs=100, d_steps=10, g_steps=10): 370 | self.D.to(self.D.device) 371 | self.G.to(self.G.device) 372 | self.encoder.to(self.encoder.device) 373 | real_datagen = self.utils.data_generator("train") 374 | test_datagen = self.utils.data_generator("test") 375 | for epoch in range(num_epochs): 376 | d_ct = Clock(d_steps, title="Train Discriminator(%d/%d)"%(epoch, num_epochs)) 377 | for d_step, real_data in zip(range(d_steps), real_datagen): 378 | # 1. 
Train D on real+fake 379 | self.D.zero_grad() 380 | 381 | # 1A: Train D on real 382 | d_org_data, d_data_seqlen = self.utils.raw2elmo(real_data) 383 | d_mask_data, d_data_seqlen, d_mask_label = \ 384 | self.utils.elmo2mask(d_org_data, d_data_seqlen, mask_rate=epoch/num_epochs) 385 | d_real_pred = self.D(d_org_data, d_data_seqlen) 386 | d_real_error = self.criterion(d_real_pred.transpose(1, 2), torch.ones(d_mask_label.shape, dtype=torch.int64).to(self.D.device)) # ones = true 387 | d_real_error.backward() # compute/store gradients, but don't change params 388 | self.D.optimizer.step() 389 | 390 | # 1B: Train D on fake 391 | d_gen_input = self.encoder(d_org_data, d_data_seqlen) 392 | d_fake_data = self.G(d_mask_data, d_data_seqlen, hidden=d_gen_input).detach() # detach to avoid training G on these labels 393 | d_fake_pred = self.D(d_fake_data, d_data_seqlen, numpy=False) 394 | d_fake_error = self.criterion(d_fake_pred.transpose(1, 2), torch.from_numpy(d_mask_label).to(self.D.device)) # zeros = fake 395 | d_fake_error.backward() 396 | self.D.optimizer.step() # Only optimizes D's parameters; changes based on stored gradients from backward() 397 | d_ct.flush(info={"D_loss":d_fake_error.item()}) 398 | 399 | g_ct = Clock(g_steps, title="Train Generator(%d/%d)"%(epoch, num_epochs)) 400 | for g_step, real_data in zip(range(g_steps), real_datagen): 401 | # 2. 
Train G on D's response (but DO NOT train D on these labels) 402 | self.G.zero_grad() 403 | 404 | g_org_data, g_data_seqlen = self.utils.raw2elmo(real_data) 405 | g_mask_data, g_data_seqlen, g_mask_label = \ 406 | self.utils.elmo2mask(g_org_data, g_data_seqlen, mask_rate=epoch/num_epochs) 407 | gen_input = self.encoder(g_org_data, g_data_seqlen) 408 | g_fake_data = self.G(g_mask_data, g_data_seqlen, hidden=gen_input) 409 | dg_fake_pred = self.D(g_fake_data, g_data_seqlen, numpy=False) 410 | g_error = self.criterion(dg_fake_pred.transpose(1, 2), torch.ones(g_mask_label.shape, dtype=torch.int64).to(self.D.device)) # we want to fool, so pretend it's all genuine 411 | 412 | g_error.backward() 413 | self.G.optimizer.step() # Only optimizes G's parameters 414 | self.encoder.optimizer.step() 415 | g_ct.flush(info={"G_loss": g_error.item()}) 416 | 417 | with torch.no_grad(): 418 | for _, real_data in zip(range(2), test_datagen): 419 | g_org_data, g_data_seqlen = self.utils.raw2elmo(real_data) 420 | [g_org_data, g_data_seqlen], _ind = sort_by([g_org_data, g_data_seqlen], piv=1) 421 | g_mask_data, g_data_seqlen, g_mask_label = \ 422 | self.utils.elmo2mask(g_org_data, g_data_seqlen, mask_rate=epoch/num_epochs) 423 | gen_input = self.encoder(g_org_data, g_data_seqlen, sort=False) 424 | g_fake_data = self.G(g_mask_data, g_data_seqlen, hidden=gen_input, sort=False) 425 | 426 | gen_sents = self.invelmo.test(g_fake_data.cpu().numpy(), g_data_seqlen) 427 | for i, j in zip(real_data, gen_sents): 428 | print("="*50) 429 | print(' '.join(i)) 430 | print("---") 431 | print(' '.join(j)) 432 | print("=" * 50) 433 | self.save_model() 434 | 435 | 436 | 437 | if __name__ == "__main__": 438 | parser = argparse.ArgumentParser() 439 | parser.add_argument("mode", help="execute mode") 440 | # parser.add_argument("-train_file", default=None, required=False, help="test filename") 441 | # parser.add_argument("-test_file", default=None, required=True, help="test filename") 442 | 
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("mode", help="execute mode: train | pretrain")
    # FIX: each flag now has a truthful help string and numeric flags are
    # parsed with type=int (the original cast str->int at every use site).
    parser.add_argument("-load_model_path", default=None, required=False,
                        help="directory containing Dis_model.ckpt / Gen_model.ckpt")
    parser.add_argument("-epoch", type=int, default=1000, required=False,
                        help="number of training epochs")
    parser.add_argument("-d_step", type=int, default=100, required=False,
                        help="discriminator updates per epoch")
    parser.add_argument("-g_step", type=int, default=100, required=False,
                        help="generator updates per epoch")
    args = parser.parse_args()

    embedder, discriminator, generator, encoder, utils = \
        Embedder(), Discriminator(), Generator(), Encoder(), \
        Utils(training_data_path="data/train_as.txt",
              testing_data_path="data/test_as.txt", elmo_device="cuda:0")
    model = MaskGAN(embedder, discriminator, generator, encoder, utils)
    if args.load_model_path is not None:
        model.load_model(args.load_model_path)
    if args.mode == "train":
        model.train_model(num_epochs=args.epoch, d_steps=args.d_step, g_steps=args.g_step)
    elif args.mode == "pretrain":
        model.pretrain(num_epochs=args.epoch)
def backward_decode(model, embed, src, src_mask, max_len, start_symbol=2, raw=False, return_term=-1):
    """Greedy-decode `src` while accumulating a differentiable "soft" output.

    At each step the generator's sharpened softmax distribution (scale=10.0)
    is multiplied into the embedding matrix, yielding a soft word vector that
    keeps gradients flowing through the discrete decoding step; the hard
    argmax token ids are tracked in parallel.

    `return_term` selects the result: -1 -> (soft_embeddings, token_ids),
    0 -> soft_embeddings only, 1 -> token_ids only, anything else -> None.
    `raw=True` means `src` is already an embedded/continuous sequence.
    """
    # Encoder input: raw continuous vectors, or embedding lookups of token ids.
    enc_in = src.to(device) if raw else embed(src.to(device))
    memory = model.encode(enc_in, src_mask)

    bs = src.shape[0]
    tokens = torch.full((bs, 1), start_symbol, dtype=torch.int64).to(device)
    soft_emb = embed(tokens).float()

    # max_len + 1 greedy steps, matching the original range(max_len + 2 - 1).
    for _ in range(max_len + 1):
        step_mask = subsequent_mask(tokens.size(1)).type_as(src.data)
        out = model.decode(memory, src_mask, embed(tokens), step_mask)
        # Sharpened distribution over the vocabulary for the newest position.
        dist = model.generator.scaled_forward(out[:, -1], scale=10.0)
        # Soft (weighted-sum) embedding keeps the path differentiable.
        soft_vec = torch.matmul(dist, embed.weight.data.float())
        next_tok = dist.max(dim=1)[1]
        tokens = torch.cat([tokens, next_tok.unsqueeze(-1)], dim=1)
        soft_emb = torch.cat([soft_emb, soft_vec.unsqueeze(1)], dim=1)

    if return_term == -1:
        return soft_emb, tokens
    if return_term == 0:
        return soft_emb
    if return_term == 1:
        return tokens
    return None
def wgan_pg(netD, fake_data, real_data, lamb=10, device=None):
    """WGAN-GP gradient penalty (Gulrajani et al., 2017).

    Interpolates between real and fake batches, runs the critic on the
    interpolates, and penalizes deviation of each sample's gradient L2-norm
    from 1.

    Args:
        netD: critic; callable whose output reshapes to (batch_size,).
        fake_data: generated batch, shape (batch, seq, dim) — same as real_data.
        real_data: real batch.
        lamb: penalty coefficient (lambda), default 10.
        device: device for intermediates; defaults to fake_data.device
            (backward compatible — callers already pass tensors on the
            training device).

    Returns:
        Scalar tensor: lamb * mean((||grad_per_sample||_2 - 1) ** 2).
    """
    if device is None:
        device = fake_data.device
    batch_size = fake_data.shape[0]
    # 1. random per-sample interpolation between real and fake
    alpha = torch.rand(batch_size, 1, 1).expand(real_data.size()).to(device)
    interpolates = (alpha * real_data + (1 - alpha) * fake_data)
    interpolates = interpolates.detach().to(device).requires_grad_(True)
    # 2. gradient of the critic score w.r.t. the interpolates
    disc_interpolates = netD(interpolates).view(batch_size, )
    gradients = torch.autograd.grad(
        outputs=disc_interpolates, inputs=interpolates,
        grad_outputs=torch.ones(disc_interpolates.size()).to(device),
        create_graph=True, retain_graph=True, only_inputs=True)[0]
    # 3. penalty term. Fix: flatten each sample's gradient before the L2
    # norm — the old `gradients.norm(2, dim=1)` took the norm over the
    # sequence axis only, which is not the per-sample WGAN-GP penalty for
    # 3-D inputs.
    gradients = gradients.view(batch_size, -1)
    gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() * lamb
    return gradient_penalty
SimpleLossCompute(self.R.generator, self.r_criterion, self.R_opt) 112 | 113 | def save_model(self, d_path="Dis_model.ckpt", g_path="Gen_model.ckpt", r_path="Res_model.ckpt"): 114 | torch.save(self.D.state_dict(), d_path) 115 | torch.save(self.G.state_dict(), g_path) 116 | torch.save(self.R.state_dict(), r_path) 117 | 118 | def load_model(self, path="", g_file=None, d_file=None, r_file=None): 119 | if g_file!=None: 120 | self.G.load_state_dict(torch.load(os.path.join(path, g_file), map_location=device)) 121 | if d_file!=None: 122 | self.D.load_state_dict(torch.load(os.path.join(path, d_file), map_location=device)) 123 | if r_file!=None: 124 | self.R.load_state_dict(torch.load(os.path.join(path, r_file), map_location=device)) 125 | print("model loaded!") 126 | 127 | def pretrain_disc(self, num_epochs=100): 128 | # self.D.to(self.D.device) 129 | # self.G.to(self.G.device) 130 | # self.R.to(self.R.device) 131 | X_datagen = self.utils.data_generator("X") 132 | Y_datagen = self.utils.data_generator("Y") 133 | for epoch in range(num_epochs): 134 | d_steps = self.utils.train_step_num 135 | d_ct = Clock(d_steps, title="Train Discriminator(%d/%d)"%(epoch, num_epochs)) 136 | for step, X_data, Y_data in zip(range(d_steps), data_gen(X_datagen, self.utils.sents2idx), data_gen(Y_datagen, self.utils.sents2idx)): 137 | # 1. 
Train D on real+fake 138 | # if epoch == 0: 139 | # break 140 | self.D.zero_grad() 141 | d_real_data = self.embed(Y_data.src.to(device)).float() 142 | 143 | # 1A: Train D on real 144 | d_real_pred = self.D(d_real_data) 145 | # d_real_error = self.criterion(d_real_pred, torch.ones((d_real_pred.shape[0],), dtype=torch.int64).to(self.D.device)) # ones = true 146 | 147 | # 1B: Train D on fake 148 | d_fake_data = self.embed(X_data.src.to(device)).float() 149 | d_fake_pred = self.D(d_fake_data) 150 | # d_fake_error = self.criterion(d_fake_pred, torch.zeros((d_fake_pred.shape[0],), dtype=torch.int64).to(self.D.device)) # zeros = fake 151 | # (d_fake_error + d_real_error).backward() 152 | 153 | d_loss = d_fake_pred.mean() - d_real_pred.mean() 154 | d_loss += wgan_pg(self.D, d_fake_data, d_real_data, lamb=10) 155 | d_loss.backward() 156 | self.D_opt.step() # Only optimizes D's parameters; changes based on stored gradients from backward() 157 | d_ct.flush(info={"D_loss": d_loss.item()}) 158 | torch.save(self.D.state_dict(), "model_disc_pretrain.ckpt") 159 | 160 | def train_model(self, num_epochs=100, d_steps=20, g_steps=20, g_scale=1.0, r_scale=1000.0): 161 | for i, batch in enumerate(data_gen(self.utils.test_generator("X"), self.utils.sents2idx)): 162 | X_test_batch = batch 163 | break 164 | 165 | for i, batch in enumerate(data_gen(self.utils.test_generator("Y"), self.utils.sents2idx)): 166 | Y_test_batch = batch 167 | break 168 | X_datagen = self.utils.data_generator("X") 169 | Y_datagen = self.utils.data_generator("Y") 170 | for epoch in range(num_epochs): 171 | d_ct = Clock(d_steps, title="Train Discriminator(%d/%d)" % (epoch, num_epochs)) 172 | if epoch>0: 173 | for i, X_data, Y_data in zip(range(d_steps), data_gen(X_datagen, self.utils.sents2idx), data_gen(Y_datagen, self.utils.sents2idx)): 174 | # 1. 
Train D on real+fake 175 | # if epoch == 0: 176 | # break 177 | self.D.zero_grad() 178 | 179 | # 1A: Train D on real 180 | d_real_data = self.embed(Y_data.src.to(device)).float() 181 | d_real_pred = self.D(d_real_data) 182 | # d_real_error = self.criterion(d_real_pred, torch.ones((d_real_pred.shape[0],), dtype=torch.int64).to(self.D.device)) # ones = true 183 | 184 | # 1B: Train D on fake 185 | self.G.to(device) 186 | d_fake_data = backward_decode(self.G, self.embed, X_data.src, X_data.src_mask, max_len=self.utils.max_len, return_term=0).detach() # detach to avoid training G on these labels 187 | d_fake_pred = self.D(d_fake_data) 188 | # d_fake_error = self.criterion(d_fake_pred, torch.zeros((d_fake_pred.shape[0],), dtype=torch.int64).to(self.D.device)) # zeros = fake 189 | # (d_fake_error + d_real_error).backward() 190 | 191 | d_loss = d_fake_pred.mean() - d_real_pred.mean() 192 | d_loss += wgan_pg(self.D, d_fake_data, d_real_data, lamb=10) 193 | d_loss.backward() 194 | self.D_opt.step() # Only optimizes D's parameters; changes based on stored gradients from backward() 195 | d_ct.flush(info={"D_loss":d_loss.item()}) 196 | 197 | g_ct = Clock(g_steps, title="Train Generator(%d/%d)"%(epoch, num_epochs)) 198 | r_ct = Clock(g_steps, title="Train Reconstructor(%d/%d)" % (epoch, num_epochs)) 199 | if epoch>0: 200 | for i, X_data in zip(range(g_steps), data_gen(X_datagen, self.utils.sents2idx)): 201 | # 2. 
Train G on D's response (but DO NOT train D on these labels) 202 | self.G.zero_grad() 203 | g_fake_data = backward_decode(self.G, self.embed, X_data.src, X_data.src_mask, max_len=self.utils.max_len, return_term=0) 204 | dg_fake_pred = self.D(g_fake_data) 205 | # g_error = self.criterion(dg_fake_pred, torch.ones((dg_fake_pred.shape[0],), dtype=torch.int64).to(self.D.device)) # we want to fool, so pretend it's all genuine 206 | g_loss = -dg_fake_pred.mean() 207 | 208 | g_loss.backward(retain_graph=True) 209 | self.G_opt.step() # Only optimizes G's parameters 210 | self.G.zero_grad() 211 | # g_ct.flush(info={"G_loss": g_loss.item()}) 212 | 213 | # 3. reconstructor 643988636173-69t5i8ehelccbq85o3esu11jgh61j8u5.apps.googleusercontent.com 214 | # way_3 215 | out = self.R.forward(g_fake_data, embedding_layer(X_data.trg.to(device)), 216 | None, X_data.trg_mask) 217 | r_loss = r_scale*self.r_loss_compute(out, X_data.trg_y, X_data.ntokens) 218 | # way_2 219 | # r_reco_data = prob_backward(self.R, self.embed, g_fake_data, None, max_len=self.utils.max_len, raw=True) 220 | # x_orgi_data = X_data.src[:, 1:] 221 | # r_loss = SimpleLossCompute(None, criterion, self.R_opt)(r_reco_data, x_orgi_data, X_data.ntokens) 222 | # way_1 223 | # viewed_num = r_reco_data.shape[0]*r_reco_data.shape[1] 224 | # r_error = r_scale*self.cosloss(r_reco_data.float().view(-1, self.embed.weight.shape[1]), x_orgi_data.float().view(-1, self.embed.weight.shape[1]), torch.ones(viewed_num, dtype=torch.float32).to(device)) 225 | self.G_opt.step() 226 | self.G_opt.optimizer.zero_grad() 227 | r_ct.flush(info={"G": g_loss.item(), 228 | "R": r_loss / X_data.ntokens.float().to(device)}) 229 | 230 | with torch.no_grad(): 231 | x_cont, x_ys = backward_decode(model, self.embed, X_test_batch.src, X_test_batch.src_mask, max_len=25, start_symbol=2) 232 | x = utils.idx2sent(x_ys) 233 | y_cont, y_ys = backward_decode(model, self.embed, Y_test_batch.src, Y_test_batch.src_mask, max_len=25, start_symbol=2) 234 | y = 
utils.idx2sent(y_ys) 235 | r_x = utils.idx2sent(backward_decode(self.R, self.embed, x_cont, None, max_len=self.utils.max_len, raw=True, return_term=1)) 236 | r_y = utils.idx2sent(backward_decode(self.R, self.embed, y_cont, None, max_len=self.utils.max_len, raw=True, return_term=1)) 237 | 238 | for i,j,l in zip(X_test_batch.src, x, r_x): 239 | print("===") 240 | k = utils.idx2sent([i])[0] 241 | print("ORG:", " ".join(k[:k.index('')+1])) 242 | print("GEN:", " ".join(j[:j.index('')+1] if '' in j else j)) 243 | print("REC:", " ".join(l[:l.index('')+1] if '' in l else l)) 244 | print("=====") 245 | for i, j, l in zip(Y_test_batch.src, y, r_y): 246 | print("===") 247 | k = utils.idx2sent([i])[0] 248 | print("ORG:", " ".join(k[:k.index('')+1])) 249 | print("GEN:", " ".join(j[:j.index('')+1] if '' in j else j)) 250 | print("REC:", " ".join(l[:l.index('')+1] if '' in l else l)) 251 | self.save_model() 252 | 253 | def demo(self, sent): 254 | sent = sent.strip() 255 | data = torch.from_numpy(self.utils.sents2idx([self.utils.process_sent(sent)])).long() 256 | data = torch.cat((torch.full((data.shape[0], 1), 2, dtype=torch.long), data), dim=1) 257 | src = Variable(data, requires_grad=False) 258 | tgt = Variable(data, requires_grad=False) 259 | batch = Batch(src, tgt, 0) 260 | with torch.no_grad(): 261 | x_cont, x_ys = backward_decode(model, self.embed, batch.src, batch.src_mask, max_len=25, start_symbol=2) 262 | x = utils.idx2sent(x_ys) 263 | r_x = utils.idx2sent(backward_decode(self.R, self.embed, x_cont, None, max_len=self.utils.max_len, raw=True, return_term=1)) 264 | 265 | for i,j,l in zip(batch.src, x, r_x): 266 | print("===") 267 | k = utils.idx2sent([i])[0] 268 | print("ORG:", " ".join(k[:k.index('')+1])) 269 | print("GEN:", " ".join(j[:j.index('')+1] if '' in j else j)) 270 | print("REC:", " ".join(l[:l.index('')+1] if '' in l else l)) 271 | print("=====") 272 | 273 | def live(self): 274 | while True: 275 | e = input("sent> ") 276 | self.demo(e) 277 | 278 | 279 | def 
def get_embedding_layer(utils, target_device=None):
    """Build a frozen nn.Embedding initialized from pretrained vectors.

    Args:
        utils: object exposing `emb_mat`, a (vocab, d_model) array of
            pretrained word vectors.
        target_device: where to place the layer; defaults to the
            module-global `device` (backward compatible with old callers).

    Returns:
        nn.Embedding whose weights equal utils.emb_mat and do not train.
    """
    if target_device is None:
        target_device = device
    vocab, d_model = utils.emb_mat.shape
    embedding_layer = nn.Embedding(vocab, d_model)
    embedding_layer.weight.data = torch.tensor(utils.emb_mat)
    # Keep the pretrained vectors fixed during training.
    embedding_layer.weight.requires_grad = False
    embedding_layer.to(target_device)
    return embedding_layer
X_test_batch = batch 325 | break 326 | 327 | for i, batch in enumerate(data_gen(utils.data_generator("Y"), utils.sents2idx)): 328 | Y_test_batch = batch 329 | break 330 | model.to(device) 331 | for epoch in range(epoch_num): 332 | model.train() 333 | print("EPOCH %d:"%(epoch+1)) 334 | pretrain_run_epoch(data_gen(utils.data_generator("Y"), utils.sents2idx), model, 335 | SimpleLossCompute(model.generator, criterion, model_opt), utils.train_step_num, embedding_layer) 336 | pretrain_run_epoch(data_gen(utils.data_generator("X"), utils.sents2idx), model, 337 | SimpleLossCompute(model.generator, criterion, model_opt), utils.train_step_num, embedding_layer) 338 | model.eval() 339 | torch.save(model.state_dict(), 'model_pretrain.ckpt') 340 | x = utils.idx2sent(greedy_decode(model, embedding_layer, X_test_batch.src, X_test_batch.src_mask, max_len=20, start_symbol=2)) 341 | y = utils.idx2sent(greedy_decode(model, embedding_layer, Y_test_batch.src, Y_test_batch.src_mask, max_len=20, start_symbol=2)) 342 | 343 | for i,j in zip(X_test_batch.src, x): 344 | print("===") 345 | k = utils.idx2sent([i])[0] 346 | print("ORG:", " ".join(k[:k.index('')+1])) 347 | print("--") 348 | print("GEN:", " ".join(j[:j.index('')+1] if '' in j else j)) 349 | print("===") 350 | print("=====") 351 | for i, j in zip(Y_test_batch.src, y): 352 | print("===") 353 | k = utils.idx2sent([i])[0] 354 | print("ORG:", " ".join(k[:k.index('')+1])) 355 | print("--") 356 | print("GEN:", " ".join(j[:j.index('')+1] if '' in j else j)) 357 | print("===") 358 | 359 | if __name__ == "__main__": 360 | parser = argparse.ArgumentParser() 361 | parser.add_argument("mode", help="execute mode") 362 | parser.add_argument("-filename", default=None, required=False, help="test filename") 363 | parser.add_argument("-load_model", default=False, required=False, help="test filename") 364 | parser.add_argument("-model_name", default="model.ckpt", required=False, help="test filename") 365 | parser.add_argument("-disc_name", 
default="cnn_disc/model_disc_pretrain_xy_inv.ckpt", required=False, help="test filename") 366 | parser.add_argument("-save_path", default="", required=False, help="test filename") 367 | parser.add_argument("-X_file", default="shuf_cna.txt", required=False, help="X domain text filename") 368 | parser.add_argument("-Y_file", default="shuf_cou.txt", required=False, help="Y domain text filename") 369 | parser.add_argument("-X_test", default="test.cna", required=False, help="X domain text filename") 370 | parser.add_argument("-Y_test", default="test.cou", required=False, help="Y domain text filename") 371 | parser.add_argument("-epoch", default=1, required=False, help="test filename") 372 | parser.add_argument("-max_len", default=20, required=False, help="test filename") 373 | parser.add_argument("-batch_size", default=32, required=False, help="batch size") 374 | parser.add_argument("-r_scale", default=10, required=False, help="batch size") 375 | args = parser.parse_args() 376 | 377 | model = Transformer(N=2) 378 | utils = Utils(X_data_path=args.X_file, Y_data_path=args.Y_file, 379 | X_test_path=args.X_test, Y_test_path=args.Y_test, 380 | batch_size=int(args.batch_size)) 381 | embedding_layer = get_embedding_layer(utils).to(device) 382 | model.generator = Generator(d_model = utils.emb_mat.shape[1], vocab=utils.emb_mat.shape[0]) 383 | if args.load_model: 384 | model.load_state_dict(torch.load(args.model_name)) 385 | if args.mode == "pretrain": 386 | pretrain(model, embedding_layer, utils, int(args.epoch)) 387 | if args.mode == "cycle": 388 | disc = Discriminator(word_dim=utils.emb_mat.shape[1], inner_dim=512, seq_len=20) 389 | main_model = CycleGAN(disc, model, utils, embedding_layer) 390 | main_model.to(device) 391 | main_model.load_model(g_file="model_pretrain.ckpt", r_file="model_pretrain.ckpt", d_file="model_disc_pretrain.ckpt") 392 | main_model.train_model(num_epochs=int(args.epoch), r_scale=int(args.r_scale)) 393 | if args.mode == "demo": 394 | disc = 
Discriminator(word_dim=utils.emb_mat.shape[1], inner_dim=512, seq_len=20) 395 | main_model = CycleGAN(disc, model, utils, embedding_layer) 396 | main_model.to(device) 397 | main_model.load_model(g_file="Gen_model.ckpt", r_file="Res_model.ckpt", d_file="Dis_model.ckpt") 398 | main_model.live() 399 | if args.mode == "disc": 400 | disc = Discriminator(word_dim=utils.emb_mat.shape[1], inner_dim=512, seq_len=20) 401 | main_model = CycleGAN(disc, model, utils, embedding_layer) 402 | main_model.to(device) 403 | main_model.pretrain_disc(2) 404 | 405 | if args.mode == "dev": 406 | model = Transformer(N=2) 407 | utils = Utils(X_data_path="big_cou.txt", Y_data_path="big_cna.txt") 408 | model.generator = Generator(d_model = utils.emb_mat.shape[1], vocab=utils.emb_mat.shape[0]) 409 | criterion = LabelSmoothing(size=utils.emb_mat.shape[0], padding_idx=0, smoothing=0.0) 410 | model_opt = NoamOpt(utils.emb_mat.shape[1], 1, 400, 411 | torch.optim.Adam(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9)) 412 | X_test_batch = None 413 | Y_test_batch = None 414 | for i, batch in enumerate(data_gen(utils.data_generator("X"), utils.sents2idx)): 415 | X_test_batch = batch 416 | break 417 | 418 | for i, batch in enumerate(data_gen(utils.data_generator("Y"), utils.sents2idx)): 419 | Y_test_batch = batch 420 | break 421 | # if args.load_model: 422 | # model.load_model(filename=args.model_name) 423 | # if args.mode == "train": 424 | # model.train_model(num_epochs=int(args.epoch)) 425 | # print("========= Testing =========") 426 | # model.load_model() 427 | # model.test_corpus() 428 | # if args.mode == "test": 429 | # model.test_corpus() 430 | --------------------------------------------------------------------------------