├── .idea ├── .name ├── vcs.xml ├── inspectionProfiles │ └── profiles_settings.xml ├── modules.xml ├── misc.xml ├── Paraphrase-OPT.iml └── workspace.xml ├── fine_tuning ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-310.pyc │ ├── fine_tune_bart.cpython-310.pyc │ └── fine_tune_opt.cpython-310.pyc ├── README.md ├── fine_tune_bart.py └── fine_tune_opt.py ├── metrics ├── __init__.py ├── benchmark_runs │ ├── model_preds │ │ └── .gitignore │ └── model_benchmarked_results │ │ └── .gitignore ├── __pycache__ │ ├── __init__.cpython-310.pyc │ └── bart_metric.cpython-310.pyc └── bart_metric.py ├── soft_prompt_tuning ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-39.pyc │ ├── __init__.cpython-310.pyc │ ├── soft_embedding.cpython-310.pyc │ ├── soft_embedding.cpython-39.pyc │ ├── soft_prompt_opt.cpython-310.pyc │ └── soft_prompt_opt.cpython-39.pyc ├── soft_embedding.py └── soft_prompt_opt.py ├── training_datasets ├── __init__.py ├── para_nmt │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-39.pyc │ │ ├── para_nmt.cpython-39.pyc │ │ ├── __init__.cpython-310.pyc │ │ └── para_nmt.cpython-310.pyc │ ├── README.md │ └── para_nmt.py ├── parabank │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-39.pyc │ │ ├── parabank.cpython-39.pyc │ │ ├── __init__.cpython-310.pyc │ │ └── parabank.cpython-310.pyc │ └── parabank.py ├── __pycache__ │ ├── __init__.cpython-310.pyc │ ├── __init__.cpython-39.pyc │ ├── paracombined.cpython-39.pyc │ └── paracombined.cpython-310.pyc └── paracombined.py ├── .gitattributes ├── images ├── W&B Chart.png ├── thats_quite_big.png ├── Screenshot 2022-06-20 174801.png ├── Screenshot 2022-06-20 174920.png └── Screenshot 2022-06-21 131258.png ├── .gitignore ├── .ipynb_checkpoints ├── README-checkpoint.md ├── REPORT-checkpoint.ipynb ├── example-checkpoint.html └── EDA_benchmarking-checkpoint.ipynb ├── train_fine_tune.py ├── config-defaults.yaml ├── test_vis.py ├── train_soft_prompt.py ├── gui.py ├── requirements.txt ├── README.md ├── model_benchmark.py ├── REPORT.md └── EDA_embedding.ipynb /.idea/.name: -------------------------------------------------------------------------------- 1 | gui.py -------------------------------------------------------------------------------- /fine_tuning/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /metrics/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /soft_prompt_tuning/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /training_datasets/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /training_datasets/para_nmt/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /training_datasets/parabank/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /metrics/benchmark_runs/model_preds/.gitignore: 
-------------------------------------------------------------------------------- 1 | * 2 | !.gitignore -------------------------------------------------------------------------------- /metrics/benchmark_runs/model_benchmarked_results/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | # Auto detect text files and perform LF normalization 2 | * text=auto 3 | -------------------------------------------------------------------------------- /images/W&B Chart.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Clyde013/Paraphrase-OPT/HEAD/images/W&B Chart.png -------------------------------------------------------------------------------- /images/thats_quite_big.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Clyde013/Paraphrase-OPT/HEAD/images/thats_quite_big.png -------------------------------------------------------------------------------- /images/Screenshot 2022-06-20 174801.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Clyde013/Paraphrase-OPT/HEAD/images/Screenshot 2022-06-20 174801.png -------------------------------------------------------------------------------- /images/Screenshot 2022-06-20 174920.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Clyde013/Paraphrase-OPT/HEAD/images/Screenshot 2022-06-20 174920.png -------------------------------------------------------------------------------- /images/Screenshot 2022-06-21 131258.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Clyde013/Paraphrase-OPT/HEAD/images/Screenshot 2022-06-21 131258.png -------------------------------------------------------------------------------- /metrics/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Clyde013/Paraphrase-OPT/HEAD/metrics/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /metrics/__pycache__/bart_metric.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Clyde013/Paraphrase-OPT/HEAD/metrics/__pycache__/bart_metric.cpython-310.pyc -------------------------------------------------------------------------------- /fine_tuning/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Clyde013/Paraphrase-OPT/HEAD/fine_tuning/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /fine_tuning/__pycache__/fine_tune_bart.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Clyde013/Paraphrase-OPT/HEAD/fine_tuning/__pycache__/fine_tune_bart.cpython-310.pyc -------------------------------------------------------------------------------- /fine_tuning/__pycache__/fine_tune_opt.cpython-310.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Clyde013/Paraphrase-OPT/HEAD/fine_tuning/__pycache__/fine_tune_opt.cpython-310.pyc -------------------------------------------------------------------------------- /soft_prompt_tuning/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Clyde013/Paraphrase-OPT/HEAD/soft_prompt_tuning/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /training_datasets/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Clyde013/Paraphrase-OPT/HEAD/training_datasets/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /training_datasets/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Clyde013/Paraphrase-OPT/HEAD/training_datasets/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /soft_prompt_tuning/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Clyde013/Paraphrase-OPT/HEAD/soft_prompt_tuning/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /training_datasets/__pycache__/paracombined.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Clyde013/Paraphrase-OPT/HEAD/training_datasets/__pycache__/paracombined.cpython-39.pyc -------------------------------------------------------------------------------- /training_datasets/__pycache__/paracombined.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Clyde013/Paraphrase-OPT/HEAD/training_datasets/__pycache__/paracombined.cpython-310.pyc -------------------------------------------------------------------------------- /soft_prompt_tuning/__pycache__/soft_embedding.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Clyde013/Paraphrase-OPT/HEAD/soft_prompt_tuning/__pycache__/soft_embedding.cpython-310.pyc -------------------------------------------------------------------------------- /soft_prompt_tuning/__pycache__/soft_embedding.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Clyde013/Paraphrase-OPT/HEAD/soft_prompt_tuning/__pycache__/soft_embedding.cpython-39.pyc -------------------------------------------------------------------------------- /soft_prompt_tuning/__pycache__/soft_prompt_opt.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Clyde013/Paraphrase-OPT/HEAD/soft_prompt_tuning/__pycache__/soft_prompt_opt.cpython-310.pyc -------------------------------------------------------------------------------- /soft_prompt_tuning/__pycache__/soft_prompt_opt.cpython-39.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Clyde013/Paraphrase-OPT/HEAD/soft_prompt_tuning/__pycache__/soft_prompt_opt.cpython-39.pyc -------------------------------------------------------------------------------- /training_datasets/para_nmt/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Clyde013/Paraphrase-OPT/HEAD/training_datasets/para_nmt/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /training_datasets/para_nmt/__pycache__/para_nmt.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Clyde013/Paraphrase-OPT/HEAD/training_datasets/para_nmt/__pycache__/para_nmt.cpython-39.pyc -------------------------------------------------------------------------------- /training_datasets/parabank/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Clyde013/Paraphrase-OPT/HEAD/training_datasets/parabank/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /training_datasets/parabank/__pycache__/parabank.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Clyde013/Paraphrase-OPT/HEAD/training_datasets/parabank/__pycache__/parabank.cpython-39.pyc -------------------------------------------------------------------------------- /fine_tuning/README.md: -------------------------------------------------------------------------------- 1 | This folder is specifically for vanilla fine tuning directly on OPT models for benchmarking of 2 | soft prompt performance vs traditional fine tuning of model weights. 
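For reference, a minimal sketch of the contrast being benchmarked (illustrative only, not one of the training scripts in this repo; the model name is just an example, matching the default in fine_tune_opt.py):

```python
from transformers import OPTForCausalLM

model = OPTForCausalLM.from_pretrained("facebook/opt-350m")

# Vanilla fine tuning (this folder): every pretrained weight stays trainable
# and receives gradient updates during training.
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

# Soft prompt tuning (soft_prompt_tuning/): the pretrained weights are frozen and only
# a small learned embedding prepended to the input (soft_embedding.py) is optimised.
for p in model.parameters():
    p.requires_grad = False
```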
-------------------------------------------------------------------------------- /training_datasets/para_nmt/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Clyde013/Paraphrase-OPT/HEAD/training_datasets/para_nmt/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /training_datasets/para_nmt/__pycache__/para_nmt.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Clyde013/Paraphrase-OPT/HEAD/training_datasets/para_nmt/__pycache__/para_nmt.cpython-310.pyc -------------------------------------------------------------------------------- /training_datasets/parabank/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Clyde013/Paraphrase-OPT/HEAD/training_datasets/parabank/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /training_datasets/parabank/__pycache__/parabank.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Clyde013/Paraphrase-OPT/HEAD/training_datasets/parabank/__pycache__/parabank.cpython-310.pyc -------------------------------------------------------------------------------- /.idea/vcs.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /.idea/inspectionProfiles/profiles_settings.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 6 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | training_checkpoints/ 2 | wandb/ 3 | /training_datasets/para_nmt/para-nmt-5m-processed.zip 4 | 5 | /metrics/benchmark_runs/model_benchmarked_results/* 6 | /metrics/benchmark_runs/model_preds/* 7 | /visualisations/ 8 | 9 | .env 10 | /fine_tuning/bart.pth 11 | -------------------------------------------------------------------------------- /.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /.idea/misc.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 7 | -------------------------------------------------------------------------------- /training_datasets/para_nmt/README.md: -------------------------------------------------------------------------------- 1 | Download the para-nmt5m filtered dataset from here 2 | https://drive.google.com/file/d/19NQ87gEFYu3zOIp_VNYQZgmnwRuSIyJd/view 3 | 4 | Leave the zip file in this directory. 
5 | 6 | If downloading from the command line (e.g. on Google Cloud Compute), referencing https://stackoverflow.com/a/63781195: 7 | 8 | Install the gdown PyPI module using pip: 9 | 10 | ```pip install gdown``` 11 | 12 | Download the file using gdown and the intended ID (19NQ87gEFYu3zOIp_VNYQZgmnwRuSIyJd): 13 | 14 | ```gdown --id 19NQ87gEFYu3zOIp_VNYQZgmnwRuSIyJd``` -------------------------------------------------------------------------------- /.idea/Paraphrase-OPT.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 11 | 12 | 15 | -------------------------------------------------------------------------------- /.ipynb_checkpoints/README-checkpoint.md: -------------------------------------------------------------------------------- 1 | # Paraphrase OPT 2 | 3 | Training OPT for paraphrasing through prompt engineering. 4 | 5 | It seems that GPT-3 performs quite well when you simply tell it directly to paraphrase the following sentence and 6 | bump up the frequency and presence penalties. 7 | 8 | # Ideas 9 | 10 | Given that OPT is a decoder-only model, how will we get it to perform what is traditionally considered a seq-2-seq task 11 | which involves cross attention from the encoder outputs, transforming the input sequence into an output sequence of a 12 | different phrasing? The function of the encoder output cross attention is for the model to maintain a strong reference 13 | point to the original sequence while predicting the next tokens, which is better than appending the input to the start 14 | of the sequence and referring to the causal mask of the decoder sequence, since the encoder and decoder embedding spaces 15 | no longer have to be aligned. 16 | 17 | We use a differentiable prompt ([DART](https://arxiv.org/pdf/2108.13161.pdf)), except we adapt it from MLM to CLM. Instead of labels 18 | based on the output of a single \[MASK\] token, we generate a whole sequence and evaluate the semantic similarity of 19 | the output sequence. The input template, when fed into an MLM model, looks like this: 20 | 21 | Xprompt = [CLS] Xin [SEP] T [SEP] 22 | 23 | where T is the template prompt containing a single [MASK] token, of the form: 24 | 25 | {h0,...,hi,w([MASK]),hi+1,...,hm} 26 | 27 | # GCloud Compute CLI Cheatsheet 28 | 29 | ``` 30 | gcloud compute ssh liewweipyn 31 | gcloud compute instances stop liewweipyn 32 | ``` -------------------------------------------------------------------------------- /fine_tuning/fine_tune_bart.py: -------------------------------------------------------------------------------- 1 | import wandb 2 | from pytorch_lightning import LightningModule 3 | from torch.optim.lr_scheduler import ReduceLROnPlateau 4 | from transformers import BartForConditionalGeneration, BartTokenizer 5 | import os 6 | import torch 7 | from torch.optim import Adam 8 | 9 | # current working directory changes when imported from other modules, so to ensure bart_path is correct we store 10 | # the absolute path to the module for reference.
11 | package_directory = os.path.dirname(os.path.abspath(__file__)) 12 | 13 | 14 | class FineTuneBART(LightningModule): 15 | bart_path = os.path.join(package_directory, "bart.pth") 16 | 17 | def __init__(self, model_name='facebook/bart-large-cnn'): 18 | super().__init__() 19 | self.model = BartForConditionalGeneration.from_pretrained(model_name) 20 | self.save_hyperparameters() 21 | 22 | def load(self, path=None): 23 | """ Load model from paraphrase finetuning """ 24 | if path is None: 25 | path = self.bart_path 26 | self.model.load_state_dict(torch.load(path, map_location=self.device)) 27 | 28 | def forward(self, **inputs): 29 | return self.model(**inputs) 30 | 31 | 32 | if __name__ == "__main__": 33 | model = FineTuneBART() 34 | model.load() 35 | tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn") 36 | 37 | ARTICLE_TO_PARAPHRASE = ( 38 | "PG&E stated it scheduled the blackouts in response to forecasts for high winds " 39 | "amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were " 40 | "scheduled to be affected by the shutoffs which were expected to last through at least midday tomorrow." 41 | ) 42 | 43 | while ARTICLE_TO_PARAPHRASE != "": 44 | inputs = tokenizer([ARTICLE_TO_PARAPHRASE], max_length=1024, truncation=True, return_tensors="pt") 45 | 46 | # Generate Summary 47 | summary_ids = model.model.generate(inputs["input_ids"], num_beams=2, do_sample=True, min_length=0, max_length=50) 48 | outputs = tokenizer.batch_decode(summary_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] 49 | print(outputs) 50 | 51 | ARTICLE_TO_PARAPHRASE = input("Enter: ") 52 | -------------------------------------------------------------------------------- /train_fine_tune.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import wandb 4 | from pytorch_lightning import Trainer 5 | from pytorch_lightning.callbacks import ModelCheckpoint 6 | from pytorch_lightning.loggers import WandbLogger 7 | 8 | from fine_tuning.fine_tune_opt import FineTuneOPT 9 | from training_datasets.paracombined import ParaCombinedDataModule 10 | from training_datasets.parabank.parabank import ParabankDataModule 11 | from training_datasets.para_nmt.para_nmt import ParaNMTDataModule 12 | 13 | 14 | if __name__ == "__main__": 15 | # initialisation steps 16 | torch.cuda.empty_cache() 17 | AVAIL_GPUS = min(1, torch.cuda.device_count()) 18 | 19 | run = wandb.init(project="fine-tune-opt", entity="clyde013") 20 | 21 | with run: 22 | datamodule = ParaCombinedDataModule(wandb.config["model_name"], batch_size=wandb.config["batch_size"], 23 | steps_per_epoch=wandb.config["steps_per_epoch"], 24 | datamodules=[ParabankDataModule, ParaNMTDataModule], 25 | probabilities=[0.5, 0.5]) 26 | datamodule.setup() 27 | 28 | if (wandb.config["load_from_checkpoint"] is not None) and (os.path.isfile(wandb.config["load_from_checkpoint"])): 29 | model = FineTuneOPT.load_from_checkpoint(checkpoint_path=wandb.config["load_from_checkpoint"]) 30 | else: 31 | model = FineTuneOPT(wandb.config["model_name"]) 32 | 33 | checkpoint_callback = ModelCheckpoint(dirpath=wandb.config["checkpoint_save_dir"], 34 | save_top_k=2, monitor="val_loss", 35 | filename="fine-tune-opt-epoch={epoch:03d}-val_loss={val_loss:.3f}") 36 | 37 | # create wandb logger (obviously) 38 | wandb_logger = WandbLogger(checkpoint_callback=False) 39 | 40 | print("TRAINING MODEL") 41 | trainer = Trainer(max_epochs=wandb.config["max_epochs"], gpus=AVAIL_GPUS, 42 | 
check_val_every_n_epoch=wandb.config["check_val_every_n_epoch"], 43 | callbacks=[checkpoint_callback], 44 | logger=wandb_logger) 45 | trainer.fit(model, datamodule=datamodule) 46 | 47 | wandb.finish() 48 | 49 | print("Training complete.") 50 | -------------------------------------------------------------------------------- /soft_prompt_tuning/soft_embedding.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copied from https://github.com/kipgparker/soft-prompt-tuning/blob/main/soft_embedding.py 3 | It's really that simple. huh. 4 | """ 5 | 6 | import torch 7 | import torch.nn as nn 8 | 9 | 10 | class SoftEmbedding(nn.Module): 11 | def __init__(self, 12 | wte: nn.Embedding, 13 | n_tokens: int = 10, 14 | random_range: float = 0.5, 15 | initialize_from_vocab: bool = True): 16 | """appends learned embedding to 17 | Args: 18 | wte (nn.Embedding): original transformer word embedding 19 | n_tokens (int, optional): number of tokens for task. Defaults to 10. 20 | random_range (float, optional): range to init embedding (if not initialize from vocab). Defaults to 0.5. 21 | initialize_from_vocab (bool, optional): initalizes from default vocab. Defaults to True. 22 | """ 23 | super(SoftEmbedding, self).__init__() 24 | self.wte = wte 25 | self.n_tokens = n_tokens 26 | self.learned_embedding = nn.parameter.Parameter(self.initialize_embedding(wte, 27 | n_tokens, 28 | random_range, 29 | initialize_from_vocab)) 30 | 31 | def initialize_embedding(self, 32 | wte: nn.Embedding, 33 | n_tokens: int = 10, 34 | random_range: float = 0.5, 35 | initialize_from_vocab: bool = True) -> torch.Tensor: 36 | """initializes learned embedding 37 | Args: 38 | same as __init__ 39 | Returns: 40 | torch.float: initialized using original schemes 41 | """ 42 | if initialize_from_vocab: 43 | return self.wte.weight[:n_tokens].clone().detach() 44 | return torch.FloatTensor(n_tokens, wte.weight.size(1)).uniform_(-random_range, random_range) 45 | 46 | def forward(self, tokens): 47 | """run forward pass 48 | Args: 49 | tokens (torch.long): input tokens before encoding 50 | Returns: 51 | torch.float: encoding of text concatenated with learned task specific embedding 52 | """ 53 | input_embedding = self.wte(tokens[:, self.n_tokens:]) 54 | learned_embedding = self.learned_embedding.repeat(input_embedding.size(0), 1, 1) 55 | return torch.cat([learned_embedding, input_embedding], 1) 56 | -------------------------------------------------------------------------------- /config-defaults.yaml: -------------------------------------------------------------------------------- 1 | # configs file for wandb.config 2 | --- 3 | model_name: 4 | desc: Name of variation of OPT from huggingface to initialise 5 | value: "facebook/opt-1.3b" 6 | load_from_checkpoint: 7 | desc: Checkpoint path to load the model from, if None will instantiate from default 8 | value: null 9 | #value: "training_checkpoints/soft-opt-epoch=000-val_loss=21.112.ckpt" 10 | 11 | # training config 12 | # since we are using streaming dataloaders, we cannot define an epoch as end of the dataset, hence we use 13 | # num_batches_per_epoch = steps_per_epoch / batch_size 14 | max_epochs: 15 | desc: Maximum number of training epochs 16 | value: 30 17 | steps_per_epoch: 18 | desc: Number of steps per epoch 19 | value: 8000 20 | batch_size: 21 | desc: Batch size that each datamodule will output 22 | value: 32 23 | check_val_every_n_epoch: 24 | desc: Every n epochs validation loop is run 25 | value: 2 26 | 27 | # checkpoint configs 28 | 
checkpoint_every_n_epochs: 29 | desc: Every n epochs we save a checkpoint of the model 30 | value: 30 31 | checkpoint_save_dir: 32 | desc: Directory to save checkpoints to 33 | value: "training_checkpoints/07-06-2022-optimize/" 34 | layers_to_optimize: 35 | desc: The name of the layers to optimize, and then save. regex matching will match even incomplete names, although try not the break this on purpose by matching more than 1 layer :) 36 | value: ["soft_embedding.learned_embedding"] 37 | 38 | # optimizers config 39 | optimizer_type: 40 | desc: Type of optimizer to use (currently only supports "Adam" or "SGD") 41 | value: "Adam" 42 | optimizer_params: 43 | desc: Parameters for the optimizer in the form of a dictionary, same format as state_dict. 44 | # btw when using scientific notation write 1.0e-3 and not 1e-3 otherwise it is mistakenly parsed as string (bug on pyYAML side) 45 | value: {"lr": 1.0e-3} 46 | lr_scheduler_type: 47 | desc: Type of learning rate scheduler to use (currently only supports "ReduceLROnPlateau") 48 | value: "ReduceLROnPlateau" 49 | lr_scheduler_params: 50 | desc: Parameters for the learning rate scheduler in the form of a dictionary, same format as state_dict. 51 | value: {"mode": "min", "patience": 10} 52 | lr_scheduler_config: 53 | desc: Learning rate scheduler configuration for pytorch lightning 54 | value: {"monitor": "train_loss"} 55 | 56 | # learnable embedding config 57 | embedding_n_tokens: 58 | desc: Number of learnable tokens to be prepended to the embedding 59 | value: 20 60 | init_from_vocab: 61 | desc: Whether to intialise the learned embedding as a copy of the existing vocabulary so it does not have to be trained from scratch (basically no downside to having this always True) 62 | value: True 63 | ... -------------------------------------------------------------------------------- /fine_tuning/fine_tune_opt.py: -------------------------------------------------------------------------------- 1 | import wandb 2 | from pytorch_lightning import LightningModule 3 | from torch.optim.lr_scheduler import ReduceLROnPlateau 4 | from transformers import OPTForCausalLM 5 | 6 | import torch 7 | from torch.optim import Adam 8 | 9 | 10 | class FineTuneOPT(LightningModule): 11 | """ 12 | very straightforward direct fine tuning on the OPT model 13 | """ 14 | def __init__(self, model_name="facebook/opt-350m"): 15 | super().__init__() 16 | self.model = OPTForCausalLM.from_pretrained(model_name) 17 | self.save_hyperparameters() 18 | 19 | def forward(self, **inputs): 20 | return self.model(**inputs) 21 | 22 | def training_step(self, batch, batch_idx): 23 | outputs = self(**batch) 24 | loss = outputs[0] 25 | self.log("train_loss", loss) 26 | 27 | return loss 28 | 29 | def validation_step(self, batch, batch_idx, dataloader_idx=0): 30 | outputs = self(**batch) 31 | val_loss, logits = outputs[:2] 32 | 33 | # we care only about the last token being predicted 34 | pred_token_logits = logits[:, -1, :] 35 | pred_token = torch.argmax(pred_token_logits, dim=-1) 36 | labels = batch["labels"][:, -1] 37 | 38 | self.log("val_loss", val_loss) 39 | 40 | return {"loss": val_loss, "preds": pred_token, "labels": labels} 41 | 42 | def configure_optimizers(self): 43 | optimizer = Adam(self.model.parameters(), **wandb.config["optimizer_params"]) 44 | 45 | # configure learning rate scheduler 46 | lr_scheduler = ReduceLROnPlateau(optimizer, **wandb.config["lr_scheduler_params"]) 47 | 48 | lr_scheduler_config = {"scheduler": lr_scheduler} 49 | 
lr_scheduler_config.update(wandb.config["lr_scheduler_config"]) 50 | 51 | return {"optimizer": optimizer, "lr_scheduler": lr_scheduler_config} 52 | 53 | """ 54 | Note on following hooks (on_train_epoch_start and on_validation_epoch_start): 55 | 56 | Using the following code to access dataloaders: self.train_dataloader().dataset.set_epoch(self.current_epoch) 57 | Results in an exception like such : pytorch_lightning.utilities.exceptions.MisconfigurationException: 58 | `val_dataloader` must be implemented to be used with the Lightning Trainer 59 | 60 | Although train_dataloader() is a valid hook, the hook is overridden only in the datamodule and we cannot reference 61 | that. We have to use self.trainer.train_dataloader.dataset which returns some CombinedDataset and then .datasets 62 | that one to get the original TorchIterableDataset. 63 | 64 | On the other hand, we can access validation dataloaders with self.trainer.val_dataloaders[0].dataset as that one is 65 | apparently a list and not a CombinedDataset. 66 | 67 | Pain. 68 | """ 69 | 70 | def on_train_epoch_start(self) -> None: 71 | # reshuffle the dataset for every train epoch 72 | self.trainer.train_dataloader.dataset.datasets.set_epoch(self.trainer.current_epoch) 73 | 74 | def on_validation_epoch_start(self) -> None: 75 | # reshuffle the dataset for every validation epoch 76 | self.trainer.val_dataloaders[0].dataset.set_epoch(self.trainer.current_epoch) 77 | -------------------------------------------------------------------------------- /.ipynb_checkpoints/REPORT-checkpoint.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "pycharm": { 7 | "name": "#%% md\n" 8 | } 9 | }, 10 | "source": [ 11 | "# Teaching OPT to Paraphrase through Soft Prompt Tuning\n", 12 | "\n", 13 | "## Table of Contents\n", 14 | "1. [Introduction](#introduction)\n", 15 | "2. [Some Paragraph](#para1)\n", 16 | " 1. sub\n", 17 | " 1. Huh\n", 18 | " 2. Sub\n", 19 | "3. 
Some Paragraph" 20 | ] 21 | }, 22 | { 23 | "cell_type": "code", 24 | "execution_count": null, 25 | "metadata": { 26 | "pycharm": { 27 | "name": "#%%\n" 28 | } 29 | }, 30 | "outputs": [], 31 | "source": [ 32 | "a = 2.4" 33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": null, 38 | "metadata": { 39 | "pycharm": { 40 | "name": "#%%\n" 41 | } 42 | }, 43 | "outputs": [], 44 | "source": [] 45 | }, 46 | { 47 | "cell_type": "code", 48 | "execution_count": null, 49 | "metadata": { 50 | "pycharm": { 51 | "name": "#%%\n" 52 | } 53 | }, 54 | "outputs": [], 55 | "source": [] 56 | }, 57 | { 58 | "cell_type": "code", 59 | "execution_count": null, 60 | "metadata": { 61 | "pycharm": { 62 | "name": "#%%\n" 63 | } 64 | }, 65 | "outputs": [], 66 | "source": [] 67 | }, 68 | { 69 | "cell_type": "code", 70 | "execution_count": null, 71 | "metadata": { 72 | "pycharm": { 73 | "name": "#%%\n" 74 | } 75 | }, 76 | "outputs": [], 77 | "source": [] 78 | }, 79 | { 80 | "cell_type": "code", 81 | "execution_count": null, 82 | "metadata": { 83 | "pycharm": { 84 | "name": "#%%\n" 85 | } 86 | }, 87 | "outputs": [], 88 | "source": [] 89 | }, 90 | { 91 | "cell_type": "code", 92 | "execution_count": null, 93 | "metadata": { 94 | "pycharm": { 95 | "name": "#%%\n" 96 | } 97 | }, 98 | "outputs": [], 99 | "source": [] 100 | }, 101 | { 102 | "cell_type": "code", 103 | "execution_count": null, 104 | "metadata": { 105 | "pycharm": { 106 | "name": "#%%\n" 107 | } 108 | }, 109 | "outputs": [], 110 | "source": [] 111 | }, 112 | { 113 | "cell_type": "code", 114 | "execution_count": null, 115 | "metadata": { 116 | "pycharm": { 117 | "name": "#%%\n" 118 | } 119 | }, 120 | "outputs": [], 121 | "source": [] 122 | }, 123 | { 124 | "cell_type": "markdown", 125 | "metadata": { 126 | "pycharm": { 127 | "name": "#%% md\n" 128 | } 129 | }, 130 | "source": [ 131 | "# Introduction\n", 132 | "" 133 | ] 134 | }, 135 | { 136 | "cell_type": "code", 137 | "execution_count": null, 138 | "metadata": { 139 | "pycharm": { 140 | "name": "#%%\n" 141 | } 142 | }, 143 | "outputs": [], 144 | "source": [] 145 | }, 146 | { 147 | "cell_type": "code", 148 | "execution_count": null, 149 | "metadata": { 150 | "pycharm": { 151 | "name": "#%%\n" 152 | } 153 | }, 154 | "outputs": [], 155 | "source": [] 156 | }, 157 | { 158 | "cell_type": "code", 159 | "execution_count": null, 160 | "metadata": { 161 | "pycharm": { 162 | "name": "#%%\n" 163 | } 164 | }, 165 | "outputs": [], 166 | "source": [] 167 | } 168 | ], 169 | "metadata": { 170 | "kernelspec": { 171 | "display_name": "Python 3 (ipykernel)", 172 | "language": "python", 173 | "name": "python3" 174 | }, 175 | "language_info": { 176 | "codemirror_mode": { 177 | "name": "ipython", 178 | "version": 3 179 | }, 180 | "file_extension": ".py", 181 | "mimetype": "text/x-python", 182 | "name": "python", 183 | "nbconvert_exporter": "python", 184 | "pygments_lexer": "ipython3", 185 | "version": "3.10.4" 186 | } 187 | }, 188 | "nbformat": 4, 189 | "nbformat_minor": 1 190 | } 191 | -------------------------------------------------------------------------------- /metrics/bart_metric.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Any, List 2 | 3 | from torch import Tensor 4 | import torch.nn as nn 5 | import torch 6 | 7 | from torchmetrics import Metric 8 | from transformers import BartTokenizer, BartForConditionalGeneration 9 | 10 | 11 | class BartScore(Metric): 12 | """ 13 | Torchmetric version of BartScore as adapted from their github 14 | 
https://github.com/neulab/BARTScore 15 | With lots of reference to bertscore implementation 16 | https://github.com/PyTorchLightning/metrics/blob/master/torchmetrics/text/bert.py#L40-L235 17 | 18 | Compute the score by: 19 | ``` 20 | bartscore = BartScore() 21 | score = bartscore(['This is interesting.', 'This is a good idea.'], ['This is fun.', 'Sounds like a good idea.']) 22 | ``` 23 | and it should return [-2.152808666229248, -2.948076009750366]. 24 | """ 25 | 26 | def __init__(self, device='cuda:0', max_length=1024, checkpoint='facebook/bart-large-cnn', 27 | **kwargs: Dict[str, Any]): 28 | super().__init__(**kwargs) 29 | 30 | # Set up model 31 | self.device_ = device 32 | self.max_length = max_length 33 | self.tokenizer = BartTokenizer.from_pretrained(checkpoint) 34 | self.model = BartForConditionalGeneration.from_pretrained(checkpoint) 35 | self.model.eval() 36 | self.model.to(device) 37 | 38 | # Set up loss 39 | self.loss_fct = nn.NLLLoss(reduction='none', ignore_index=self.model.config.pad_token_id) 40 | self.lsm = nn.LogSoftmax(dim=1) 41 | 42 | # Set up metric state variables which keep track of state on each call of update 43 | self.add_state("src_input_ids", [], dist_reduce_fx="cat") 44 | self.add_state("src_attention_mask", [], dist_reduce_fx="cat") 45 | self.add_state("target_input_ids", [], dist_reduce_fx="cat") 46 | self.add_state("target_attention_mask", [], dist_reduce_fx="cat") 47 | 48 | def update(self, preds: List[str], target: List[str]) -> None: 49 | # dict of 2d list of tensors [batch_size, input_size] although input_size is not fixed 50 | encoded_src = self.tokenizer(preds, padding=True, return_tensors='pt') 51 | encoded_targets = self.tokenizer(target, padding=True, return_tensors='pt') 52 | 53 | # 3d list of 2d tensors, since default values of state variables can only be lists or tensors 54 | self.src_input_ids.append(encoded_src['input_ids']) 55 | self.src_attention_mask.append(encoded_src['attention_mask']) 56 | self.target_input_ids.append(encoded_targets['input_ids']) 57 | self.target_attention_mask.append(encoded_targets['attention_mask']) 58 | 59 | def compute(self): 60 | score_list = [] 61 | 62 | src_tokens = self.src_input_ids[0].to(self.device_) 63 | src_mask = self.src_attention_mask[0].to(self.device_) 64 | 65 | tgt_tokens = self.target_input_ids[0].to(self.device_) 66 | tgt_mask = self.target_attention_mask[0] 67 | tgt_len = tgt_mask.sum(dim=1).to(self.device_) 68 | 69 | # while we do not use the loss computation as a result of labels being provided, the labels also cause 70 | # https://github.com/huggingface/transformers/blob/v4.17.0/src/transformers/models/bart/modeling_bart.py#L1320 71 | # the decoder input id to be shifted to the right for us, which is needed for this to work 72 | output = self.model( 73 | input_ids=src_tokens, 74 | attention_mask=src_mask, 75 | labels=tgt_tokens 76 | ) 77 | 78 | # loss calculation based on original bart_score 79 | logits = output.logits.view(-1, self.model.config.vocab_size) 80 | lsm_output = self.lsm(logits) 81 | loss = self.loss_fct(lsm_output, tgt_tokens.view(-1)) 82 | loss = loss.view(tgt_tokens.shape[0], -1) 83 | loss = loss.sum(dim=1) / tgt_len 84 | curr_score_list = [-x.item() for x in loss] 85 | score_list += curr_score_list 86 | 87 | return score_list 88 | 89 | 90 | def main(): 91 | bartscore = BartScore() 92 | score = bartscore(['This is interesting.', 'This is interesting.'], 93 | ['This is very curious.', 'This is incredibly strange.']) 94 | print(score) 95 | 96 | 97 | if __name__ == "__main__": 98 | 
main() 99 | -------------------------------------------------------------------------------- /test_vis.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | import wandb 5 | import numpy as np 6 | import pickle 7 | 8 | from transformers import GPT2Tokenizer 9 | from soft_prompt_tuning.soft_prompt_opt import ParaphraseOPT 10 | 11 | from faerun import Faerun 12 | import tmap as tm 13 | 14 | import re 15 | 16 | print("Initialising...") 17 | 18 | wandb.init(project="test-popt-dump", entity="clyde013", name="test-model", allow_val_change=True) 19 | wandb.config.update({"embedding_n_tokens": 111}, allow_val_change=True) 20 | 21 | #checkpoint = r"training_checkpoints/30-05-2022-1.3b/soft-opt-epoch=179-val_loss=1.397.ckpt" 22 | checkpoint = r"training_checkpoints/optimize/soft-opt-epoch=029-val_loss=0.487-optimizer_type=Adam-embedding_n_tokens=111.ckpt" 23 | model_name = "facebook/opt-1.3b" 24 | 25 | torch.cuda.empty_cache() 26 | 27 | AVAIL_GPUS = min(1, torch.cuda.device_count()) 28 | 29 | model = ParaphraseOPT.load_from_custom_save(model_name, checkpoint) 30 | model = model.eval() 31 | 32 | # default_model = ParaphraseOPT(model_name) 33 | 34 | tokenizer = GPT2Tokenizer.from_pretrained(model_name) 35 | 36 | learned_embeddings = model.model.soft_embedding.learned_embedding.detach() 37 | original_embeddings = model.model.soft_embedding.wte.weight.detach() 38 | 39 | 40 | def visualise(lf_filename: str, load_lf: bool): 41 | """ 42 | minhash and lsh forest visualisation 43 | http://matthewcasperson.blogspot.com/2013/11/minhash-for-dummies.html 44 | http://infolab.stanford.edu/~bawa/Pub/similarity.pdf 45 | """ 46 | dims = 512 47 | enc = tm.Minhash(learned_embeddings.size(dim=1), seed=69, sample_size=dims) 48 | lf = tm.LSHForest(dims * 2, 128) 49 | 50 | c_labels = np.concatenate([np.ones(original_embeddings.size(dim=0)), np.zeros(learned_embeddings.size(dim=0))]) 51 | print(c_labels) 52 | print(c_labels.shape) 53 | 54 | # generate labels when you click the points 55 | labels = [] 56 | # add labels for embeddings to be their decoded tokens 57 | # iterate through all tokenizer keys, that are stored as unicode, encode them with utf-8 to get their byte like 58 | # representations, have to repr() the byte string to get b'' and then manually remove the b'', and then use 59 | # regex to remove the special character that GPTs BPE uses to denote whitespace as well as any ' " \\ that will mess 60 | # up the javascript source file. 
61 | labels.extend([re.sub(r'(\\xc4|\\xa0)|[\'\"\\]', '', repr(i.encode("utf-8"))[2:-1]) for i in tokenizer.get_vocab().keys()]) 62 | 63 | for i in range(learned_embeddings.size(dim=0)): 64 | labels.append(f"learned embedding {i}") 65 | 66 | if load_lf: 67 | lf.restore(f"visualisations/{lf_filename}") 68 | else: 69 | np_arr = np.concatenate([original_embeddings, learned_embeddings]) 70 | tmp = [] 71 | for i in np_arr: 72 | tmp.append(tm.VectorFloat(i.tolist())) 73 | print("batch add") 74 | lf.batch_add(enc.batch_from_weight_array(tmp)) 75 | print("index") 76 | lf.index() 77 | print("saving lf") 78 | lf.store(f"visualisations/{lf_filename}") 79 | 80 | print("layout") 81 | config = tm.LayoutConfiguration() 82 | config.fme_randomize = False 83 | x, y, s, t, _ = tm.layout_from_lsh_forest(lf, config=config) 84 | print("faerun") 85 | legend_labels = [ 86 | (0, "learned embeddings"), 87 | (255, "default embeddings") 88 | ] 89 | 90 | faerun = Faerun(clear_color="#111111", view="front", coords=False) 91 | faerun.add_scatter( 92 | "Embeddings", 93 | {"x": x, "y": y, "c": c_labels, "labels": labels}, 94 | colormap="RdYlBu", 95 | shader="smoothCircle", 96 | point_scale=3, 97 | max_point_size=20, 98 | has_legend=True, 99 | categorical=True, 100 | legend_labels=legend_labels, 101 | ) 102 | faerun.add_tree( 103 | "Embeddings_tree", {"from": s, "to": t}, point_helper="Embeddings", color="#666666" 104 | ) 105 | 106 | faerun.plot(f"Embeddings_{wandb.config['embedding_n_tokens']}", path="visualisations/") 107 | print("done") 108 | 109 | 110 | if __name__ == "__main__": 111 | visualise(f"lf_{wandb.config['embedding_n_tokens']}_seed=69.dat", False) 112 | -------------------------------------------------------------------------------- /train_soft_prompt.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | import wandb 5 | from pytorch_lightning import Trainer 6 | from pytorch_lightning.loggers import WandbLogger 7 | 8 | from soft_prompt_tuning.soft_prompt_opt import ParaphraseOPT, SpecificLayersCheckpoint 9 | from training_datasets.paracombined import ParaCombinedDataModule 10 | from training_datasets.parabank.parabank import ParabankDataModule 11 | from training_datasets.para_nmt.para_nmt import ParaNMTDataModule 12 | 13 | import optuna 14 | from optuna.trial import Trial 15 | from optuna.integration import PyTorchLightningPruningCallback 16 | 17 | # initialisation steps 18 | torch.cuda.empty_cache() 19 | AVAIL_GPUS = min(1, torch.cuda.device_count()) 20 | 21 | 22 | def objective(trial: Trial): 23 | # clear cache so we don't RuntimeError: CUDA out of memory. Tried to allocate 17.00 GB (GPU 0; 39.59 GiB total capacity; 37.53 GiB already allocated; 22.19 MiB free; 37.53 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. 
See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF 24 | torch.cuda.empty_cache() 25 | 26 | # initialise hyperparameter search 27 | trial_config = dict() 28 | 29 | # number of embedding tokens 30 | embedding_n_tokens = trial.suggest_int("embedding_n_tokens", 50, 150) 31 | trial_config["embedding_n_tokens"] = embedding_n_tokens 32 | # optimizers 33 | # optimizer_type = trial.suggest_categorical("optimizer_type", ["Adam", "SGD"]) 34 | optimizer_type = "Adam" 35 | trial_config["optimizer_type"] = optimizer_type 36 | 37 | # override default params with the hyperparamters being searched for 38 | run = wandb.init(project="optimize-popt", entity="clyde013", 39 | name=f"optimizer_type={optimizer_type}-embedding_n_tokens={embedding_n_tokens}") 40 | with run: 41 | wandb.config.update(trial_config, allow_val_change=True) 42 | 43 | datamodule = ParaCombinedDataModule(wandb.config["model_name"], batch_size=wandb.config["batch_size"], 44 | steps_per_epoch=wandb.config["steps_per_epoch"], 45 | datamodules=[ParabankDataModule, ParaNMTDataModule], 46 | probabilities=[0.5, 0.5]) 47 | datamodule.setup() 48 | 49 | if (wandb.config["load_from_checkpoint"] is not None) and (os.path.isfile(wandb.config["load_from_checkpoint"])): 50 | model = ParaphraseOPT.load_from_custom_save(wandb.config["model_name"], 51 | wandb.config["load_from_checkpoint"]) 52 | else: 53 | model = ParaphraseOPT(wandb.config["model_name"]) 54 | 55 | checkpoint_callback = SpecificLayersCheckpoint( 56 | monitor="val_loss", 57 | dirpath=wandb.config["checkpoint_save_dir"], 58 | filename="soft-opt-epoch={epoch:03d}-val_loss={val_loss:.3f}" + 59 | f"-optimizer_type={optimizer_type}-embedding_n_tokens={embedding_n_tokens}" + ".ckpt", 60 | every_n_epochs=wandb.config["checkpoint_every_n_epochs"], 61 | layers_to_save=wandb.config["layers_to_optimize"] 62 | ) 63 | 64 | early_stopping_callback = PyTorchLightningPruningCallback(trial, monitor="val_loss") 65 | 66 | # create wandb logger (obviously) 67 | wandb_logger = WandbLogger(checkpoint_callback=False) 68 | 69 | print("TRAINING MODEL") 70 | trainer = Trainer(max_epochs=wandb.config["max_epochs"], gpus=AVAIL_GPUS, 71 | check_val_every_n_epoch=wandb.config["check_val_every_n_epoch"], 72 | callbacks=[checkpoint_callback, early_stopping_callback], 73 | logger=wandb_logger) 74 | trainer.fit(model, datamodule=datamodule) 75 | 76 | wandb.finish() 77 | 78 | return trainer.callback_metrics["val_loss"].item() 79 | 80 | 81 | if __name__ == "__main__": 82 | pruner: optuna.pruners.BasePruner = optuna.pruners.MedianPruner(n_warmup_steps=1000) 83 | 84 | study = optuna.create_study(direction="minimize", pruner=pruner) 85 | study.optimize(objective, n_trials=20) 86 | 87 | print("Number of finished trials: {}".format(len(study.trials))) 88 | 89 | print("Best trial:") 90 | trial = study.best_trial 91 | 92 | print(" Value: {}".format(trial.value)) 93 | 94 | print(" Params: ") 95 | for key, value in trial.params.items(): 96 | print(" {}: {}".format(key, value)) 97 | -------------------------------------------------------------------------------- /gui.py: -------------------------------------------------------------------------------- 1 | import copy 2 | 3 | import streamlit as st 4 | 5 | import torch 6 | from transformers import GPT2Tokenizer, OPTForCausalLM 7 | 8 | import wandb 9 | 10 | from soft_prompt_tuning.soft_prompt_opt import ParaphraseOPT 11 | 12 | # init 13 | wandb.init(project="popt-gui", entity="clyde013") 14 | wandb.config.update({"embedding_n_tokens": 143}, allow_val_change=True) 15 | 
AVAIL_GPUS = min(1, torch.cuda.device_count()) 16 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 17 | fits_on_gpu = True 18 | 19 | # we want to define the functions without actually calling them, so we wrap them in lambdas 20 | model_name = "facebook/opt-1.3b" 21 | checkpoint = r"training_checkpoints/23-06-2022/soft-opt-epoch=029-val_loss=0.397-embedding_n_tokens=143.ckpt" 22 | model_type_key = {'OPT1.3B Prompt Fine Tuned': lambda: ParaphraseOPT.load_from_custom_save(model_name, checkpoint), 23 | 'OPT1.3B Base Model': lambda: OPTForCausalLM.from_pretrained(model_name)} 24 | # model specific prompt changes 25 | model_prompt_key = {'OPT1.3B Prompt Fine Tuned': lambda x: x + "", 26 | 'OPT1.3B Base Model': lambda x: x} 27 | 28 | 29 | # expensive functions that need caching 30 | @st.cache(hash_funcs={ParaphraseOPT: lambda _: None}) 31 | def reconstruct_tokens(model): 32 | """ 33 | Find the nearest tokens in the embedding space. 34 | https://stackoverflow.com/questions/64523788/how-to-invert-a-pytorch-embedding 35 | """ 36 | embeddings = model.model.soft_embedding.wte 37 | learned_embedding = model.model.soft_embedding.learned_embedding 38 | 39 | reconstructed = list() 40 | for i in learned_embedding: 41 | distance = torch.norm(embeddings.weight.detach() - i, dim=1) 42 | nearest = torch.argmin(distance) 43 | reconstructed.append(nearest.item()) 44 | 45 | return reconstructed 46 | 47 | 48 | @st.cache(hash_funcs={ParaphraseOPT: lambda _: None, OPTForCausalLM: lambda _: None, GPT2Tokenizer: lambda _: None}) 49 | def init_model(selection: str): 50 | global device 51 | global fits_on_gpu 52 | global gpu_counter 53 | 54 | init_model = model_type_key[selection]() 55 | try: 56 | init_model.to(device) 57 | except RuntimeError: 58 | fits_on_gpu = False 59 | device = torch.device("cpu") 60 | init_model.to(device) 61 | torch.cuda.empty_cache() 62 | gpu_counter.text("Model does not fit on GPU, using CPU instead.") 63 | 64 | return init_model, GPT2Tokenizer.from_pretrained(model_name) 65 | 66 | 67 | def tokenize(model_type: str, prompt: str): 68 | soft_prompt = model_prompt_key[model_type](prompt) 69 | encoded_inputs = tokenizer(soft_prompt, return_tensors="pt") 70 | return encoded_inputs 71 | 72 | 73 | def predict(model_type: str, prompt: str, max_len: int): 74 | inputs = tokenize(model_type, prompt) 75 | if model_type == "OPT1.3B Base Model": 76 | outputs = model.generate(inputs.input_ids, max_length=max_len, use_cache=False) 77 | else: 78 | outputs = model.model.generate(inputs.input_ids, max_length=max_len, use_cache=False) 79 | outputs = outputs[:, inputs['input_ids'].size(dim=-1):] 80 | decoded = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0] 81 | return decoded 82 | 83 | 84 | gpu_counter = st.sidebar.empty() 85 | if fits_on_gpu: 86 | gpu_counter.text(f"Currently utilising {AVAIL_GPUS} GPU(s).") 87 | else: 88 | gpu_counter.text("Model does not fit on GPU, using CPU instead.") 89 | cache_warning = st.sidebar.write("Loading models for the first time may take a while. Future selections will " 90 | "be greatly faster as models are cached! :)") 91 | 92 | # Add a selectbox for model type to the sidebar: 93 | model_selectbox = st.sidebar.selectbox( 94 | 'Which model would you like to use?', 95 | model_type_key.keys() 96 | ) 97 | 98 | # Add slider for max length of the sequence 99 | max_len_slider = st.sidebar.slider("What is the maximum sequence length that you would like the model to generate? 
" 100 | "(counts input tokens)", 101 | min_value=30, max_value=100, value=45) 102 | 103 | model, tokenizer = init_model(model_selectbox) 104 | 105 | 106 | with st.expander("Model Demo"): 107 | # input text box 108 | input_txt = st.text_area("Enter text for the model:", "The quick brown fox jumped over the fence.") 109 | 110 | # token count for input 111 | seq_len = tokenize(model_selectbox, input_txt)['input_ids'].size(dim=-1) 112 | token_count = st.caption(f"Token length: {seq_len}") 113 | 114 | # output text box 115 | with st.spinner(text="Generating..."): 116 | model_outputs_txt = st.code(predict(model_selectbox, input_txt, max_len_slider), language="markdown") 117 | 118 | 119 | with st.expander("Learned Embeddings"): 120 | if model_selectbox == 'OPT1.3B Prompt Fine Tuned': 121 | tokens = reconstruct_tokens(model) 122 | outputs = tokenizer.batch_decode(tokens) 123 | st.text(outputs) 124 | st.text("".join(outputs)) 125 | else: 126 | st.text("Default model does not have any prepended learned embeddings.") -------------------------------------------------------------------------------- /training_datasets/paracombined.py: -------------------------------------------------------------------------------- 1 | from typing import List, Type, Optional 2 | 3 | from datasets import IterableDataset, interleave_datasets 4 | from pytorch_lightning import LightningDataModule 5 | from torch.utils.data import DataLoader 6 | from transformers import GPT2Tokenizer, DataCollatorForLanguageModeling 7 | 8 | from training_datasets.parabank.parabank import ParabankDataModule 9 | from training_datasets.para_nmt.para_nmt import ParaNMTDataModule 10 | 11 | 12 | class ParaCombinedDataModule(LightningDataModule): 13 | """ 14 | LightningDataModule for combining different datasets for causal language modelling 15 | 16 | Note on num_workers: https://github.com/huggingface/datasets/pull/4375 17 | IterableDatasets do not support Dataloaders with num_workers > 0. Watch the PR to see if the fix will be merged. 18 | """ 19 | def __init__(self, opt_name, batch_size, steps_per_epoch, datamodules: List[Type[LightningDataModule]], 20 | probabilities: List[float], num_workers=0, seed=69, pre_tokenize=True): 21 | """ 22 | 23 | Parameters 24 | ---------- 25 | opt_name: str 26 | Name of model type 27 | batch_size: int 28 | batch_size output by dataloader 29 | steps_per_epoch: int 30 | dataset_size = steps_per_epoch * batch_size 31 | Since we do not know the dataset size we simply leave it to the user to determine how many steps per epoch 32 | we should have. 
33 | datamodules: List[Type[LightningDataModule]] 34 | List specifying the datamodules whose datasets will be interleaved 35 | probabilities: List[float] 36 | List of probabilities for respective datamodules that should sum to 1 37 | num_workers: int 38 | refer to note above on PR https://github.com/huggingface/datasets/pull/4375 39 | seed: int 40 | haha funny number 41 | pre_tokenize: bool 42 | should we tokenize the texts (if true: dataset will return tokenized ids instead of source text) 43 | """ 44 | super().__init__() 45 | self.opt_name = opt_name 46 | self.batch_size = batch_size 47 | self.steps_per_epoch = steps_per_epoch 48 | self.num_workers = num_workers 49 | self.seed = seed 50 | self.pre_tokenize = pre_tokenize 51 | self.datamodules = datamodules 52 | self.probabilities = probabilities 53 | self.tokenizer = None 54 | self.dataset = None 55 | 56 | # sanity check 57 | assert sum(self.probabilities) == 1, "Probabilities for interleaved datasets do not sum to 1.0" 58 | 59 | def prepare_data(self) -> None: 60 | # download and cache 61 | GPT2Tokenizer.from_pretrained(self.opt_name) 62 | 63 | def setup(self, stage: Optional[str] = None) -> None: 64 | # tokenizer is not actually used once instantiated but to stay consistent with other datamodule implementations 65 | # we instantiate it anyway 66 | self.tokenizer = GPT2Tokenizer.from_pretrained(self.opt_name, use_fast=False) 67 | 68 | # instantiate all the datamodules and extract the dataset from them 69 | datasets = list() 70 | for datamodule in self.datamodules: 71 | dm = datamodule(self.opt_name, self.batch_size, self.steps_per_epoch, 72 | seed=self.seed, pre_tokenize=self.pre_tokenize) 73 | dm.setup() 74 | datasets.append(dm.dataset) 75 | 76 | self.dataset = interleave_datasets(datasets, probabilities=self.probabilities, seed=self.seed) 77 | self.dataset = self.dataset.with_format("torch") 78 | 79 | # monkeypatch of __len__ function in the dataloader so that the trainer knows how many 80 | # steps there are per epoch. Sure this violates many programming paradigms but it works. 
81 | n = self.steps_per_epoch 82 | 83 | def __len__(self): 84 | return n 85 | 86 | IterableDataset.__len__ = __len__ 87 | 88 | # dataloaders are basically all the same since we cannot split a streamed dataset 89 | def train_dataloader(self): 90 | dataloader = DataLoader(self.dataset, 91 | batch_size=self.batch_size, 92 | num_workers=self.num_workers) 93 | if self.pre_tokenize: dataloader.collate_fn = DataCollatorForLanguageModeling(self.tokenizer, mlm=False) 94 | return dataloader 95 | 96 | def val_dataloader(self): 97 | dataloader = DataLoader(self.dataset, 98 | batch_size=self.batch_size, 99 | num_workers=self.num_workers) 100 | if self.pre_tokenize: dataloader.collate_fn = DataCollatorForLanguageModeling(self.tokenizer, mlm=False) 101 | return dataloader 102 | 103 | def test_dataloader(self): 104 | dataloader = DataLoader(self.dataset, 105 | batch_size=self.batch_size, 106 | num_workers=self.num_workers) 107 | if self.pre_tokenize: dataloader.collate_fn = DataCollatorForLanguageModeling(self.tokenizer, mlm=False) 108 | return dataloader 109 | 110 | def predict_dataloader(self): 111 | dataloader = DataLoader(self.dataset, 112 | batch_size=self.batch_size, 113 | num_workers=self.num_workers) 114 | if self.pre_tokenize: dataloader.collate_fn = DataCollatorForLanguageModeling(self.tokenizer, mlm=False) 115 | return dataloader 116 | 117 | 118 | if __name__ == "__main__": 119 | model_name = "facebook/opt-1.3b" 120 | datamodule = ParaCombinedDataModule(model_name, 1, 1000, [ParabankDataModule, ParaNMTDataModule], 121 | probabilities=[0.35, 0.65], seed=1337, pre_tokenize=False) 122 | datamodule.setup() 123 | dl = datamodule.val_dataloader() 124 | it = iter(dl) 125 | 126 | for i in range(10): 127 | print(next(it)) 128 | -------------------------------------------------------------------------------- /training_datasets/para_nmt/para_nmt.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Optional 3 | from datasets import load_dataset, IterableDataset 4 | from pytorch_lightning import LightningDataModule 5 | from torch.utils.data import DataLoader 6 | from transformers import GPT2Tokenizer, DataCollatorForLanguageModeling 7 | 8 | # current working directory changes when imported from other modules, so to ensure para_nmt_path is correct we store 9 | # the absolute path to the module for reference. 10 | package_directory = os.path.dirname(os.path.abspath(__file__)) 11 | 12 | 13 | class ParaNMTDataModule(LightningDataModule): 14 | """ 15 | LightningDataModule for para_nmt dataset for causal language modelling 16 | 17 | Note on num_workers: https://github.com/huggingface/datasets/pull/4375 18 | IterableDatasets do not support Dataloaders with num_workers > 0. Watch the PR to see if the fix will be merged. 19 | """ 20 | para_nmt_path = os.path.join(package_directory, "para-nmt-5m-processed.zip") 21 | 22 | def __init__(self, opt_name, batch_size, steps_per_epoch, num_workers=0, seed=69, pre_tokenize=True): 23 | """ 24 | 25 | Parameters 26 | ---------- 27 | opt_name: str 28 | name of the OPT model type (i.e. facebook/opt-350m) 29 | batch_size: int 30 | batch_size output by dataloader 31 | steps_per_epoch: int 32 | dataset_size = steps_per_epoch * batch_size 33 | Since we do not know the dataset size we simply leave it to the user to determine how many steps per epoch 34 | we should have. 
35 | num_workers: int 36 | refer to note above on PR https://github.com/huggingface/datasets/pull/4375 37 | seed: int 38 | haha funny number 39 | pre_tokenize: bool 40 | should we tokenize the texts (if true: dataset will return tokenized ids instead of source text) 41 | """ 42 | 43 | super().__init__() 44 | self.opt_name = opt_name 45 | self.batch_size = batch_size 46 | self.steps_per_epoch = steps_per_epoch 47 | self.num_workers = num_workers 48 | self.seed = seed 49 | self.pre_tokenize = pre_tokenize 50 | 51 | # init None to make pycharm happy 52 | self.tokenizer = None 53 | self.dataset = None 54 | 55 | def prepare_data(self) -> None: 56 | # download and cache 57 | GPT2Tokenizer.from_pretrained(self.opt_name) 58 | 59 | def setup(self, stage: Optional[str] = None) -> None: 60 | # load tokenizer (should be cached) 61 | self.tokenizer = GPT2Tokenizer.from_pretrained(self.opt_name, use_fast=False) 62 | 63 | # preprocess function for the dataset's entries 64 | def preprocess(examples): 65 | # list of len batch 66 | batch = examples['text'] 67 | processed_batch = list() 68 | for i in batch: 69 | # replace the \t splitting with a '' token to denote source-target 70 | processed_batch.append(str.replace(i, "\t", self.tokenizer.eos_token)) 71 | 72 | if self.pre_tokenize: 73 | outputs = self.tokenizer( 74 | processed_batch, 75 | truncation=True, 76 | max_length=69, 77 | ) 78 | else: 79 | outputs = {"source": processed_batch} 80 | return outputs 81 | 82 | # init dataset in streaming mode 83 | self.dataset = load_dataset("text", data_files=self.para_nmt_path, streaming=True)['train'] 84 | # elements within buffer size will be shuffled as they are loaded in 85 | self.dataset = self.dataset.shuffle(seed=self.seed, buffer_size=10_000) 86 | # preprocessing will take place while being streamed by dataloader 87 | self.dataset = self.dataset.map(preprocess, batched=True, remove_columns=['text']) 88 | # ensure pytorch tensors are returned 89 | self.dataset = self.dataset.with_format("torch") 90 | 91 | # monkeypatch of __len__ function in the dataloader so that the trainer knows how many 92 | # steps there are per epoch. Sure this violates many programming paradigms but it works. 
93 | n = self.steps_per_epoch 94 | 95 | def __len__(self): 96 | return n 97 | 98 | IterableDataset.__len__ = __len__ 99 | 100 | # dataloaders are basically all the same since we cannot split a streamed dataset 101 | def train_dataloader(self): 102 | dataloader = DataLoader(self.dataset, 103 | batch_size=self.batch_size, 104 | num_workers=self.num_workers) 105 | if self.pre_tokenize: dataloader.collate_fn = DataCollatorForLanguageModeling(self.tokenizer, mlm=False) 106 | return dataloader 107 | 108 | def val_dataloader(self): 109 | dataloader = DataLoader(self.dataset, 110 | batch_size=self.batch_size, 111 | num_workers=self.num_workers) 112 | if self.pre_tokenize: dataloader.collate_fn = DataCollatorForLanguageModeling(self.tokenizer, mlm=False) 113 | return dataloader 114 | 115 | def test_dataloader(self): 116 | dataloader = DataLoader(self.dataset, 117 | batch_size=self.batch_size, 118 | num_workers=self.num_workers) 119 | if self.pre_tokenize: dataloader.collate_fn = DataCollatorForLanguageModeling(self.tokenizer, mlm=False) 120 | return dataloader 121 | 122 | def predict_dataloader(self): 123 | dataloader = DataLoader(self.dataset, 124 | batch_size=self.batch_size, 125 | num_workers=self.num_workers) 126 | if self.pre_tokenize: dataloader.collate_fn = DataCollatorForLanguageModeling(self.tokenizer, mlm=False) 127 | return dataloader 128 | 129 | 130 | if __name__ == "__main__": 131 | model_name = "facebook/opt-1.3b" 132 | datamodule = ParaNMTDataModule(model_name, 1, 1000, seed=1337) 133 | datamodule.setup() 134 | dl = datamodule.val_dataloader() 135 | it = iter(dl) 136 | 137 | for i in range(10): 138 | print(datamodule.tokenizer.batch_decode(next(it)['input_ids'])[0]) 139 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # This file may be used to create an environment using: 2 | # $ conda create --name --file 3 | # platform: win-64 4 | absl-py=1.0.0=pypi_0 5 | aiohttp=3.8.1=pypi_0 6 | aiosignal=1.2.0=pypi_0 7 | alembic=1.8.0=pypi_0 8 | altair=4.2.0=pypi_0 9 | annoy=1.17.0=pypi_0 10 | anyio=3.6.1=pypi_0 11 | argon2-cffi=21.3.0=pypi_0 12 | argon2-cffi-bindings=21.2.0=pypi_0 13 | asttokens=2.0.5=pypi_0 14 | async-timeout=4.0.2=pypi_0 15 | attrs=21.4.0=pypi_0 16 | autopage=0.5.1=pypi_0 17 | babel=2.10.1=pypi_0 18 | backcall=0.2.0=pypi_0 19 | beautifulsoup4=4.11.1=pypi_0 20 | bleach=5.0.0=pypi_0 21 | blinker=1.4=pypi_0 22 | bzip2=1.0.8=he774522_0 23 | ca-certificates=2022.4.26=haa95532_0 24 | cachetools=5.1.0=pypi_0 25 | certifi=2022.5.18.1=py310haa95532_0 26 | cffi=1.15.0=pypi_0 27 | cgraph=0.1=pypi_0 28 | charset-normalizer=2.0.12=pypi_0 29 | cheroot=8.6.0=pypi_0 30 | cherrypy=18.6.1=pypi_0 31 | click=8.1.3=pypi_0 32 | cliff=3.10.1=pypi_0 33 | cmaes=0.8.2=pypi_0 34 | cmd2=2.4.1=pypi_0 35 | colorama=0.4.4=pypi_0 36 | colorlog=6.6.0=pypi_0 37 | colour=0.1.5=pypi_0 38 | commonmark=0.9.1=pypi_0 39 | cppyy=1.9.6=pypi_0 40 | cppyy-backend=1.14.4=pypi_0 41 | cppyy-cling=6.21.7=pypi_0 42 | cpycppyy=1.12.5=pypi_0 43 | cycler=0.11.0=pypi_0 44 | datasets=2.2.2=pypi_0 45 | debugpy=1.6.0=pypi_0 46 | decorator=5.1.1=pypi_0 47 | defusedxml=0.7.1=pypi_0 48 | dill=0.3.4=pypi_0 49 | docker-pycreds=0.4.0=pypi_0 50 | entrypoints=0.4=pypi_0 51 | executing=0.8.3=pypi_0 52 | faerun=0.3.20=pypi_0 53 | fastjsonschema=2.15.3=pypi_0 54 | filelock=3.7.0=pypi_0 55 | fonttools=4.33.3=pypi_0 56 | frozenlist=1.3.0=pypi_0 57 | fsspec=2022.5.0=pypi_0 58 | gitdb=4.0.9=pypi_0 59 | 
gitpython=3.1.27=pypi_0 60 | google-auth=2.6.6=pypi_0 61 | google-auth-oauthlib=0.4.6=pypi_0 62 | graphistry=0.25.2=pypi_0 63 | greenlet=1.1.2=pypi_0 64 | grpcio=1.46.3=pypi_0 65 | huggingface-hub=0.6.0=pypi_0 66 | idna=3.3=pypi_0 67 | importlib-metadata=4.11.4=pypi_0 68 | importlib-resources=1.5.0=pypi_0 69 | ipykernel=6.13.0=pypi_0 70 | ipython=8.3.0=pypi_0 71 | ipython-genutils=0.2.0=pypi_0 72 | ipywidgets=7.7.0=pypi_0 73 | jaraco-classes=3.2.1=pypi_0 74 | jaraco-collections=3.5.1=pypi_0 75 | jaraco-context=4.1.1=pypi_0 76 | jaraco-functools=3.5.0=pypi_0 77 | jaraco-text=3.8.0=pypi_0 78 | jedi=0.18.1=pypi_0 79 | jinja2=3.1.2=pypi_0 80 | joblib=1.1.0=pypi_0 81 | json5=0.9.8=pypi_0 82 | jsonpickle=2.2.0=pypi_0 83 | jsonschema=4.5.1=pypi_0 84 | jupyter-client=7.3.1=pypi_0 85 | jupyter-core=4.10.0=pypi_0 86 | jupyter-server=1.17.0=pypi_0 87 | jupyterlab=3.4.2=pypi_0 88 | jupyterlab-pygments=0.2.2=pypi_0 89 | jupyterlab-server=2.14.0=pypi_0 90 | jupyterlab-widgets=1.1.0=pypi_0 91 | kiwisolver=1.4.2=pypi_0 92 | libffi=3.4.2=h604cdb4_1 93 | llvmlite=0.38.1=pypi_0 94 | mako=1.2.0=pypi_0 95 | markdown=3.3.7=pypi_0 96 | markupsafe=2.1.1=pypi_0 97 | matplotlib=3.5.2=pypi_0 98 | matplotlib-inline=0.1.3=pypi_0 99 | mistune=0.8.4=pypi_0 100 | more-itertools=8.13.0=pypi_0 101 | multidict=6.0.2=pypi_0 102 | multiprocess=0.70.12.2=pypi_0 103 | nbclassic=0.3.7=pypi_0 104 | nbclient=0.6.3=pypi_0 105 | nbconvert=6.5.0=pypi_0 106 | nbformat=5.4.0=pypi_0 107 | nest-asyncio=1.5.5=pypi_0 108 | networkx=2.8.4rc1.dev0=dev_0 109 | nltk=3.7=pypi_0 110 | notebook=6.4.11=pypi_0 111 | notebook-shim=0.1.0=pypi_0 112 | numba=0.55.2=pypi_0 113 | numpy=1.22.4=pypi_0 114 | oauthlib=3.2.0=pypi_0 115 | ogdf-python=0.1.3=pypi_0 116 | openssl=1.1.1o=h2bbff1b_0 117 | optuna=2.10.0=pypi_0 118 | packaging=21.3=pypi_0 119 | pandas=1.4.2=pypi_0 120 | pandocfilters=1.5.0=pypi_0 121 | parso=0.8.3=pypi_0 122 | pathtools=0.1.2=pypi_0 123 | patsy=0.5.2=pypi_0 124 | pbr=5.9.0=pypi_0 125 | pickleshare=0.7.5=pypi_0 126 | pillow=9.1.1=pypi_0 127 | pip=21.2.4=py310haa95532_0 128 | plotly=5.8.0=pypi_0 129 | portend=3.1.0=pypi_0 130 | prettytable=3.3.0=pypi_0 131 | prometheus-client=0.14.1=pypi_0 132 | promise=2.3=pypi_0 133 | prompt-toolkit=3.0.29=pypi_0 134 | protobuf=3.20.1=pypi_0 135 | psutil=5.9.1=pypi_0 136 | pure-eval=0.2.2=pypi_0 137 | pyarrow=8.0.0=pypi_0 138 | pyasn1=0.4.8=pypi_0 139 | pyasn1-modules=0.2.8=pypi_0 140 | pycparser=2.21=pypi_0 141 | pydeck=0.7.1=pypi_0 142 | pydeprecate=0.3.2=pypi_0 143 | pygments=2.12.0=pypi_0 144 | pygraphviz=1.9=pypi_0 145 | pympler=1.0.1=pypi_0 146 | pynndescent=0.5.7=pypi_0 147 | pyparsing=3.0.9=pypi_0 148 | pyperclip=1.8.2=pypi_0 149 | pyreadline3=3.4.1=pypi_0 150 | pyrsistent=0.18.1=pypi_0 151 | python=3.10.4=hbb2ffb3_0 152 | python-dateutil=2.8.2=pypi_0 153 | python-dotenv=0.20.0=pypi_0 154 | pytorch-lightning=1.6.3=pypi_0 155 | pytz=2022.1=pypi_0 156 | pytz-deprecation-shim=0.1.0.post0=pypi_0 157 | pyvis=0.2.1=pypi_0 158 | pywin32=304=pypi_0 159 | pywinpty=2.0.5=pypi_0 160 | pyyaml=6.0=pypi_0 161 | pyzmq=23.0.0=pypi_0 162 | regex=2022.4.24=pypi_0 163 | requests=2.27.1=pypi_0 164 | requests-oauthlib=1.3.1=pypi_0 165 | responses=0.18.0=pypi_0 166 | rich=12.4.4=pypi_0 167 | rsa=4.8=pypi_0 168 | scikit-learn=1.1.1=pypi_0 169 | scipy=1.8.1=pypi_0 170 | seaborn=0.11.2=pypi_0 171 | semver=2.13.0=pypi_0 172 | send2trash=1.8.0=pypi_0 173 | sentry-sdk=1.5.12=pypi_0 174 | setproctitle=1.2.3=pypi_0 175 | setuptools=61.2.0=py310haa95532_0 176 | shortuuid=1.0.9=pypi_0 177 | six=1.16.0=pypi_0 178 | 
smmap=5.0.0=pypi_0 179 | sniffio=1.2.0=pypi_0 180 | soupsieve=2.3.2.post1=pypi_0 181 | sqlalchemy=1.4.37=pypi_0 182 | sqlite=3.38.3=h2bbff1b_0 183 | stack-data=0.2.0=pypi_0 184 | statsmodels=0.13.2=pypi_0 185 | stevedore=3.5.0=pypi_0 186 | streamlit=1.10.0=pypi_0 187 | tabulate=0.8.9=pypi_0 188 | tempora=5.0.1=pypi_0 189 | tenacity=8.0.1=pypi_0 190 | tensorboard=2.9.0=pypi_0 191 | tensorboard-data-server=0.6.1=pypi_0 192 | tensorboard-plugin-wit=1.8.1=pypi_0 193 | terminado=0.15.0=pypi_0 194 | threadpoolctl=3.1.0=pypi_0 195 | tinycss2=1.1.1=pypi_0 196 | tk=8.6.11=h2bbff1b_1 197 | tmap-viz=1.0.16=pypi_0 198 | tokenizers=0.12.1=pypi_0 199 | toml=0.10.2=pypi_0 200 | toolz=0.11.2=pypi_0 201 | torch=1.11.0+cu113=pypi_0 202 | torchaudio=0.11.0+cu113=pypi_0 203 | torchmetrics=0.8.2=pypi_0 204 | torchvision=0.12.0+cu113=pypi_0 205 | tornado=6.1=pypi_0 206 | tqdm=4.64.0=pypi_0 207 | traitlets=5.2.1.post0=pypi_0 208 | transformers=4.19.2=pypi_0 209 | typing-extensions=4.2.0=pypi_0 210 | tzdata=2022.1=pypi_0 211 | tzlocal=4.2=pypi_0 212 | ujson=5.3.0=pypi_0 213 | umap-learn=0.5.3=pypi_0 214 | urllib3=1.26.9=pypi_0 215 | validators=0.20.0=pypi_0 216 | vc=14.2=h21ff451_1 217 | vs2015_runtime=14.27.29016=h5e58377_2 218 | wandb=0.12.17=pypi_0 219 | watchdog=2.1.8=pypi_0 220 | wcwidth=0.2.5=pypi_0 221 | webencodings=0.5.1=pypi_0 222 | websocket-client=1.3.2=pypi_0 223 | werkzeug=2.1.2=pypi_0 224 | wheel=0.37.1=pyhd3eb1b0_0 225 | widgetsnbextension=3.6.0=pypi_0 226 | wincertstore=0.2=py310haa95532_2 227 | xxhash=3.0.0=pypi_0 228 | xz=5.2.5=h8cc25b3_1 229 | yarl=1.7.2=pypi_0 230 | zc-lockfile=2.0=pypi_0 231 | zipp=3.8.0=pypi_0 232 | zlib=1.2.12=h8cc25b3_2 233 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Paraphrase OPT 2 | 3 | Training OPT for paraphrasing through prompt engineering. 4 | 5 | It seems like GPT3 will perform quite well with just telling it directly to paraphrase the following sentence, and 6 | bumping up the frequency and presence penalties. 7 | 8 | However, OPT's smaller variants (125m & 350m) do not seem to be able to understand the prompt "paraphrase:" and instead 9 | attempt to continue the sentence: 10 | 11 | ``` 12 | 'Once, a group of frogs was roaming around the forest in search of water. 13 | Paraphrase: "I\'m thirsty."\n\nThe group of frogs was so thirsty that they were unable 14 | to find water.\n\nThe group of frogs was' 15 | ``` 16 | 17 | Trying 13B gives this result: 18 | ``` 19 | RuntimeError: CUDA out of memory. Tried to allocate 100.00 MiB (GPU 0; 39.59 GiB total capacity; 37.53 GiB already allocated; 22.19 MiB free; 37.53 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF 20 | ``` 21 | 22 | # Current Implementation 23 | 24 | Copy pasted soft prompt tuning and created a huggingface model wrapper around OPT for it. It seems to be working. 25 | Emphasis on *seems*. 26 | 27 | 28 | TODO: 29 | - [x] profit??? 
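
For anyone who wants to reproduce the zero-shot behaviour quoted above, a minimal sketch is below (the model size and generation settings are illustrative, not the exact ones used for that output):

```python
from transformers import GPT2Tokenizer, OPTForCausalLM

# any of the smaller OPT variants shows the same "just keep continuing the sentence" behaviour
model_name = "facebook/opt-350m"
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
model = OPTForCausalLM.from_pretrained(model_name)

prompt = "Once, a group of frogs was roaming around the forest in search of water. Paraphrase:"
inputs = tokenizer(prompt, return_tensors="pt")
outputs = model.generate(**inputs, max_length=64, do_sample=False)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```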
30 | 31 | I had some time while waiting for the checkpoint to download 32 | 33 | ![](images/thats_quite_big.png) 34 | 35 | # Sprint Review 36 | BartScore (compare to BertScore) 37 | Why is it good 38 | Parabank was done on parallel language pairs > can be extended to like indonesian multilingual datasets 39 | provided we have a parallel dataset > we can train our own bartscore on different languages 40 | 41 | 42 | 43 | # Ideas 44 | 45 | Given that OPT is a decoder only model, how will we get it to perform what is traditionally considered a seq-2-seq task 46 | which involves cross attention from the encoder outputs, transforming the input sequence into an output sequence of a 47 | different phrasing. The function of the encoder output cross attention is for the model to maintain a strong reference 48 | point to the original sequence while predicting the next tokens, which is better than appending the input to the start 49 | of the sequence and referring to the causal mask of the decoder sequence since the encoder and decoder embedding spaces 50 | no longer have to be aligned. 51 | 52 | ## Soft Prompt Tuning 53 | 54 | [Parameter Efficient Soft Prompt Tuning](https://arxiv.org/pdf/2104.08691.pdf) seems to be the original implementation 55 | that was not referenced in the DART paper. The codebase is much simpler (actually just 1 .py file) and extends the 56 | HuggingFace library in a very simple way, just concatenating the soft prompt embeddings directly to the input. 57 | 58 | 59 | 60 | [Github](https://github.com/kipgparker/soft-prompt-tuning) source for soft prompt tuning. 61 | 62 | ## DART Implementation (???) 63 | 64 | Refer to soft prompt tuning. The methodology seems exactly the same, except that DART can be applied to any language 65 | model, and they added fluency constraint objectives to the model training to ensure the differentiable prompts retain 66 | association between template tokens. 67 | 68 | Differentiable prompts [DART](https://arxiv.org/pdf/2108.13161.pdf) except we adapt it from MLM to CLM. Instead of 69 | labels based on a the output of a single [MASK] token we generate a whole sequence and evaluate the semantic similarity 70 | of the output sequence. 71 | 72 | The input prompt when fed into an MLM model looks like this: 73 | Xprompt = [CLS] Xin [SEP] T [SEP] 74 | 75 | where T is the template prompt with containing single [MASK] token, of the form: 76 | {h0,...,hi,w([MASK]),hi+1,...,hm} 77 | 78 | Since OPT as a decoder is autoregressive, we alter T as such (predk are predicted tokens from previous k 79 | iterations): 80 | {h0,...,hi,pred0,...,predk,w([MASK])} 81 | 82 | Prompt embeddings that come after w([MASK]) will be masked and ignored anyway, hence we omit them in this 83 | implementation. The input prompt when fed into OPT (formatted similarly to GPT2's tokenizer) will then look like this: 84 | Xprompt = [EOS] Xin [BOS/EOS] {h0,...,hi,pred0,...,pred 85 | k,w([MASK])} 86 | 87 | We then iterate through multiple forward passes until we reach an eos_token output by the model or max length of the 88 | sequence. 89 | 90 | # ERROR SHEET 91 | 92 | Some errors may pop up when trying to run the program. "But it works on my machine" yeah it will work on your machine 93 | when you do these things. 94 | 95 | ### PyGraphViz Installation Errors (Especially on Windows) 96 | https://pygraphviz.github.io/documentation/stable/install.html#install 97 | Cancer. Follow the instructions. Use powershell or anaconda powershell. Restart the terminal/editor. 
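
On Windows the documented route boils down to something like the following (the include/lib paths are an assumption based on the default Graphviz install location, adjust them to wherever Graphviz actually ended up):

```commandline
python -m pip install --global-option=build_ext --global-option="-IC:\Program Files\Graphviz\include" --global-option="-LC:\Program Files\Graphviz\lib" pygraphviz
```

Graphviz itself has to be installed first, otherwise those headers and libs won't exist.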
98 | 99 | 100 | ### Memory Errors 101 | Who doesn't love training large models? Some errors aren't due to the large model though. Like this one, this one occurs 102 | if the batch size is too large. Reduce the batch size because huggingface is trying to allocate a continguous block of 103 | gpu memory to compare the logits, and if the batch size is too large the logits are as a result too large to fit in the 104 | gpu. 105 | 106 | ```commandline 107 | File "/opt/conda/envs/OPT/lib/python3.10/site-packages/transformers/models/opt/modeling_opt.py", line 951, in forward 108 | shift_logits = logits[..., :-1, :].contiguous() 109 | RuntimeError: CUDA out of memory. Tried to allocate 2.11 GiB (GPU 0; 39.59 GiB total capacity; 36.18 GiB already allocated; 910.19 MiB free; 36.66 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF 110 | ``` 111 | 112 | 113 | ### protobuf error 114 | 115 | Might encounter an error with protobuf apparently one of google's updates broke it so its incompatible with pytorch 116 | lightning. Quick fix is to downgrade it to an older version: 117 | 118 | ```buildoutcfg 119 | pip install protobuf==3.20.1 120 | ``` 121 | 122 | # GCloud Compute CLI Cheatsheet 123 | 124 | ## ssh 125 | ```commandline 126 | gcloud compute instances start liewweipyn 127 | gcloud compute ssh liewweipyn 128 | gcloud compute instances stop liewweipyn 129 | ``` 130 | 131 | ## tmux 132 | To detach sessions from the ssh shell, so we can close the ssh client without ending the training. 133 | Use ctrl + b + d to exit a session. 134 | ```commandline 135 | tmux new // create new session 136 | tmux ls // look at all created sessions 137 | tmux attach -t 0 // reattach to a detached session 138 | ``` 139 | 140 | ## scp 141 | To transfer files between google cloud compute and desktop, works both ways. 142 | ``` 143 | gcloud compute scp liewweipyn: 144 | ``` 145 | -------------------------------------------------------------------------------- /model_benchmark.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from typing import List 4 | 5 | from tqdm import tqdm 6 | import pandas as pd 7 | import torch 8 | import wandb 9 | from torch.nn.utils.rnn import pad_sequence 10 | from transformers import GPT2Tokenizer, BartTokenizer 11 | 12 | from metrics.bart_metric import BartScore 13 | from torchmetrics.text.bleu import BLEUScore 14 | from torchmetrics.text.rouge import ROUGEScore 15 | 16 | from soft_prompt_tuning.soft_prompt_opt import ParaphraseOPT 17 | from fine_tuning.fine_tune_opt import FineTuneOPT 18 | from fine_tuning.fine_tune_bart import FineTuneBART 19 | 20 | from training_datasets.para_nmt.para_nmt import ParaNMTDataModule 21 | from training_datasets.parabank.parabank import ParabankDataModule 22 | from training_datasets.paracombined import ParaCombinedDataModule 23 | 24 | """ 25 | Script for automatically benchmarking model outputs against BartScore, BLEU and ROUGE scores. The file should be .pkl 26 | format of a dataframe where the first column is the source (model predictions) and second column is the target (labels). 
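Note: there are no CLI flags despite the argparse import; the model type, checkpoint path, dataset size and save
locations are all set directly in the __main__ block at the bottom of this file. Running `python model_benchmark.py`
first writes the raw predictions to metrics/benchmark_runs/model_preds/, then writes the scored pairs to
metrics/benchmark_runs/model_benchmarked_results/, and also starts a wandb run.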
27 | """ 28 | 29 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 30 | 31 | 32 | def run_model(dataset: List[str], batch_size: int, save_path: str, model_type: str, model_name: str, checkpoint: str = None, append_seq: str = ""): 33 | # init the dataset 34 | print("Initialising.") 35 | if model_type == "soft": 36 | if checkpoint is None: 37 | model = ParaphraseOPT(model_name) 38 | else: 39 | model = ParaphraseOPT.load_from_custom_save(model_name, checkpoint) 40 | elif model_type == "fine-tuned": 41 | if checkpoint is None: 42 | model = FineTuneOPT(model_name) 43 | else: 44 | model = FineTuneOPT.load_from_checkpoint(checkpoint_path=checkpoint) 45 | elif model_type == "bart": 46 | model = FineTuneBART() 47 | model.load() 48 | else: 49 | # suffer 50 | assert False 51 | 52 | model = model.eval() 53 | model.to(device) 54 | 55 | if model_type == "bart": 56 | tokenizer = BartTokenizer.from_pretrained(model_name) 57 | else: 58 | tokenizer = GPT2Tokenizer.from_pretrained(model_name) 59 | 60 | # pad to the left because the model is autoregressive (anything to the right is ignored) 61 | tokenizer.padding_side = 'left' 62 | 63 | print("Encoding dataset.") 64 | # append a sequence to the end of every input (could be token or prompt like "paraphrase:") and encode all 65 | encoded_inputs = tokenizer([i + append_seq for i in dataset], padding=True, return_tensors='pt') 66 | 67 | print("Generating model predictions.") 68 | """ Yeah. Don't pass .generate() all the encoded inputs at once. 69 | RuntimeError: CUDA out of memory. Tried to allocate 17.61 GiB (GPU 0; 39.59 GiB total capacity; 23.04 GiB 70 | already allocated; 14.16 GiB free; 23.38 GiB reserved in total by PyTorch) If reserved memory is >> allocated 71 | memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and 72 | PYTORCH_CUDA_ALLOC_CONF 73 | """ 74 | output_sequences = list() 75 | # ensure no intermediate gradient tensors are stored. We need all the memory we can get. 
76 | with torch.no_grad(): 77 | for i in tqdm(range(0, encoded_inputs['input_ids'].size(dim=0), batch_size)): 78 | batch = encoded_inputs['input_ids'][i:i+batch_size] 79 | # if use_cache=False is not used there will be dim mismatch as huggingface is cringe 80 | output_batch = model.model.generate(inputs=batch.to(model.model.device), 81 | max_length=100, 82 | use_cache=False).to('cpu') 83 | # free the memory (it isn't actually removed from gpu but is able to be overwritten) 84 | del batch 85 | 86 | # remove the source sentence based on the length of the inputs 87 | if model_type != "bart": 88 | output_batch = output_batch[:, encoded_inputs['input_ids'].size(dim=-1):] 89 | 90 | # decode outputs, after removal of source sentence should only remain eos token and padding on the right 91 | # which are omitted by skip_special_tokens=True 92 | outputs = tokenizer.batch_decode(output_batch, skip_special_tokens=True, clean_up_tokenization_spaces=False) 93 | output_sequences.extend(outputs) 94 | 95 | print("Dataframe saving.") 96 | df = pd.DataFrame({"preds": output_sequences, "src": dataset}) 97 | df.to_pickle(save_path) 98 | 99 | print(df) 100 | 101 | 102 | def benchmark_pairs(filepath, save_path): 103 | print("Loading for predictions.") 104 | df = pd.read_pickle(filepath) 105 | 106 | # init metrics 107 | bart = BartScore() 108 | rouge = ROUGEScore() 109 | bleu = BLEUScore() 110 | 111 | # apply the metrics on the source and target sentence 112 | def score(row): 113 | src, target = row 114 | bartscore = bart([src], [target])[0] 115 | bleuscore = bleu([src], [[target]]).item() 116 | rougescore = {k: v.item() for k, v in rouge(src, target).items()} 117 | series = pd.Series([src, target, bartscore, bleuscore], index=["src", "target", "bartscore", "bleuscore"]) 118 | return pd.concat([series, pd.Series(rougescore)]) 119 | 120 | # apply score function along each row 121 | print("Scoring sequence pairs.") 122 | df = df.apply(score, axis=1) 123 | print(df) 124 | df.to_pickle(save_path) 125 | 126 | 127 | if __name__ == "__main__": 128 | package_directory = os.path.dirname(os.path.abspath(__file__)) 129 | 130 | filename = "bart-samples=500.pkl" 131 | model_preds_save_path = "metrics/benchmark_runs/model_preds/" 132 | benchmark_save_path = "metrics/benchmark_runs/model_benchmarked_results/" 133 | checkpoint_path = "" 134 | 135 | model_name = 'facebook/bart-large-cnn' 136 | model_type = "bart" 137 | dataset_size = 500 138 | 139 | wandb.init(project="benchmark_popt", entity="clyde013", name="benchmark_run") 140 | wandb.config.update({"embedding_n_tokens": 111}, allow_val_change=True) 141 | 142 | print("Datamodule setup.") 143 | datamodule = ParaCombinedDataModule(model_name, 1, 1000, [ParabankDataModule, ParaNMTDataModule], 144 | probabilities=[0.5, 0.5], seed=82765, pre_tokenize=False) 145 | datamodule.setup() 146 | 147 | # get the values from {"source": "......"} dict and then take only the first as dataset input for model 148 | if model_type == "bart": 149 | dataset = [i["source"].split("<|endoftext|>")[0] for i in list(datamodule.dataset.take(dataset_size))] 150 | else: 151 | dataset = [i["source"].split("")[0] for i in list(datamodule.dataset.take(dataset_size))] 152 | 153 | run_model(dataset=dataset, 154 | batch_size=5, 155 | save_path=os.path.join(package_directory, model_preds_save_path, filename), 156 | model_type=model_type, 157 | model_name=model_name, 158 | checkpoint=os.path.join(package_directory, checkpoint_path), 159 | append_seq="") 160 | 161 | benchmark_pairs(os.path.join(package_directory, 
model_preds_save_path, filename), 162 | save_path=os.path.join(package_directory, benchmark_save_path, filename)) 163 | -------------------------------------------------------------------------------- /training_datasets/parabank/parabank.py: -------------------------------------------------------------------------------- 1 | """ 2 | README from parabank-2.0.zip 3 | 4 | The TSV file contains ParaBank 2, a diverse collection of paraphrases generated 5 | through bilingual generation. Details of the dataset and how it's created can 6 | be found here: 7 | 8 | Hu, J. E., A. Singh, N. Holzenberger, M. Post, & B. Van Durme. 2019. Large-scale, 9 | Diverse, Paraphrastic Bitexts via Sampling and Clustering. Proceedings of CoNLL 2019, 10 | Hong Kong, Nov 3 – Nov 4, 2019. 11 | 12 | Each line of the file contains a bilingual dual-condition score, a reference 13 | sentence, and paraphrases of the same reference sentence. A reference sentence may 14 | have between one to five distinct paraphrases. The lines are in descending 15 | order of the dual-conditioned score, a measurement of the quality of the 16 | original bilingual sentence pair. Within the same line, paraphrases are ranked by 17 | model score as described in the paper - i.e., the first paraphrase from left 18 | to right correspond to the system with subscript "1" in evaluation, and the 19 | last to "5". All sentences are raw text (untokenized). The reference sentences 20 | appear in ascending order of their bidirectional model scores (the lower the 21 | better), which we use to filter the bilingual resource used to generate ParaBank 2. 22 | """ 23 | from typing import Optional 24 | from datasets import load_dataset, IterableDataset 25 | from pytorch_lightning import LightningDataModule 26 | from torch.utils.data import DataLoader 27 | from transformers import GPT2Tokenizer, DataCollatorForLanguageModeling 28 | 29 | 30 | class ParabankDataModule(LightningDataModule): 31 | """ 32 | LightningDataModule for parabank dataset for causal language modelling 33 | 34 | Note on num_workers: https://github.com/huggingface/datasets/pull/4375 35 | IterableDatasets do not support Dataloaders with num_workers > 0. Watch the PR to see if the fix will be merged. 36 | """ 37 | parabank_url = "http://cs.jhu.edu/~vandurme/data/parabank-2.0.zip" 38 | 39 | def __init__(self, opt_name, batch_size, steps_per_epoch, num_workers=0, seed=69, pre_tokenize=True): 40 | """ 41 | 42 | Parameters 43 | ---------- 44 | opt_name: str 45 | name of the OPT model type (i.e. facebook/opt-350m) 46 | batch_size: int 47 | batch_size output by dataloader 48 | steps_per_epoch: int 49 | dataset_size = steps_per_epoch * batch_size 50 | Since we do not know the dataset size we simply leave it to the user to determine how many steps per epoch 51 | we should have. 
52 | num_workers: int 53 | refer to note above on PR https://github.com/huggingface/datasets/pull/4375 54 | seed: int 55 | haha funny number 56 | pre_tokenize: bool 57 | should we tokenize the texts (if true: dataset will return tokenized ids instead of source text) 58 | """ 59 | 60 | super().__init__() 61 | self.opt_name = opt_name 62 | self.batch_size = batch_size 63 | self.steps_per_epoch = steps_per_epoch 64 | self.num_workers = num_workers 65 | self.seed = seed 66 | self.pre_tokenize = pre_tokenize 67 | 68 | # init None to make pycharm happy 69 | self.tokenizer = None 70 | self.dataset = None 71 | 72 | def prepare_data(self) -> None: 73 | # download and cache 74 | GPT2Tokenizer.from_pretrained(self.opt_name) 75 | 76 | def setup(self, stage: Optional[str] = None) -> None: 77 | # load tokenizer (should be cached) 78 | self.tokenizer = GPT2Tokenizer.from_pretrained(self.opt_name, use_fast=False) 79 | 80 | # preprocess function for the dataset's entries 81 | def preprocess(examples): 82 | # list of len batch 83 | batch = examples['text'] 84 | processed_batch = list() 85 | for i in batch: 86 | # split by \t (it is a tsv file) and omit the initial dual-condition score (it is useless) 87 | i = i.split('\t')[1:] 88 | # filter entries without paraphrases and split them with a '' token to denote source-target 89 | if len(i) > 1: 90 | processed_batch.append(i[0] + self.tokenizer.eos_token + i[1]) 91 | 92 | if self.pre_tokenize: 93 | outputs = self.tokenizer( 94 | processed_batch, 95 | truncation=True, 96 | max_length=69, 97 | ) 98 | else: 99 | outputs = {"source": processed_batch} 100 | return outputs 101 | 102 | # init dataset in streaming mode 103 | self.dataset = load_dataset("text", data_files=self.parabank_url, streaming=True)['train'] 104 | # elements within buffer size will be shuffled as they are loaded in 105 | self.dataset = self.dataset.shuffle(seed=self.seed, buffer_size=10_000) 106 | # preprocessing will take place while being streamed by dataloader 107 | self.dataset = self.dataset.map(preprocess, batched=True, remove_columns=['text']) 108 | # ensure pytorch tensors are returned 109 | self.dataset = self.dataset.with_format("torch") 110 | 111 | # monkeypatch of __len__ function in the dataloader so that the trainer knows how many 112 | # steps there are per epoch. Sure this violates many programming paradigms but it works. 
113 | n = self.steps_per_epoch 114 | 115 | def __len__(self): 116 | return n 117 | 118 | IterableDataset.__len__ = __len__ 119 | 120 | # dataloaders are basically all the same since we cannot split a streamed dataset 121 | def train_dataloader(self): 122 | dataloader = DataLoader(self.dataset, 123 | batch_size=self.batch_size, 124 | num_workers=self.num_workers) 125 | if self.pre_tokenize: dataloader.collate_fn = DataCollatorForLanguageModeling(self.tokenizer, mlm=False) 126 | return dataloader 127 | 128 | def val_dataloader(self): 129 | dataloader = DataLoader(self.dataset, 130 | batch_size=self.batch_size, 131 | num_workers=self.num_workers) 132 | if self.pre_tokenize: dataloader.collate_fn = DataCollatorForLanguageModeling(self.tokenizer, mlm=False) 133 | return dataloader 134 | 135 | def test_dataloader(self): 136 | dataloader = DataLoader(self.dataset, 137 | batch_size=self.batch_size, 138 | num_workers=self.num_workers) 139 | if self.pre_tokenize: dataloader.collate_fn = DataCollatorForLanguageModeling(self.tokenizer, mlm=False) 140 | return dataloader 141 | 142 | def predict_dataloader(self): 143 | dataloader = DataLoader(self.dataset, 144 | batch_size=self.batch_size, 145 | num_workers=self.num_workers) 146 | if self.pre_tokenize: dataloader.collate_fn = DataCollatorForLanguageModeling(self.tokenizer, mlm=False) 147 | return dataloader 148 | 149 | 150 | if __name__ == "__main__": 151 | model_name = "facebook/opt-1.3b" 152 | datamodule = ParabankDataModule(model_name, 1, 1000, seed=1337) 153 | datamodule.setup() 154 | dl = datamodule.val_dataloader() 155 | it = iter(dl) 156 | 157 | for i in range(10): 158 | print(datamodule.tokenizer.batch_decode(next(it)['input_ids'])[0]) 159 | -------------------------------------------------------------------------------- /soft_prompt_tuning/soft_prompt_opt.py: -------------------------------------------------------------------------------- 1 | import re 2 | from functools import reduce 3 | from typing import Dict 4 | 5 | from pytorch_lightning import LightningModule, Callback 6 | from torch.optim import Adam, SGD, Optimizer 7 | from torch.optim.lr_scheduler import ReduceLROnPlateau, _LRScheduler 8 | from transformers.models.opt.modeling_opt import * 9 | from soft_prompt_tuning.soft_embedding import SoftEmbedding 10 | 11 | import wandb 12 | import os 13 | 14 | 15 | class SoftOPTModelWrapper(OPTForCausalLM): 16 | """Wrapper class for OPTForCausalLM to add learnable embedding functionality 17 | Simply initialise it with from_pretrained OPT files and it should work out of the box. 18 | """ 19 | _keys_to_ignore_on_load_missing = [r"soft_embedding.wte.weight", r"soft_embedding.learned_embedding", 20 | r"lm_head.weight"] 21 | 22 | def __init__(self, config: OPTConfig): 23 | super().__init__(config) 24 | 25 | # init parameters for embedding 26 | self.n_tokens = wandb.config["embedding_n_tokens"] 27 | self.init_from_vocab = wandb.config["init_from_vocab"] 28 | 29 | # initialise the embedding to learn 30 | self.soft_embedding = SoftEmbedding(self.get_input_embeddings(), 31 | n_tokens=self.n_tokens, 32 | initialize_from_vocab=self.init_from_vocab) 33 | 34 | @classmethod 35 | def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): 36 | """Incredibly scuffed but we have to set the input embeddings to the soft embeddings only AFTER 37 | the pretrained weights have been loaded in. 
All parameters are the same as a normal from_pretrained() call 38 | """ 39 | 40 | pretrained_model = super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) 41 | pretrained_model.set_input_embeddings(pretrained_model.soft_embedding) 42 | return pretrained_model 43 | 44 | def forward(self, 45 | input_ids: torch.LongTensor = None, 46 | attention_mask: Optional[torch.Tensor] = None, 47 | labels: Optional[torch.LongTensor] = None, 48 | **kwargs): 49 | """Shitty forward pass 50 | need to pad attention_mask and input_ids to be full seq_len + n_learned_tokens 51 | even though it does not matter what we pad input_ids with, it's just to make HF happy 52 | """ 53 | 54 | batch_size = input_ids.shape[0] 55 | # Note: concatenation of tensors have to happen on the same device 56 | # concat padding representing our learned embedding tokens for batched inputs 57 | # inputs come in as (batch_size, seq_len) and are padded to be (batch_size, n_tokens + seq_len) 58 | input_ids = torch.cat([torch.full((batch_size, self.n_tokens), 50256).to(input_ids.device), input_ids], dim=1) 59 | attention_mask = torch.cat( 60 | [torch.full((batch_size, self.n_tokens), 1).to(attention_mask.device), attention_mask], dim=1) 61 | if labels is not None: 62 | labels = torch.cat([torch.full((batch_size, self.n_tokens), 50256).to(labels.device), labels], dim=1) 63 | 64 | return super().forward(input_ids=input_ids, attention_mask=attention_mask, labels=labels, **kwargs) 65 | 66 | 67 | class ParaphraseOPT(LightningModule): 68 | def __init__(self, model_name="facebook/opt-350m", init_optimizer=None, init_lr_scheduler=None): 69 | super().__init__() 70 | self.model = SoftOPTModelWrapper.from_pretrained(model_name) 71 | 72 | # these inits should be exclusively used for loading from checkpoints 73 | # see load_from_custom_save for why. 74 | self.init_optimizer = init_optimizer 75 | self.init_lr_scheduler = init_lr_scheduler 76 | 77 | self.save_hyperparameters("model_name") 78 | 79 | def forward(self, **inputs): 80 | return self.model(**inputs) 81 | 82 | def training_step(self, batch, batch_idx): 83 | outputs = self(**batch) 84 | loss = outputs[0] 85 | self.log("train_loss", loss) 86 | 87 | return loss 88 | 89 | def validation_step(self, batch, batch_idx, dataloader_idx=0): 90 | outputs = self(**batch) 91 | val_loss, logits = outputs[:2] 92 | 93 | # we care only about the last token being predicted 94 | pred_token_logits = logits[:, -1, :] 95 | pred_token = torch.argmax(pred_token_logits, dim=-1) 96 | labels = batch["labels"][:, -1] 97 | 98 | self.log("val_loss", val_loss) 99 | 100 | return {"loss": val_loss, "preds": pred_token, "labels": labels} 101 | 102 | def configure_optimizers(self): 103 | # thanks stack overflow! 104 | # https://stackoverflow.com/questions/38460918/regex-matching-a-dictionary-efficiently-in-python 105 | # extracting all the layers that are specified by layers_to_optimize using regex for partial matches 106 | regex_matches = [re.compile(".*" + pattern + ".*").match for pattern in wandb.config["layers_to_optimize"]] 107 | layers_to_optimize = [k for k in self.model.state_dict().keys() 108 | if any(regex_match(k) for regex_match in regex_matches)] 109 | 110 | # configure optimizer 111 | optimizers_key = {"Adam": Adam, "SGD": SGD} 112 | if self.init_optimizer is None: 113 | """ 114 | thanks forums! 
https://discuss.pytorch.org/t/how-to-access-to-a-layer-by-module-name/83797/8 115 | We cannot directly pass in the tensor output (value) from state_dict() since that is not the same reference 116 | as the actual layer, hence we instead look for the layer name with the regex matching, then access the 117 | module by name as below. 118 | """ 119 | def get_module_by_name(module: Union[torch.Tensor, nn.Module], 120 | access_string: str): 121 | """Retrieve a module nested in another by its access string. 122 | 123 | Works even when there is a Sequential in the module. 124 | """ 125 | names = access_string.split(sep='.') 126 | return reduce(getattr, names, module) 127 | 128 | layers = [get_module_by_name(self.model, layer_name) for layer_name in layers_to_optimize] 129 | # pass in the layers into optimizer 130 | optimizer_type = optimizers_key[wandb.config["optimizer_type"]] 131 | optimizer = optimizer_type(layers, **wandb.config["optimizer_params"]) 132 | else: 133 | optimizer = self.init_optimizer 134 | 135 | # configure learning rate scheduler 136 | lr_scheduler_key = {"ReduceLROnPlateau": ReduceLROnPlateau} 137 | if self.init_lr_scheduler is None: 138 | lr_scheduler_type = lr_scheduler_key[wandb.config["lr_scheduler_type"]] 139 | lr_scheduler = lr_scheduler_type(optimizer, **wandb.config["lr_scheduler_params"]) 140 | else: 141 | lr_scheduler = self.init_lr_scheduler 142 | 143 | lr_scheduler_config = {"scheduler": lr_scheduler} 144 | lr_scheduler_config.update(wandb.config["lr_scheduler_config"]) 145 | 146 | return {"optimizer": optimizer, "lr_scheduler": lr_scheduler_config} 147 | 148 | @classmethod 149 | def load_from_custom_save(cls, model_name, path, optimizer: Optimizer = None, lr_scheduler: _LRScheduler = None): 150 | """ 151 | Custom save function to load from checkpoints created by SpecificLayersCheckpoint callback. 152 | 153 | Unfortunately pytorch lightning locks the optimizers in place after instantiation and there is no clean way 154 | to change them afterwards. There are some workarounds but they all suck too: 155 | https://github.com/PyTorchLightning/pytorch-lightning/discussions/9354 156 | https://github.com/PyTorchLightning/pytorch-lightning/discussions/6131 157 | 158 | So the current implementation is to throw in the optimizer and lr_scheduler as optional parameters during model 159 | instantiation before then actually updating the model weights. 160 | 161 | To try different optimizers and lr_schedulers change configure_optimizers() directly. 
162 | """ 163 | # load the saved checkpoint 164 | state_dict = torch.load(path) 165 | 166 | # load optimizer if required 167 | if optimizer is not None: 168 | optimizer.load_state_dict(state_dict["optimizer"]) 169 | 170 | # load lr_scheduler if required 171 | if lr_scheduler is not None: 172 | lr_scheduler.load_state_dict(state_dict["lr_scheduler"]) 173 | 174 | # instantiate lightningmodule with pretrained model 175 | model = cls(model_name, optimizer, lr_scheduler) 176 | 177 | # load updated state dict into the model (as long as no layers are named optimizer or lr_scheduler) 178 | model.model.load_state_dict(state_dict, strict=False) 179 | 180 | return model 181 | 182 | """ 183 | Note on following hooks (on_train_epoch_start and on_validation_epoch_start): 184 | 185 | Using the following code to access dataloaders: self.train_dataloader().dataset.set_epoch(self.current_epoch) 186 | Results in an exception like such : pytorch_lightning.utilities.exceptions.MisconfigurationException: 187 | `val_dataloader` must be implemented to be used with the Lightning Trainer 188 | 189 | Although train_dataloader() is a valid hook, the hook is overridden only in the datamodule and we cannot reference 190 | that. We have to use self.trainer.train_dataloader.dataset which returns some CombinedDataset and then .datasets 191 | that one to get the original TorchIterableDataset. 192 | 193 | On the other hand, we can access validation dataloaders with self.trainer.val_dataloaders[0].dataset as that one is 194 | apparently a list and not a CombinedDataset. 195 | 196 | Pain. 197 | """ 198 | 199 | def on_train_epoch_start(self) -> None: 200 | # reshuffle the dataset for every train epoch 201 | self.trainer.train_dataloader.dataset.datasets.set_epoch(self.trainer.current_epoch) 202 | 203 | def on_validation_epoch_start(self) -> None: 204 | # reshuffle the dataset for every validation epoch 205 | self.trainer.val_dataloaders[0].dataset.set_epoch(self.trainer.current_epoch) 206 | 207 | 208 | class SpecificLayersCheckpoint(Callback): 209 | """ 210 | Custom saving of specific layers into a state_dict that can be loaded in using torch.load() 211 | Ideally, we load in the model with from_pretrained, and then use state_dict.update() to update the 212 | weights of the loaded model. 213 | """ 214 | 215 | def __init__(self, monitor: str, dirpath: str, filename: str, 216 | every_n_epochs: int, layers_to_save: List[str]): 217 | super().__init__() 218 | self.monitor = monitor 219 | self.dirpath = dirpath 220 | self.filename = filename 221 | self.every_n_epochs = every_n_epochs 222 | self.layers_to_save = layers_to_save 223 | 224 | def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: 225 | # if model should be saved this epoch (+1 since epoch count starts from 0) 226 | if (trainer.current_epoch + 1) % self.every_n_epochs == 0: 227 | # thanks stack overflow! 
228 | # https://stackoverflow.com/questions/38460918/regex-matching-a-dictionary-efficiently-in-python 229 | # extracting all the layers that are specified by layers_to_save using regex for partial matches 230 | regex_matches = [re.compile(".*" + pattern + ".*").match for pattern in self.layers_to_save] 231 | save_dict = {k: v for k, v in pl_module.model.state_dict().items() 232 | if any(regex_match(k) for regex_match in regex_matches)} 233 | 234 | # save the optimizer 235 | if pl_module.optimizers() is not None: 236 | save_dict.update({"optimizer": pl_module.optimizers().optimizer.state_dict()}) 237 | 238 | # save the lr_scheduler 239 | if pl_module.lr_schedulers() is not None: 240 | save_dict.update({"lr_scheduler": pl_module.lr_schedulers().state_dict()}) 241 | 242 | formatted_filename = self.filename.format(epoch=trainer.current_epoch, **trainer.callback_metrics) 243 | torch.save(save_dict, os.path.join(self.dirpath, formatted_filename)) 244 | -------------------------------------------------------------------------------- /.ipynb_checkpoints/example-checkpoint.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 |
--------------------------------------------------------------------------------
/.idea/workspace.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/REPORT.md:
--------------------------------------------------------------------------------
1 | # Teaching OPT to Paraphrase through Soft Prompt Tuning
2 | 
3 | ## Table of Contents
4 | 1. [Introduction](#introduction)
5 |    1. [Prompt Tuning](#prompt-tuning)
6 |    2. [Soft Prompts](#soft-prompts)
7 | 2. [Datasets](#datasets)
8 | 3. [Implementation](#implementation)
9 |    1. [HuggingFace Model Wrapper](#HF-wrapper)
10 |    2. [Training](#training)
11 | 4. [Results](#results)
12 |    1. [Scoring Metrics](#metrics)
13 |    2. [Visualisation](#visualisation)
14 | 5. [Conclusion](#conclusion)
15 | 
16 | 
17 | 
18 | ## Introduction 19 | Open Pre-trained Transformer models are a collection of open source decoder-only pre-trained transformers ranging 20 | from 125M to 175B parameters, with 175B showing comparable performance with GPT3. Such large language models 21 | display remarkable performance in zero and few-shot tasks, making prompting a promising solution for many tasks 22 | due to the capability of coaxing a large model into solving tasks that they were not explicitly trained to do. 23 | 24 | The task that we are trying to accomplish here, is to prompt OPT models to paraphrase sentences. The 25 | task of paraphrasing is traditionally a sequence-to-sequence task accomplished using encoder-decoder 26 | transformer architectures (such as BART), however there is still some promise in leveraging the large pretrained 27 | decoder only models like OPT, whose capabilities to model natural language may overcome the architectural 28 | limitations. 29 | 30 | However, in the course of this experiment we only work with OPT1.3B, a much smaller variant of the OPT175B model, 31 | hence the results will of course not be incredibly good, as a smaller model is unable to grasp sufficient complexities 32 | of the task at hand. 33 | 34 | 35 | ### Prompt Tuning 36 | For example, in OpenAI's GPT3 playground, we can use different techniques such as 37 | [in-context learning](http://ai.stanford.edu/blog/in-context-learning/) and 38 | [chain of thought](https://arxiv.org/pdf/2205.11916.pdf) prompting. An excellent example of chain-of-thought 39 | prompting is provided by the aforementioned paper: 40 | 41 | ![](images/Screenshot%202022-06-20%20174801.png) 42 | ![](images/Screenshot%202022-06-20%20174920.png) 43 | 44 | 45 | ### Soft Prompts 46 | The concept of soft prompts was introduced by Lester et al. in the [paper](https://arxiv.org/pdf/2104.08691.pdf) titled 47 | "The Power of Scale for Parameter-Efficient Prompt Tuning", where they explored prepending learnable tokens to the 48 | inputs of frozen language models as such: 49 | 50 | 51 | 52 | The learnable tokens can be thought of as passing conceptual notions to the model in an attempt to get it to better 53 | understand the task that is requested. As the models represent words as numbers in a high dimensional space, 54 | we need not restrict our prompts to discrete words, and can search in the space between words to find the most 55 | suitable prompts to feed into the model. 56 | 57 | These prompts are very efficient in terms of memory and compute, requiring just 0.2% of the size of a complete model 58 | checkpoint which would store fine-tuned model parameters, as well as being capable of achieving good results in 59 | less training time. 60 | 61 | 62 | ## Datasets 63 | Two popular paraphrasic datasets were used in soft prompt training of the models. 64 | [ParaBank 2.0](https://nlp.jhu.edu/parabank/) and [ParaNMT-50M](https://arxiv.org/pdf/1711.05732.pdf), 65 | both datasets generated through automatic translation of large amounts of bilingual textual data, translating a 66 | foreign language to english to obtain english-english paraphrase pairs. 67 | 68 | For example, the ParaNMT-50M dataset used Czech-English parallel pairs and applied a Czech to English 69 | pretrained model for translation on the czech pairs. 70 | 71 | As the datasets are incredibly large, we utilised HuggingFace's [dataset streaming](https://huggingface.co/docs/datasets/stream) 72 | feature to progressively feed training data to the model. 
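
As a rough sketch of how that looks in code (the interleaving probabilities and seed below are illustrative; the actual values come from the datamodule configuration), the two corpora are streamed and mixed like so:

```python
from datasets import load_dataset, interleave_datasets

# both corpora are streamed rather than downloaded and unpacked up front
parabank = load_dataset("text", data_files="http://cs.jhu.edu/~vandurme/data/parabank-2.0.zip",
                        streaming=True)["train"]
para_nmt = load_dataset("text", data_files="para-nmt-5m-processed.zip", streaming=True)["train"]

# shuffle within a buffer as elements are loaded in, then mix the two streams by sampling probability
combined = interleave_datasets(
    [parabank.shuffle(seed=1337, buffer_size=10_000), para_nmt.shuffle(seed=1337, buffer_size=10_000)],
    probabilities=[0.5, 0.5], seed=1337,
)
```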
73 | 74 | Initially the baseline 20 token model was trained on a 35%-65% split of Parabank 2.0 and ParaNMT-50M respectively, 75 | however for parameter optimization, all further models were trained on a 50%-50% split of Parabank 2.0 and ParaNMT-50M 76 | respectively. 77 | 78 | 79 | 80 | ## Implementation 81 | The model was implemented using the OPT model provided by the HuggingFace team, organising the 82 | training logic with Pytorch Lightning, tracking the model performance with Weights and Biases, and 83 | multiple visualisations using Streamlit and Graphistry. 84 | 85 | 86 | ### HuggingFace Model Wrapper 87 | The implementation of the soft prompts follows nearly identical to the 88 | [Github here](https://github.com/kipgparker/soft-prompt-tuning) where the soft prompts are 89 | simply float tensors duplicated from existing vocabulary and adding them to the module's list of 90 | parameters, to be considered backpropagatable tensors. 91 | 92 | The relevant code snippet is shown below, and the full implementation is 93 | [here](https://github.com/Clyde013/Paraphrase-OPT/blob/fb8f59d6987e3902baf05fa375c856f86e139bb3/soft_prompt_tuning/soft_embedding.py#L26-L44). 94 | ```python 95 | self.learned_embedding = nn.parameter.Parameter(self.initialize_embedding(wte, 96 | n_tokens, 97 | random_range, 98 | initialize_from_vocab)) 99 | 100 | def initialize_embedding(self, 101 | wte: nn.Embedding, 102 | n_tokens: int = 10, 103 | random_range: float = 0.5, 104 | initialize_from_vocab: bool = True) -> torch.Tensor: 105 | """initializes learned embedding 106 | Args: 107 | same as __init__ 108 | Returns: 109 | torch.float: initialized using original schemes 110 | """ 111 | if initialize_from_vocab: 112 | return self.wte.weight[:n_tokens].clone().detach() 113 | return torch.FloatTensor(n_tokens, wte.weight.size(1)).uniform_(-random_range, random_range) 114 | ``` 115 | 116 | We then subclass the HuggingFace's [`OPTForCausalLM`](https://huggingface.co/docs/transformers/model_doc/opt#transformers.OPTForCausalLM) 117 | class, initialise a new soft embedding and 118 | [override the forward pass](https://github.com/Clyde013/Paraphrase-OPT/blob/fb8f59d6987e3902baf05fa375c856f86e139bb3/soft_prompt_tuning/soft_prompt_opt.py#L44-L64) 119 | to prepend our learned embeddings in front of the input. 120 | 121 | ```python 122 | def forward(self, 123 | input_ids: torch.LongTensor = None, 124 | attention_mask: Optional[torch.Tensor] = None, 125 | labels: Optional[torch.LongTensor] = None, 126 | **kwargs): 127 | batch_size = input_ids.shape[0] 128 | # Note: concatenation of tensors have to happen on the same device 129 | # concat padding representing our learned embedding tokens for batched inputs 130 | # inputs come in as (batch_size, seq_len) and are padded to be (batch_size, n_tokens + seq_len) 131 | input_ids = torch.cat([torch.full((batch_size, self.n_tokens), 50256).to(input_ids.device), input_ids], dim=1) 132 | attention_mask = torch.cat( 133 | [torch.full((batch_size, self.n_tokens), 1).to(attention_mask.device), attention_mask], dim=1) 134 | if labels is not None: 135 | labels = torch.cat([torch.full((batch_size, self.n_tokens), 50256).to(labels.device), labels], dim=1) 136 | 137 | return super().forward(input_ids=input_ids, attention_mask=attention_mask, labels=labels, **kwargs) 138 | ``` 139 | 140 | 141 | ### Training 142 | Training was done on the OPT1.3B variant, and hyperparameter search for the optimal number of soft tokens 143 | using Optuna. 
144 | 145 | All models were trained for 8000 steps per epoch with batch size of 32, and some early stopping applied to 146 | prune under performing models. 147 | 148 | It was clear early on that Adam optimizer performed better than Stochastic Gradient Descent, and as such all 149 | further trials were done using the Adam optimizer. 150 | 151 | Below are a few selected runs that show a very clear trend. 152 | ![](images/W&B%20Chart.png) 153 | 154 | 155 | ## Results 156 | The models were allowed to run on a small subset of the dataset and their outputs saved, as expected the 157 | results are not fantastic. The model is comparatively small, with only 1.3 billion parameters, and as such 158 | soft prompt tuning will not achieve state of the art performance. Nevertheless, it is observed that at 159 | semantic similarity is maintained, instead of the usual action of OPT continuing to generate the sentence. 160 | Unfortunately the model is unable to comprehend that it should be paraphrasing, and thus changing lexical components 161 | of the input, however as model size increases, it is reasonable to assume that performance will get better. The 162 | following is a selection of some of the better paraphrased results from the model. 163 | 164 | | model preds | target | 165 | |:------------------------------------------------------------------------------------------------------|:------------------------------------------------------------------------------------------------------| 166 | | for the movie that's being shot in 1 5 minutes? | for the movie that we 're shooting in about 1 5 minutes ? | 167 | | adler took a moment to consider that, and nodded. | adler took a few seconds to consider that , then nodded thoughtfully . | 168 | | david schwartz was unaware of the fact that he was only narrowly avoided a permanent dent in his ego. | david schwartz was unaware of how narrowly he had escaped crippling and a permanent dent in his ego . | 169 | | i had no idea you were a stunt performer. | i had no idea you did stunt work . | 170 | | seldon was no longer traveling around only when accompanied. | seldon no longer traveled around only if accompanied . | 171 | 172 | The next question that comes to mind is how do we evaluate their predictions? 173 | 174 | 175 | ### Metrics 176 | In order to evaluate our model, we employ multiple different metrics. Traditional metrics such as BLEU and ROUGE might 177 | not be suitable to evaluate our model directly as good paraphrases usually do not share the same vocabulary, and thus 178 | would attain a lower ROUGE score, despite being semantically equivalent. 179 | 180 | Many alternative metrics are available to tackle this problem, and one of them is 181 | [BARTScore](https://github.com/neulab/BARTScore). BARTScore leverages a pretrained BART model for 182 | paraphrase generation to score sentence pairs. A generated sentence is scored by the BART model based 183 | on its probability that the model itself would generate the same sentence, gauging the quality 184 | of the paraphrase sentence according to how much the BART model agrees with it. 185 | 186 | Below are tabulated results of some selected models, compared to the baselines of OPT1.3B with their weights 187 | directly fine tuned for the task of paraphrasing, and the BART model fine tuned for paraphrasing. 
188 | 189 | | | soft prompt 20 tokens | soft prompt 111 tokens | soft prompt 163 tokens | fine tuned | bart fine tuned | 190 | |:--------------------|------------------------:|-------------------------:|-------------------------:|-------------:|------------------:| 191 | | bartscore | -3.02511 | -2.15795 | -2.19397 | -4.32509 | -2.65748 | 192 | | bleuscore | 0.246091 | 0.342787 | 0.316936 | 0.0251696 | 0.0833621 | 193 | | rouge1_fmeasure | 0.632655 | 0.835004 | 0.834778 | 0.315754 | 0.316741 | 194 | | rouge1_precision | 0.70008 | 0.856809 | 0.850439 | 0.304833 | 0.207854 | 195 | | rouge1_recall | 0.636459 | 0.838207 | 0.833884 | 0.374748 | 0.935199 | 196 | | rouge2_fmeasure | 0.538138 | 0.737537 | 0.721758 | 0.140419 | 0.251569 | 197 | | rouge2_precision | 0.590409 | 0.756071 | 0.734675 | 0.130611 | 0.164845 | 198 | | rouge2_recall | 0.540979 | 0.743406 | 0.722555 | 0.178818 | 0.816269 | 199 | | rougeL_fmeasure | 0.626995 | 0.83046 | 0.829546 | 0.300252 | 0.301592 | 200 | | rougeL_precision | 0.693667 | 0.852231 | 0.845049 | 0.288716 | 0.197495 | 201 | | rougeL_recall | 0.630616 | 0.83334 | 0.828588 | 0.358478 | 0.900656 | 202 | | rougeLsum_fmeasure | 0.626495 | 0.830814 | 0.82999 | 0.302298 | 0.309371 | 203 | | rougeLsum_precision | 0.693297 | 0.852436 | 0.845449 | 0.290669 | 0.202609 | 204 | | rougeLsum_recall | 0.629847 | 0.833801 | 0.829088 | 0.360918 | 0.920124 | 205 | 206 | 207 | 208 | ### Visualisation 209 | The next step might be to visualise the meanings of the soft prompts with respect to where in the model's 210 | embedding space they end up in, for example in the original paper it was found that clusters of nearest neighbours 211 | maintained high lexical and semantic similarities, and that several prompt tokens end up in the vicinity of each other. 212 | 213 | The numerical representation of word tokens are of high dimensionality, and with the specific instance of 214 | OPT1.3B being used, has a hidden size of 2048. That is 2048 dimensions, incomprehensible to the human mind, and 215 | while traditional methods such as PCA and TSNE can produce viable results, lots of information is lost when decomposing 216 | a high dimensional space into 2 dimensions for us to view. In addition, the TSNE algorithm is stochastic and multiple 217 | restarts with different seeds can yield different embeddings, hence we have no way to directly compare two embedding 218 | spaces. 219 | 220 | The visualisation below is produced through the use of a data structure called a locality sensitive hash forest 221 | and a graph visualisation tool graphistry. However, this technique does suffer from information loss 222 | and is even stochastic to a certain extent. We mitigate this issue by utilising the fixed 223 | embeddings as anchor points, such that they always end up in the same position in the visualisation (determined by an 224 | initial random seed), and then fit the learned embeddings onto the generated anchor points. 225 | 226 | If the graph renders as a multicoloured tree, you might need to reload the page as it is a little buggy with 50k 227 | visualisation points :). 
The visualisation is also available 228 | [here](https://hub.graphistry.com/graph/graph.html?dataset=05a0c49697bd4a5ebe88c624d709d87f&type=arrow&splashAfter=false&info=False&play=0&menu=True&showArrows=False&pointSize=0.07&edgeCurvature=0.01&edgeSize=1.0&edgeOpacity=0.5&pointOpacity=0.9&lockedX=True&lockedY=True&lockedR=False&linLog=False&strongGravity=False&dissuadeHubs=False&edgeInfluence=1.0&precisionVsSpeed=1.0&gravity=1.0&scalingRatio=0.5&showLabels=True&showLabelOnHover=True&showPointsOfInterest=False&showPointsOfInterestLabel=False&showLabelPropertiesOnHover=True&pointsOfInterestMax=0). 229 | In the graph, red points are the default embeddings, blue points belong to the prompt of 59 prepended tokens, and green 230 | points belong to the prompt of 111 prepended tokens. 231 |
232 | 233 |
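To make the anchoring step concrete without digging through `EDA_embedding.ipynb`, here is a minimal, self-contained sketch of the idea using `networkx`. The dimensions, random data and brute-force nearest-neighbour search are stand-ins for illustration: the notebook uses the real OPT embeddings and an LSH forest (`tmap`) instead.

```python
# Minimal sketch of the anchor-point idea (toy data; the real pipeline lives in EDA_embedding.ipynb).
import numpy as np
import networkx as nx

rng = np.random.default_rng(69)
fixed = rng.normal(size=(1000, 64))   # stand-in for the ~50k x 2048 fixed vocab embeddings
learned = rng.normal(size=(5, 64))    # stand-in for the learned soft prompt embeddings

# a precomputed layout of the fixed embeddings acts as the set of anchor positions
g = nx.Graph()
anchor_pos = {i: rng.uniform(-1.0, 1.0, size=2) for i in range(len(fixed))}
g.add_nodes_from(anchor_pos)

# attach every learned embedding to its nearest fixed embedding
# (brute force here, where the notebook queries an LSH forest)
for j, vec in enumerate(learned):
    nearest = int(np.argmin(np.linalg.norm(fixed - vec, axis=1)))
    g.add_edge(len(fixed) + j, nearest)

# spring layout with the anchors held fixed, so only the learned embeddings move
pos = nx.spring_layout(g, pos=anchor_pos, fixed=list(anchor_pos), seed=69)
```

In the actual notebook the anchor layout is not random: it comes from an MST built over the LSH forest's k-nearest-neighbour graph, laid out once with a fixed seed, which is what keeps the fixed embeddings in the same place across runs.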
234 | 
235 | 
236 | ## Conclusion
237 | We've taught OPT1.3B to paraphrase!
238 | 
239 | The results of this implementation largely agree with the conclusions of the authors of the original prompt tuning paper:
240 | 1. Increasing model size improves soft prompt performance.
241 | 2. Increasing the length of the soft prompts improves model performance.
242 | 3. This method largely outperforms zero-shot prompting (i.e. "paraphrase the following:"), at least when tested on
243 | OPT1.3B.
244 | 
245 | Furthermore, some exciting avenues for further exploration are:
246 | 1. Training the full OPT175B model.
247 | 2. Distilling the large prepended soft prompt model into a smaller model without the need for prepended prompts.
248 | 3. Prompting to achieve better chain-of-thought intermediate responses, thereby improving the final response.
249 | 
250 | All of the code is publicly available on GitHub:
251 | https://github.com/Clyde013/Paraphrase-OPT
--------------------------------------------------------------------------------
/.ipynb_checkpoints/EDA_benchmarking-checkpoint.ipynb:
--------------------------------------------------------------------------------
1 | {
2 |  "cells": [
3 |   {
4 |    "cell_type": "code",
5 |    "execution_count": 1,
6 |    "id": "ce5f4e17-2400-4b69-8933-e3234d70b104",
7 |    "metadata": {},
8 |    "outputs": [],
9 |    "source": [
10 |     "import pandas as pd"
11 |    ]
12 |   },
13 |   {
14 |    "cell_type": "code",
15 |    "execution_count": 3,
16 |    "id": "5777276a-0e78-4151-8fa0-c73f18c37af6",
17 |    "metadata": {},
18 |    "outputs": [
19 |     {
20 |      "data": {
21 |       "text/html": [
22 |        "<div>
\n", 23 | "\n", 36 | "\n", 37 | " \n", 38 | " \n", 39 | " \n", 40 | " \n", 41 | " \n", 42 | " \n", 43 | " \n", 44 | " \n", 45 | " \n", 46 | " \n", 47 | " \n", 48 | " \n", 49 | " \n", 50 | " \n", 51 | " \n", 52 | " \n", 53 | " \n", 54 | " \n", 55 | " \n", 56 | " \n", 57 | " \n", 58 | " \n", 59 | " \n", 60 | " \n", 61 | " \n", 62 | " \n", 63 | " \n", 64 | " \n", 65 | " \n", 66 | " \n", 67 | " \n", 68 | " \n", 69 | " \n", 70 | " \n", 71 | " \n", 72 | " \n", 73 | " \n", 74 | " \n", 75 | " \n", 76 | " \n", 77 | " \n", 78 | " \n", 79 | " \n", 80 | " \n", 81 | " \n", 82 | " \n", 83 | " \n", 84 | " \n", 85 | " \n", 86 | " \n", 87 | " \n", 88 | " \n", 89 | " \n", 90 | " \n", 91 | " \n", 92 | " \n", 93 | " \n", 94 | " \n", 95 | " \n", 96 | " \n", 97 | " \n", 98 | " \n", 99 | " \n", 100 | " \n", 101 | " \n", 102 | " \n", 103 | " \n", 104 | " \n", 105 | " \n", 106 | " \n", 107 | " \n", 108 | " \n", 109 | " \n", 110 | " \n", 111 | " \n", 112 | " \n", 113 | " \n", 114 | " \n", 115 | " \n", 116 | " \n", 117 | " \n", 118 | " \n", 119 | " \n", 120 | " \n", 121 | " \n", 122 | " \n", 123 | " \n", 124 | " \n", 125 | " \n", 126 | " \n", 127 | " \n", 128 | " \n", 129 | " \n", 130 | " \n", 131 | " \n", 132 | " \n", 133 | " \n", 134 | " \n", 135 | " \n", 136 | " \n", 137 | " \n", 138 | " \n", 139 | " \n", 140 | " \n", 141 | " \n", 142 | " \n", 143 | " \n", 144 | " \n", 145 | " \n", 146 | " \n", 147 | " \n", 148 | " \n", 149 | " \n", 150 | " \n", 151 | " \n", 152 | " \n", 153 | " \n", 154 | " \n", 155 | " \n", 156 | " \n", 157 | " \n", 158 | " \n", 159 | " \n", 160 | " \n", 161 | " \n", 162 | " \n", 163 | " \n", 164 | " \n", 165 | " \n", 166 | " \n", 167 | " \n", 168 | " \n", 169 | " \n", 170 | " \n", 171 | " \n", 172 | " \n", 173 | " \n", 174 | " \n", 175 | " \n", 176 | " \n", 177 | " \n", 178 | " \n", 179 | " \n", 180 | " \n", 181 | " \n", 182 | " \n", 183 | " \n", 184 | " \n", 185 | " \n", 186 | " \n", 187 | " \n", 188 | " \n", 189 | " \n", 190 | " \n", 191 | " \n", 192 | " \n", 193 | " \n", 194 | " \n", 195 | " \n", 196 | " \n", 197 | " \n", 198 | " \n", 199 | " \n", 200 | " \n", 201 | " \n", 202 | " \n", 203 | " \n", 204 | " \n", 205 | " \n", 206 | " \n", 207 | " \n", 208 | " \n", 209 | " \n", 210 | " \n", 211 | " \n", 212 | " \n", 213 | " \n", 214 | " \n", 215 | " \n", 216 | " \n", 217 | " \n", 218 | " \n", 219 | " \n", 220 | " \n", 221 | " \n", 222 | " \n", 223 | " \n", 224 | " \n", 225 | " \n", 226 | " \n", 227 | " \n", 228 | " \n", 229 | " \n", 230 | " \n", 231 | " \n", 232 | " \n", 233 | " \n", 234 | " \n", 235 | " \n", 236 | " \n", 237 | " \n", 238 | " \n", 239 | " \n", 240 | " \n", 241 | " \n", 242 | " \n", 243 | " \n", 244 | " \n", 245 | " \n", 246 | " \n", 247 | " \n", 248 | " \n", 249 | " \n", 250 | " \n", 251 | " \n", 252 | " \n", 253 | " \n", 254 | " \n", 255 | " \n", 256 | " \n", 257 | " \n", 258 | " \n", 259 | " \n", 260 | " \n", 261 | " \n", 262 | " \n", 263 | " \n", 264 | " \n", 265 | " \n", 266 | " \n", 267 | " \n", 268 | " \n", 269 | "
srctargetbartscorebleuscorerouge1_fmeasurerouge1_precisionrouge1_recallrouge2_fmeasurerouge2_precisionrouge2_recallrougeL_fmeasurerougeL_precisionrougeL_recallrougeLsum_fmeasurerougeLsum_precisionrougeLsum_recall
0david schwartz was unaware of the fact that he...david schwartz was unaware of how narrowly he ...-2.4006150.3967090.7027030.6842110.7222220.5142860.5000000.5294120.6486490.6315790.6666670.6486490.6315790.666667
1we 'll be safe here.we'il be safe here .-3.4855800.0000000.8000000.8000000.8000000.5000000.5000000.5000000.8000000.8000000.8000000.8000000.8000000.800000
2we 'll talk about it.why do n't you come in and we 'll talk about it .-3.9396870.1350160.5555561.0000000.3846150.5000001.0000000.3333330.5555561.0000000.3846150.5555561.0000000.384615
3What was your plan?What was your plan?-0.6028471.0000001.0000001.0000001.0000001.0000001.0000001.0000001.0000001.0000001.0000001.0000001.0000001.000000
4historians can make a credible case that the p...historians can make a credible case that perio...-2.1472720.3534860.7450980.7916670.7037040.5714290.6086960.5384620.7058820.7500000.6666670.7058820.7500000.666667
...................................................
495adler took a moment to consider that, and nodded.adler took a few seconds to consider that , th...-2.6339090.0000000.7000000.7777780.6363640.4444440.5000000.4000000.7000000.7777780.6363640.7000000.7777780.636364
496Five minutes.Ladies and gentlemen, five minutes.-2.3461470.0000000.5714291.0000000.4000000.4000001.0000000.2500000.5714291.0000000.4000000.5714291.0000000.400000
497i, uh -- i brought you, uh -- i brought you.i , uh -- i brought you , uh ---2.5297540.2899780.8000000.6666671.0000000.7692310.6250001.0000000.8000000.6666671.0000000.8000000.6666671.000000
498Every child is different.Every child is different.-0.4539741.0000001.0000001.0000001.0000001.0000001.0000001.0000001.0000001.0000001.0000001.0000001.0000001.000000
499Qatar motorcycle Grand PrixQatar motorcycle Grand Prix-1.0977751.0000001.0000001.0000001.0000001.0000001.0000001.0000001.0000001.0000001.0000001.0000001.0000001.000000
\n", 270 | "

500 rows × 16 columns

\n", 271 | "
" 272 | ], 273 | "text/plain": [ 274 | " src \\\n", 275 | "0 david schwartz was unaware of the fact that he... \n", 276 | "1 we 'll be safe here. \n", 277 | "2 we 'll talk about it. \n", 278 | "3 What was your plan? \n", 279 | "4 historians can make a credible case that the p... \n", 280 | ".. ... \n", 281 | "495 adler took a moment to consider that, and nodded. \n", 282 | "496 Five minutes. \n", 283 | "497 i, uh -- i brought you, uh -- i brought you. \n", 284 | "498 Every child is different. \n", 285 | "499 Qatar motorcycle Grand Prix \n", 286 | "\n", 287 | " target bartscore bleuscore \\\n", 288 | "0 david schwartz was unaware of how narrowly he ... -2.400615 0.396709 \n", 289 | "1 we'il be safe here . -3.485580 0.000000 \n", 290 | "2 why do n't you come in and we 'll talk about it . -3.939687 0.135016 \n", 291 | "3 What was your plan? -0.602847 1.000000 \n", 292 | "4 historians can make a credible case that perio... -2.147272 0.353486 \n", 293 | ".. ... ... ... \n", 294 | "495 adler took a few seconds to consider that , th... -2.633909 0.000000 \n", 295 | "496 Ladies and gentlemen, five minutes. -2.346147 0.000000 \n", 296 | "497 i , uh -- i brought you , uh -- -2.529754 0.289978 \n", 297 | "498 Every child is different. -0.453974 1.000000 \n", 298 | "499 Qatar motorcycle Grand Prix -1.097775 1.000000 \n", 299 | "\n", 300 | " rouge1_fmeasure rouge1_precision rouge1_recall rouge2_fmeasure \\\n", 301 | "0 0.702703 0.684211 0.722222 0.514286 \n", 302 | "1 0.800000 0.800000 0.800000 0.500000 \n", 303 | "2 0.555556 1.000000 0.384615 0.500000 \n", 304 | "3 1.000000 1.000000 1.000000 1.000000 \n", 305 | "4 0.745098 0.791667 0.703704 0.571429 \n", 306 | ".. ... ... ... ... \n", 307 | "495 0.700000 0.777778 0.636364 0.444444 \n", 308 | "496 0.571429 1.000000 0.400000 0.400000 \n", 309 | "497 0.800000 0.666667 1.000000 0.769231 \n", 310 | "498 1.000000 1.000000 1.000000 1.000000 \n", 311 | "499 1.000000 1.000000 1.000000 1.000000 \n", 312 | "\n", 313 | " rouge2_precision rouge2_recall rougeL_fmeasure rougeL_precision \\\n", 314 | "0 0.500000 0.529412 0.648649 0.631579 \n", 315 | "1 0.500000 0.500000 0.800000 0.800000 \n", 316 | "2 1.000000 0.333333 0.555556 1.000000 \n", 317 | "3 1.000000 1.000000 1.000000 1.000000 \n", 318 | "4 0.608696 0.538462 0.705882 0.750000 \n", 319 | ".. ... ... ... ... \n", 320 | "495 0.500000 0.400000 0.700000 0.777778 \n", 321 | "496 1.000000 0.250000 0.571429 1.000000 \n", 322 | "497 0.625000 1.000000 0.800000 0.666667 \n", 323 | "498 1.000000 1.000000 1.000000 1.000000 \n", 324 | "499 1.000000 1.000000 1.000000 1.000000 \n", 325 | "\n", 326 | " rougeL_recall rougeLsum_fmeasure rougeLsum_precision rougeLsum_recall \n", 327 | "0 0.666667 0.648649 0.631579 0.666667 \n", 328 | "1 0.800000 0.800000 0.800000 0.800000 \n", 329 | "2 0.384615 0.555556 1.000000 0.384615 \n", 330 | "3 1.000000 1.000000 1.000000 1.000000 \n", 331 | "4 0.666667 0.705882 0.750000 0.666667 \n", 332 | ".. ... ... ... ... 
\n", 333 | "495 0.636364 0.700000 0.777778 0.636364 \n", 334 | "496 0.400000 0.571429 1.000000 0.400000 \n", 335 | "497 1.000000 0.800000 0.666667 1.000000 \n", 336 | "498 1.000000 1.000000 1.000000 1.000000 \n", 337 | "499 1.000000 1.000000 1.000000 1.000000 \n", 338 | "\n", 339 | "[500 rows x 16 columns]" 340 | ] 341 | }, 342 | "execution_count": 3, 343 | "metadata": {}, 344 | "output_type": "execute_result" 345 | } 346 | ], 347 | "source": [ 348 | "filepath = r\"metrics/benchmark_runs/model_benchmarked_results/1.3b-optimized-tokens=111-samples=500.pkl\"\n", 349 | "df = pd.read_pickle(filepath)\n", 350 | "df" 351 | ] 352 | }, 353 | { 354 | "cell_type": "code", 355 | "execution_count": 5, 356 | "id": "0a785aa1-1893-4a5b-bba6-781e284cea48", 357 | "metadata": {}, 358 | "outputs": [ 359 | { 360 | "name": "stderr", 361 | "output_type": "stream", 362 | "text": [ 363 | "C:\\Users\\weipy\\AppData\\Local\\Temp\\ipykernel_5208\\3698961737.py:1: FutureWarning: Dropping of nuisance columns in DataFrame reductions (with 'numeric_only=None') is deprecated; in a future version this will raise TypeError. Select only valid columns before calling the reduction.\n", 364 | " df.mean()\n" 365 | ] 366 | }, 367 | { 368 | "data": { 369 | "text/plain": [ 370 | "bartscore -2.157947\n", 371 | "bleuscore 0.342787\n", 372 | "rouge1_fmeasure 0.835004\n", 373 | "rouge1_precision 0.856809\n", 374 | "rouge1_recall 0.838207\n", 375 | "rouge2_fmeasure 0.737537\n", 376 | "rouge2_precision 0.756071\n", 377 | "rouge2_recall 0.743406\n", 378 | "rougeL_fmeasure 0.830460\n", 379 | "rougeL_precision 0.852231\n", 380 | "rougeL_recall 0.833340\n", 381 | "rougeLsum_fmeasure 0.830814\n", 382 | "rougeLsum_precision 0.852436\n", 383 | "rougeLsum_recall 0.833801\n", 384 | "dtype: float64" 385 | ] 386 | }, 387 | "execution_count": 5, 388 | "metadata": {}, 389 | "output_type": "execute_result" 390 | } 391 | ], 392 | "source": [ 393 | "df.mean()" 394 | ] 395 | }, 396 | { 397 | "cell_type": "code", 398 | "execution_count": null, 399 | "id": "275fe94e-476a-436b-ac24-1c5eefac27c7", 400 | "metadata": {}, 401 | "outputs": [], 402 | "source": [ 403 | "filepath = r\"metrics/benchmark_runs/model_benchmarked_results/1.3b-fine-tuned-samples=500.pkl\"\n", 404 | "df = pd.read_pickle(filepath)\n", 405 | "df" 406 | ] 407 | } 408 | ], 409 | "metadata": { 410 | "kernelspec": { 411 | "display_name": "Python 3 (ipykernel)", 412 | "language": "python", 413 | "name": "python3" 414 | }, 415 | "language_info": { 416 | "codemirror_mode": { 417 | "name": "ipython", 418 | "version": 3 419 | }, 420 | "file_extension": ".py", 421 | "mimetype": "text/x-python", 422 | "name": "python", 423 | "nbconvert_exporter": "python", 424 | "pygments_lexer": "ipython3", 425 | "version": "3.10.4" 426 | } 427 | }, 428 | "nbformat": 4, 429 | "nbformat_minor": 5 430 | } 431 | -------------------------------------------------------------------------------- /EDA_embedding.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 2, 6 | "id": "b7e011c3-5eb8-454e-b54e-24f5b96bf80c", 7 | "metadata": { 8 | "pycharm": { 9 | "name": "#%%\n" 10 | } 11 | }, 12 | "outputs": [], 13 | "source": [ 14 | "import tmap as tm\n", 15 | "import torch\n", 16 | "import time\n", 17 | "import numpy as np\n", 18 | "from numpy.random import RandomState\n", 19 | "import pandas as pd\n", 20 | "import re\n", 21 | "\n", 22 | "import graphistry\n", 23 | "\n", 24 | "import os\n", 25 | "from dotenv import 
load_dotenv\n", 26 | "load_dotenv() # take environment variables from .env.\n", 27 | "\n", 28 | "from pyvis import network as net\n", 29 | "import networkx as nx\n", 30 | "from sklearn import manifold\n", 31 | "from sklearn.decomposition import PCA\n", 32 | "\n", 33 | "%matplotlib inline\n", 34 | "import matplotlib.pyplot as plt\n", 35 | "from matplotlib import ticker\n", 36 | "plt.rcParams['figure.figsize'] = [20, 20]" 37 | ] 38 | }, 39 | { 40 | "cell_type": "code", 41 | "execution_count": 3, 42 | "id": "8e0a7d83-dfc6-45ac-9cfa-b02e4bd401e6", 43 | "metadata": { 44 | "pycharm": { 45 | "name": "#%%\n" 46 | } 47 | }, 48 | "outputs": [], 49 | "source": [ 50 | "graphistry.register(api=3, protocol=\"https\", server=\"hub.graphistry.com\", username=os.environ['GRAPHISTRY_USERNAME'], password=os.environ['GRAPHISTRY_PASSWORD'])" 51 | ] 52 | }, 53 | { 54 | "cell_type": "code", 55 | "execution_count": 4, 56 | "id": "946c923d-2778-400b-a80b-4f4bf3d1344b", 57 | "metadata": { 58 | "pycharm": { 59 | "name": "#%%\n" 60 | } 61 | }, 62 | "outputs": [], 63 | "source": [ 64 | "import wandb\n", 65 | "from transformers import GPT2Tokenizer\n", 66 | "from soft_prompt_tuning.soft_prompt_opt import ParaphraseOPT" 67 | ] 68 | }, 69 | { 70 | "cell_type": "markdown", 71 | "id": "c88b40fa-726c-4062-95d6-3f14fc4cc60d", 72 | "metadata": { 73 | "pycharm": { 74 | "name": "#%% md\n" 75 | } 76 | }, 77 | "source": [ 78 | "# Init embedding space" 79 | ] 80 | }, 81 | { 82 | "cell_type": "code", 83 | "execution_count": 5, 84 | "id": "24cd0cac-59b8-4d04-9b8a-32e1eef717c3", 85 | "metadata": { 86 | "pycharm": { 87 | "name": "#%%\n" 88 | } 89 | }, 90 | "outputs": [ 91 | { 92 | "name": "stderr", 93 | "output_type": "stream", 94 | "text": [ 95 | "\u001B[34m\u001B[1mwandb\u001B[0m: Currently logged in as: \u001B[33mclyde013\u001B[0m. Use \u001B[1m`wandb login --relogin`\u001B[0m to force relogin\n" 96 | ] 97 | }, 98 | { 99 | "data": { 100 | "text/html": [ 101 | "wandb version 0.12.18 is available! To upgrade, please run:\n", 102 | " $ pip install wandb --upgrade" 103 | ], 104 | "text/plain": [ 105 | "" 106 | ] 107 | }, 108 | "metadata": {}, 109 | "output_type": "display_data" 110 | }, 111 | { 112 | "data": { 113 | "text/html": [ 114 | "Tracking run with wandb version 0.12.17" 115 | ], 116 | "text/plain": [ 117 | "" 118 | ] 119 | }, 120 | "metadata": {}, 121 | "output_type": "display_data" 122 | }, 123 | { 124 | "data": { 125 | "text/html": [ 126 | "Run data is saved locally in C:\\Users\\weipy\\OneDrive\\Documents\\GitHub\\Paraphrase-OPT\\wandb\\run-20220622_165114-37zd2lmm" 127 | ], 128 | "text/plain": [ 129 | "" 130 | ] 131 | }, 132 | "metadata": {}, 133 | "output_type": "display_data" 134 | }, 135 | { 136 | "data": { 137 | "text/html": [ 138 | "Syncing run test-model to Weights & Biases (docs)
" 139 | ], 140 | "text/plain": [ 141 | "" 142 | ] 143 | }, 144 | "metadata": {}, 145 | "output_type": "display_data" 146 | }, 147 | { 148 | "data": { 149 | "application/vnd.jupyter.widget-view+json": { 150 | "model_id": "206dd6934b3d4b8eb1bccf0b165b39c2", 151 | "version_major": 2, 152 | "version_minor": 0 153 | }, 154 | "text/plain": [ 155 | "Downloading: 0%| | 0.00/653 [00:00\n", 312 | "\n", 325 | "\n", 326 | " \n", 327 | " \n", 328 | " \n", 329 | " \n", 330 | " \n", 331 | " \n", 332 | " \n", 333 | " \n", 334 | " \n", 335 | " \n", 336 | " \n", 337 | " \n", 338 | " \n", 339 | " \n", 340 | " \n", 341 | " \n", 342 | " \n", 343 | " \n", 344 | " \n", 345 | " \n", 346 | " \n", 347 | " \n", 348 | " \n", 349 | " \n", 350 | " \n", 351 | " \n", 352 | " \n", 353 | " \n", 354 | " \n", 355 | " \n", 356 | " \n", 357 | " \n", 358 | " \n", 359 | " \n", 360 | " \n", 361 | " \n", 362 | " \n", 363 | " \n", 364 | " \n", 365 | " \n", 366 | " \n", 367 | " \n", 368 | " \n", 369 | " \n", 370 | " \n", 371 | " \n", 372 | " \n", 373 | " \n", 374 | " \n", 375 | " \n", 376 | " \n", 377 | " \n", 378 | " \n", 379 | " \n", 380 | " \n", 381 | " \n", 382 | " \n", 383 | " \n", 384 | " \n", 385 | " \n", 386 | " \n", 387 | " \n", 388 | " \n", 389 | " \n", 390 | " \n", 391 | " \n", 392 | " \n", 393 | " \n", 394 | " \n", 395 | " \n", 396 | " \n", 397 | " \n", 398 | " \n", 399 | " \n", 400 | " \n", 401 | " \n", 402 | " \n", 403 | " \n", 404 | " \n", 405 | " \n", 406 | " \n", 407 | " \n", 408 | " \n", 409 | " \n", 410 | " \n", 411 | " \n", 412 | " \n", 413 | " \n", 414 | "
indexxytitle
000.51655787.0414120
111.97146980.5935590
2270.881744-50.7330250
331.03274087.2926030
4473.105972-49.0376400
...............
504375043775.097878-53.2998502
504385043853.05204469.5669022
504395043973.531258-49.4637722
504405044046.07576054.1895222
504415044173.672615-47.9431572
\n", 415 | "

50442 rows × 4 columns

\n", 416 | "" 417 | ], 418 | "text/plain": [ 419 | " index x y title\n", 420 | "0 0 0.516557 87.041412 0\n", 421 | "1 1 1.971469 80.593559 0\n", 422 | "2 2 70.881744 -50.733025 0\n", 423 | "3 3 1.032740 87.292603 0\n", 424 | "4 4 73.105972 -49.037640 0\n", 425 | "... ... ... ... ...\n", 426 | "50437 50437 75.097878 -53.299850 2\n", 427 | "50438 50438 53.052044 69.566902 2\n", 428 | "50439 50439 73.531258 -49.463772 2\n", 429 | "50440 50440 46.075760 54.189522 2\n", 430 | "50441 50441 73.672615 -47.943157 2\n", 431 | "\n", 432 | "[50442 rows x 4 columns]" 433 | ] 434 | }, 435 | "execution_count": 33, 436 | "metadata": {}, 437 | "output_type": "execute_result" 438 | } 439 | ], 440 | "source": [ 441 | "df = pd.DataFrame(embedded, columns=['x', 'y']).reset_index()\n", 442 | "df[\"title\"] = np.concatenate([np.zeros(50272), np.ones(111), np.full(59, 2)])\n", 443 | "df[\"title\"] = df[\"title\"].astype(int)\n", 444 | "df" 445 | ] 446 | }, 447 | { 448 | "cell_type": "code", 449 | "execution_count": 34, 450 | "id": "b0e899b1", 451 | "metadata": {}, 452 | "outputs": [ 453 | { 454 | "data": { 455 | "text/html": [ 456 | "
\n", 457 | "\n", 470 | "\n", 471 | " \n", 472 | " \n", 473 | " \n", 474 | " \n", 475 | " \n", 476 | " \n", 477 | " \n", 478 | " \n", 479 | " \n", 480 | " \n", 481 | " \n", 482 | " \n", 483 | " \n", 484 | " \n", 485 | "
sourcetarget
000
\n", 486 | "
" 487 | ], 488 | "text/plain": [ 489 | " source target\n", 490 | "0 0 0" 491 | ] 492 | }, 493 | "execution_count": 34, 494 | "metadata": {}, 495 | "output_type": "execute_result" 496 | } 497 | ], 498 | "source": [ 499 | "df_edges = pd.DataFrame({'source':[0], 'target':[0]})\n", 500 | "df_edges" 501 | ] 502 | }, 503 | { 504 | "cell_type": "code", 505 | "execution_count": 35, 506 | "id": "9c922176", 507 | "metadata": {}, 508 | "outputs": [ 509 | { 510 | "data": { 511 | "text/html": [ 512 | "\n", 513 | " \n", 520 | " \n", 521 | " \n", 526 | " " 527 | ], 528 | "text/plain": [ 529 | "" 530 | ] 531 | }, 532 | "execution_count": 35, 533 | "metadata": {}, 534 | "output_type": "execute_result" 535 | } 536 | ], 537 | "source": [ 538 | "graph = graphistry.bind(source=\"source\", destination=\"target\", point_x=\"x\", point_y=\"y\", point_title=\"index\")\n", 539 | "graph = graph.edges(df_edges).nodes(df, 'index')\n", 540 | "graph = graph.encode_point_color('title', categorical_mapping={'0': '#ff9999', '2': '#99F', '1': '#32a834'}, default_mapping='silver')\n", 541 | "graph = graph.encode_point_size('title', categorical_mapping={'0': 1, '2': 3, '1': 3}, default_mapping=1)\n", 542 | "graph = graph.settings(url_params={\n", 543 | " 'play': 0,\n", 544 | " 'menu': True, 'info': False,\n", 545 | " 'showArrows': False,\n", 546 | " 'pointSize': 0.07, 'edgeCurvature': 0.01, 'edgeSize': 1.0,\n", 547 | " 'edgeOpacity': 0.5, 'pointOpacity': 0.9,\n", 548 | " 'lockedX': True, 'lockedY': True, 'lockedR': False,\n", 549 | " 'linLog': False, 'strongGravity': False, 'dissuadeHubs': False,\n", 550 | " 'edgeInfluence': 1.0, 'precisionVsSpeed': 1.0, 'gravity': 1.0, 'scalingRatio': 0.5,\n", 551 | " 'showLabels': True, 'showLabelOnHover': True,\n", 552 | " 'showPointsOfInterest': False, 'showPointsOfInterestLabel': False, 'showLabelPropertiesOnHover': True,\n", 553 | " 'pointsOfInterestMax': 0\n", 554 | " })\n", 555 | "graph.plot()" 556 | ] 557 | }, 558 | { 559 | "cell_type": "markdown", 560 | "id": "664a18b2-d868-42d6-a4c4-5370933b0017", 561 | "metadata": { 562 | "pycharm": { 563 | "name": "#%% md\n" 564 | } 565 | }, 566 | "source": [ 567 | "The inconsistencies in the relative positions of the embedding points could be attributed to phase 4 in the tmap algorithm described [here](https://jcheminf.biomedcentral.com/articles/10.1186/s13321-020-0416-x#Sec2). They conduct kruskal's to get a constructed MST tree, reducing computation times by large margins, before using a spring based graph layout alogrithm (probably the Fruchterman-Reingold force-directed algorithm) to plot out the points. But the tradeoff is that if points are not connected on the MST then their relative distances are not accounted for when placing points on the graph, only the neighbours. While this *might* yield multiple locally optimal solutions, it is possible that relative distances between global clusters of points are not accounted for, resulting in vastly different positionings of embedding points. While their neighbours might always be close together no matter the random initialisation, their global position might vary, as there is no way to solve for a deterministic solution without the connections of a fully connected graph (which is too expensive to compute)." 568 | ] 569 | }, 570 | { 571 | "cell_type": "markdown", 572 | "id": "00b35fad-f581-47d6-837a-7d784d39b5ec", 573 | "metadata": { 574 | "pycharm": { 575 | "name": "#%% md\n" 576 | } 577 | }, 578 | "source": [ 579 | "Have to create both an MST and the complete graph. 
We have 2 options:\n", 580 | "1. using the weights of the complete graph to generate a layout and then only plotting the MST's edge connections (very computationally expensive)\n", 581 | "1. following the original paper implementation of using MST for both layout generation and plotting edge connections.\n", 582 | "\n", 583 | "When using `spring_connection` we have to invert the edge weights as the edge weights become spring attractive coefficients, whereas in other layouts such as `kamada_kawai` edge weights are used as cost functions." 584 | ] 585 | }, 586 | { 587 | "cell_type": "markdown", 588 | "id": "efd85838-5dd4-4dde-954c-56ae15f2623f", 589 | "metadata": { 590 | "pycharm": { 591 | "name": "#%% md\n" 592 | } 593 | }, 594 | "source": [ 595 | "# Custom class" 596 | ] 597 | }, 598 | { 599 | "cell_type": "markdown", 600 | "id": "612cc3fc-57d7-4171-adb2-a68627c4a190", 601 | "metadata": { 602 | "pycharm": { 603 | "name": "#%% md\n" 604 | } 605 | }, 606 | "source": [ 607 | "Create a custom class to deal with visualisations. Should be able to initialise from a model's fixed embeddings. Then in the init function construct a minhash encoder and lsh forest (with a seed), and then using the lsh forest create an initial MST use the networkx layouts to find the x, y positions, then use those as anchor points. Should then be able to take in inputs of learned embeddings with another function, and then using `query linear scan` find the closest neighbours of all the passed in learned embeddings, from there add the points to the MST, either without trying to form another MST (just leave knn connections) or find a locally optimal MST solution. Output graphs should ideally be pyvis networks as they allow for interactive visualisations." 608 | ] 609 | }, 610 | { 611 | "cell_type": "code", 612 | "execution_count": null, 613 | "id": "11cd1a62", 614 | "metadata": { 615 | "pycharm": { 616 | "name": "#%%\n" 617 | } 618 | }, 619 | "outputs": [], 620 | "source": [ 621 | "class vis(): \n", 622 | " def __init__(self, fixed_embeddings, dims:int=512, load_path:str=None, save_path:str=r\"visualisations/vis_lf_fixed.dat\", seed:int=69, verbose:bool=True):\n", 623 | " if verbose: print(\"seeding...\")\n", 624 | " self.seed = seed\n", 625 | " np.random.seed(self.seed)\n", 626 | " \n", 627 | " self.fixed_embeddings = fixed_embeddings\n", 628 | " self.enc = tm.Minhash(self.fixed_embeddings.size(dim=1), seed=self.seed, sample_size=dims)\n", 629 | " self.lf = tm.LSHForest(dims * 2, 128)\n", 630 | " \n", 631 | " # init the LSHForest\n", 632 | " if load_path is None:\n", 633 | " tmp = []\n", 634 | " if verbose: print(\"batch add and indexing...\")\n", 635 | " for i in fixed_embeddings:\n", 636 | " tmp.append(tm.VectorFloat(i.tolist()))\n", 637 | " self.lf.batch_add(self.enc.batch_from_weight_array(tmp))\n", 638 | " self.lf.index()\n", 639 | " self.lf.store(save_path)\n", 640 | " else:\n", 641 | " if verbose: print(f\"loading from {load_path}...\")\n", 642 | " self.lf.restore(load_path)\n", 643 | " \n", 644 | " # Construct the k-nearest neighbour graph\n", 645 | " if verbose: print(\"Getting KNN graph...\")\n", 646 | " knng_from = tm.VectorUint()\n", 647 | " knng_to = tm.VectorUint()\n", 648 | " knng_weight = tm.VectorFloat()\n", 649 | " _ = self.lf.get_knn_graph(knng_from, knng_to, knng_weight, 10) \n", 650 | "\n", 651 | " # find the MST of the knn graph\n", 652 | " if verbose: print(\"Finding MST...\")\n", 653 | " self.g_mst = self.create_mst([i for i in zip(knng_from, knng_to, knng_weight) if i[0] != i[1]])\n", 654 | "\n", 
655 | " # find x, y positions of the fixed embeddings layout\n", 656 | " if verbose: print(\"Generating layout...\")\n", 657 | " self.pos = nx.nx_agraph.graphviz_layout(self.g_mst, prog=\"sfdp\")\n", 658 | " self.fixed = list(self.pos.keys())\n", 659 | " \n", 660 | " def graph_learned_embeddings(self, learned_embeddings, type_str:str, g: nx.Graph=None):\n", 661 | " # create deepcopy of g_mst\n", 662 | " if g is None:\n", 663 | " g = nx.Graph(self.g_mst)\n", 664 | " # index to begin from (since indexes are 0 indexed we start from len)\n", 665 | " index = len(g)\n", 666 | " for i in learned_embeddings:\n", 667 | " query_hash = self.enc.from_weight_array(tm.VectorFloat(i.tolist()))\n", 668 | " # query_linear_scan returns list of tuples(weight, neighbour). invert the weights because spring layout.\n", 669 | " scan = self.lf.query_linear_scan(query_hash, 1)[0]\n", 670 | " g.add_edge(index, scan[1], weight=1-scan[0])\n", 671 | " g.nodes[index]['type_str'] = type_str\n", 672 | " index += 1\n", 673 | " \n", 674 | " return g\n", 675 | " \n", 676 | " # kruskals algorithm for finding MST\n", 677 | " def create_mst(self, edgelist):\n", 678 | " self.par = [i for i in range(0, self.fixed_embeddings.size(dim=0)+1)]\n", 679 | " self.rnk = [0 for i in range(0, self.fixed_embeddings.size(dim=0)+1)]\n", 680 | " edges = sorted(edgelist, key=lambda x:x[2])\n", 681 | " g_mst = nx.Graph()\n", 682 | "\n", 683 | " for edge in edges:\n", 684 | " x = edge[0]\n", 685 | " y = edge[1]\n", 686 | "\n", 687 | " if self._find_par(x) != self._find_par(y):\n", 688 | " # append edge to the mst. invert the weights because spring layout.\n", 689 | " g_mst.add_edge(edge[0], edge[1], weight=edge[2])\n", 690 | " self._join(x, y)\n", 691 | "\n", 692 | " return g_mst\n", 693 | " \n", 694 | " def _find_par(self, i):\n", 695 | " if self.par[i] == i:\n", 696 | " return i\n", 697 | " self.par[i] = self._find_par(self.par[i])\n", 698 | " return self.par[i]\n", 699 | "\n", 700 | " def _join(self, x, y):\n", 701 | " x = self._find_par(x)\n", 702 | " y = self._find_par(y)\n", 703 | " if x == y:\n", 704 | " return\n", 705 | " if self.rnk[x] < self.rnk[y]:\n", 706 | " self.par[x] = y\n", 707 | " else:\n", 708 | " self.par[y] = x\n", 709 | " if self.rnk[x] == self.rnk[y]:\n", 710 | " self.rnk[x] += 1" 711 | ] 712 | }, 713 | { 714 | "cell_type": "code", 715 | "execution_count": null, 716 | "id": "d789c0b1", 717 | "metadata": { 718 | "pycharm": { 719 | "name": "#%%\n" 720 | } 721 | }, 722 | "outputs": [], 723 | "source": [ 724 | "visualisation = vis(original_embeddings, load_path=r\"visualisations/vis_lf_fixed.dat\")" 725 | ] 726 | }, 727 | { 728 | "cell_type": "code", 729 | "execution_count": null, 730 | "id": "62303dfd", 731 | "metadata": { 732 | "pycharm": { 733 | "name": "#%%\n" 734 | } 735 | }, 736 | "outputs": [], 737 | "source": [ 738 | "g = visualisation.graph_learned_embeddings(learned_embeddings_111, \"111\")\n", 739 | "g = visualisation.graph_learned_embeddings(learned_embeddings_59, \"59\", g)\n", 740 | "pos = nx.spring_layout(g, fixed=visualisation.g_mst.nodes, pos=visualisation.pos, k=0.0001)\n", 741 | "\n", 742 | "mapping = {v: re.sub(r'(\\\\xc4|\\\\xa0)|[\\'\\\"\\\\]', '', repr(k.encode(\"utf-8\"))[2:-1]) for k, v in tokenizer.get_vocab().items()}\n", 743 | "\n", 744 | "for n, p in pos.items():\n", 745 | " g.nodes[n]['x'] = float(p[0])\n", 746 | " g.nodes[n]['y'] = float(p[1])\n", 747 | " # there are some unreachable tokens as the tokenizer's vocab size does not match that of the config\n", 748 | " g.nodes[n]['title'] = 
mapping[n] if n < len(tokenizer) else f\"learned embedding {n-len(visualisation.g_mst.nodes)}\"\n", 749 | " # denote learned vs fixed embeddings\n", 750 | " if n < len(visualisation.g_mst.nodes):\n", 751 | " g.nodes[n]['type_str'] = 'F'" 752 | ] 753 | }, 754 | { 755 | "cell_type": "markdown", 756 | "id": "016f5aaa", 757 | "metadata": { 758 | "pycharm": { 759 | "name": "#%% md\n" 760 | } 761 | }, 762 | "source": [ 763 | "I really cannot tell if the bugginess comes from networkx integration with graphistry or what, so to be safe everything is being converted to pandas dataframes and fed in that way, which has actual documentation support." 764 | ] 765 | }, 766 | { 767 | "cell_type": "code", 768 | "execution_count": null, 769 | "id": "da9d1479", 770 | "metadata": { 771 | "pycharm": { 772 | "name": "#%%\n" 773 | } 774 | }, 775 | "outputs": [], 776 | "source": [ 777 | "edges = nx.to_pandas_edgelist(g)\n", 778 | "nodes = pd.DataFrame.from_dict(dict(g.nodes(data=True)), orient='index').reset_index(level=0)\n", 779 | "nodes" 780 | ] 781 | }, 782 | { 783 | "cell_type": "code", 784 | "execution_count": null, 785 | "id": "cf3f27ec", 786 | "metadata": { 787 | "pycharm": { 788 | "name": "#%%\n" 789 | } 790 | }, 791 | "outputs": [], 792 | "source": [ 793 | "graph = graphistry.bind(source='source', destination='target', point_x=\"x\", point_y=\"y\", point_title=\"title\")\n", 794 | "graph = graph.edges(edges).nodes(nodes, 'index')\n", 795 | "graph = graph.encode_point_color('type_str', categorical_mapping={'F': '#ff9999', '59': '#99F', '111': '#32a834'}, default_mapping='silver')\n", 796 | "graph = graph.encode_point_size('type_str', categorical_mapping={'F': 1, '59': 3, '111': 3}, default_mapping=1)\n", 797 | "graph = graph.settings(url_params={\n", 798 | " 'play': 0,\n", 799 | " 'menu': True, 'info': False,\n", 800 | " 'showArrows': False,\n", 801 | " 'pointSize': 0.07, 'edgeCurvature': 0.01, 'edgeSize': 1.0,\n", 802 | " 'edgeOpacity': 0.5, 'pointOpacity': 0.9,\n", 803 | " 'lockedX': True, 'lockedY': True, 'lockedR': False,\n", 804 | " 'linLog': False, 'strongGravity': False, 'dissuadeHubs': False,\n", 805 | " 'edgeInfluence': 1.0, 'precisionVsSpeed': 1.0, 'gravity': 1.0, 'scalingRatio': 0.5,\n", 806 | " 'showLabels': True, 'showLabelOnHover': True,\n", 807 | " 'showPointsOfInterest': False, 'showPointsOfInterestLabel': False, 'showLabelPropertiesOnHover': True,\n", 808 | " 'pointsOfInterestMax': 0\n", 809 | " })\n", 810 | "graph.plot()" 811 | ] 812 | }, 813 | { 814 | "cell_type": "code", 815 | "execution_count": 85, 816 | "id": "b876e78c-732b-4943-a170-741dbeb792d1", 817 | "metadata": { 818 | "pycharm": { 819 | "name": "#%%\n" 820 | }, 821 | "scrolled": false, 822 | "tags": [] 823 | }, 824 | "outputs": [ 825 | { 826 | "data": { 827 | "text/html": [ 828 | "\n", 829 | " \n", 836 | " \n", 837 | " \n", 842 | " " 843 | ], 844 | "text/plain": [ 845 | "" 846 | ] 847 | }, 848 | "execution_count": 85, 849 | "metadata": {}, 850 | "output_type": "execute_result" 851 | } 852 | ], 853 | "source": [ 854 | "graph = graphistry.bind(source='source', destination='target', point_x=\"x\", point_y=\"y\", point_title=\"title\")\n", 855 | "graph = graph.edges(edges).nodes(nodes, 'index')\n", 856 | "graph = graph.encode_point_color('type_str', categorical_mapping={'F': '#ff9999', '59': '#99F', '111': '#32a834'}, default_mapping='silver')\n", 857 | "graph = graph.encode_point_size('type_str', categorical_mapping={'F': 1, '59': 3, '111': 3}, default_mapping=1)\n", 858 | "graph = graph.settings(url_params={\n", 859 | " 'play': 
0,\n", 860 | " 'menu': True, 'info': False,\n", 861 | " 'showArrows': False,\n", 862 | " 'pointSize': 0.07, 'edgeCurvature': 0.01, 'edgeSize': 1.0,\n", 863 | " 'edgeOpacity': 0.5, 'pointOpacity': 0.9,\n", 864 | " 'lockedX': True, 'lockedY': True, 'lockedR': False,\n", 865 | " 'linLog': False, 'strongGravity': False, 'dissuadeHubs': False,\n", 866 | " 'edgeInfluence': 1.0, 'precisionVsSpeed': 1.0, 'gravity': 1.0, 'scalingRatio': 0.5,\n", 867 | " 'showLabels': True, 'showLabelOnHover': True,\n", 868 | " 'showPointsOfInterest': False, 'showPointsOfInterestLabel': False, 'showLabelPropertiesOnHover': True,\n", 869 | " 'pointsOfInterestMax': 0\n", 870 | " })\n", 871 | "graph.plot()" 872 | ] 873 | } 874 | ], 875 | "metadata": { 876 | "kernelspec": { 877 | "display_name": "Python 3 (ipykernel)", 878 | "language": "python", 879 | "name": "python3" 880 | }, 881 | "language_info": { 882 | "codemirror_mode": { 883 | "name": "ipython", 884 | "version": 3 885 | }, 886 | "file_extension": ".py", 887 | "mimetype": "text/x-python", 888 | "name": "python", 889 | "nbconvert_exporter": "python", 890 | "pygments_lexer": "ipython3", 891 | "version": "3.10.4" 892 | } 893 | }, 894 | "nbformat": 4, 895 | "nbformat_minor": 5 896 | } --------------------------------------------------------------------------------