├── .gitignore
├── LICENSE
├── Model
│   ├── Constants.py
│   ├── Modules.py
│   └── __init__.py
├── README.md
├── train.py
└── utils.py
/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
env/
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# dotenv
.env

# virtualenv
.venv
venv/
ENV/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2017

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/Model/Constants.py:
--------------------------------------------------------------------------------
PAD = 0
UNK = 1
BOS = 2
EOS = 3

PAD_WORD = '<blank>'
UNK_WORD = '<unk>'
BOS_WORD = '<s>'
EOS_WORD = '</s>'

--------------------------------------------------------------------------------
/Model/Modules.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.nn.functional as F
import numpy as np

import Model.Constants as Constants
from utils import check_cuda


class Encoder(nn.Module):
    '''An LSTM encoder that encodes a sentence into a latent vector z.'''
    def __init__(
        self,
        n_src_vocab,
        n_layers=1,
        d_word_vec=150,
        d_inner_hid=300,
        dropout=0.1,
        d_out_hid=300,
        use_cuda=False,
    ):
        super(Encoder, self).__init__()

        self.drop = nn.Dropout(dropout)
        self.src_word_emb = nn.Embedding(n_src_vocab, d_word_vec, padding_idx=Constants.PAD)

        self.n_layers = n_layers
        self.d_inner_hid = d_inner_hid
        self.use_cuda = use_cuda

        # NOTE: Maybe try a GRU here
        self.rnn = nn.LSTM(d_word_vec, d_inner_hid, n_layers, dropout=dropout)

        # Linear layers that parameterize the Gaussian q(z|x)
        self._enc_mu = nn.Linear(d_inner_hid, d_out_hid)
        self._enc_log_sigma = nn.Linear(d_inner_hid, d_out_hid)

        self.init_weights()

    # Borrowed from https://github.com/ethanluoyc/pytorch-vae/blob/master/vae.py
    def _sample_latent(self, enc_hidden):
        mu = self._enc_mu(enc_hidden)
        log_sigma = self._enc_log_sigma(enc_hidden)
        sigma = torch.exp(log_sigma)
        std_z = torch.from_numpy(np.random.normal(0, 1, size=sigma.size())).float()

        self.z_mean = mu
        self.z_sigma = sigma

        std_z_var = Variable(std_z, requires_grad=False)
        std_z_var = check_cuda(std_z_var, self.use_cuda)
        return mu + sigma * std_z_var

    def forward(self, src_seq, hidden, dont_pass_emb=False):
        if dont_pass_emb:
            enc_input = self.drop(src_seq)
        else:
            enc_input = self.drop(self.src_word_emb(src_seq))
        # Reshape from (batch_size, seq_len, d_word_vec) to (seq_len, batch_size, d_word_vec)
        enc_input = enc_input.permute(1, 0, 2)
        _, hidden = self.rnn(enc_input, hidden)
        hidden = (
            self._sample_latent(hidden[0]),
            hidden[1]
        )
        return hidden

    def init_hidden(self, batch_size):
        # NOTE: LSTM needs 2 hidden states
        hidden = [
            Variable(torch.zeros(self.n_layers, batch_size, self.d_inner_hid)),
            Variable(torch.zeros(self.n_layers, batch_size, self.d_inner_hid))
        ]
        hidden[0] = check_cuda(hidden[0], self.use_cuda)
        hidden[1] = check_cuda(hidden[1], self.use_cuda)
        return hidden

    def init_weights(self):
        initrange = 0.1
        self.src_word_emb.weight.data.uniform_(-initrange, initrange)
        self._enc_mu.weight.data.uniform_(-initrange, initrange)
        self._enc_log_sigma.weight.data.uniform_(-initrange, initrange)

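# Encoder._sample_latent implements the VAE reparameterization trick:
# z = mu + sigma * eps with eps ~ N(0, 1), so the sampling step stays
# differentiable with respect to mu and sigma.  The KL term that pushes
# (mu, sigma) toward N(0, I) is computed in train.py (latent_loss).
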
class Generator(nn.Module):
    '''An LSTM generator that synthesizes a sentence from (z, c),
    where z is the latent vector from the encoder and c is the attribute code.
    '''
    def __init__(
        self,
        n_target_vocab,
        n_layers=1,
        d_word_vec=150,
        d_inner_hid=300,
        c_dim=1,
        dropout=0.1,
        use_cuda=False,
    ):
        super(Generator, self).__init__()

        self.drop = nn.Dropout(dropout)

        self.d_inner_hid = d_inner_hid
        self.c_dim = c_dim
        self.n_layers = n_layers
        self.use_cuda = use_cuda

        self.target_word_emb = nn.Embedding(
            n_target_vocab, d_word_vec, padding_idx=Constants.PAD)

        self.rnn = nn.LSTM(d_word_vec, d_inner_hid + c_dim, n_layers, dropout=dropout)

        self.to_word_emb = nn.Sequential(
            nn.Linear(d_inner_hid + c_dim, d_word_vec),
            nn.ReLU()
        )
        self.linear = nn.Linear(d_word_vec, n_target_vocab)
        # Special embedding for the low-temperature (soft one-hot) trick
        self.one_hot_to_word_emb = nn.Linear(n_target_vocab, d_word_vec)
        # Share embedding weights between the input embedding, the output
        # projection, and the soft one-hot embedding
        self.linear.weight = self.target_word_emb.weight
        self.one_hot_to_word_emb.weight = torch.nn.Parameter(self.linear.weight.permute(1, 0).data)

        self.softmax = nn.Softmax()

        self.init_weights()

    def forward(self, target_word, hidden, low_temp=False, one_hot_input=False):
        '''Decode one step.  hidden is composed of z and c; the Generator is fed
        its input word by word.
        '''
        if one_hot_input:
            dec_input = self.drop(
                self.one_hot_to_word_emb(target_word)).unsqueeze(0)
        else:
            dec_input = self.drop(
                self.target_word_emb(target_word)).unsqueeze(0)
        output, hidden = self.rnn(dec_input, hidden)
        output = self.to_word_emb(output)
        output = self.linear(output)
        # Low-temperature trick: dividing the logits by a small temperature makes
        # the softmax output close to a one-hot vector while staying differentiable
        if low_temp:
            pre_soft = output[0]
            lowed_output = pre_soft / 0.001
            output = self.softmax(lowed_output)
            return output, hidden, pre_soft
        return output, hidden

    def init_hidden_c_for_lstm(self, batch_size):
        hidden = Variable(torch.zeros(self.n_layers, batch_size, self.d_inner_hid))
        hidden = check_cuda(hidden, self.use_cuda)
        return hidden

    def init_weights(self):
        initrange = 0.1
        self.target_word_emb.weight.data.uniform_(-initrange, initrange)
        self.to_word_emb[0].weight.data.uniform_(-initrange, initrange)
        self.linear.weight.data.uniform_(-initrange, initrange)

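# Shape walk-through for one Generator step (batch size B, vocab size V,
# word-embedding size E = d_word_vec, hidden size H = d_inner_hid, code size C = c_dim):
#   target_word:    (B,) word indices, or (B, V) soft one-hot when one_hot_input=True
#   embedded input: (1, B, E)
#   hidden state:   a pair of (1, B, H + C) tensors
#   rnn output:     (1, B, H + C) -> to_word_emb -> (1, B, E) -> linear -> (1, B, V)
#   with low_temp=True the leading dim is dropped and the (B, V) softmax scores
#   are returned together with the (B, V) pre-softmax logits.
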
class Discriminator(nn.Module):
    '''A CNN discriminator that classifies the attribute of a given sentence.'''
    def __init__(
        self,
        n_src_vocab,
        maxlen,
        d_word_vec=150,
        dropout=0.1,
        use_cuda=False,
    ):
        super(Discriminator, self).__init__()

        self.use_cuda = use_cuda

        self.src_word_emb = nn.Embedding(n_src_vocab, d_word_vec, padding_idx=Constants.PAD)
        self.drop = nn.Dropout(dropout)
        self.conv1 = nn.Conv1d(maxlen, 128, kernel_size=5)
        self.conv2 = nn.Conv1d(128, 128, kernel_size=5)
        self.conv3 = nn.Conv1d(128, 128, kernel_size=5)
        self.softmax = nn.LogSoftmax()

    def forward(self, input_sentence, is_softmax=False, dont_pass_emb=False):
        if dont_pass_emb:
            emb_sentence = input_sentence
        else:
            emb_sentence = self.src_word_emb(input_sentence)
        relu1 = F.relu(self.conv1(emb_sentence))
        layer1 = F.max_pool1d(relu1, 3)
        relu2 = F.relu(self.conv2(layer1))
        layer2 = F.max_pool1d(relu2, 3)
        layer3 = F.max_pool1d(F.relu(self.conv3(layer2)), 10)
        flatten = self.drop(layer3.view(layer3.size()[0], -1))
        # NOTE: the output layer is created lazily on the first forward pass, so
        # an optimizer built before that call will not see its parameters
        if not hasattr(self, 'linear'):
            self.linear = nn.Linear(flatten.size()[1], 2)
            self.linear = check_cuda(self.linear, self.use_cuda)
        logit = self.linear(flatten)
        if is_softmax:
            logit = self.softmax(logit)
        return logit

--------------------------------------------------------------------------------
/Model/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GBLin5566/toward-controlled-generation-of-text-pytorch/dfb95ccb6833321930cc08004eeffbeaf71082ed/Model/__init__.py
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# A PyTorch Implementation of "Toward Controlled Generation of Text"
This is a [PyTorch](http://pytorch.org/) implementation of the model
proposed in the paper [Toward Controlled Generation of Text](http://proceedings.mlr.press/v70/hu17e.html),
which aims to generate natural-language sentences that exhibit given target attributes.
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
import numpy as np
from keras.preprocessing import sequence
from keras.datasets import imdb
import torch
from torch.autograd import Variable
from torch.optim import Adam

import Model.Constants as Constants
from Model.Modules import Encoder, Generator, Discriminator
from utils import check_cuda

max_features = 10000
maxlen = 20
batch_size = 128
epoch = 1
c_dim = 2
d_word_vec = 150
lambda_c = 0.1
lambda_z = 0.1
use_cuda = False

print('Loading data...')
(x_train, y_train), (x_test, y_test) = imdb.load_data(
    num_words=max_features,
    start_char=Constants.BOS,
    oov_char=Constants.UNK,
    index_from=Constants.EOS,
)

forward_dict = imdb.get_word_index()
for key, value in forward_dict.items():
    forward_dict[key] = value + Constants.EOS
forward_dict[Constants.PAD_WORD] = Constants.PAD
forward_dict[Constants.UNK_WORD] = Constants.UNK
forward_dict[Constants.BOS_WORD] = Constants.BOS
forward_dict[Constants.EOS_WORD] = Constants.EOS

backward_dict = {}
for key, value in forward_dict.items():
    backward_dict[value] = key

x_train = sequence.pad_sequences(
    x_train,
    maxlen=maxlen,
    padding='post',
    truncating='post',
    value=Constants.PAD,
)
x_test = sequence.pad_sequences(
    x_test,
    maxlen=maxlen,
    padding='post',
    truncating='post',
    value=Constants.PAD,
)

def get_batch(data, index, batch_size, testing=False):
    tensor = torch.from_numpy(data[index:index+batch_size]).type(torch.LongTensor)
    input_data = Variable(tensor, volatile=testing, requires_grad=False)
    input_data = check_cuda(input_data, use_cuda)
    output_data = input_data
    return input_data, output_data

def get_batch_label(data, label, index, batch_size, testing=False):
    tensor = torch.from_numpy(data[index:index+batch_size]).type(torch.LongTensor)
    input_data = Variable(tensor, volatile=testing, requires_grad=False)
    input_data = check_cuda(input_data, use_cuda)
    label_tensor = torch.from_numpy(label[index:index+batch_size]).type(torch.LongTensor)
    output_data = Variable(label_tensor, volatile=testing, requires_grad=False)
    output_data = check_cuda(output_data, use_cuda)
    return input_data, output_data

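# latent_loss below implements the closed-form KL divergence between the
# encoder's Gaussian q(z|x) = N(mu, sigma^2) and the standard normal prior,
# averaged over all dimensions:
#     KL = 0.5 * mean(mu^2 + sigma^2 - log(sigma^2) - 1)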
# Borrowed from https://github.com/ethanluoyc/pytorch-vae/blob/master/vae.py
def latent_loss(z_mean, z_stddev):
    mean_sq = z_mean * z_mean
    stddev_sq = z_stddev * z_stddev
    return 0.5 * torch.mean(mean_sq + stddev_sq - torch.log(stddev_sq) - 1)

# Make instances
encoder = Encoder(
    n_src_vocab=max_features,
    use_cuda=use_cuda,
)
decoder = Generator(
    n_target_vocab=max_features,
    c_dim=c_dim,
    use_cuda=use_cuda,
)
discriminator = Discriminator(
    n_src_vocab=max_features,
    maxlen=maxlen,
    use_cuda=use_cuda,
)
encoder = check_cuda(encoder, use_cuda)
decoder = check_cuda(decoder, use_cuda)
discriminator = check_cuda(discriminator, use_cuda)
criterion = torch.nn.CrossEntropyLoss()
vae_parameters = list(encoder.parameters()) + list(decoder.parameters())
vae_opt = Adam(vae_parameters)
e_opt = Adam(encoder.parameters())
g_opt = Adam(decoder.parameters())
d_opt = Adam(discriminator.parameters())

def train_discriminator(discriminator):
    # TODO: add the empirical Shannon entropy regularization term
    for epoch_index in range(epoch):
        for batch, index in enumerate(range(0, len(x_train) - 1, batch_size)):
            discriminator.train()
            input_data, output_data = get_batch_label(
                x_train,
                y_train,
                index,
                batch_size
            )

            discriminator.zero_grad()

            output = discriminator(input_data)
            loss = criterion(output, output_data)
            loss.backward()
            d_opt.step()

            if batch % 25 == 0:
                print("[Discriminator] Epoch {} batch {}'s loss: {}".format(
                    epoch_index,
                    batch,
                    loss.data[0],
                ))
        # Evaluate on the test set at the end of every epoch
        discriminator.eval()
        input_data, output_data = get_batch_label(x_test, y_test, 0, len(y_test), testing=True)
        _, predicted = torch.max(discriminator(input_data).data, 1)
        correct = (predicted == torch.from_numpy(y_test)).sum()
        print("[Discriminator] Test accuracy {} %".format(
            100 * correct / len(y_test)
        ))

def train_vae(encoder, decoder):
    encoder.train()
    decoder.train()
    for epoch_index in range(epoch):
        for batch, index in enumerate(range(0, len(x_train) - 1, batch_size)):
            total_loss = 0
            input_data, output_data = get_batch(x_train, index, batch_size)
            encoder.zero_grad()
            decoder.zero_grad()
            vae_opt.zero_grad()

            # The last batch may hold fewer than batch_size examples, so
            # initialize the hidden state with len(input_data) instead of batch_size
            enc_hidden = encoder.init_hidden(len(input_data))
            # The encoder consumes the whole batch of sequences at once.
            enc_hidden = encoder(input_data, enc_hidden)

            # Sample a random one-hot attribute code from the prior p(c)
            # NOTE: assume a uniform prior over codes for now
            random_one_dim = np.random.randint(c_dim, size=len(input_data))
            one_hot_array = np.zeros((len(input_data), c_dim))
            one_hot_array[np.arange(len(input_data)), random_one_dim] = 1

            c = torch.from_numpy(one_hot_array).float()
            var_c = Variable(c, requires_grad=False)
            var_c = check_cuda(var_c, use_cuda)
            # TODO: iterate along the layer dimension instead of indexing layer 0
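            # Build the Generator's initial LSTM state by concatenating the
            # attribute code c onto both halves: h0 = [z ; c], c0 = [zeros ; c],
            # so the hidden size matches d_inner_hid + c_dim.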
            cat_hidden = (torch.cat([enc_hidden[0][0], var_c], dim=1).unsqueeze(0),
                          torch.cat([decoder.init_hidden_c_for_lstm(len(input_data))[0], var_c], dim=1).unsqueeze(0))

            # Reshape output_data from (batch_size, seq_len) to (seq_len, batch_size)
            output_data = output_data.permute(1, 0)
            # The decoder is fed word by word (teacher forcing).
            for index, word in enumerate(output_data):
                if index == len(output_data) - 1:
                    break
                output, cat_hidden = decoder(word, cat_hidden)
                next_word = output_data[index+1]
                total_loss += criterion(output.view(-1, max_features), next_word)
            # Train
            avg_loss = total_loss.data[0] / maxlen
            ll = latent_loss(encoder.z_mean, encoder.z_sigma)
            total_loss += ll
            total_loss.backward()
            vae_opt.step()

            if batch % 25 == 0:
                print("[VAE] Epoch {} batch {}'s average language loss: {}, latent loss: {}".format(
                    epoch_index,
                    batch,
                    avg_loss,
                    ll.data[0],
                ))

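# NOTE: train_vae_with_attr_loss roughly follows the paper's generator update:
# the Generator is driven by its own low-temperature (soft one-hot) outputs
# instead of teacher forcing, the Discriminator scores the generated soft
# sentence to give l_attr_c, and the soft sentence is re-encoded to give
# l_attr_z.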
def train_vae_with_attr_loss(encoder, decoder, discriminator):
    for epoch_index in range(epoch):
        for batch, index in enumerate(range(0, len(x_train) - 1, batch_size)):
            encoder.zero_grad()
            decoder.zero_grad()
            e_opt.zero_grad()
            g_opt.zero_grad()
            vae_loss = 0
            ll = 0

            input_data, output_data = get_batch_label(x_train, y_train, index, batch_size)

            enc_hidden = encoder.init_hidden(len(input_data))
            enc_hidden = encoder(input_data, enc_hidden)

            # Use the true labels as the attribute code c
            target = np.array([output_data.cpu().data.numpy()]).reshape(-1)
            one_hot_array = np.eye(c_dim)[target]
            c = torch.from_numpy(one_hot_array).float()
            var_c = Variable(c, requires_grad=False)
            var_c = check_cuda(var_c, use_cuda)
            # TODO: iterate along the layer dimension instead of indexing layer 0
            cat_hidden = (torch.cat([enc_hidden[0][0], var_c], dim=1).unsqueeze(0),
                          torch.cat([decoder.init_hidden_c_for_lstm(len(input_data))[0], var_c], dim=1).unsqueeze(0))

            # Start every sequence with a (soft) one-hot BOS vector
            batch_init_word = np.zeros((len(input_data), max_features))
            batch_init_word[np.arange(len(input_data)), Constants.BOS] = 1
            batch_init_word = Variable(torch.from_numpy(batch_init_word), requires_grad=False).float()
            batch_init_word = check_cuda(batch_init_word, use_cuda)

            input_data = input_data.permute(1, 0)
            next_word = None
            for index in range(maxlen - 1):
                # Feed the previous soft output back in; the first step uses the BOS vector
                if next_word is not None:
                    word = next_word.squeeze(1)
                else:
                    word = batch_init_word
                word = check_cuda(word, use_cuda)
                output, cat_hidden, pre_soft = decoder(
                    word,
                    cat_hidden,
                    low_temp=True,
                    one_hot_input=True
                )
                next_word = output
                correct_word = input_data[index+1]
                vae_loss += criterion(pre_soft.view(-1, max_features), correct_word)
                if len(batch_init_word.size()) == 2:
                    batch_init_word = batch_init_word.unsqueeze(1)
                if len(next_word.size()) == 2:
                    next_word = next_word.unsqueeze(1)
                # Append the soft one-hot word to the generated sentence
                batch_init_word = torch.cat([batch_init_word, next_word], dim=1)
            # NOTE: latent loss
            ll = latent_loss(encoder.z_mean, encoder.z_sigma)
            # NOTE: L_attr_c loss
            generated_sentence = batch_init_word
            discriminator.eval()
            # Project the soft one-hot words into the discriminator's embedding
            # space so the CNN sees the same input size it was trained on
            disc_input = torch.matmul(generated_sentence, discriminator.src_word_emb.weight)
            logit = discriminator(disc_input, dont_pass_emb=True)
            l_attr_c = criterion(logit, output_data)
            # NOTE: L_attr_z loss
            encoder.eval()
            generated_sentence = decoder.one_hot_to_word_emb(generated_sentence)
            encoded_gen = encoder.init_hidden(len(generated_sentence))
            encoded_gen = encoder(generated_sentence, encoded_gen, dont_pass_emb=True)
            l_attr_z = latent_loss(encoder.z_mean, encoder.z_sigma)

            avg_loss = vae_loss.data[0] / maxlen

            total_vae_loss = vae_loss + ll
            extra_decoder_loss = lambda_c * l_attr_c + lambda_z * l_attr_z
            total_vae_loss.backward()
            #e_opt.step()
            #extra_decoder_loss.backward()
            #g_opt.step()
            vae_opt.step()

            if batch % 25 == 0:
                print("[Attr] Epoch {} batch {}'s average language loss: {}, latent loss: {}".format(
                    epoch_index,
                    batch,
                    avg_loss,
                    ll.data[0],
                ))
                print("l_attr_c loss: {}, l_attr_z loss: {}".format(
                    l_attr_c.data[0],
                    l_attr_z.data[0],
                ))


def main_alg(encoder, decoder, discriminator):
    train_vae(encoder, decoder)
    repeat_times = 10
    for repeat_index in range(repeat_times):
        train_discriminator(discriminator)
        #train_discriminator(discriminator)
        #train_vae(encoder, decoder)
        train_vae_with_attr_loss(encoder, decoder, discriminator)

if __name__ == '__main__':
    main_alg(encoder, decoder, discriminator)

--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
from torch import cuda

def check_cuda(torch_var, use_cuda=False):
    if use_cuda and cuda.is_available():
        return torch_var.cuda()
    else:
        return torch_var

--------------------------------------------------------------------------------