├── evaluation.py ├── baselines ├── __init__.py ├── ppvae.py └── optimus.py ├── .DS_Store ├── pics ├── acc_results.jpg └── pcae_struct.jpg ├── logger.py ├── LICENSE ├── pcae.py ├── README.md ├── utils.py ├── vae.py └── train.py /evaluation.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /baselines/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ImKeTT/PCAE/HEAD/.DS_Store -------------------------------------------------------------------------------- /pics/acc_results.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ImKeTT/PCAE/HEAD/pics/acc_results.jpg -------------------------------------------------------------------------------- /pics/pcae_struct.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ImKeTT/PCAE/HEAD/pics/pcae_struct.jpg -------------------------------------------------------------------------------- /logger.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | #-*- coding: utf-8 -*- 3 | """ 4 | @file: logger.py 5 | @author: ImKe at 2021/3/29 6 | @email: thq415_ic@yeah.net 7 | @feature: #Enter features here 8 | """ 9 | import logging 10 | 11 | 12 | class Logger(object): 13 | def __init__(self, log_file): 14 | self.logger = logging.getLogger() 15 | self.formatter = logging.Formatter(fmt='[%(asctime)s] %(message)s', datefmt='%Y-%m-%d %H:%M:%S') 16 | 17 | self.logger.setLevel(logging.INFO) 18 | self.logger.handlers = [] 19 | 20 | fh = logging.FileHandler(log_file, mode='w') 21 | fh.setLevel(logging.INFO) 22 | fh.setFormatter(self.formatter) 23 | self.logger.addHandler(fh) 24 | 25 | sh = logging.StreamHandler() 26 | sh.setLevel(logging.INFO) 27 | sh.setFormatter(self.formatter) 28 | self.logger.addHandler(sh) 29 | 30 | def info(self, text): 31 | self.logger.info(text) -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Anton Kiselev 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /baselines/ppvae.py: -------------------------------------------------------------------------------- 1 | from tqdm import tqdm 2 | import torch.nn as nn 3 | import torch 4 | from torch.nn import CrossEntropyLoss 5 | import torch.nn.functional as F 6 | import sys 7 | sys.path.append("../") 8 | from utils import * 9 | 10 | class PPVAE(nn.Module): 11 | def __init__(self, Vae, config, device): 12 | super().__init__() 13 | self.vae = Vae 14 | self.device = device 15 | self.z_enc = nn.Sequential(nn.Linear(config.dim_z, config.dim_z//2), nn.LeakyReLU(0.2, True), 16 | nn.Linear(config.dim_z//2, config.dim_z//4), nn.LeakyReLU(0.2, True)) 17 | self.z_dec = nn.Sequential(nn.Linear(config.dim_bottle, config.dim_z//4), 18 | nn.Linear(config.dim_z//4, config.dim_z//2), 19 | nn.Linear(config.dim_z//2, config.dim_z)) 20 | self.z2mu = nn.Linear(config.dim_z//4, config.dim_bottle) 21 | self.z2lv = nn.Linear(config.dim_z//4, config.dim_bottle) 22 | 23 | self.label_emb = nn.Linear(config.class_num, config.dim_label) 24 | self.config = config 25 | self.freeze_encoder() 26 | self.freeze_decoder() 27 | 28 | def freeze_encoder(self): 29 | for param in self.vae.encoder.parameters(): 30 | param.requires_grad = False 31 | def freeze_decoder(self): 32 | for param in self.vae.decoder.parameters(): 33 | param.requires_grad = False 34 | 35 | def generate(self, z, top_k=10, top_p=0.5, temperature=1.0): 36 | num_samples = z.size(0) 37 | z = self.z_enc(z) 38 | mu, logvar = self.z2mu(z), self.z2lv(z) 39 | z_out = self.vae.reparameterize(mu, logvar, nsamples=1).squeeze(1) 40 | rec_z = self.z_dec(z_out) 41 | z = self.vae.linear(rec_z) 42 | 43 | generated =torch.tensor([[self.vae.tokenizer.bos_token_id]] * num_samples, device=self.device) 44 | while generated.shape[1] < 1000: 45 | decoder_outputs = self.vae.decoder(input_ids=generated, encoder_hidden_states=z) 46 | logits = self.vae.lm_head(decoder_outputs[0]) 47 | logits = logits[:, -1, :] / temperature 48 | filtered_logits = top_k_top_p_filtering_batch(logits, top_k=top_k, top_p=top_p) 49 | 50 | probabilities = F.softmax(filtered_logits, dim=-1) 51 | next_token_id = torch.multinomial(probabilities, 1) 52 | # next_token_id = torch.argmax(filtered_logits, dim=-1).unsqueeze(0) 53 | 54 | generated = torch.cat((generated, next_token_id), dim=1) 55 | not_finished = next_token_id != self.vae.tokenizer.eos_token_id 56 | if torch.sum(not_finished) == 0: 57 | break 58 | return generated 59 | 60 | 61 | def forward(self, x, labels, x_attention_mask, decoder_inputs, decoder_inputs_mask=None, beta=1.): 62 | encoder_outputs = self.vae.encoder(x, x_attention_mask) 63 | pooled_hidden_fea = self.vae.pooler(encoder_outputs[0]) # model outputs are always tuple in pytorch-transformers (see doc) 64 | mu_in, logvar_in = self.vae.z_linear(pooled_hidden_fea).chunk(2, -1) 65 | 66 | z_in = self.vae.reparameterize(mu_in, logvar_in, nsamples=1) 67 | z_in = z_in.squeeze(1) 68 | z = self.z_enc(z_in) 69 | mu, logvar = self.z2mu(z), self.z2lv(z) 70 | 71 | loss_kl = 0.5 * (mu.pow(2) + logvar.exp() - logvar - 1) 72 | kl_mask = (loss_kl > self.config.dim_target_kl).float() 73 | loss_kl = (kl_mask * loss_kl).sum(dim=1) 74 | 75 | z_out = 
self.vae.reparameterize(mu, logvar, nsamples=1).squeeze(1) 76 | rec_z = self.z_dec(z_out) 77 | loss_z_rec = torch.mean(torch.pow(z_in - rec_z, 2)) 78 | 79 | # pdb.set_trace() 80 | z = self.vae.linear(rec_z) 81 | decoder_outputs = self.vae.decoder(input_ids=decoder_inputs, attention_mask=decoder_inputs_mask, 82 | encoder_hidden_states=z) 83 | logits = self.vae.lm_head(decoder_outputs[0]) 84 | 85 | loss_rec = None 86 | loss_fct = CrossEntropyLoss() 87 | loss_rec = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1)) 88 | 89 | loss = loss_z_rec + loss_rec + beta * loss_kl 90 | return loss_z_rec, loss_rec, loss_kl, loss -------------------------------------------------------------------------------- /pcae.py: -------------------------------------------------------------------------------- 1 | from vae import VAE 2 | import torch.nn as nn 3 | import torch 4 | from torch.nn import CrossEntropyLoss 5 | import torch.nn.functional as F 6 | from utils import * 7 | 8 | 9 | class BroadcastingNet(nn.Module): 10 | def __init__(self, emb_size, z_size, class_num, layer_num=5): 11 | """ 12 | Broadcasting Net in the original paper 13 | """ 14 | super().__init__() 15 | # self.es = emb_size 16 | # self.hs = hid_size 17 | # self.z_size = z_size 18 | self.lm = layer_num 19 | self.forwardnet = nn.ModuleList() 20 | self.act = nn.ReLU() 21 | for i in range(layer_num): 22 | self.forwardnet.append(nn.Linear(z_size+emb_size, z_size)) 23 | 24 | def forward(self, label, z): 25 | label_emb = label 26 | for i in range(self.lm): 27 | z = torch.cat([label_emb, z], -1) 28 | z = self.forwardnet[i](z) 29 | # z = self.act(z) 30 | return z 31 | 32 | class PCAE(nn.Module): 33 | def __init__(self, Vae, config, device, layer_num=5): 34 | super().__init__() 35 | self.vae = Vae 36 | self.device = device 37 | self.layer_num = layer_num 38 | self.lper = BroadcastingNet(config.dim_label, config.dim_z, config.class_num, self.layer_num) 39 | self.label_emb = nn.Linear(config.class_num, config.dim_label) 40 | self.config = config 41 | self.freeze_encoder() 42 | 43 | def freeze_encoder(self): 44 | for param in self.vae.encoder.parameters(): 45 | param.requires_grad = False 46 | def freeze_decoder(self): 47 | for param in self.vae.decoder.parameters(): 48 | param.requires_grad = False 49 | 50 | def generate(self, z, y, top_k=10, top_p=0.5, temperature=1.0): 51 | num_samples = z.size(0) 52 | labels = torch.full([num_samples, 1], y).to(self.device) 53 | labels = torch.stack([torch.eye(self.config.class_num)[label.squeeze(0)].to(self.device) for label in labels]) 54 | labels = self.label_emb(labels) 55 | 56 | z = self.lper(labels, z) 57 | z = self.vae.linear(z) 58 | 59 | generated =torch.tensor([[self.vae.tokenizer.bos_token_id]] * num_samples, device=self.device) 60 | while generated.shape[1] <= 50: 61 | decoder_outputs = self.vae.decoder(input_ids=generated, encoder_hidden_states=z) 62 | logits = self.vae.lm_head(decoder_outputs[0]) 63 | logits = logits[:, -1, :] / temperature 64 | filtered_logits = top_k_top_p_filtering_batch(logits, top_k=top_k, top_p=top_p) 65 | 66 | probabilities = F.softmax(filtered_logits, dim=-1) 67 | next_token_id = torch.multinomial(probabilities, 1) 68 | # next_token_id = torch.argmax(filtered_logits, dim=-1).unsqueeze(0) 69 | 70 | generated = torch.cat((generated, next_token_id), dim=1) 71 | not_finished = next_token_id != self.vae.tokenizer.eos_token_id 72 | if torch.sum(not_finished) == 0: 73 | break 74 | return generated 75 | 76 | 77 | def forward(self, x, y, labels, x_attention_mask, decoder_inputs, 
use_mean, decoder_inputs_mask=None, beta=1.): 78 | encoder_outputs = self.vae.encoder(x, x_attention_mask) 79 | pooled_hidden_fea = self.vae.pooler(encoder_outputs[0]) # model outputs are always tuple in pytorch-transformers (see doc) 80 | mu, logvar = self.vae.z_linear(pooled_hidden_fea).chunk(2, -1) 81 | 82 | label = F.one_hot(y.squeeze(0).long(), self.config.class_num).float().to(self.device) 83 | label = self.label_emb(label) 84 | mu_label = self.lper(label, mu) 85 | 86 | latent_z = self.vae.reparameterize(mu_label, logvar, nsamples=1) 87 | z_label = latent_z.squeeze(1) 88 | loss_kl = 0.5 * (mu_label.pow(2) + logvar.exp() - logvar - 1) 89 | kl_mask = (loss_kl > self.config.dim_target_kl).float() 90 | loss_kl = (kl_mask * loss_kl).sum(dim=1) 91 | 92 | # pdb.set_trace() 93 | if not use_mean: 94 | z = self.vae.linear(z_label) 95 | else: 96 | z = mu_label 97 | z = self.vae.linear(z) 98 | decoder_outputs = self.vae.decoder(input_ids=decoder_inputs, attention_mask=decoder_inputs_mask, 99 | encoder_hidden_states=z) 100 | logits = self.vae.lm_head(decoder_outputs[0]) 101 | 102 | loss_rec = None 103 | loss_fct = CrossEntropyLoss() 104 | loss_rec = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1)) 105 | 106 | loss = loss_rec + beta * loss_kl 107 | return loss_rec, loss_kl, loss -------------------------------------------------------------------------------- /baselines/optimus.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | #-*- coding: utf-8 -*- 3 | """ 4 | @file: optimus.py 5 | @author: ImKe 6 | @email: tuisaac163@gmail.com 7 | @feature: #Enter features here 8 | Optimus Baseline from https://github.com/ChunyuanLI/Optimus/blob/master/code/examples/big_ae/run_lm_vae_label_ctrl_gen.py 9 | """ 10 | from tqdm import tqdm 11 | import torch.nn as nn 12 | import torch 13 | from torch.nn import CrossEntropyLoss 14 | import torch.nn.functional as F 15 | import sys 16 | sys.path.append("../") 17 | from utils import * 18 | 19 | 20 | class Optimus(nn.Module): 21 | def __init__(self, Vae, config, device, layer_num=5): 22 | super().__init__() 23 | self.vae = Vae 24 | self.device = device 25 | self.layer_num = layer_num 26 | self.latent_generator = nn.Linear(config.dim_z, config.dim_z) 27 | self.latent_discriminator = nn.Linear(config.dim_z, 1) 28 | self.label_emb = nn.Linear(config.class_num, config.dim_label) 29 | self.label_linear = nn.Linear(config.dim_label, config.dim_z) 30 | self.config = config 31 | 32 | self.CrossEntropyLoss = nn.CrossEntropyLoss() 33 | self.BCEWithLogitsLoss = nn.BCEWithLogitsLoss() 34 | 35 | def freeze_encoder(self): 36 | for param in self.vae.encoder.parameters(): 37 | param.requires_grad = False 38 | def freeze_decoder(self): 39 | for param in self.vae.decoder.parameters(): 40 | param.requires_grad = False 41 | 42 | def generate(self, z, y, top_k=10, top_p=0.5, temperature=1.0): 43 | num_samples = z.size(0) 44 | y = torch.tensor(y, device=self.device) 45 | label = F.one_hot(y.squeeze(0).long(), self.config.class_num).float().to(self.device) 46 | label = self.label_emb(label) 47 | label = self.label_linear(label) 48 | 49 | z_label = z + label 50 | # pdb.set_trace() 51 | z = self.vae.linear(z_label) 52 | 53 | generated =torch.tensor([[self.vae.tokenizer.bos_token_id]] * num_samples, device=self.device) 54 | while generated.shape[1] < 80: 55 | decoder_outputs = self.vae.decoder(input_ids=generated, encoder_hidden_states=z) 56 | logits = self.vae.lm_head(decoder_outputs[0]) 57 | logits = logits[:, -1, :] / 
temperature 58 | filtered_logits = top_k_top_p_filtering_batch(logits, top_k=top_k, top_p=top_p) 59 | 60 | probabilities = F.softmax(filtered_logits, dim=-1) 61 | next_token_id = torch.multinomial(probabilities, 1) 62 | # next_token_id = torch.argmax(filtered_logits, dim=-1).unsqueeze(0) 63 | 64 | generated = torch.cat((generated, next_token_id), dim=1) 65 | not_finished = next_token_id != self.vae.tokenizer.eos_token_id 66 | if torch.sum(not_finished) == 0: 67 | break 68 | return generated 69 | 70 | 71 | def forward(self, x, y, lm_labels, x_attention_mask, decoder_inputs, decoder_inputs_mask=None, alpha=1., beta=1.): 72 | ones_label = torch.ones_like(y).to(dtype=torch.float32) 73 | zeros_label = torch.zeros_like(y).to(dtype=torch.float32) 74 | random_noise = torch.nn.init.normal_(torch.empty(x.size(0), self.config.dim_z)).to(device=self.device, dtype=torch.float32) 75 | 76 | encoder_outputs = self.vae.encoder(x, x_attention_mask) 77 | pooled_hidden_fea = self.vae.pooler(encoder_outputs[0]) # model outputs are always tuple in pytorch-transformers (see doc) 78 | mu, logvar = self.vae.z_linear(pooled_hidden_fea).chunk(2, -1) 79 | 80 | label = F.one_hot(y.squeeze(0).long(), self.config.class_num).float().to(self.device) 81 | label = self.label_emb(label) 82 | label = self.label_linear(label) 83 | 84 | latent_z = self.vae.reparameterize(mu, logvar, nsamples=1) 85 | z = latent_z.squeeze(1) 86 | loss_kl = 0.5 * (mu.pow(2) + logvar.exp() - logvar - 1) 87 | kl_mask = (loss_kl > self.config.dim_target_kl).float() 88 | loss_kl = (kl_mask * loss_kl).sum(dim=1) 89 | 90 | ## fake z 91 | gen_z = self.latent_generator(random_noise) 92 | # gen_z = random_noise 93 | 94 | #################### Latent discriminator for sampling from a simple distribution #################### 95 | prob_encode_z_dis = self.latent_discriminator(z).squeeze(1).float() # (B) 96 | prob_gen_z_dis = self.latent_discriminator(gen_z).squeeze(1).float() # (B) 97 | # Train latent discriminator 98 | loss_lsd = self.BCEWithLogitsLoss(prob_gen_z_dis, zeros_label) + self.BCEWithLogitsLoss(prob_encode_z_dis, ones_label) 99 | acc_encode_z_dis = ((prob_encode_z_dis >= 0).float() == ones_label).float() 100 | acc_gen_z_dis = ((prob_gen_z_dis >= 0).float() == zeros_label).float() 101 | # Train sampler adversarially 102 | loss_lsg = self.BCEWithLogitsLoss(prob_gen_z_dis, ones_label) 103 | 104 | z_label = z + label 105 | 106 | # pdb.set_trace() 107 | z = self.vae.linear(z_label) 108 | decoder_outputs = self.vae.decoder(input_ids=decoder_inputs, attention_mask=decoder_inputs_mask, 109 | encoder_hidden_states=z) 110 | logits = self.vae.lm_head(decoder_outputs[0]) 111 | 112 | loss_rec = None 113 | loss_rec = self.CrossEntropyLoss(logits.view(-1, logits.size(-1)), lm_labels.view(-1)) 114 | 115 | loss = loss_rec + beta * loss_kl + alpha * loss_lsg 116 | return loss_rec, loss_kl, loss_lsd, loss -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PCAE: A Framework of Plug-in Conditional Auto-Encoder for Controllable Text Generation 2 | 3 | Official PyTorch implementation of *[PCAE: A Framework of Plug-in Conditional Auto-Encoder for Controllable Text Generation](https://www.sciencedirect.com/science/article/pii/S0950705122008942)*, published in *Knowledge-Based Systems*. We provide PCAE as well as all implemented baselines (PPVAE and Optimus) under pre-trained BART. 
4 | 
5 | ![pcae_struct](pics/pcae_struct.jpg)
6 | 
7 | 
8 | 
9 | ## News
10 | 
11 | - [2023-02-02] We are sharing the fine-tuned BART VAE [here](https://drive.google.com/file/d/1hp_vm1rQIxWgCSkgm7cKtasGWxCG1kL8/view?usp=sharing) now!
12 | - [2022-10-10] Our paper is available on [arXiv](https://arxiv.org/abs/2210.03496) now.
13 | - [2022-09-27] Our paper is now in this [paper list](https://github.com/ImKeTT/CTG-latentAEs), which aims to collect all kinds of latent variational auto-encoders that controllably generate text. Feel free to check it out and contribute!
14 | - [2022-09-27] We release our PCAE and baseline code under the pre-trained BART setup.
15 | - [2022-09-06] Our work PCAE is [available](https://www.sciencedirect.com/science/article/pii/S0950705122008942) online.
16 | - [2022-08-21] Our paper PCAE is accepted to *Knowledge-Based Systems*.
17 | 
18 | ## Setup
19 | 
20 | Make sure you have installed:
21 | 
22 | ```bash
23 | transformers
24 | tqdm
25 | torch
26 | numpy
27 | ```
28 | 
29 | ## Dataset
30 | 
31 | We conduct five tasks spanning three datasets: *Yelp review*, *Titles* and *Yahoo Question*.
32 | 
33 | We provide our full processed datasets in:
34 | 
35 | - [BaiduPan](https://pan.baidu.com/s/11vEqD_liL_U8brCEC6Nohg?pwd=bx81) (password `bx81`)
36 | - [GoogleDrive](https://drive.google.com/file/d/1XDHN3rbXhl-dc_cqIFsQCd01pr6BiQjn/view?usp=sharing)
37 | 
38 | Please download `data.zip` and unzip it to the current folder.
39 | 
40 | You can also try your own data; follow the split in the `data` folder. Note that for PPVAE you have to manually split out negative samples for every control signal.
41 | 
42 | ## Training
43 | 
44 | ### Stage 1 BART VAE Finetuning
45 | 
46 | You can download the full `checkpoints` folder from [here](https://drive.google.com/file/d/1hp_vm1rQIxWgCSkgm7cKtasGWxCG1kL8/view?usp=sharing) and unzip `checkpoints.zip` to the current folder.
47 | 
48 | Or you can train the BART VAE from scratch.
49 | 
50 | Finetune on one of the three datasets (choose DATA from `yelp`, `yahoo`, `titles`, with EPOCH set to 8, 10, 10 respectively):
51 | 
52 | ```shell
53 | DATA=yelp
54 | EPOCH=8
55 | python train.py --run_mode vae_ft --dataset $DATA --zmanner hidden\
56 | --gpu 0 1 --dim_z 128 --per_gpu_train_batch_size 64\
57 | --train_epochs $EPOCH --fb_mode 1 --lr 1e-4 --first_token_pooling --eval_every 500
58 | ```
59 | 
60 | 
61 | 
62 | ![acc_results](pics/acc_results.jpg)
63 | 
64 | ### Stage 2.1 PCAE Plug-in Training
65 | 
66 | Plug-in training of PCAE. Choose arguments below:
67 | 
68 | + TASK: [sentiment, tense, topics, quess_s, quess]
69 | 
70 | (topics, quess_s, quess correspond to $topics_S$, $topics_M$, $topics_L$ in the paper, respectively)
71 | 
72 | + SAMPLE_N: [100, 300, 500, 800, 1000]
73 | 
74 | + NUM_LAYER: an integer from 8 to 15 works well
75 | 
76 | + EPOCH: 10 to 20 is fine; a smaller SAMPLE_N needs fewer epochs
77 | 
78 | ```shell
79 | TASK=sentiment
80 | EPOCH=10
81 | SAMPLE_N=100
82 | NUM_LAYER=10
83 | 
84 | python train.py --run_mode pcae --task $TASK --zmanner hidden\
85 | --gpu 0 --dim_z 128 --per_gpu_train_batch_size 5\
86 | --plugin_train_epochs $EPOCH --fb_mode 1 --sample_n $SAMPLE_N\
87 | --layer_num $NUM_LAYER --lr 1e-4 --use_mean
88 | ```
89 | 
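For reference, the sketch below isolates what `train.py --run_mode pcae` does around generation: rebuild the BART VAE, load the Stage 1 checkpoint, wrap it with the PCAE plug-in, and decode label-conditioned sentences from prior samples. It is a minimal sketch under stated assumptions, not a replacement for the training script: `config` is a hand-built stand-in for the config object `train.py` assembles from the flags above (shown with the sentiment/`yelp` defaults), treating class id `1` as "positive" is an assumption that depends on your label files, and the Broadcasting Net weights only carry signal after the plug-in training loop has run (`train.py` generates right after training and does not save a separate PCAE checkpoint).

```python
import torch
from types import SimpleNamespace
from transformers import BartTokenizer, BartForConditionalGeneration
from vae import VAE
from pcae import PCAE

device = torch.device("cuda:0")

# Stand-in for the argparse-driven config train.py builds (sentiment task on yelp).
config = SimpleNamespace(fb_mode=1, dim_target_kl=0.1, zmanner="hidden",
                         dim_label=8, dim_z=128, class_num=2,
                         first_token_pooling=True)

tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
bart = BartForConditionalGeneration.from_pretrained("facebook/bart-base")

# Rebuild the BART VAE and load the Stage 1 weights written by train.py.
vae = VAE(bart.model.encoder, bart.model.decoder, tokenizer, config, device)
state = torch.load("checkpoints/yelp/BART/best_val_vae_hidden.pt", map_location="cpu")
state = {k.replace("module.", ""): v for k, v in state.items()}  # drop DataParallel prefix
vae.load_state_dict(state)

# Wrap the VAE with the PCAE plug-in (Broadcasting Net + frozen encoder).
# These plug-in weights are only meaningful after the training loop in train.py has run.
model = PCAE(vae, config, device, layer_num=10).to(device)
model.eval()

y = 1                                             # assumed "positive" class id
z = torch.randn(5, config.dim_z, device=device)   # latent codes from the prior
with torch.no_grad():
    token_ids = model.generate(z, y, top_k=10, top_p=0.5)

for ids in token_ids:
    eos = (ids == tokenizer.eos_token_id).nonzero(as_tuple=True)[0]
    end = eos.min().item() if len(eos) else ids.size(0)
    print(tokenizer.decode(ids[1:end]))
```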
90 | ### Stage 2.2 PPVAE Plug-in Training
91 | 
92 | Plug-in training of PPVAE BART. Choose arguments below:
93 | 
94 | + TASK: [sentiment, tense, topics, quess_s, quess]
95 | + SAMPLE_N: [100, 300, 500, 800, 1000]
96 | + TASK_LABEL: [pos, neg] for the sentiment task; [present, past] for the tense task; [0, 1, 2, 3, 4] for the topics task; [0, 1, 2, 3, ..., 9] for the quess task
97 | + EPOCH: 10 to 20 is fine; a smaller SAMPLE_N needs fewer epochs
98 | 
99 | For example, if you want to train PPVAE to generate **positive** sentences in the sentiment task with 100 training samples per class, run:
100 | 
101 | ```shell
102 | TASK=sentiment
103 | EPOCH=10
104 | SAMPLE_N=100
105 | TASK_LABEL=pos
106 | 
107 | python train.py --run_mode ppvae --task $TASK --zmanner hidden\
108 | --gpu 0 --dim_z 128 --per_gpu_train_batch_size 5\
109 | --plugin_train_epochs $EPOCH --fb_mode 1 --sample_n $SAMPLE_N\
110 | --task_label $TASK_LABEL --lr 1e-4 --ppvae_dim_bottle 25
111 | ```
112 | 
113 | ### Stage 2.3 Optimus_{bart} Plug-in Finetuning
114 | 
115 | Plug-in finetuning of Optimus under the BART setup. Choose arguments below:
116 | 
117 | + TASK: [sentiment, tense, topics, quess_s, quess]
118 | + SAMPLE_N: [100, 300, 500, 800, 1000]
119 | + EPOCH: 10 to 20 is fine; a smaller SAMPLE_N needs fewer epochs
120 | 
121 | ```shell
122 | TASK=sentiment
123 | EPOCH=10
124 | SAMPLE_N=100
125 | 
126 | python train.py --run_mode optimus --task $TASK --zmanner hidden\
127 | --gpu 0 --dim_z 128 --per_gpu_train_batch_size 5\
128 | --plugin_train_epochs $EPOCH --fb_mode 1 --sample_n $SAMPLE_N\
129 | --lr 1e-4
130 | ```
131 | 
132 | ## Others
133 | 
134 | Please [email](mailto:tuisaac163@gmail.com) me or open an issue if you have further questions.
135 | 
136 | If you find our work useful, please cite the paper and star the repo~ :)
137 | 
138 | ```bibtex
139 | @article{tu2022pcae,
140 |   title={PCAE: A framework of plug-in conditional auto-encoder for controllable text generation},
141 |   author={Tu, Haoqin and Yang, Zhongliang and Yang, Jinshuai and Zhang, Siyu and Huang, Yongfeng},
142 |   journal={Knowledge-Based Systems},
143 |   volume={256},
144 |   pages={109766},
145 |   year={2022},
146 |   publisher={Elsevier}
147 | }
148 | ```
149 | 
150 | We thank the open-sourced code for VAEs and plug-and-play models that inspired our work!
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data import Dataset, DataLoader
2 | import torch
3 | import numpy as np
4 | import torch.nn.functional as F
5 | 
6 | 
7 | def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float("Inf")):
8 |     """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
9 |     Args:
10 |         logits: logits distribution shape (vocabulary size)
11 |         top_k >0: keep only top k tokens with highest probability (top-k filtering).
12 |         top_p >0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
13 |             Nucleus filtering is described in Holtzman et al. 
(http://arxiv.org/abs/1904.09751) 14 | """ 15 | assert ( 16 | logits.dim() == 1 17 | ) # batch size 1 for now - could be updated for more but the code would be less clear 18 | top_k = min(top_k, logits.size(-1)) # Safety check 19 | if top_k > 0: 20 | # Remove all tokens with a probability less than the last token of the top-k 21 | indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None] 22 | logits[indices_to_remove] = filter_value 23 | 24 | if top_p > 0.0: 25 | sorted_logits, sorted_indices = torch.sort(logits, descending=True) 26 | cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) 27 | 28 | # Remove tokens with cumulative probability above the threshold 29 | sorted_indices_to_remove = cumulative_probs > top_p 30 | # Shift the indices to the right to keep also the first token above the threshold 31 | sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() 32 | sorted_indices_to_remove[..., 0] = 0 33 | 34 | indices_to_remove = sorted_indices[sorted_indices_to_remove] 35 | logits[indices_to_remove] = filter_value 36 | return logits 37 | 38 | def top_k_top_p_filtering_batch(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')): 39 | """ Filter a distribution of logits using top-k and/or nucleus (top-p) filtering 40 | Args: 41 | logits: logits distribution shape (vocabulary size) 42 | top_k > 0: keep only top k tokens with highest probability (top-k filtering). 43 | top_p > 0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering). 44 | Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751) 45 | From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317 46 | """ 47 | # assert logits.dim() == 1 # batch size 1 for now - could be updated for more but the code would be less clear 48 | 49 | top_k = min(top_k, logits.size(-1)) # Safety check 50 | 51 | if top_k > 0: 52 | # Remove all tokens with a probability less than the last token of the top-k 53 | indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None] 54 | # logits.masked_fill_(logits < threshold, filter_value) # (B, vocab_size) 55 | logits[indices_to_remove] = filter_value 56 | 57 | if top_p > 0.0: 58 | sorted_logits, sorted_indices = torch.sort(logits, descending=True) # (B, vocab_size) 59 | cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) # (B, vocab_size) 60 | 61 | # Remove tokens with cumulative probability above the threshold 62 | sorted_indices_to_remove = cumulative_probs > top_p 63 | 64 | # Shift the indices to the right to keep also the first token above the threshold 65 | sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() 66 | sorted_indices_to_remove[..., 0] = 0 67 | 68 | # indices_to_remove = sorted_indices[sorted_indices_to_remove] 69 | 70 | # logits.masked_fill_(indices_to_remove, filter_value) 71 | indices_to_remove = sorted_indices_to_remove.scatter(dim=1, index=sorted_indices, src=sorted_indices_to_remove) 72 | logits[indices_to_remove] = filter_value 73 | 74 | return logits 75 | 76 | def read_txt(file): 77 | with open(file, "r") as f: 78 | txt = f.readlines() 79 | return [item[:-1] for item in txt] 80 | 81 | class ConditionalGenerationDataset(Dataset): 82 | def __init__(self, dl): 83 | self.x = [] 84 | self.text_len = [] 85 | self.y = [] 86 | self.init_data(dl) 87 | self.length = len(self.x) 88 | 89 | def init_data(self, dl): 90 | for inst in dl: 91 | inst = inst.split('\t') 92 | ## label 93 | self.y.append(inst[0]) 94 | 
self.x.append(inst[1]) 95 | self.text_len.append(len(inst[1].split())) 96 | 97 | def __getitem__(self, index: int) -> dict: 98 | ## add BOS and EOS special token 99 | x = self.x[index][:-1] 100 | y = self.y[index] 101 | 102 | return str(x), int(y) 103 | 104 | def __len__(self): 105 | return self.length 106 | 107 | ## call for direct input 108 | @staticmethod 109 | def from_file(file_path: str): 110 | with open(file_path, 'r') as f: 111 | dl = f.readlines() 112 | return ConditionalGenerationDataset(dl) 113 | 114 | def frange_cycle_linear(n_iter, start=0.0, stop=1.0, n_cycle=4, ratio=0.5): 115 | L = np.ones(n_iter) * stop 116 | period = int(n_iter / n_cycle) 117 | step = (stop - start) / (period * ratio) # linear schedule 118 | 119 | for c in range(n_cycle): 120 | v, i = start, 0 121 | while v <= stop and (int(i + c * period) < n_iter): 122 | L[int(i + c * period)] = v 123 | v += step 124 | i += 1 125 | return L 126 | 127 | def frange_cycle_zero_linear(n_iter, start=0.0, stop=1.0, n_cycle=4, ratio_increase=0.3, ratio_zero=0.1): 128 | L = np.ones(n_iter) * stop 129 | period = n_iter / n_cycle 130 | step = (stop-start) / (period * ratio_increase) # linear schedule 131 | 132 | for c in range(n_cycle): 133 | v, i = start, 0 134 | while v <= stop and (int(i + c * period) < n_iter): 135 | if i < period * ratio_zero: 136 | L[int(i + c * period)] = start 137 | else: 138 | L[int(i + c * period)] = v 139 | v += step 140 | i += 1 141 | return L -------------------------------------------------------------------------------- /vae.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | #-*- coding: utf-8 -*- 3 | """ 4 | @file: vae.py 5 | @author: ImKe 6 | @email: tuisaac163@gmail.com 7 | @feature: #Enter features here 8 | Modified from https://github.com/ChunyuanLI/Optimus/blob/master/code/examples/big_ae/modules/vae.py 9 | """ 10 | import torch.nn as nn 11 | import torch 12 | from torch.nn import CrossEntropyLoss 13 | import torch.nn.functional as F 14 | 15 | 16 | class BartPooler(nn.Module): 17 | def __init__(self, config, first_token_pooling=True): 18 | super(BartPooler, self).__init__() 19 | self.dense = nn.Linear(config.hidden_size, config.hidden_size) 20 | self.activation = nn.Tanh() 21 | self.first_token_pooling = first_token_pooling 22 | 23 | def forward(self, hidden_states): 24 | if self.first_token_pooling: 25 | pooled_token_tensor = hidden_states[:, 0] 26 | else: 27 | pooled_token_tensor = hidden_states.mean(1) 28 | pooled_output = self.dense(pooled_token_tensor) 29 | pooled_output = self.activation(pooled_output) 30 | return pooled_output 31 | 32 | class VAE(nn.Module): 33 | """VAE with normal prior""" 34 | def __init__(self, encoder, decoder, tokenizer, args, device): 35 | super(VAE, self).__init__() 36 | self.encoder = encoder 37 | self.decoder = decoder 38 | self.tokenizer = tokenizer 39 | self.config = self.decoder.config 40 | self.pooler = BartPooler(self.config, args.first_token_pooling) 41 | 42 | padding_idx, vocab_size = tokenizer.pad_token_id, self.config.vocab_size 43 | self.vocab_size = vocab_size 44 | self.shared = nn.Embedding(self.vocab_size, self.config.d_model, padding_idx) 45 | 46 | self.args = args 47 | self.nz = args.dim_z 48 | self.z_linear = nn.Linear(self.config.hidden_size, 2 * self.nz, bias=False) 49 | 50 | self.eos_token_id = tokenizer.eos_token_id 51 | self.pad_token_id = padding_idx 52 | self.device = device 53 | 54 | 55 | # connector: from Bert hidden units to the latent space 56 | # self.linear = 
nn.Linear(args.nz, 2 * args.nz, bias=False) 57 | 58 | # Standard Normal prior 59 | loc = torch.zeros(self.nz, device=device) 60 | scale = torch.ones(self.nz, device=device) 61 | self.prior = torch.distributions.normal.Normal(loc, scale) 62 | 63 | if self.args.zmanner=="mem": 64 | self.linear = nn.Linear(self.nz, self.config.hidden_size * self.config.decoder_layers, bias=False) 65 | elif self.args.zmanner=="hidden": 66 | self.linear = nn.Linear(self.nz, self.config.hidden_size, bias=False) 67 | 68 | self.register_buffer("final_logits_bias", torch.zeros((1, self.shared.num_embeddings))) 69 | self.lm_head = nn.Linear(self.decoder.config.d_model, self.shared.num_embeddings, bias=False) 70 | 71 | def connect(self, bert_fea, nsamples=1): 72 | """ 73 | Returns: Tensor1, Tensor2 74 | Tensor1: the tensor latent z with shape [batch, nsamples, nz] 75 | Tensor2: the tenor of KL for each x with shape [batch] 76 | """ 77 | 78 | # (batch_size, nz) 79 | mean, logvar = self.z_linear(bert_fea).chunk(2, -1) 80 | # pdb.set_trace() 81 | # mean, logvar = mean.squeeze(0), logvar.squeeze(0) 82 | 83 | # (batch, nsamples, nz) 84 | z = self.reparameterize(mean, logvar, nsamples) 85 | KL = 0.5 * (mean.pow(2) + logvar.exp() - logvar - 1).sum(dim=1) 86 | return z, KL 87 | 88 | def connect_deterministic(self, bert_fea, nsamples=1): 89 | """ 90 | Returns: Tensor1, Tensor2 91 | Tensor1: the tensor latent z with shape [batch, nsamples, nz] 92 | Tensor2: the tenor of KL for each x with shape [batch] 93 | """ 94 | 95 | # (batch_size, nz) 96 | 97 | mean, logvar = self.z_linear(bert_fea).chunk(2, -1) 98 | # pdb.set_trace() 99 | # mean, logvar = mean.squeeze(0), logvar.squeeze(0) 100 | 101 | logvar.fill_(.0) 102 | # (batch, nsamples, nz) 103 | z = self.reparameterize(mean, logvar, nsamples) 104 | KL = 0.5 * (mean.pow(2) + logvar.exp() - logvar - 1).sum(dim=1) 105 | 106 | return z, KL 107 | 108 | def reparameterize(self, mu, logvar, nsamples=1): 109 | """sample from posterior Gaussian family 110 | Args: 111 | mu: Tensor 112 | Mean of gaussian distribution with shape (batch, nz) 113 | logvar: Tensor 114 | logvar of gaussian distibution with shape (batch, nz) 115 | Returns: Tensor 116 | Sampled z with shape (batch, nsamples, nz) 117 | """ 118 | batch_size, nz = mu.size() 119 | std = logvar.mul(0.5).exp() 120 | 121 | mu_expd = mu.unsqueeze(1).expand(batch_size, nsamples, nz) 122 | std_expd = std.unsqueeze(1).expand(batch_size, nsamples, nz) 123 | 124 | eps = torch.zeros_like(std_expd).normal_() 125 | 126 | return mu_expd + torch.mul(eps, std_expd) 127 | 128 | def cond_gen(self, logits, labels): 129 | masked_lm_loss = None 130 | loss_fct = CrossEntropyLoss() 131 | masked_lm_loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1)) 132 | return masked_lm_loss 133 | 134 | def rec_loss(self, x, target): 135 | loss = F.cross_entropy( 136 | x.transpose(1, 2), 137 | target, 138 | ignore_index=self.pad_token_id, 139 | reduction="none", 140 | ) 141 | return loss 142 | 143 | def build_past(self, z): 144 | projection = self.linear(z) 145 | 146 | cross_attn = projection.reshape( 147 | self.config.decoder_layers, 148 | projection.shape[0], 149 | self.config.decoder_attention_heads, 150 | 1, 151 | int(self.config.hidden_size / self.config.decoder_attention_heads) 152 | ) 153 | past_key_values = tuple((ca, ca) for ca in cross_attn) 154 | return past_key_values 155 | 156 | def generate(self, z, top_k=10, top_p=0.5, use_cache=None, temperature=1.0): 157 | num_samples = z.size(0) 158 | if self.args.zmanner == "mem": 159 | z = 
self.build_past(z) 160 | generated = torch.tensor([[self.tokenizer.bos_token_id]] * num_samples, device=device) 161 | while generated.shape[1] < 1000: 162 | decoder_outputs = self.decoder(input_ids=generated, past_key_values=z, use_cache=use_cache) 163 | z = decoder_outputs[1] if use_cache is not None else z 164 | logits = self.lm_head(decoder_outputs[0]) 165 | logits = logits[:, -1, :] / temperature 166 | filtered_logits = top_k_top_p_filtering_batch(logits, top_k=top_k, top_p=top_p) 167 | 168 | probabilities = F.softmax(filtered_logits, dim=-1) 169 | next_token_id = torch.multinomial(probabilities, 1) 170 | # next_token_id = torch.argmax(filtered_logits, dim=-1).unsqueeze(0) 171 | 172 | generated = torch.cat((generated, next_token_id), dim=1) 173 | if next_token_id == self.tokenizer.eos_token_id: 174 | break 175 | elif self.args.zmanner == "hidden": 176 | z = self.linear(z) 177 | generated =torch.tensor([[self.tokenizer.bos_token_id]] * num_samples, device=device) 178 | while generated.shape[1] < 1000: 179 | decoder_outputs = self.decoder(input_ids=generated, encoder_hidden_states=z, use_cache=use_cache) 180 | logits = self.lm_head(decoder_outputs[0]) 181 | logits = logits[:, -1, :] / temperature 182 | filtered_logits = top_k_top_p_filtering_batch(logits, top_k=top_k, top_p=top_p) 183 | 184 | probabilities = F.softmax(filtered_logits, dim=-1) 185 | next_token_id = torch.multinomial(probabilities, 1) 186 | # next_token_id = torch.argmax(filtered_logits, dim=-1).unsqueeze(0) 187 | 188 | generated = torch.cat((generated, next_token_id), dim=1) 189 | not_finished = next_token_id != self.tokenizer.eos_token_id 190 | if torch.sum(not_finished) == 0: 191 | break 192 | # if next_token_id == self.tokenizer.eos_token_id: 193 | # break 194 | return generated 195 | 196 | 197 | def forward(self, inputs, labels, attention_mask, decoder_inputs, decoder_inputs_mask=None, beta=1.): 198 | reconstrution_mask=(labels != self.pad_token_id).float() 199 | sent_length = torch.sum(reconstrution_mask, dim=1) 200 | 201 | 202 | encoder_outputs = self.encoder(inputs, attention_mask) 203 | pooled_hidden_fea = self.pooler(encoder_outputs[0]) # model outputs are always tuple in pytorch-transformers (see doc) 204 | 205 | if self.args.fb_mode==0: 206 | # Connect hidden feature to the latent space 207 | latent_z, loss_kl = self.connect(pooled_hidden_fea) 208 | latent_z = latent_z.squeeze(1) 209 | if self.args.zmanner == "mem": 210 | past = self.build_past(latent_z) 211 | past_length = 1 # past[0][0].size(-2) 212 | # Decoding 213 | decoder_inputs = decoder_inputs[:, -1:] 214 | decoder_outputs = self.decoder(input_ids=decoder_inputs, attention_mask=decoder_inputs_mask, 215 | encoder_hidden_states=encoder_outputs[0], past_key_values=past) 216 | elif self.args.zmanner == "hidden": 217 | z = self.linear(latent_z) 218 | decoder_outputs = self.decoder(input_ids=decoder_inputs, attention_mask=decoder_inputs_mask, 219 | encoder_hidden_states=z) 220 | logits = self.lm_head(decoder_outputs[0]) 221 | if self.args.zmanner == "mem": 222 | logits = logits.repeat(1, labels.size(-1), 1) 223 | loss_rec = self.cond_gen(logits, labels) # model outputs are always tuple in pytorch-transformers (see doc) 224 | 225 | elif self.args.fb_mode==1: 226 | # Connect hidden feature to the latent space 227 | mu, logvar = self.z_linear(pooled_hidden_fea).chunk(2, -1) 228 | latent_z = self.reparameterize(mu, logvar, nsamples=1) 229 | latent_z = latent_z.squeeze(1) 230 | loss_kl = 0.5 * (mu.pow(2) + logvar.exp() - logvar - 1) 231 | kl_mask = (loss_kl > 
self.args.dim_target_kl).float() 232 | loss_kl = (kl_mask * loss_kl).sum(dim=1) 233 | 234 | # pdb.set_trace() 235 | if self.args.zmanner == "mem": 236 | past = self.build_past(latent_z) 237 | past_length = 1 # past[0][0].size(-2) 238 | # Decoding 239 | decoder_inputs = decoder_inputs[:, -1:] 240 | decoder_outputs = self.decoder(input_ids=decoder_inputs, attention_mask=decoder_inputs_mask, 241 | encoder_hidden_states=encoder_outputs[0], past_key_values=past) 242 | elif self.args.zmanner == "hidden": 243 | z = self.linear(latent_z) 244 | decoder_outputs = self.decoder(input_ids=decoder_inputs, attention_mask=decoder_inputs_mask, 245 | encoder_hidden_states=z) 246 | logits = self.lm_head(decoder_outputs[0]) 247 | if self.args.zmanner == "mem": 248 | logits = logits.repeat(1, labels.size(-1), 1) 249 | loss_rec = self.cond_gen(logits, labels) # model outputs are always tuple in pytorch-transformers (see doc) 250 | 251 | 252 | # pdb.set_trace() 253 | if self.args.length_weighted_loss: 254 | loss = loss_rec / sent_length + beta * loss_kl 255 | else: 256 | loss = loss_rec + beta * loss_kl 257 | 258 | 259 | return loss_rec, loss_kl, loss -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from vae import VAE 3 | from pcae import PCAE 4 | from baselines.ppvae import PPVAE 5 | from baselines.optimus import Optimus 6 | from tqdm import tqdm 7 | from transformers import BartTokenizer, BartModel, AdamW, BartForConditionalGeneration 8 | from utils import * 9 | import torch.nn as nn 10 | import torch 11 | from logger import Logger 12 | import datetime, math, os, sys, json, argparse, time, re, copy 13 | import numpy as np 14 | 15 | parser = argparse.ArgumentParser() 16 | ## data preparation 17 | parser.add_argument("--dataset", default='', type=str, required=False, choices=["yelp", "yahoo", "titles"], 18 | help="Training dataset.") 19 | parser.add_argument('--no_gpu', action='store_true') 20 | parser.add_argument('--gpu', nargs='+', type=int, default=[0]) 21 | parser.add_argument('--seed', default=42, type=int) 22 | parser.add_argument("--per_gpu_train_batch_size", default=42, type=int, 23 | help="Batch size per GPU/CPU for training.") 24 | parser.add_argument("--per_gpu_eval_batch_size", default=5, type=int, 25 | help="Batch size per GPU/CPU for evaluation.") 26 | parser.add_argument("--eval_every", default=500, type=int, 27 | help="eval step in finetuning.") 28 | parser.add_argument("--dim_label", default=8, type=int, help="Dim of label embedding layer") 29 | parser.add_argument("--zmanner", default='hidden', type=str, choices=['hidden', 'mem']) 30 | parser.add_argument("--dim_z", default=128, type=int, help="Dim of latent space") 31 | 32 | parser.add_argument("--bart_version", default='facebook/bart-base', type=str) 33 | parser.add_argument('--use_mean', action='store_true', help="Use mean representation of latent space in PCAE") 34 | parser.add_argument('--length_weighted_loss', action='store_true') 35 | parser.add_argument("--gen_batch_size", default=5, type=int, 36 | help="Batch size per GPU/CPU for generation.") 37 | parser.add_argument("--workers", default=3, type=int, 38 | help="Dataloader worker.") 39 | parser.add_argument("--lr", default=3e-5, type=float, help="The initial learning rate.") 40 | parser.add_argument("--alpha", default=0.1, type=float) 41 | parser.add_argument("--beta", default=1., type=float) 42 | parser.add_argument("--train_epochs", 
default=10, type=int, help="Training Epoch for Finetuning.") 43 | parser.add_argument("--plugin_train_epochs", default=20, type=int, help="Training Epoch for Plugin Training.") 44 | 45 | parser.add_argument("--fb_mode", default=1, type=int, help="Free bit threshold mode.") 46 | parser.add_argument("--layer_num", default=10, type=int, help="Broadcasting Layer Number of PCAE.") 47 | parser.add_argument("--dim_target_kl", default=0.1, type=float, help="KL thresh for each dimension in VAE.") 48 | parser.add_argument("--gen_k", default=100, type=int, help="Number of batch sentence to generate.") 49 | parser.add_argument("--task", default='sentiment', type=str, choices=['sentiment', 'tense', 'topics', 'quess_s', 'quess']) 50 | parser.add_argument("--task_label", default='pos', type=str, help="For PPVAE only") 51 | parser.add_argument("--ppvae_dim_bottle", default=25, type=int, help="For PPVAE only") 52 | parser.add_argument("--ppvae_loss_relax", default=10, type=float, help="For PPVAE only") 53 | parser.add_argument("--sample_n", default=100, type=int, help="Number of training instance for each class for training.") 54 | 55 | parser.add_argument("--run_mode", default='pcae', type=str, choices=['vae_ft', 'ppvae', 'pcae', 'optimus']) 56 | 57 | parser.add_argument('--first_token_pooling', action='store_true', 58 | help='Use the first token as the pooling signal in VAE, else the mean pooling.') 59 | 60 | 61 | 62 | ## Data setup details for PCAE plugin training 63 | task_dataset_dict = {"sentiment": "yelp", "tense": "yelp", "topics": "titles", "quess_s": "yahoo", "quess": "yahoo"} 64 | class_num_dataset_dict = {"sentiment": 2, "tense": 2, "topics": 4, "quess_s": 6, "quess": 10} 65 | 66 | def evaluate_vae(dataloader, model, tokenizer, device, logger): 67 | model.eval() 68 | losses = [] 69 | losses_rec = [] 70 | losses_kl = [] 71 | for batch_id, texts in enumerate(tqdm(dataloader)): 72 | out = tokenizer.batch_encode_plus(texts, return_tensors="pt", padding=True) 73 | pad_token_id = tokenizer.pad_token_id 74 | y = out['input_ids'] 75 | y_ids = y[:, :-1].contiguous() 76 | y_mask = out['attention_mask'][:, :-1] 77 | lm_labels = y[:, 1:].clone() 78 | lm_labels[y[:, 1:] == pad_token_id] = -100 79 | 80 | loss_rec, loss_kl, loss = model(out['input_ids'].to(device), labels=lm_labels.to(device), 81 | decoder_inputs=y_ids.to(device), 82 | attention_mask=out['attention_mask'].to(device)) 83 | loss_rec, loss_kl, loss = loss_rec.mean(), loss_kl.mean(), loss.mean() 84 | losses.append(loss.detach().cpu().numpy()) 85 | losses_rec.append(loss_rec.detach().cpu().numpy()) 86 | losses_kl.append(loss_kl.detach().cpu().numpy()) 87 | 88 | logger.info("Val Loss : {:.4f}".format(np.mean(losses))) 89 | logger.info("Val Loss Rec : {:.4f}".format(np.mean(losses_rec))) 90 | logger.info("Val Loss KL. 
: {:.4f}".format(np.mean(losses_kl))) 91 | model.train() 92 | return np.mean(losses_rec) 93 | 94 | def VAE_finetuning(args): 95 | gpu = not args.no_gpu 96 | args.train_batch_size = args.per_gpu_train_batch_size 97 | args.eval_batch_size = args.per_gpu_eval_batch_size 98 | device = torch.device(args.gpu[0] if gpu else "cpu") 99 | 100 | # randomness 101 | np.random.seed(args.seed) 102 | prng = np.random.RandomState() 103 | torch.random.manual_seed(args.seed) 104 | if gpu: torch.cuda.manual_seed(args.seed); torch.cuda.manual_seed_all(args.seed) 105 | 106 | ## 'facebook/bart-base' 107 | tokenizer = BartTokenizer.from_pretrained(args.bart_version) 108 | model = BartForConditionalGeneration.from_pretrained(args.bart_version) 109 | model.to(device) 110 | 111 | # epos = [8, 10, 10] 112 | # bss = [64, 64, 64] 113 | dataname = args.dataset 114 | vae = VAE(model.model.encoder, model.model.decoder, tokenizer, args, device) 115 | vae.shared = model.model.shared 116 | vae.lm_head = model.lm_head 117 | if len(args.gpu)>1: 118 | vae = nn.DataParallel(vae, device_ids=args.gpu) 119 | vae.to(device) 120 | os.makedirs(f"checkpoints/{dataname}/BART", exist_ok=True) 121 | log_file = f"checkpoints/{dataname}/BART/log_e{args.train_epochs}_vae_z{args.dim_z}_{args.zmanner}.txt" 122 | logger = Logger(log_file) 123 | train_data = read_txt(f"data/{dataname}/train.txt") 124 | val_data = read_txt(f"data/{dataname}/valid.txt") 125 | 126 | train_loader = DataLoader(train_data, batch_size=args.train_batch_size, pin_memory=True, drop_last=False, num_workers=args.workers, shuffle=True) 127 | iterations = args.train_epochs * len(train_loader) 128 | print(f"Iterations: {iterations}") 129 | betas = frange_cycle_zero_linear(iterations, start=0.0, stop=1.0, n_cycle=4, 130 | ratio_increase=0.2, ratio_zero=0.1) 131 | val_loader = DataLoader(val_data, batch_size=args.eval_batch_size, pin_memory=True, drop_last=False, num_workers=args.workers, shuffle=True) 132 | optimizer = AdamW(vae.parameters(), lr=args.lr, correct_bias=True) 133 | 134 | ## Fine-tuning 135 | best_val_loss = 99999. 136 | total_iters = 0 137 | for e in range(args.train_epochs): 138 | vae.train() 139 | losses = [] 140 | losses_rec = [] 141 | losses_kl = [] 142 | for batch_id, texts in enumerate(tqdm(train_loader)): 143 | beta = betas[total_iters] 144 | out = tokenizer.batch_encode_plus(texts, return_tensors="pt", padding=True) 145 | pad_token_id = tokenizer.pad_token_id 146 | ## target and source input for VAE 147 | y = out['input_ids'] 148 | y_ids = y[:, :-1].contiguous() 149 | y_mask = out['attention_mask'][:, :-1] 150 | lm_labels = y[:, 1:].clone() 151 | lm_labels[y[:, 1:] == pad_token_id] = -100 152 | 153 | loss_rec, loss_kl, loss = vae(out['input_ids'].to(device), 154 | labels=lm_labels.to(device), decoder_inputs=y_ids.to(device), 155 | attention_mask=out['attention_mask'].to(device), beta=beta) 156 | 157 | loss_rec, loss_kl, loss = loss_rec.mean(), loss_kl.mean(), loss.mean() 158 | loss.backward() 159 | optimizer.step() 160 | optimizer.zero_grad() 161 | losses.append(loss.detach().cpu().numpy()) 162 | losses_rec.append(loss_rec.detach().cpu().numpy()) 163 | losses_kl.append(loss_kl.detach().cpu().numpy()) 164 | total_iters += 1 165 | 166 | ## eval the model in each epoch 167 | if total_iters % args.eval_every==0: 168 | logger.info("Train Loss : {:.4f}".format(np.mean(losses))) 169 | logger.info("Train Loss Rec : {:.4f}".format(np.mean(losses_rec))) 170 | logger.info("Train Loss KL. 
: {:.4f}".format(np.mean(losses_kl))) 171 | val_loss = evaluate_vae(val_loader, vae, tokenizer, device, logger) 172 | ## reset training loss lists 173 | losses = [] 174 | losses_rec = [] 175 | losses_kl = [] 176 | 177 | if val_loss < best_val_loss: 178 | best_val_loss = val_loss 179 | logger.info("Saving the Best Eval Weights..") 180 | save_orderdict = vae.state_dict() 181 | torch.save(save_orderdict, f"checkpoints/{dataname}/BART/best_val_vae_{args.zmanner}.pt") 182 | 183 | def Optimus_plugin_fintuning(args): 184 | gpu = not args.no_gpu 185 | args.train_batch_size = args.per_gpu_train_batch_size 186 | args.eval_batch_size = args.per_gpu_eval_batch_size 187 | device = torch.device(args.gpu[0] if gpu else "cpu") 188 | 189 | # randomness 190 | np.random.seed(args.seed) 191 | prng = np.random.RandomState() 192 | torch.random.manual_seed(args.seed) 193 | if gpu: torch.cuda.manual_seed(args.seed); torch.cuda.manual_seed_all(args.seed) 194 | 195 | dataset = task_dataset_dict[args.task] 196 | 197 | ## setup config for each training task 198 | class config(): 199 | pass 200 | config.fb_mode=args.fb_mode ## 1 201 | config.dim_target_kl=args.dim_target_kl ## 0.1 202 | config.zmanner=args.zmanner ## "hidden" 203 | config.dim_label = args.dim_label ## 8 204 | config.dim_z=args.dim_z ## 128 205 | config.class_num = class_num_dataset_dict[args.task] 206 | config.train = f"data/{dataset}/{args.task}/label_text{args.sample_n}.txt" 207 | config.test = f"data/{dataset}/{args.task}/label_text100.txt" 208 | config.lr = args.lr ## 1e-4 209 | config.epoch=args.plugin_train_epochs ## 28 210 | config.gen_k = args.gen_k ## 100 211 | config.first_token_pooling = args.first_token_pooling 212 | 213 | tokenizer = BartTokenizer.from_pretrained(args.bart_version) 214 | model = BartForConditionalGeneration.from_pretrained(args.bart_version) 215 | 216 | vae = VAE(model.model.encoder, model.model.decoder, tokenizer, config, device) 217 | state = torch.load(f"checkpoints/{dataset}/BART/best_val_vae_hidden.pt") 218 | if 'module' in list(state.keys())[0]: # model_path is data parallel model with attr 'module' 219 | state_copy = copy.copy(state) 220 | keys = state_copy.keys() 221 | for k in keys: 222 | state[k.replace('module.', '')] = state.pop(k) 223 | vae.load_state_dict(state) 224 | del state 225 | print("Finish Loading Pre-trained Weights..") 226 | model = Optimus(vae, config, device) 227 | model.to(device) 228 | 229 | traindata = ConditionalGenerationDataset.from_file(file_path=config.train) 230 | trainloader = DataLoader(traindata, batch_size=args.train_batch_size, pin_memory=True, drop_last=False, num_workers=args.workers, shuffle=True) 231 | 232 | optimizer = AdamW(model.parameters(), lr=config.lr, correct_bias=True) 233 | os.makedirs(f"bart_result/optimus/results/{dataset}", exist_ok=True) 234 | log_file = f"bart_result/optimus/results/{dataset}/{args.task}-epoch{config.epoch}-bs{args.train_batch_size}-lr{config.lr}-ns{args.sample_n}.log" 235 | logger = Logger(log_file) 236 | 237 | ## Model training 238 | model.train() 239 | total_iters = 0 240 | for e in range(config.epoch): 241 | losses = [] 242 | losses_rec = [] 243 | losses_kl = [] 244 | losses_lsd = [] 245 | for batch_id, (x, ylabel) in enumerate(tqdm(trainloader)): 246 | out = tokenizer.batch_encode_plus(x, return_tensors="pt", padding=True) 247 | pad_token_id = tokenizer.pad_token_id 248 | y = out['input_ids'] 249 | y_ids = y[:, :-1].contiguous() 250 | y_mask = out['attention_mask'][:, :-1] 251 | lm_labels = y[:, 1:].clone() 252 | lm_labels[y[:, 1:] == 
pad_token_id] = -100 253 | 254 | ylabel = torch.tensor(ylabel, device=device) 255 | loss_rec, loss_kl, loss_lsd, loss = model(out['input_ids'].to(device), ylabel, 256 | lm_labels.to(device), 257 | out['attention_mask'].to(device), 258 | y_ids.to(device), alpha=args.alpha, beta=args.beta) 259 | 260 | loss_rec, loss_kl, loss_lsd, loss = loss_rec.mean(), loss_kl.mean(), loss_lsd.mean(), loss.mean() 261 | loss.backward() 262 | optimizer.step() 263 | optimizer.zero_grad() 264 | losses.append(loss.detach().cpu().numpy()) 265 | losses_rec.append(loss_rec.detach().cpu().numpy()) 266 | losses_kl.append(loss_kl.detach().cpu().numpy()) 267 | losses_lsd.append(loss_lsd.detach().cpu().numpy()) 268 | total_iters += 1 269 | 270 | logger.info("Train Loss : {:.4f}".format(np.mean(losses))) 271 | logger.info("Train Loss Rec : {:.4f}".format(np.mean(losses_rec))) 272 | logger.info("Train Loss KL. : {:.4f}".format(np.mean(losses_kl))) 273 | logger.info("Train Loss LSD : {:.4f}".format(np.mean(losses_lsd))) 274 | 275 | ## Generation 276 | os.makedirs(f"bart_result/optimus/sentences/{dataset}/{args.task}", exist_ok=True) 277 | for y in range(config.class_num): 278 | model.eval() 279 | finalsents = [] 280 | for _ in tqdm(range(config.gen_k)): 281 | z = torch.randn(args.gen_batch_size, config.dim_z).to(device) 282 | with torch.no_grad(): 283 | sents = model.generate(z, y) 284 | texts = [] 285 | for ii in sents: 286 | endindex = (ii == tokenizer.eos_token_id).nonzero(as_tuple=True)[0] 287 | if len(endindex) !=0: 288 | texts.append(tokenizer.decode(ii[1: min(endindex)])) 289 | else: 290 | continue 291 | finalsents.extend(texts) 292 | with open(f"bart_result/optimus/sentences/{dataset}/{args.task}/{y}-epoch{config.epoch}-bs{args.train_batch_size}-lr{config.lr}-ns{args.sample_n}-{config.gen_k}K.txt", "w") as f: 293 | for sent in finalsents: 294 | f.write(sent + "\n") 295 | f.close() 296 | 297 | def PPVAE_plugin_training(args): 298 | gpu = not args.no_gpu 299 | args.train_batch_size = args.per_gpu_train_batch_size 300 | args.eval_batch_size = args.per_gpu_eval_batch_size 301 | device = torch.device(args.gpu[0] if gpu else "cpu") 302 | 303 | # randomness 304 | np.random.seed(args.seed) 305 | prng = np.random.RandomState() 306 | torch.random.manual_seed(args.seed) 307 | if gpu: torch.cuda.manual_seed(args.seed); torch.cuda.manual_seed_all(args.seed) 308 | 309 | # datali = [100, 300, 500, 800] 310 | dataset = task_dataset_dict[args.task] 311 | class config(): 312 | pass 313 | config.fb_mode=args.fb_mode ##1 314 | config.dim_target_kl=args.dim_target_kl ##0.1 315 | config.beta=args.beta ##1.0 316 | config.zmanner=args.zmanner ##"hidden" 317 | config.dim_label = args.dim_label ##8 318 | config.dim_z = args.dim_z ##128 319 | config.class_num = class_num_dataset_dict[args.task] 320 | config.alpha = args.alpha ## 0.1 321 | config.train = f"data/{dataset}/{args.task}/{args.sample_n}.{args.task_label}" 322 | 323 | if args.task == "sentiment": 324 | if args.task_label == "pos": 325 | config.neg_train = f"data/{dataset}/{args.task}/{args.sample_n}.neg" 326 | elif args.task_label == "neg": 327 | config.neg_train = f"data/{dataset}/{args.task}/{args.sample_n}.pos" 328 | else: 329 | raise NotImplementedError 330 | elif args.task == "tense": 331 | if args.task_label == "past": 332 | config.neg_train = f"data/{dataset}/{args.task}/{args.sample_n}.present" 333 | elif args.task_label == "present": 334 | config.neg_train = f"data/{dataset}/{args.task}/{args.sample_n}.past" 335 | else: 336 | raise NotImplementedError 337 | else: 338 | 
config.neg_train = f"data/{dataset}/{args.task}/{args.sample_n}.{args.task_label}_neg" 339 | 340 | config.test = f"data/{dataset}/{args.task}/label_text100.txt" 341 | config.lr = args.lr ## 3e-4 342 | config.epoch=args.plugin_train_epochs ## 12 343 | config.dim_target_kl=args.dim_target_kl ## 0.1 344 | config.gen_k = args.gen_k ## 300 345 | config.dim_bottle=args.ppvae_dim_bottle ## 25 346 | config.relax=args.ppvae_loss_relax ## 10.0 347 | config.first_token_pooling = args.first_token_pooling 348 | 349 | tokenizer = BartTokenizer.from_pretrained(args.bart_version) 350 | model = BartForConditionalGeneration.from_pretrained(args.bart_version) 351 | 352 | vae = VAE(model.model.encoder, model.model.decoder, tokenizer, config, device) 353 | state = torch.load(f"checkpoints/{dataset}/BART/best_val_vae_hidden.pt") 354 | if 'module' in list(state.keys())[0]: # model_path is data parallel model with attr 'module' 355 | state_copy = copy.copy(state) 356 | keys = state_copy.keys() 357 | for k in keys: 358 | state[k.replace('module.', '')] = state.pop(k) 359 | vae.load_state_dict(state) 360 | del state 361 | print("Finish Loading Pre-trained Weights..") 362 | 363 | model = PPVAE(vae, config, device) 364 | model.to(device) 365 | 366 | traindata = read_txt(config.train) 367 | negtraindata = read_txt(config.neg_train) 368 | trainloader = DataLoader(traindata, batch_size=args.train_batch_size, pin_memory=True, drop_last=False, num_workers=args.workers, shuffle=True) 369 | neg_trainloader = DataLoader(negtraindata, batch_size=args.train_batch_size, pin_memory=True, drop_last=False, num_workers=args.workers, shuffle=True) 370 | 371 | optimizer = AdamW(model.parameters(), lr=config.lr, correct_bias=True) 372 | os.makedirs(f"bart_result/ppvae/results/{dataset}", exist_ok=True) 373 | log_file = f"bart_result/ppvae/results/{dataset}/{args.task}-{args.task_label}-epoch{config.epoch}-bs{args.train_batch_size}-lr{config.lr}-ns{args.sample_n}.log" 374 | logger = Logger(log_file) 375 | 376 | model.train() 377 | total_iters = 0 378 | for e in range(config.epoch): 379 | losses = [] 380 | losses_rec = [] 381 | losses_kl = [] 382 | losses_z_rec=[] 383 | for batch_id, x in enumerate(tqdm(trainloader)): 384 | out = tokenizer.batch_encode_plus(x, return_tensors="pt", padding=True) 385 | pad_token_id = tokenizer.pad_token_id 386 | y = out['input_ids'] 387 | y_ids = y[:, :-1].contiguous() 388 | y_mask = out['attention_mask'][:, :-1] 389 | lm_labels = y[:, 1:].clone() 390 | lm_labels[y[:, 1:] == pad_token_id] = -100 391 | 392 | neg_x = next(iter(neg_trainloader)) 393 | out_neg = tokenizer.batch_encode_plus(neg_x, return_tensors="pt", padding=True) 394 | neg_y = out_neg['input_ids'] 395 | neg_y_ids = neg_y[:, :-1].contiguous() 396 | neg_y_mask = out_neg['attention_mask'][:, :-1] 397 | neg_lm_labels = neg_y[:, 1:].clone() 398 | neg_lm_labels[neg_y[:, 1:] == pad_token_id] = -100 399 | 400 | loss_z_rec, loss_rec, loss_kl, loss = model(out['input_ids'].to(device), 401 | lm_labels.to(device), 402 | out['attention_mask'].to(device), 403 | y_ids.to(device)) 404 | neg_loss_z_rec, neg_loss_rec, neg_loss_kl, neg_loss = model(out_neg['input_ids'].to(device), 405 | neg_lm_labels.to(device), 406 | out_neg['attention_mask'].to(device), 407 | neg_y_ids.to(device)) 408 | 409 | loss_z_rec, loss_rec, loss_kl, loss = loss_z_rec.mean(), loss_rec.mean(), loss_kl.mean(), loss.mean() 410 | neg_loss_z_rec, neg_loss_rec, neg_loss_kl, neg_loss = neg_loss_z_rec.mean(), neg_loss_rec.mean(), neg_loss_kl.mean(), neg_loss.mean() 411 | loss = 
411 |             loss = torch.max(torch.tensor(0).to(device), loss - config.alpha*neg_loss + config.relax)
412 |             loss.backward()
413 |             torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) # max_grad_norm=1.0; clip after backward so gradients exist
414 |             optimizer.step()
415 |             optimizer.zero_grad()
416 |             losses.append(loss.detach().cpu().numpy())
417 |             losses_z_rec.append(loss_z_rec.detach().cpu().numpy())
418 |             losses_rec.append(loss_rec.detach().cpu().numpy())
419 |             losses_kl.append(loss_kl.detach().cpu().numpy())
420 |             total_iters += 1
421 |
422 |         logger.info("Train Loss : {:.4f}".format(np.mean(losses)))
423 |         logger.info("Train Loss z Rec : {:.4f}".format(np.mean(losses_z_rec)))
424 |         logger.info("Train Loss Rec : {:.4f}".format(np.mean(losses_rec)))
425 |         logger.info("Train Loss KL. : {:.4f}".format(np.mean(losses_kl)))
426 |
427 |     ## Generation
428 |     os.makedirs(f"bart_result/ppvae/sentences/{dataset}/{args.task}", exist_ok=True)
429 |     model.eval()
430 |     finalsents = []
431 |     for _ in tqdm(range(config.gen_k)):
432 |         z = torch.randn(args.gen_batch_size, config.dim_z).to(device)
433 |         with torch.no_grad():
434 |             sents = model.generate(z)
435 |         texts = []
436 |         for ii in sents:
437 |             endindex = (ii == tokenizer.eos_token_id).nonzero(as_tuple=True)[0]
438 |             if len(endindex) !=0:
439 |                 texts.append(tokenizer.decode(ii[1: min(endindex)]))
440 |             else:
441 |                 continue
442 |         finalsents.extend(texts)
443 |     with open(f"bart_result/ppvae/sentences/{dataset}/{args.task}/{args.task_label}-epoch{config.epoch}-bs{args.train_batch_size}-lr{config.lr}-ns{args.sample_n}-{config.gen_k}K-alpha{config.alpha}.txt", "w") as f:
444 |         for sent in finalsents:
445 |             f.write(sent + "\n")
446 |     f.close()
447 |
448 | def PCAE_plugin_training(args):
449 |     gpu = not args.no_gpu
450 |     args.train_batch_size = args.per_gpu_train_batch_size
451 |     args.eval_batch_size = args.per_gpu_eval_batch_size
452 |     device = torch.device(args.gpu[0] if gpu else "cpu")
453 |
454 |     # randomness
455 |     np.random.seed(args.seed)
456 |     prng = np.random.RandomState()
457 |     torch.random.manual_seed(args.seed)
458 |     if gpu: torch.cuda.manual_seed(args.seed); torch.cuda.manual_seed_all(args.seed)
459 |
460 |     dataset = task_dataset_dict[args.task]
461 |
462 |     ## setup config for each training task
463 |     class config():
464 |         pass
465 |     config.fb_mode=args.fb_mode ## 1
466 |     config.dim_target_kl=args.dim_target_kl ## 0.1
467 |     config.zmanner=args.zmanner ## "hidden"
468 |     config.dim_label = args.dim_label ## 8
469 |     config.dim_z=args.dim_z ## 128
470 |     config.class_num = class_num_dataset_dict[args.task]
471 |     config.train = f"data/{dataset}/{args.task}/label_text{args.sample_n}.txt"
472 |     config.layer_num=args.layer_num
473 |     config.lr = args.lr ## 4e-4
474 |     config.epoch=args.plugin_train_epochs ## 28
475 |     config.gen_k = args.gen_k ## 100
476 |     config.first_token_pooling = args.first_token_pooling
477 |
478 |     tokenizer = BartTokenizer.from_pretrained(args.bart_version)
479 |     model = BartForConditionalGeneration.from_pretrained(args.bart_version)
480 |
481 |     vae = VAE(model.model.encoder, model.model.decoder, tokenizer, config, device)
482 |     state = torch.load(f"checkpoints/{dataset}/BART/best_val_vae_hidden.pt")
483 |     if 'module' in list(state.keys())[0]: # model_path is data parallel model with attr 'module'
484 |         state_copy = copy.copy(state)
485 |         keys = state_copy.keys()
486 |         for k in keys:
487 |             state[k.replace('module.', '')] = state.pop(k)
488 |     vae.load_state_dict(state)
489 |     del state
490 |     print("Finish Loading Pre-trained Weights..")
491 |
492 |     model = PCAE(vae, config, device, layer_num=config.layer_num)
493 |     model.to(device)
494 |     traindata = ConditionalGenerationDataset.from_file(file_path=config.train)
495 |     trainloader = DataLoader(traindata, batch_size=args.train_batch_size, pin_memory=True, drop_last=False, num_workers=args.workers, shuffle=True)
496 |
497 |     optimizer = AdamW(model.parameters(), lr=config.lr, correct_bias=True)
498 |     os.makedirs(f"bart_result/pcae/results/{dataset}", exist_ok=True)
499 |     log_file = f"bart_result/pcae/results/{dataset}/{args.task}-epoch{config.epoch}-bs{args.train_batch_size}-lr{config.lr}-ln{config.layer_num}-ns{args.sample_n}-mean{args.use_mean}.log"
500 |     logger = Logger(log_file)
501 |
502 |     model.train()
503 |     total_iters = 0
504 |     for e in range(config.epoch):
505 |         losses = []
506 |         losses_rec = []
507 |         losses_kl = []
508 |         for batch_id, (x, ylabel) in enumerate(tqdm(trainloader)):
509 |             out = tokenizer.batch_encode_plus(x, return_tensors="pt", padding=True)
510 |             pad_token_id = tokenizer.pad_token_id
511 |             y = out['input_ids']
512 |             y_ids = y[:, :-1].contiguous()
513 |             y_mask = out['attention_mask'][:, :-1]
514 |             lm_labels = y[:, 1:].clone()
515 |             lm_labels[y[:, 1:] == pad_token_id] = -100
516 |
517 |             ylabel = torch.tensor(ylabel, device=device)
518 |             loss_rec, loss_kl, loss = model(out['input_ids'].to(device), ylabel,
519 |                                             lm_labels.to(device),
520 |                                             out['attention_mask'].to(device),
521 |                                             y_ids.to(device), args.use_mean, beta=args.beta)
522 |
523 |             loss_rec, loss_kl, loss = loss_rec.mean(), loss_kl.mean(), loss.mean()
524 |             loss.backward()
525 |             optimizer.step()
526 |             optimizer.zero_grad()
527 |             losses.append(loss.detach().cpu().numpy())
528 |             losses_rec.append(loss_rec.detach().cpu().numpy())
529 |             losses_kl.append(loss_kl.detach().cpu().numpy())
530 |             total_iters += 1
531 |
532 |         logger.info("Train Loss : {:.4f}".format(np.mean(losses)))
533 |         logger.info("Train Loss Rec : {:.4f}".format(np.mean(losses_rec)))
534 |         logger.info("Train Loss KL. : {:.4f}".format(np.mean(losses_kl)))
535 |
536 |     ## generation
537 |     os.makedirs(f"bart_result/pcae/sentences/{dataset}/{args.task}", exist_ok=True)
538 |     for y in range(config.class_num):
539 |         model.eval()
540 |         finalsents = []
541 |         ## iteratively generate controllable texts
542 |         for _ in tqdm(range(config.gen_k)):
543 |             z = torch.randn(args.gen_batch_size, config.dim_z).to(device)
544 |             with torch.no_grad():
545 |                 sents = model.generate(z, y)
546 |             texts = []
547 |             for ii in sents:
548 |                 endindex = (ii == tokenizer.eos_token_id).nonzero(as_tuple=True)[0]
549 |                 if len(endindex) !=0:
550 |                     texts.append(tokenizer.decode(ii[1: min(endindex)]))
551 |                 else:
552 |                     texts.append(tokenizer.decode(ii[1: ]))
553 |             finalsents.extend(texts)
554 |         with open(f"bart_result/pcae/sentences/{dataset}/{args.task}/{y}-epoch{config.epoch}-bs{args.train_batch_size}-lr{config.lr}-ln{config.layer_num}-ns{args.sample_n}-{config.gen_k}K.txt", "w") as f:
555 |             for sent in finalsents:
556 |                 f.write(sent + "\n")
557 |         f.close()
558 |
559 | if __name__=="__main__":
560 |     args = parser.parse_args()
561 |     if args.run_mode == "vae_ft":
562 |         VAE_finetuning(args)
563 |     elif args.run_mode == "pcae":
564 |         PCAE_plugin_training(args)
565 |     elif args.run_mode == "ppvae":
566 |         PPVAE_plugin_training(args)
567 |     elif args.run_mode == "optimus":
568 |         Optimus_plugin_fintuning(args)
569 |     else:
570 |         raise NotImplementedError
--------------------------------------------------------------------------------
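The `__main__` block above selects one of four routines via `args.run_mode`. For orientation, below is a minimal sketch of driving the PCAE plug-in stage programmatically rather than from the shell. It assumes this file is train.py and that the argparse flag names mirror the `args.*` attributes used above (e.g. `--run_mode`, `--task`, `--sample_n`); the authoritative flag list is the parser defined earlier in train.py, so adjust accordingly.

    # Sketch only, not part of the repository; the flag names below are assumptions.
    import sys
    from train import parser, PCAE_plugin_training  # assumes the module-level argparse parser in train.py

    # Equivalent in spirit to: python train.py --run_mode pcae --task sentiment --sample_n 100
    sys.argv = ["train.py", "--run_mode", "pcae", "--task", "sentiment", "--sample_n", "100"]
    args = parser.parse_args()
    PCAE_plugin_training(args)

The same pattern applies to the `vae_ft`, `ppvae`, and `optimus` run modes dispatched above.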