├── models
│   ├── __init__.py
│   ├── clip_emb_aug.py
│   ├── linear_emb_aug.py
│   ├── cross_encoder.py
│   ├── bert.py
│   ├── bert_trainer.py
│   ├── rnn.py
│   └── latent_aug.py
├── .DS_Store
├── .gitattributes
├── generate_prompts
│   ├── generate_background.sh
│   ├── generate_images.py
│   └── upsample_captions.py
├── train_background.sh
├── .vscode
│   ├── settings.json
│   └── launch.json
├── README.md
├── requirements.txt
├── cache_datasets.py
├── train_bert.py
├── .gitignore
├── utils.py
├── notebooks
│   ├── seq2seq_dataset.ipynb
│   ├── create_dataset.ipynb
│   ├── bert.ipynb
│   └── finetune.ipynb
├── train_linear_emb_aug.py
├── train_emb.py
├── train_latent_aug.py
├── train_cross_encoder.py
├── train_bert_seq2seq.py
└── train_t5.py

--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/.gitattributes:
--------------------------------------------------------------------------------
# Auto detect text files and perform LF normalization
* text=auto

--------------------------------------------------------------------------------
/generate_prompts/generate_background.sh:
--------------------------------------------------------------------------------
nohup python upsample_captions.py --local > nohup.out 2>&1 &
tail -f nohup.out

--------------------------------------------------------------------------------
/train_background.sh:
--------------------------------------------------------------------------------
fn_name="train_cross_encoder"
pkill -9 -f $fn_name.py
nohup python $fn_name.py --use_wandb > nohup.out 2>&1 &
tail -f nohup.out

--------------------------------------------------------------------------------
/.vscode/settings.json:
--------------------------------------------------------------------------------
{
    "[python]": {
        "editor.defaultFormatter": "ms-python.black-formatter"
    },
    "python.formatting.provider": "none",
    "python.analysis.typeCheckingMode": "basic"
}

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# SuperPrompt

A set of models for augmenting the inputs of prompt-based image generation models.

### Latent augmentation

```bash
wget http://images.cocodataset.org/zips/val2017.zip
```

--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
protobuf<=3.20.1
torch
transformers
python-dotenv
datasets
wandb
evaluate
numpy
fire
tabulate==0.9.0
pillow==10.0.1
spacy
diffusers==0.21.4
accelerate
torchinfo
lpips
webdataset
einops
omegaconf
bitsandbytes
scikit-learn
--------------------------------------------------------------------------------
/.vscode/launch.json:
--------------------------------------------------------------------------------
{
    // Use IntelliSense to learn about possible attributes.
    // Hover to view descriptions of existing attributes.
    // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
    "version": "0.2.0",
    "configurations": [
        {
            "name": "Train BERT",
            "type": "python",
            "request": "launch",
            "program": "train_bert.py",
            "console": "integratedTerminal",
            "justMyCode": true
        },
        {
            "name": "Train Seq2Seq",
            "type": "python",
            "request": "launch",
            "program": "train_bert_seq2seq.py",
            "console": "integratedTerminal",
            "justMyCode": true
        }
    ]
}

--------------------------------------------------------------------------------
/models/clip_emb_aug.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.models.clip.modeling_clip import CLIPTextModel, CLIPEncoderLayer
import copy


class CLIPEmbeddingAugmenter(nn.Module):
    def __init__(self, clip_model: CLIPTextModel):
        super().__init__()
        # a single trainable encoder layer stacked on top of the frozen CLIP text encoder
        self.unfrozen_encoder_layer = CLIPEncoderLayer(copy.deepcopy(clip_model.config))
        self.unfrozen_encoder_layer.to(clip_model.device)
        self.unfrozen_encoder_layer.train()
        self.unfrozen_encoder_layer.requires_grad_(True)

    def forward(self, masked_emb: torch.Tensor):
        # masked_emb is the last hidden state of the CLIP text encoder
        emb_enc = self.unfrozen_encoder_layer(masked_emb, None, None)[0]
        return emb_enc

--------------------------------------------------------------------------------
/cache_datasets.py:
--------------------------------------------------------------------------------
from datasets import load_dataset
from diffusers import (
    StableDiffusionPipeline,
)
import torch

dataset = load_dataset("roborovski/diffusiondb-seq2seq")
dataset = load_dataset("THUDM/ImageRewardDB", "4k", verification_mode="no_checks")
dataset = load_dataset("bentrevett/multi30k")
dataset = load_dataset(
    "huggan/CelebA-HQ",
    data_files={"train": "data/train-00000-of-00068.parquet"},
    verification_mode="no_checks",
)
dataset = load_dataset(
    "roborovski/celeba-faces-captioned",
    data_files={
        "train": [
            "data/train-00000-of-00036-416615b669d11cd3.parquet",
            "data/train-00001-of-00036-411c3786c0f93eac.parquet",
        ]
    },
    verification_mode="no_checks",
)
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    torch_dtype=torch.float16,
    safety_checker=None,
)
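`cache_datasets.py` exists only for its side effects: each `load_dataset` / `from_pretrained` call warms the local Hugging Face cache so the training scripts don't block on downloads. A minimal sketch of how a later run can rely on that cache (the offline flag is an assumption about usage, not something this script sets):

```python
import os

# serve datasets from the local cache only, failing instead of re-downloading
os.environ["HF_DATASETS_OFFLINE"] = "1"

from datasets import load_dataset

dataset = load_dataset("bentrevett/multi30k")  # resolved from the warmed cache
```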
--------------------------------------------------------------------------------
/models/linear_emb_aug.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import os
from typing import List
import torch.nn.functional as F


class LinearEmbAug(nn.Module):
    def __init__(
        self,
        n_tokens: int = 77,
        context_dim: int = 768,
        layers: List[int] = [1024],
        device="cpu",
    ):
        super().__init__()
        self.n_tokens = n_tokens
        self.context_dim = context_dim

        self.dropout = nn.Dropout(0.1)
        fc_layers = []
        fc_layers.append(torch.nn.Linear(context_dim, layers[0]))
        for in_size, out_size in zip(layers[:-1], layers[1:]):
            fc_layers.append(torch.nn.Linear(in_size, out_size))
            fc_layers.append(torch.nn.ReLU())
        fc_layers.append(torch.nn.Linear(layers[-1], context_dim))
        self.embed_fc = torch.nn.Sequential(*fc_layers).to(device)
        self.norm = nn.LayerNorm(context_dim).to(device)

    def forward(self, x):
        if self.training:
            x = self.dropout(x)
        x = self.embed_fc(x)
        x = self.norm(x)
        return x

--------------------------------------------------------------------------------
/train_bert.py:
--------------------------------------------------------------------------------
from argparse import Namespace
import gc
from models.bert import BERT
from models.bert_trainer import BERTTrainer
import wandb
from datasets import load_dataset
from transformers import BertTokenizer, DataCollatorForLanguageModeling
import sys
import os

from utils import should_use_wandb, sample_prompt_pairs


gc.collect()


class Args(Namespace):
    hidden = 256
    batch_size = 64
    layers = 8
    attn_heads = 8
    adam_weight_decay = 0.01
    adam_beta1 = 0.9
    output_path = "/home/ubuntu/superprompt/saved"
    epochs = 500
    log_freq = 32 * 2
    save_freq = 32 * 10
    valid_freq = 32 * 5
    adam_beta2 = 0.999
    num_workers = 4
    lr = 3e-4
    max_len = 128
    use_wandb = should_use_wandb()


# exit early if this module is imported rather than run directly
if __name__ != "__main__":
    sys.exit(0)

dataset = load_dataset(
    "Gustavosta/Stable-Diffusion-Prompts",
    streaming=True,
)
tokenizer: BertTokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer, mlm=True, mlm_probability=0.15
)
dataset = dataset.map(
    lambda x: tokenizer(
        x["Prompt"],
        truncation=True,
        padding="max_length",
        max_length=Args.max_len,
        return_tensors="pt",
    ),
    batched=True,
)
dataset = dataset.remove_columns(["Prompt"])


bert = BERT(
    tokenizer.vocab_size,
    hidden=Args.hidden,
    n_layers=Args.layers,
    attn_heads=Args.attn_heads,
    max_len=Args.max_len,
)

if Args.use_wandb:
    wandb.init(config=Args, project="superprompt")
    wandb.watch(bert, log_freq=Args.log_freq)

print("Creating BERT Trainer")
trainer = BERTTrainer(
    bert,
    tokenizer,
    collator,
    dataset["train"],
    dataset["test"],
    Args.lr,
    betas=(Args.adam_beta1, Args.adam_beta2),
    weight_decay=Args.adam_weight_decay,
    max_len=Args.max_len,
    log_freq=Args.log_freq,
    valid_freq=Args.valid_freq,
    batch_size=Args.batch_size,
    use_wandb=Args.use_wandb,
)

for epoch in range(Args.epochs):
    trainer.train(epoch)
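A minimal inference sketch for the trained masked LM. The checkpoint path is an assumption (BERTTrainer.save() writes `<output_path>.ep<N>` and serializes the whole module, so the model class must be importable when loading):

```python
import torch
from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = torch.load("saved.ep0")  # hypothetical checkpoint from BERTTrainer.save()
model.eval()

enc = tokenizer(
    "a portrait of a [MASK] in a forest",
    return_tensors="pt", padding="max_length", max_length=128,
)
with torch.no_grad():
    log_probs = model(enc["input_ids"], enc["attention_mask"])
# the model returns log-probabilities over the vocabulary per position
print(tokenizer.decode(log_probs.argmax(dim=-1)[0]))
```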
--------------------------------------------------------------------------------
/models/cross_encoder.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModel
from typing import List
from torch import Tensor


def cos_sim(a: Tensor, b: Tensor):
    """
    Computes the cosine similarity cos_sim(a[i], b[j]) for all i and j.
    :return: Matrix with res[i][j] = cos_sim(a[i], b[j])
    """
    a_norm = torch.nn.functional.normalize(a, p=2, dim=1)
    b_norm = torch.nn.functional.normalize(b, p=2, dim=1)
    return torch.mm(a_norm, b_norm.transpose(0, 1))


def pooling(model_output: Tensor, attention_mask: Tensor) -> Tensor:
    # mean pooling over token embeddings, weighted by the attention mask
    token_embeddings = model_output[0]
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)
    sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
    return sum_embeddings / sum_mask


def multiple_negatives_ranking_loss(
    embeddings_a: Tensor,
    embeddings_b: Tensor,
    scale: float = 20.0,
) -> Tensor:
    """
    Returns the scaled cosine-similarity score matrix and the target labels;
    the cross entropy between them (each a[i] scored against every b[j]) is
    the multiple-negatives ranking loss.
    """
    # [bsz, bsz]
    scores = cos_sim(embeddings_a, embeddings_b) * scale
    # the label is the index of the positive example for each row:
    # example a[i] should match with b[i]
    labels = torch.tensor(range(len(scores)), dtype=torch.long, device=scores.device)
    return scores, labels


class CrossEncoder(nn.Module):
    def __init__(self, device: torch.device) -> None:
        super().__init__()
        self.tokenizer = AutoTokenizer.from_pretrained("distilroberta-base")
        self.language_model = AutoModel.from_pretrained("distilroberta-base")
        self.language_model.to(device)
        self.language_model.train()
        self.embedding_dimension: int = self.language_model.config.hidden_size
        self.device = device

    def forward(self, input_text: List[str]) -> Tensor:
        tokenizer_kwargs = {
            "return_tensors": "pt",
            "padding": True,
            "truncation": True,
            "max_length": 128,
        }
        tokenized = self.tokenizer(input_text, **tokenizer_kwargs).to(self.device)
        out = self.language_model(**tokenized)
        out = pooling(out, tokenized["attention_mask"])

        return out
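Note that `multiple_negatives_ranking_loss` returns `(scores, labels)` rather than the loss itself; cross entropy over them gives the actual training loss. A usage sketch assuming the definitions above are imported (the example sentences are made up):

```python
import torch
import torch.nn.functional as F

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
encoder = CrossEncoder(device)

emb_a = encoder(["a cat on a mat", "a dog in a park"])
emb_b = encoder(["a fluffy cat sitting on a woven mat",
                 "a happy dog running through a sunny park"])

scores, labels = multiple_negatives_ranking_loss(emb_a, emb_b)
loss = F.cross_entropy(scores, labels)  # a[i] should rank b[i] above all other b[j]
loss.backward()
```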
--------------------------------------------------------------------------------
/generate_prompts/generate_images.py:
--------------------------------------------------------------------------------
import torch
from datasets import Dataset, Features
from datasets import Image as ImageFeature
from datasets import Value, load_dataset
from diffusers import DiffusionPipeline

BATCH_SIZE = 4


def main():
    print("Loading dataset...")
    drawbench = load_dataset(
        "sayakpaul/drawbench-upsampled-zephyr-7b-alpha", split="train"
    )

    print("Loading pipeline...")
    ckpt_id = "stabilityai/stable-diffusion-xl-base-1.0"
    pipe = DiffusionPipeline.from_pretrained(ckpt_id, torch_dtype=torch.float16).to(
        "cuda"
    )
    pipe.set_progress_bar_config(disable=True)

    seed = 0
    generator = torch.manual_seed(seed)

    print("Running inference...")
    main_dict = {}
    regular_caption_paths = []
    upsampled_caption_paths = []

    for i in range(0, len(drawbench), BATCH_SIZE):
        samples = drawbench[i : i + BATCH_SIZE]

        # Regular captions.
        prompts = list(samples["Prompt"])
        images = pipe(prompts, generator=generator, num_inference_steps=25).images
        for j in range(len(images)):
            img_name = f"sdxl_{i + j}.png"
            images[j].save(img_name)
            regular_caption_paths.append(img_name)

        # Upsampled captions.
        upsampled_prompts = list(samples["Upsampled Prompt"])
        images = pipe(
            upsampled_prompts, generator=generator, num_inference_steps=25
        ).images
        for j in range(len(images)):
            img_name = f"sdxl_upsampled_prompt_{i + j}.png"
            images[j].save(img_name)
            upsampled_caption_paths.append(img_name)

    for i in range(len(drawbench)):
        sample = drawbench[i]
        main_dict.update(
            {
                i: {
                    "Prompt": sample["Prompt"],
                    "Image": regular_caption_paths[i],
                    "Upsampled_Prompt": sample["Upsampled Prompt"],
                    "Image_With_Upsampled_Prompt": upsampled_caption_paths[i],
                }
            }
        )

    def generation_fn():
        for i in main_dict:
            entry = main_dict[i]
            yield {
                "Prompt": entry["Prompt"],
                "Image": entry["Image"],
                "Upsampled_Prompt": entry["Upsampled_Prompt"],
                "Image_With_Upsampled_Prompt": entry["Image_With_Upsampled_Prompt"],
                "model_name": ckpt_id,
                "seed": seed,
            }

    print("Preparing HF dataset...")
    ds = Dataset.from_generator(
        generation_fn,
        features=Features(
            Prompt=Value("string"),
            Image=ImageFeature(),
            Upsampled_Prompt=Value("string"),
            Image_With_Upsampled_Prompt=ImageFeature(),
            model_name=Value("string"),
            seed=Value("int64"),
        ),
    )
    ds_id = "drawbench-sdxl"
    ds.push_to_hub(ds_id)


if __name__ == "__main__":
    main()
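One subtlety: the script threads a single generator through both sets of calls, so its state advances and the regular and upsampled variants of a prompt are not denoised from identical initial latents. If matched noise is wanted for a fair side-by-side (an assumption about intent, not what the script does), a sketch:

```python
import torch

# fresh, identically-seeded generators so both caption variants
# start from the same initial latents
gen_a = torch.Generator(device="cuda").manual_seed(0)
gen_b = torch.Generator(device="cuda").manual_seed(0)

images_regular = pipe(prompts, generator=gen_a, num_inference_steps=25).images
images_upsampled = pipe(upsampled_prompts, generator=gen_b, num_inference_steps=25).images
```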
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

wandb/*

*.parquet
nohup.out
*.pt
out

vae.safetensors
*.zip

image_reward_processed
*_cache

generate_prompts/*.csv
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
import torch
import os
import platform
from typing import Dict, List
import numpy as np


def get_available_device():
    device = torch.device("cpu")
    if torch.cuda.is_available():
        device = torch.device("cuda")
    elif torch.backends.mps.is_available():
        device = torch.device("mps")
    return device


def should_use_wandb():
    if os.environ.get("NO_WANDB", False):
        return False
    return os.environ.get("USER") == "ubuntu" and platform.system().lower() == "linux"


sample_prompt_pairs = [
    (
        "portait witch hyper background",
        "portait of mystical witch, hyper detailed, flowing background, intricate and detailed, trippy, 8 k ",
    ),
    (
        "painting ghost riders sky sunrise wlop tooth wu charlie russell",
        "a beautiful painting of ghost riders in the sky, sunrise, by wlop, tooth wu and charlie russell",
    ),
    (
        "princess dance fairy gold dress sky stanley artgerm lau greg rutkowski victo ngai alphonse loish norman",
        "chinese princess, dance, fairy, beautiful, stunning, red and gold dress, spinning in the sky, by stanley artgerm lau, greg rutkowski, victo ngai, alphonse mucha, loish, norman rockwell",
    ),
    (
        "scene painting girl balustrade dress pattern seaside resort buildings background dusk clouds seagulls artstation krenz cushart alphonse maria mucha point composition k resolution hand illustration style",
        "a beautiful scene painting of a young girl, with a maiden balustrade in a white dress with a beautiful pattern, a beautiful deserted seaside resort with many wooden buildings in the background, romantic dusk, beautiful clouds, seagulls, trending on artstation, by krenz cushart and alphonse maria mucha, three - point composition, 8 k resolution, hand - painted, illustration style",
    ),
    (
        "portrait samurai goth punk colors style alexander mcqueen hyper art bill sienkiewicz artstation background",
        "close up portrait of old samurai, goth punk, vibrant yellow colors, surreal, french baroque style by alexander mcqueen, hyper detailed, cinematic, art by bill sienkiewicz trending artstation, remove red background",
    ),
]

sample_translate_pairs = [
    ("Ich bin ein Mann mit einem Pferd", "I am a man with a horse"),
    ("Ich möchte den Gipfel des Berges sehen", "I wish to see the top of the mountain"),
]


def get_model_gradient_norm(model):
    total_norm = 0
    for p in model.parameters():
        if p.grad is not None and p.grad.data is not None:
            param_norm = p.grad.data.norm(2)
            total_norm += param_norm.item() ** 2
    total_norm = total_norm ** (1.0 / 2)
    return total_norm


def weights_biases_sum(model):
    total_weight_sum = 0.0
    for param in model.parameters():
        total_weight_sum += param.data.sum().item()
    return total_weight_sum


def split_subject_descriptors(batch: Dict, nlp):
    """
    Splits a batch of prompts into subjects and descriptors.
    """
    out = {
        "subject": [],
        "descriptor": [],
    }
    for prompt in batch["prompt"]:
        doc = nlp(prompt)
        subject_tokens, descriptor_tokens = [], []

        # find the first chunk with either an entity or a proper noun.
        subject_found = False
        for chunk in doc.noun_chunks:
            if subject_found:
                descriptor_tokens.append(chunk.text)
            else:
                proper_nouns = [token for token in chunk if token.pos_ == "PROPN"]
                proper_ents, non_proper_ents = [], []
                for ent in chunk.ents:
                    if ent.label_ == "PERSON" or ent.label_ == "ORG":
                        proper_ents.append(ent)
                    else:
                        non_proper_ents.append(ent)
                subject_tokens.append(chunk.root.text)
                if len(non_proper_ents) > 0:
                    subject_tokens.append(chunk.text)
                    subject_found = True
                elif len(proper_nouns) > 0 and len(proper_ents) == 0:
                    subject_tokens.append(chunk.text)
                    subject_found = True

        # deduplicate while preserving order
        subject_tokens = [
            tok for i, tok in enumerate(subject_tokens) if tok not in subject_tokens[:i]
        ]
        out["subject"].append(" ".join(subject_tokens))
        out["descriptor"].append(" ".join(descriptor_tokens))
    return out


def compute_dcg(relevance: List[int], k):
    dcg = 0.0
    for i in range(k):
        dcg += (2 ** relevance[i] - 1) / np.log2(i + 2)
    return dcg


def compute_ndcg(true_rankings, pred_rankings, k):
    true_relevance = [1 if i in true_rankings else 0 for i in range(k)]
    true_dcg = compute_dcg(true_relevance, k)
    pred_relevance = [1 if i in pred_rankings else 0 for i in range(k)]
    pred_dcg = compute_dcg(pred_relevance, k)
    return pred_dcg / true_dcg
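A quick sketch of exercising `split_subject_descriptors` on a single-prompt batch (assumes a spaCy model is installed; `en_core_web_sm` is the one used in train_emb.py, and the exact split depends on that model):

```python
import spacy
from utils import split_subject_descriptors

nlp = spacy.load("en_core_web_sm")
batch = {"prompt": ["portrait of a mystical witch, hyper detailed, flowing background"]}
out = split_subject_descriptors(batch, nlp)
print(out["subject"], out["descriptor"])
```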
--------------------------------------------------------------------------------
/notebooks/seq2seq_dataset.ipynb:
--------------------------------------------------------------------------------
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from urllib.request import urlretrieve\n",
    "import pandas as pd\n",
    "from datasets import load_dataset\n",
    "\n",
    "table_url = f'https://huggingface.co/datasets/poloclub/diffusiondb/resolve/main/metadata-large.parquet'\n",
    "\n",
    "pbar = None\n",
    "\n",
    "if not os.path.exists('metadata.parquet'):\n",
    "    print(\"retrieving metadata file\")\n",
    "    urlretrieve(table_url, 'metadata.parquet')\n",
    "# Read the table using Pandas\n",
    "metadata_df = pd.read_parquet('metadata.parquet')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import importlib.util\n",
    "spec = importlib.util.find_spec(\"en_core_web_trf\")\n",
    "if spec is None:\n",
    "    print(\"Installing en_core_web_trf\")\n",
    "    !pip install https://huggingface.co/spacy/en_core_web_trf/resolve/main/en_core_web_trf-any-py3-none-any.whl"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import en_core_web_trf\n",
    "import spacy\n",
    "nlp = en_core_web_trf.load()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "metadata_df[\"prompt\"].head(10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "n_samples = 100_000\n",
    "first_n_unique_prompts = metadata_df[\"prompt\"].sample(n=n_samples, random_state=42).drop_duplicates().head(n_samples)\n",
    "display(first_n_unique_prompts.head(5).tolist())\n",
    "display(first_n_unique_prompts.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from datasets import Dataset, Features\n",
    "from spacy import displacy\n",
    "from spacy.symbols import nsubj, VERB\n",
    "\n",
    "\n",
    "def process(batch):\n",
    "    out = {\n",
    "        \"subject\": [],\n",
    "        \"descriptor\": [],\n",
    "    }\n",
    "    for prompt in batch[\"prompt\"]:\n",
    "        doc = nlp(prompt)\n",
    "        # displacy.render(doc, style=\"dep\")\n",
    "        subject_tokens, descriptor_tokens = [], []\n",
    "        # find the first chunk with either an entity or a proper noun.\n",
    "        subject_found = False\n",
    "        for chunk in doc.noun_chunks:\n",
    "            if subject_found:\n",
    "                descriptor_tokens.append(chunk.text)\n",
    "            else:\n",
    "                proper_nouns = [token for token in chunk if token.pos_ == \"PROPN\"]\n",
    "                proper_ents, non_proper_ents = [], []\n",
    "                for ent in chunk.ents:\n",
    "                    if ent.label_ == \"PERSON\" or ent.label_ == \"ORG\":\n",
    "                        proper_ents.append(ent)\n",
    "                    else:\n",
    "                        non_proper_ents.append(ent)\n",
    "                subject_tokens.append(chunk.root.text)\n",
    "                if len(non_proper_ents) > 0:\n",
    "                    subject_tokens.append(chunk.text)\n",
    "                    subject_found = True\n",
    "                elif len(proper_nouns) > 0 and len(proper_ents) == 0:\n",
    "                    subject_tokens.append(chunk.text)\n",
    "                    subject_found = True\n",
    "\n",
    "        # print(\"token deps\")\n",
    "        subject_tokens = [\n",
    "            tok for i, tok in enumerate(subject_tokens) if tok not in subject_tokens[:i]\n",
    "        ]\n",
    "        out[\"subject\"].append(\" \".join(subject_tokens))\n",
    "        out[\"descriptor\"].append(\" \".join(descriptor_tokens))\n",
    "    return out\n",
    "\n",
    "\n",
    "# display([(p, process(p)) for p in [\n",
    "#     \"stunning goddess of beers portrait, clear eyes and dark skin. realistic, symmetrical face. art by bowater charlie, mark brooks, julie bell, arian mark, tony sandoval \"\n",
    "# ]])\n",
    "display([(p, process({\"prompt\": [p]})) for p in first_n_unique_prompts[:10]])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "prompt_only_df = first_n_unique_prompts.to_frame()\n",
    "dataset = Dataset.from_pandas(prompt_only_df, preserve_index=False)\n",
    "dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = dataset.map(process, batched=True, batch_size=512, remove_columns=[\"prompt\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from huggingface_hub import login\n",
    "\n",
    "display(dataset)\n",
    "login(os.getenv(\"HF_TOKEN\"), add_to_git_credential=True)\n",
    "dataset.push_to_hub(\"roborovski/diffusiondb-seq2seq\")\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
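The dataset pushed at the end of the notebook can be pulled back with `datasets`, matching the name used in cache_datasets.py (after the `map` above, the remaining columns are `subject` and `descriptor`):

```python
from datasets import load_dataset

ds = load_dataset("roborovski/diffusiondb-seq2seq", split="train")
print(ds[0])  # {"subject": ..., "descriptor": ...}
```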
--------------------------------------------------------------------------------
/train_linear_emb_aug.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import os
from tqdm.auto import tqdm
from models.linear_emb_aug import LinearEmbAug
from torchvision import transforms
import fire
from torchinfo import summary
from datasets import load_dataset, Dataset
import bitsandbytes as bnb
import wandb
from utils import get_model_gradient_norm
from torch.amp.autocast_mode import autocast
from transformers import CLIPTextModel, AutoTokenizer, CLIPTokenizer

torch.manual_seed(42)

size = 512
to_tensor = transforms.ToTensor()
image_transforms = transforms.Compose(
    [
        transforms.RandomCrop(size, pad_if_needed=True, padding_mode="reflect"),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ]
)


def calculate_loss(
    model: LinearEmbAug,
    batch,
    text_encoder: CLIPTextModel,
    tokenizer: CLIPTokenizer,
    device,
):
    input_tokenized = tokenizer(
        batch["Prompt"], truncation=True, padding=True, return_tensors="pt"
    ).to(device)
    input_encoded = text_encoder(**input_tokenized).pooler_output

    label_tokenized = tokenizer(
        batch["Upsampled"], truncation=True, padding=True, return_tensors="pt"
    ).to(device)
    label_encoded = text_encoder(**label_tokenized).pooler_output

    out = model(input_encoded)
    loss = F.mse_loss(out, label_encoded)
    return loss


def train(
    test_steps: int = 1000,
    test_batches: int = 10,
    output_filename: str = "linear_emb_aug.pt",
    save_steps: int = 5000,
    batch_size: int = 8,
    num_dataloader_workers: int = 0,
    n_epochs: int = 10,
    lr: float = 1e-4,
    clip_grad_val: float = 0,
    use_bnb: bool = True,
    use_wandb: bool = False,
):
    device = torch.device("cuda")

    text_encoder: CLIPTextModel = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")  # type: ignore
    text_encoder.to(device)  # type: ignore
    tokenizer: CLIPTokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-large-patch14")  # type: ignore

    dataset = load_dataset(
        "roborovski/upsampled-prompts-parti", verification_mode="no_checks"
    )
    dataset = dataset["train"].train_test_split(test_size=0.05)  # type: ignore
    train_dataset: Dataset = dataset["train"]
    test_dataset: Dataset = dataset["test"]

    train_dataloader = DataLoader(
        train_dataset,  # type: ignore
        batch_size=batch_size,
        num_workers=num_dataloader_workers,
    )
    test_dataloader = DataLoader(
        test_dataset,  # type: ignore
        batch_size=batch_size,
        num_workers=num_dataloader_workers,
    )
    n_steps = len(train_dataloader)

    model = LinearEmbAug().to(device)

    model.train()
    model.requires_grad_(True)

    if use_wandb:
        wandb.init(project="superprompt-latent-aug")
        wandb.watch(model)

    print(summary(model))

    if use_bnb:
        optimizer = bnb.optim.Adam8bit(model.parameters(), lr=lr)
    else:
        optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    scaler = torch.cuda.amp.GradScaler(enabled=True)  # type: ignore
    linear_scheduler = torch.optim.lr_scheduler.LinearLR(
        optimizer, start_factor=0.001, total_iters=200
    )
    cosine_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, n_steps)
    scheduler = torch.optim.lr_scheduler.SequentialLR(
        optimizer, schedulers=[linear_scheduler, cosine_scheduler], milestones=[20]
    )

    model.train()
    epoch = 0
    global_step = 0

    for epoch in range(n_epochs):
        progress_bar = tqdm(range(n_steps))
        progress_bar.set_description("Steps")

        for batch in train_dataloader:
            global_step += 1

            # get loss
            with autocast(device_type="cuda", dtype=torch.float16):
                loss = calculate_loss(model, batch, text_encoder, tokenizer, device)
                loss_rounded = round(loss.cpu().item(), 2)

            scaler.scale(loss).backward()  # type: ignore
            scaler.unscale_(optimizer)

            if clip_grad_val > 0:
                nn.utils.clip_grad_norm_(model.parameters(), clip_grad_val)  # type: ignore

            norm_text, lr_text = (
                round(get_model_gradient_norm(model), 3),
                scheduler.get_last_lr()[0],
            )
            progress_bar.set_postfix(
                loss=loss_rounded, lr=lr_text, norm=norm_text, epoch=epoch
            )

            if use_wandb:
                wandb.log({"lr": lr_text, "norm": norm_text, "loss": loss})

            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()

            progress_bar.update(1)
            scheduler.step()

            if (global_step % save_steps) == 0:
                base, ext = os.path.splitext(output_filename)
                save_filename = f"{base}-{global_step}{ext}"
                torch.save(model.state_dict(), save_filename)
            if global_step % test_steps == 0:
                # use a separate counter so the test_batches argument isn't shadowed
                test_batches_run = 0
                model.eval()
                for batch in test_dataloader:
                    with torch.inference_mode():
                        loss = calculate_loss(
                            model, batch, text_encoder, tokenizer, device
                        )
                    test_batches_run += 1
                    if test_batches_run >= test_batches:
                        break
                model.train()

    if use_wandb:
        wandb.finish()
    torch.save(model.state_dict(), output_filename)
    print("Model saved")


if __name__ == "__main__":
    fire.Fire(train)
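A sketch of reloading a checkpoint written by `train()`; the filename mirrors the `save_steps` naming scheme in the loop above (the step count 5000 is hypothetical):

```python
import torch
from models.linear_emb_aug import LinearEmbAug

model = LinearEmbAug()
model.load_state_dict(torch.load("linear_emb_aug-5000.pt", map_location="cpu"))
model.eval()
```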
--------------------------------------------------------------------------------
/generate_prompts/upsample_captions.py:
--------------------------------------------------------------------------------
import os
from typing import Dict, List, Tuple

import pandas as pd
import torch
from datasets import Dataset, load_dataset, concatenate_datasets
from transformers import pipeline, Pipeline
import fire
from huggingface_hub import login
from dotenv import load_dotenv
from vllm import LLM, SamplingParams, RequestOutput


def load_chat_pipeline_hf():
    """Loads the HuggingFaceH4/zephyr-7b-alpha model and wraps it into a handy text-generation pipeline."""
    pipe = pipeline(
        "text-generation",
        model="HuggingFaceH4/zephyr-7b-alpha",
        torch_dtype=torch.bfloat16,
        device_map="auto",
    )
    return pipe


def get_messages_for_chat() -> Tuple[Dict, List[Dict]]:
    """
    Prepares the system and user-assistant style messages for inference.

    Example messages come from the DALL-E 3 technical report:
    https://cdn.openai.com/papers/dall-e-3.pdf.
    """
    system_message = {
        "role": "system",
        "content": """You are part of a team of bots that creates images. You work with an assistant bot that will draw anything you say in square brackets. For example, outputting "a beautiful morning in the woods with the sun peaking through the trees" will trigger your partner bot to output an image of a forest morning, as described. You will be prompted by people looking to create detailed, amazing images. The way to accomplish this is to take their short prompts and make them extremely detailed and descriptive.

There are a few rules to follow:

- You will only ever output a single image description per user request.
- Sometimes the user will request that you modify previous captions. In this case, you should refer to your previous conversations with the user and make the modifications requested.
- When modifications are requested, you should not simply make the description longer. You should refactor the entire description to integrate the suggestions.
- Other times the user will not want modifications, but instead want a new image. In this case, you should ignore your previous conversation with the user."
- Image descriptions must be between 15-80 words. Extra words will be ignored.
""",
    }

    user_conversation = [
        {
            "role": "user",
            "content": "Create an imaginative image descriptive caption or modify an earlier caption for the user input : 'make the light red'",
        },
        {
            "role": "assistant",
            "content": "a pale figure with long white hair stands in the center of a dark forest, holding a sword high above his head. the blade glows with a red light, casting a warm glow on the trees and bushes surrounding him.",
        },
        {
            "role": "user",
            "content": "Create an imaginative image descriptive caption or modify an earlier caption for the user input : 'draw a frog playing dominoes'",
        },
        {
            "role": "assistant",
            "content": "a frog sits on a worn table playing a game of dominoes with an elderly raccoon. the table is covered in a green cloth, and the frog is wearing a jacket and a pair of jeans. The scene is set in a forest, with a large tree in the background.",
        },
        {
            "role": "user",
            "content": "Create an imaginative image descriptive caption or modify an earlier caption for the user input : '{prompt}'",
        },
    ]
    return system_message, user_conversation
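# Usage sketch (mirrors main() below): the last user turn returned above carries
# a "{prompt}" placeholder that is filled in per source prompt before inference.
#
#   system_message, user_conversation = get_messages_for_chat()
#   user_conversation[-1]["content"] = user_conversation[-1]["content"].format(
#       prompt="a robot playing chess"
#   )
#   final_message = [system_message, *user_conversation]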


def upsample_caption_hf(pipeline: Pipeline, message: list[Dict[str, str]]):
    """Performs inference on a single prompt."""
    outputs = pipeline(
        message,
        max_new_tokens=256,
        do_sample=True,
        temperature=0.7,
        top_k=50,
        top_p=0.95,
    )
    return outputs


def upload_dataset(
    hf_dataset: Dataset, hf_dataset_name: str, new_dataset_rows: List[Dict]
):
    dataset_new_rows = Dataset.from_list(new_dataset_rows)
    dataset_new_rows.to_csv("upsampled_new_prompts.csv")

    concat_dataset = concatenate_datasets([hf_dataset, dataset_new_rows])

    print(f"Uploading {len(new_dataset_rows)} new prompts to the Hub...")
    concat_dataset.push_to_hub(hf_dataset_name)


def main():
    hf_dataset_name = "roborovski/upsampled-prompts-parti"

    print("Loading existing prompts...")
    hf_dataset: Dataset = load_dataset(hf_dataset_name, split="train")  # type: ignore

    print("Loading new prompts...")
    parti_prompts: pd.DataFrame = pd.read_csv("PartiPrompts.tsv", sep="\t")

    source_prompts_list = parti_prompts

    new_dataset_rows: List[Dict] = []

    print("Logging into the Hub...")
    file_dir = os.path.dirname(os.path.abspath(__file__))
    load_dotenv(os.path.join(file_dir, ".env"))
    token = os.getenv("HF_TOKEN")
    print(f"Logging in with token: {token}")
    login(token=token, add_to_git_credential=True)

    # initial test upload before loading the pipeline
    upload_dataset(hf_dataset, hf_dataset_name, new_dataset_rows)

    n_epochs = 100

    sampling_params = SamplingParams(temperature=0.7, top_p=0.95, max_tokens=256)
    print("Loading local pipeline...")
    model = LLM(model="HuggingFaceH4/zephyr-7b-beta", dtype="auto")
    print("Pipeline loaded.")

    tokenizer = model.get_tokenizer()

    print("Upsampling captions...")
    for epoch in range(n_epochs):
        for i, row in enumerate(source_prompts_list.itertuples()):
            original_prompt, category = row.Prompt, row.Category
            system_message, user_conversation = get_messages_for_chat()
            updated_prompt = user_conversation[-1]["content"].format(
                prompt=original_prompt
            )
            user_conversation[-1]["content"] = updated_prompt

            final_message = [system_message, *user_conversation]
            full_conversation_formatted: str = tokenizer.apply_chat_template(  # type: ignore
                final_message, tokenize=False, add_generation_prompt=True
            )

            outputs = model.generate(full_conversation_formatted, sampling_params)

            upsampled_caption = outputs[0].outputs[0].text
            new_dataset_rows.append(
                {
                    "Prompt": original_prompt,
                    "Category": category,
                    "Upsampled": upsampled_caption,
                }
            )

            print(
                f"Upsampled prompt {epoch} {i} ({category}): {original_prompt} -> {upsampled_caption}"
            )

            if i % 500 == 0:
                print(f"Upsampled {i} prompts")
                upload_dataset(hf_dataset, hf_dataset_name, new_dataset_rows)

    upload_dataset(hf_dataset, hf_dataset_name, new_dataset_rows)


if __name__ == "__main__":
    fire.Fire(main)
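Since `upload_dataset` also writes `upsampled_new_prompts.csv` locally before pushing, rows from a crashed run can be recovered without re-querying the model. A sketch:

```python
import pandas as pd

# columns match the rows appended in main(): Prompt, Category, Upsampled
recovered = pd.read_csv("upsampled_new_prompts.csv")
print(recovered[["Prompt", "Upsampled"]].head())
```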
--------------------------------------------------------------------------------
/models/bert.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
import math


class PositionEmbed(nn.Module):
    def __init__(self, dim_model, max_len=512) -> None:
        super().__init__()

        # embedding is a matrix of size (max_len, dim_model):
        # for each possible position i, entry j contains the sinusoid of frequency i / 10000^(2j/dim_model)
        pe = torch.zeros(max_len, dim_model)
        pe.requires_grad = False

        # create a 2D tensor with the position indices
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = (
            torch.arange(0, dim_model, 2, dtype=torch.float)
            * -(math.log(10000.0) / dim_model)
        ).exp()

        # for each 2 entries, starting at 0, we get a sin and cos activation
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)

        pe = pe.unsqueeze(0)
        self.register_buffer("pe", pe)

    def forward(self, x):
        # get the position embeddings for all tokens up to the current position idx
        return self.pe[:, : x.size(1)]


class BERTEmbedding(nn.Module):
    def __init__(self, vocab_size, embed_size=512, dropout=0.1, max_len=512) -> None:
        super().__init__()

        self.token = nn.Embedding(vocab_size, embed_size, padding_idx=0)
        embedding_dim = self.token.embedding_dim
        self.position = PositionEmbed(dim_model=embedding_dim, max_len=max_len)
        self.dropout = nn.Dropout(p=dropout)
        self.embed_size = embed_size

    def forward(self, sequence):
        x = self.token(sequence)
        x = x + self.position(x)
        x = self.dropout(x)
        return x


# Compute a single attention head
class Attention(nn.Module):
    # matrix multiplication of query and key, then scaled by the square root of the dimension of the query
    def forward(self, query, key, value, mask=None, dropout=None):
        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(query.size(-1))
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)

        p_attn = F.softmax(scores, dim=-1)

        if dropout is not None:
            p_attn = dropout(p_attn)

        return torch.matmul(p_attn, value), p_attn
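# Worked example of the scaling above: the module computes
# softmax(Q K^T / sqrt(d_k)) V. A quick shape check with hypothetical sizes:
#
#   q = torch.randn(2, 8, 16, 32)   # [batch, heads, seq, d_k]
#   k = torch.randn(2, 8, 16, 32)
#   v = torch.randn(2, 8, 16, 32)
#   out, p_attn = Attention()(q, k, v)
#   assert out.shape == (2, 8, 16, 32) and p_attn.shape == (2, 8, 16, 16)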


class MultiHeadAttention(nn.Module):
    def __init__(self, attn_heads, hidden, dropout=0.1) -> None:
        super().__init__()
        assert hidden % attn_heads == 0

        # We assume d_v always equals d_k
        self.d_k = hidden // attn_heads
        self.h = attn_heads

        # linear layers for query, key and value
        self.linear_layers = nn.ModuleList(
            [nn.Linear(hidden, hidden) for _ in range(3)]
        )
        # final linear layer for output
        self.output_linear = nn.Linear(hidden, hidden)

        # attention - performed per batch of queries
        self.attention = Attention()

        self.dropout = nn.Dropout(p=dropout)

    def forward(self, query, key, value, mask=None):
        batch_size = query.size(0)

        # linear projection from hidden to d_k * h,
        # i.e. for each linear layer, we get the query, key and value;
        # these represent the linear layer for each head
        query, key, value = [
            l(x).view(batch_size, -1, self.h, self.d_k).transpose(1, 2)
            for l, x in zip(self.linear_layers, (query, key, value))
        ]

        # compute attention for all heads in a batch
        x, attention = self.attention(
            query, key, value, mask=mask, dropout=self.dropout
        )

        # concatenate all heads
        x = x.transpose(1, 2).contiguous().view(batch_size, -1, self.h * self.d_k)

        # apply final linear layer
        return self.output_linear(x)


class SublayerConnection(nn.Module):
    def __init__(self, hidden, dropout) -> None:
        super(SublayerConnection, self).__init__()
        self.norm = nn.LayerNorm(hidden)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x, sublayer):
        return x + self.dropout(sublayer(self.norm(x)))


# Feed-forward layer, with dropout and activation
class PositionwiseFeedForward(nn.Module):
    def __init__(self, hidden, feed_forward_hidden, dropout=0.1) -> None:
        super(PositionwiseFeedForward, self).__init__()
        self.w_1 = nn.Linear(hidden, feed_forward_hidden)
        self.w_2 = nn.Linear(feed_forward_hidden, hidden)
        self.dropout = nn.Dropout(p=dropout)
        # the original BERT uses GELU (similar to ReLU with a slight dip before 0);
        # LeakyReLU is used here instead
        self.activation = nn.LeakyReLU()

    def forward(self, x):
        return self.w_2(self.dropout(self.activation(self.w_1(x))))


class TransformerBlock(nn.Module):
    def __init__(self, hidden, attn_heads, feed_forward_hidden, dropout) -> None:
        super().__init__()
        self.attention = MultiHeadAttention(attn_heads, hidden)
        self.feed_forward = PositionwiseFeedForward(
            hidden, feed_forward_hidden, dropout
        )
        self.input_sublayer = SublayerConnection(hidden, dropout)
        self.output_sublayer = SublayerConnection(hidden, dropout)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x, mask):
        x = self.input_sublayer(x, lambda _x: self.attention(_x, _x, _x, mask))
        x = self.output_sublayer(x, self.feed_forward)
        return self.dropout(x)


class BERT(nn.Module):
    def __init__(
        self,
        vocab_size,
        hidden=768,
        n_layers=12,
        attn_heads=12,
        dropout=0.1,
        max_len=512,
    ):
        super().__init__()
        self.hidden = hidden
        self.n_layers = n_layers
        self.attn_heads = attn_heads

        self.feed_forward_hidden = hidden * 4  # 4 is a hyperparameter

        self.embedding = BERTEmbedding(vocab_size, hidden, dropout, max_len)

        self.transformer_blocks = nn.ModuleList(
            [
                TransformerBlock(hidden, attn_heads, hidden * 4, dropout)
                for _ in range(n_layers)
            ]
        )

        # masked LM head
        self.linear = nn.Linear(hidden, vocab_size)
        self.softmax = nn.LogSoftmax(dim=-1)

    def forward(self, x, mask):
        # attention mask for padded tokens:
        # torch.ByteTensor([batch_size, 1, seq_len, seq_len])
        mask = mask.unsqueeze(1).repeat(1, x.size(1), 1).unsqueeze(1)

        # get the embedding for the input sequence
        x = self.embedding(x)

        for transformer in self.transformer_blocks:
            x = transformer(x, mask)

        # masked LM: log-probabilities over the vocabulary per position
        x = self.softmax(self.linear(x))

        return x
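A sketch instantiating the model with the defaults used in train_bert.py and checking the output shape (30522 is bert-base-uncased's vocabulary size):

```python
import torch

model = BERT(vocab_size=30522, hidden=256, n_layers=8, attn_heads=8, max_len=128)
ids = torch.randint(0, 30522, (2, 128))
mask = torch.ones(2, 128, dtype=torch.long)
out = model(ids, mask)
assert out.shape == (2, 128, 30522)  # log-probabilities per token position
```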
--------------------------------------------------------------------------------
/models/bert_trainer.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
from torch.optim import AdamW
import numpy as np
import tqdm
from models.bert import BERT
import wandb
from transformers import BertTokenizer, DataCollatorForLanguageModeling
import random
from datasets import IterableDataset
from utils import get_available_device, sample_prompt_pairs


class ScheduledOptim:
    """A simple wrapper class for learning rate scheduling"""

    def __init__(self, optimizer, d_model, n_warmup_steps):
        self._optimizer = optimizer
        self.n_warmup_steps = n_warmup_steps
        self.n_current_steps = 0
        self.init_lr = np.power(d_model, -0.5)

    def step_and_update_lr(self):
        "Step with the inner optimizer"
        self._update_learning_rate()
        self._optimizer.step()

    def zero_grad(self):
        "Zero out the gradients of the inner optimizer"
        self._optimizer.zero_grad()

    def _get_lr_scale(self):
        return np.min(
            [
                np.power(self.n_current_steps, -0.5),
                np.power(self.n_warmup_steps, -1.5) * self.n_current_steps,
            ]
        )

    def _update_learning_rate(self):
        """Learning rate scheduling per step"""

        self.n_current_steps += 1
        lr = self.init_lr * self._get_lr_scale()

        for param_group in self._optimizer.param_groups:
            param_group["lr"] = lr


class BERTTrainer:
    def __init__(
        self,
        bert: BERT,
        tokenizer: BertTokenizer,
        collator: DataCollatorForLanguageModeling,
        train_dataset: IterableDataset,
        test_dataset: IterableDataset,
        lr: float = 1e-4,
        betas=(0.9, 0.999),
        weight_decay: float = 0.01,
        max_len: int = 256,
        batch_size: int = 32,
        log_freq: int = 10,
        valid_freq: int = 10,
        save_freq: int = 1000,
        output_path: str = "./saved",
        use_wandb: bool = False,
    ):
        # Pick the best available device and move the model there;
        # the model is saved every save_freq batches.
        self.device = get_available_device()
        self.model = bert.to(self.device)

        self.tokenizer = tokenizer
        self.collator: DataCollatorForLanguageModeling = collator
        self.use_wandb: bool = use_wandb
        self.batch_size: int = batch_size
        self.output_path: str = output_path

        # Setting the train and test data loader
        self.train_data = train_dataset
        self.test_data = test_dataset
        self.max_len = max_len

        # Setting the Adam optimizer with hyper-params
        self.optim = AdamW(
            self.model.parameters(), lr=lr, betas=betas, weight_decay=weight_decay
        )
        self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(self.optim, 10)

        # Negative log-likelihood loss for predicting the masked tokens.
        # The collator marks unmasked positions with -100 in the labels, so ignore those.
        self.criterion = nn.NLLLoss(ignore_index=-100)

        self.log_freq = log_freq
        self.save_freq = save_freq
        self.valid_freq = valid_freq
        self.table_rows = []

    def train(self, epoch):
        self.iteration(epoch, self.train_data)

    def test(self, epoch):
        self.iteration(epoch, self.test_data, train=False)

    def iteration(self, epoch, dataset: IterableDataset, train=True):
        # Setting the tqdm progress bar
        data_iter = tqdm.tqdm(
            enumerate(dataset.iter(batch_size=self.batch_size)),
            bar_format="{l_bar}{r_bar}",
        )

        total_loss = 0.0

        for i, data in data_iter:
            # 0. batch data will be sent to the device (GPU or CPU)
            collated = self.collator(data["input_ids"])
            input_ids = collated["input_ids"].to(self.device)
            labels = collated["labels"].to(self.device)
            attn_mask = torch.stack(data["attention_mask"]).to(self.device)

            mask_lm_output = self.model.forward(input_ids, attn_mask)

            # NLLLoss expects [batch, vocab, seq_len]
            transposed_output = mask_lm_output.transpose(1, 2)

            # score the predictions against the collator's labels, not the masked inputs
            loss = self.criterion(transposed_output, labels)

            total_loss += loss.item()
            avg_loss = total_loss / (i + 1)
            print(f"epoch {epoch} i {i} avg_loss {avg_loss}")

            if train:
                self.optim.zero_grad()
                loss.backward()
                self.optim.step()

            if i % self.log_freq == 0:
                post_fix = {
                    "epoch": epoch,
                    "batch": i,
                    "avg_loss": avg_loss,
                    "loss": loss.item(),
                }
                data_iter.write(str(post_fix))
                if self.use_wandb:
                    wandb.log(post_fix)
            if i % self.valid_freq == 0:
                decoded = self.eval_sample()
                if self.use_wandb:
                    self.table_rows.append([epoch, avg_loss, decoded])
                    print("table", len(self.table_rows))
                    table = wandb.Table(
                        data=self.table_rows,
                        columns=["epoch", "avg_loss", "sample"],
                    )
                    wandb.log({"samples": table})
            if i % self.save_freq == 0:
                self.save(epoch, self.output_path)

    def eval_sample(self):
        # sample_prompt_pairs holds (masked, full) pairs; evaluate on the full prompt
        prompt = random.choice(sample_prompt_pairs)[1]
        print("---EVAL---")
        print("prompt", prompt)
        tokenized = self.tokenizer(
            prompt,
            truncation=True,
            padding="max_length",
            max_length=self.max_len,
            return_tensors="pt",
        )
        print("tokenized", tokenized)
        eval_batch = self.collator([tokenized])
        print("batch", eval_batch)
        input_ids = eval_batch["input_ids"].squeeze(0).to(self.device)
        attn_mask = tokenized["attention_mask"].to(self.device)
        print("input ids", input_ids)
        mask_lm_output = self.model.forward(input_ids, attn_mask)
        output = torch.argmax(mask_lm_output, dim=2)
        print("output", output)
        decoded = self.tokenizer.decode(output[0])
        print("decoded", decoded)
        return decoded

    def save(self, epoch, file_path="output/bert_trained.model"):
        """
        Saves the current BERT model to file_path.

        :param epoch: current epoch number
        :param file_path: model output path; the file is written to file_path + ".ep%d" % epoch
        :return: final_output_path
        """
        output_path = file_path + ".ep%d" % epoch
        torch.save(self.model.cpu(), output_path)
        self.model.to(self.device)
        print("EP:%d Model Saved on:" % epoch, output_path)
        return output_path
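A sketch of the labeling convention the trainer relies on: `DataCollatorForLanguageModeling` masks roughly 15% of tokens and sets every other label position to -100, which the NLLLoss above is told to ignore, so only masked positions contribute to the loss.

```python
from transformers import BertTokenizer, DataCollatorForLanguageModeling

tok = BertTokenizer.from_pretrained("bert-base-uncased")
collator = DataCollatorForLanguageModeling(tokenizer=tok, mlm=True, mlm_probability=0.15)

batch = collator([tok("a cat sat on the mat")["input_ids"]])
print(batch["labels"])  # -100 everywhere except the masked positions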
--------------------------------------------------------------------------------
/train_emb.py:
--------------------------------------------------------------------------------
from typing import Dict
from datasets import load_dataset
import spacy
import wandb
from models.clip_emb_aug import CLIPEmbeddingAugmenter
import torch
import torch.nn as nn
import torch.nn.functional as F
import fire
from torch.optim import AdamW
from tqdm import tqdm
from torch.utils.data import DataLoader
from diffusers.utils.import_utils import is_xformers_available
import gc
from torchinfo import summary
import os
from torch.optim.lr_scheduler import CosineAnnealingLR
from diffusers import (
    StableDiffusionPipeline,
)
from transformers import AutoTokenizer, CLIPTextModel
from utils import get_available_device

if not spacy.util.is_package("en_core_web_sm"):
    spacy.cli.download("en_core_web_sm")

torch.manual_seed(0)


def loss_fn_emb_aug(
    batch: Dict,
    device: torch.device,
    model: CLIPEmbeddingAugmenter,
):
    mask_emb, unmask_emb = (
        batch["masked_embeddings"].to(device),
        batch["unmasked_embeddings"].to(device),
    )

    model_out = model(mask_emb)
    loss = F.mse_loss(model_out, unmask_emb)

    return loss, model_out


def main(use_wandb: bool = False, eval_every: int = 25):
    device = get_available_device()

    clip_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14").to(
        device
    )
    tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-large-patch14")
    max_clip_length = clip_model.config.max_position_embeddings

    model = CLIPEmbeddingAugmenter(clip_model)
    print(summary(model))

    nlp = spacy.load("en_core_web_sm")

    def mask_non_nouns(prompt):
        doc = nlp(prompt)
        masked_tokens = []
        pos = set({"NOUN", "PROPN"})
        for token in doc:
            if token.pos_ in pos or (token.dep_ == "nsubj" and token.pos_ == "VERB"):
                masked_tokens.append(token.text)
        return " ".join(masked_tokens)

    def preprocess_dataset(batch):
        unmasked_prompts = batch["prompt"]
        unmasked_inputs = tokenizer(
            text=unmasked_prompts,
            return_tensors="pt",
            max_length=max_clip_length,
            padding="max_length",
            truncation=True,
        ).to(device)
        unmasked_clip_out = clip_model(**unmasked_inputs)
        unmasked_embeddings = unmasked_clip_out.last_hidden_state

        masked_prompts = [mask_non_nouns(prompt) for prompt in unmasked_prompts]
        masked_inputs = tokenizer(
            text=masked_prompts,
            return_tensors="pt",
            max_length=max_clip_length,
            padding="max_length",
            truncation=True,
        ).to(device)
        masked_clip_out = clip_model(**masked_inputs)
        masked_embeddings = masked_clip_out.last_hidden_state
        batch_dict = {
            "unmasked_embeddings": unmasked_embeddings,
            "masked_embeddings": masked_embeddings,
            "masked_prompts": masked_prompts,
            "unmasked_prompts": unmasked_prompts,
        }
        return batch_dict

    # TODO: filter for only high image-text alignment scores
    remove_cols = [
        "image",
        "prompt_id",
        "prompt",
        "classification",
        "image_amount_in_total",
        "rank",
        "overall_rating",
        "image_text_alignment_rating",
        "fidelity_rating",
    ]
    dataset = load_dataset("THUDM/ImageRewardDB", "4k", verification_mode="no_checks")
    dataset.set_format("torch")
    dataset = dataset.map(
        preprocess_dataset,
        cache_file_names={
            "train": "train_cache",
            "validation": "val_cache",
            "test": "test_cache",
        },
        batched=True,
        num_proc=1,
        drop_last_batch=True,
        batch_size=96,
        remove_columns=remove_cols,
    )
    # dataset.save_to_disk("image_reward_processed")
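    # Illustrative example (the exact output depends on the spaCy model):
    # mask_non_nouns turns "a beautiful painting of ghost riders in the sky"
    # into something like "painting riders sky", so the model is trained to
    # reconstruct full-prompt CLIP embeddings from noun-only embeddings.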
127 | 
128 |     print("Loading model..")
129 | 
130 |     model.train()
131 | 
132 |     # Hyperparameters
133 |     num_epochs: int = 200
134 |     learning_rate: float = 1e-5
135 |     batch_size: int = 64
136 | 
137 |     optimizer = AdamW(model.parameters(), lr=learning_rate)
138 |     scheduler = CosineAnnealingLR(
139 |         optimizer, T_max=num_epochs // 4, eta_min=learning_rate / 10
140 |     )
141 | 
142 |     if use_wandb:
143 |         wandb.init(project="superprompt-aug")
144 |         wandb.watch(model)
145 | 
146 |     train_dataset, val_dataset = dataset["train"], dataset["validation"]
147 |     train_loader = DataLoader(train_dataset, batch_size=batch_size)
148 |     val_loader = DataLoader(val_dataset, batch_size=4)
149 | 
150 |     for epoch in range(num_epochs):
151 |         train_iter = tqdm(train_loader, total=len(train_loader))
152 |         # enumerate so the eval-every check below has a batch index to work with
153 |         for i, batch in enumerate(train_iter):
154 |             loss, model_out = loss_fn_emb_aug(
155 |                 batch, device, model
156 |             )
157 | 
158 |             log_dict = {
159 |                 "loss": loss.item(),
160 |                 "lr": optimizer.param_groups[0]["lr"],
161 |                 "epoch": epoch,
162 |             }
163 | 
164 |             train_iter.set_postfix(log=log_dict)
165 |             if use_wandb:
166 |                 wandb.log(log_dict)
167 | 
168 |             # Backward pass and optimization
169 |             optimizer.zero_grad()
170 |             loss.backward()
171 |             optimizer.step()
172 | 
173 |             if i % eval_every == 0:
174 |                 for batch in val_loader:
175 |                     pipe = StableDiffusionPipeline.from_pretrained(
176 |                         "runwayml/stable-diffusion-v1-5",
177 |                         torch_dtype=torch.float16,
178 |                         safety_checker=None,
179 |                     )
180 |                     pipe = pipe.to("cuda")
181 |                     loss, model_out = loss_fn_emb_aug(
182 |                         batch, device, model
183 |                     )
184 | 
185 |                     if is_xformers_available():
186 |                         pipe.unet.enable_xformers_memory_efficient_attention()
187 | 
188 |                     log_dict = {
189 |                         "loss": loss.item(),
190 |                         "unmasked": [],
191 |                         "encoded": [],
192 |                     }
193 | 
194 |                     unmask_emb = batch["unmasked_embeddings"].to(device)
195 | 
196 |                     for key in ("unmasked", "encoded"):
197 |                         print(f"Generating {key} images...")
198 |                         embeds = unmask_emb if key == "unmasked" else model_out
199 |                         generations = pipe(prompt_embeds=embeds).images
200 |                         os.makedirs("out", exist_ok=True)
201 |                         # a separate index name keeps the outer batch counter intact
202 |                         for img_idx, generation in enumerate(generations):
203 |                             generation.save(f"out/{key}_{img_idx}.png")
204 |                             if use_wandb:
205 |                                 log_dict[key].append(wandb.Image(generation))
206 | 
207 |                     if use_wandb:
208 |                         wandb.log(log_dict)
209 | 
210 |                     del pipe
211 |                     gc.collect()
212 |                     torch.cuda.empty_cache()
213 | 
214 |                     # only sample from the first validation batch
215 |                     break
216 | 
217 |         # step the cosine schedule once per epoch (T_max is measured in epochs)
218 |         before_lr = optimizer.param_groups[0]["lr"]
219 |         scheduler.step()
220 |         after_lr = optimizer.param_groups[0]["lr"]
221 |         print("Epoch %d: lr %.6f -> %.6f" % (epoch, before_lr, after_lr))
222 | 
223 |     if use_wandb:
224 |         wandb.finish()
225 | 
226 | 
227 | if __name__ == "__main__":
228 |     fire.Fire(main)
229 | 
--------------------------------------------------------------------------------
/models/rnn.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import random
5 | 
6 | 
7 | class Encoder(nn.Module):
8 |     def __init__(self, input_dim, emb_dim, enc_hid_dim, dec_hid_dim, dropout):
9 |         super().__init__()
10 | 
11 |         self.embedding = nn.Embedding(input_dim, emb_dim)
12 | 
13 |         self.rnn = nn.GRU(emb_dim, enc_hid_dim, bidirectional=True)
14 | 
15 |         self.fc = nn.Linear(enc_hid_dim * 2, dec_hid_dim)
16 | 
17 |         self.dropout = nn.Dropout(dropout)
18 | 
19 |     def forward(self, src, src_len):
20 |         # src = [src len, batch size]
21 |         # src_len = [batch size]
22 | 
23 |         embedded = self.dropout(self.embedding(src))
24 | 
25 |         # 
embedded = [src len, batch size, emb dim] 26 | 27 | # need to explicitly put lengths on cpu! 28 | packed_embedded = nn.utils.rnn.pack_padded_sequence( 29 | embedded, src_len.to("cpu"), enforce_sorted=False 30 | ) 31 | 32 | packed_outputs, hidden = self.rnn(packed_embedded) 33 | 34 | # packed_outputs is a packed sequence containing all hidden states 35 | # hidden is now from the final non-padded element in the batch 36 | 37 | outputs, _ = nn.utils.rnn.pad_packed_sequence(packed_outputs) 38 | 39 | # outputs is now a non-packed sequence, all hidden states obtained 40 | # when the input is a pad token are all zeros 41 | 42 | # outputs = [src len, batch size, hid dim * num directions] 43 | # hidden = [n layers * num directions, batch size, hid dim] 44 | 45 | # hidden is stacked [forward_1, backward_1, forward_2, backward_2, ...] 46 | # outputs are always from the last layer 47 | 48 | # hidden [-2, :, : ] is the last of the forwards RNN 49 | # hidden [-1, :, : ] is the last of the backwards RNN 50 | 51 | # initial decoder hidden is final hidden state of the forwards and backwards 52 | # encoder RNNs fed through a linear layer 53 | hidden = torch.tanh( 54 | self.fc(torch.cat((hidden[-2, :, :], hidden[-1, :, :]), dim=1)) 55 | ) 56 | 57 | # outputs = [src len, batch size, enc hid dim * 2] 58 | # hidden = [batch size, dec hid dim] 59 | 60 | return outputs, hidden 61 | 62 | 63 | class Attention(nn.Module): 64 | def __init__(self, enc_hid_dim, dec_hid_dim): 65 | super().__init__() 66 | 67 | self.attn = nn.Linear((enc_hid_dim * 2) + dec_hid_dim, dec_hid_dim) 68 | self.v = nn.Linear(dec_hid_dim, 1, bias=False) 69 | 70 | def forward(self, hidden, encoder_outputs, mask): 71 | # hidden = [batch size, dec hid dim] 72 | # encoder_outputs = [src len, batch size, enc hid dim * 2] 73 | 74 | batch_size = encoder_outputs.shape[1] 75 | src_len = encoder_outputs.shape[0] 76 | 77 | # repeat decoder hidden state src_len times 78 | hidden = hidden.unsqueeze(1).repeat(1, src_len, 1) 79 | 80 | encoder_outputs = encoder_outputs.permute(1, 0, 2) 81 | 82 | # hidden = [batch size, src len, dec hid dim] 83 | # encoder_outputs = [batch size, src len, enc hid dim * 2] 84 | 85 | energy = torch.tanh(self.attn(torch.cat((hidden, encoder_outputs), dim=2))) 86 | 87 | # energy = [batch size, src len, dec hid dim] 88 | 89 | attention = self.v(energy).squeeze(2) 90 | 91 | # attention = [batch size, src len] 92 | narrow_mask = mask[:, :attention.shape[1]] 93 | 94 | attention = attention.masked_fill(narrow_mask == 0, -1e10) 95 | 96 | return F.softmax(attention, dim=1) 97 | 98 | 99 | class Decoder(nn.Module): 100 | def __init__( 101 | self, output_dim, emb_dim, enc_hid_dim, dec_hid_dim, dropout, attention 102 | ): 103 | super().__init__() 104 | 105 | self.output_dim = output_dim 106 | self.attention = attention 107 | 108 | self.embedding = nn.Embedding(output_dim, emb_dim) 109 | 110 | self.rnn = nn.GRU((enc_hid_dim * 2) + emb_dim, dec_hid_dim) 111 | 112 | self.fc_out = nn.Linear((enc_hid_dim * 2) + dec_hid_dim + emb_dim, output_dim) 113 | 114 | self.dropout = nn.Dropout(dropout) 115 | 116 | def forward(self, input, hidden, encoder_outputs, mask): 117 | # input = [batch size] 118 | # hidden = [n layers * n directions, batch size, hid dim] 119 | # cell = [n layers * n directions, batch size, hid dim] 120 | 121 | # n directions in the decoder will both always be 1, therefore: 122 | # hidden = [n layers, batch size, hid dim] 123 | # context = [n layers, batch size, hid dim] 124 | 125 | input = input.unsqueeze(0) 126 | 127 | # input = [1, 
batch size] 128 | 129 | embedded = self.dropout(self.embedding(input)) 130 | 131 | # embedded = [1, batch size, emb dim] 132 | 133 | a = self.attention(hidden, encoder_outputs, mask) 134 | 135 | a = a.unsqueeze(1) 136 | 137 | encoder_outputs = encoder_outputs.permute(1, 0, 2) 138 | 139 | weighted = torch.bmm(a, encoder_outputs) 140 | 141 | weighted = weighted.permute(1, 0, 2) 142 | 143 | rnn_input = torch.cat((embedded, weighted), dim=2) 144 | 145 | # rnn_input = [1, batch size, (enc hid dim * 2) + emb dim] 146 | 147 | output, hidden = self.rnn(rnn_input, hidden.unsqueeze(0)) 148 | 149 | assert (output == hidden).all() 150 | 151 | embedded = embedded.squeeze(0) 152 | output = output.squeeze(0) 153 | weighted = weighted.squeeze(0) 154 | 155 | prediction = self.fc_out(torch.cat((output, weighted, embedded), dim=1)) 156 | 157 | # prediction = [batch size, output dim] 158 | 159 | return prediction, hidden.squeeze(0), a.squeeze(1) 160 | 161 | 162 | class Seq2Seq(nn.Module): 163 | def __init__(self, encoder, decoder, src_pad_idx, device): 164 | super().__init__() 165 | 166 | self.encoder = encoder 167 | self.decoder = decoder 168 | self.src_pad_idx = src_pad_idx 169 | self.device = device 170 | 171 | def create_mask(self, src): 172 | mask = (src != self.src_pad_idx).permute(1, 0) 173 | return mask 174 | 175 | def forward(self, src, src_len, trg, teacher_forcing_ratio=0.5): 176 | # src = [src len, batch size] 177 | # src_len = [batch size] 178 | # trg = [trg len, batch size] 179 | # teacher_forcing_ratio is probability to use teacher forcing 180 | # e.g. if teacher_forcing_ratio is 0.75 we use teacher forcing 75% of the time 181 | 182 | batch_size = src.shape[1] 183 | trg_len = trg.shape[0] 184 | trg_vocab_size = self.decoder.output_dim 185 | 186 | # tensor to store decoder outputs 187 | outputs = torch.zeros(trg_len, batch_size, trg_vocab_size).to(self.device) 188 | 189 | # encoder_outputs is all hidden states of the input sequence, back and forwards 190 | # hidden is the final forward and backward hidden states, passed through a linear layer 191 | encoder_outputs, hidden = self.encoder(src, src_len) 192 | 193 | # first input to the decoder is the tokens 194 | input = trg[0, :] 195 | 196 | mask = self.create_mask(src) 197 | 198 | # mask = [batch size, src len] 199 | 200 | for t in range(1, trg_len): 201 | # insert input token embedding, previous hidden state, all encoder hidden states 202 | # and mask 203 | # receive output tensor (predictions) and new hidden state 204 | output, hidden, _ = self.decoder(input, hidden, encoder_outputs, mask) 205 | 206 | # place predictions in a tensor holding predictions for each token 207 | outputs[t] = output 208 | 209 | # decide if we are going to use teacher forcing or not 210 | teacher_force = random.random() < teacher_forcing_ratio 211 | 212 | # get the highest predicted token from our predictions 213 | top1 = output.argmax(1) 214 | 215 | # if teacher forcing, use actual next token as next input 216 | # if not, use predicted token 217 | input = trg[t] if teacher_force else top1 218 | 219 | return outputs 220 | -------------------------------------------------------------------------------- /train_latent_aug.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from torch.utils.data import DataLoader 6 | from diffusers import AutoencoderKL 7 | import os 8 | from tqdm.auto import tqdm 9 | from models.latent_aug import 
LatentAugmenter
10 | import lpips
11 | from collections import defaultdict
12 | from torchvision import transforms
13 | import fire
14 | from typing import List, Optional
15 | from torchinfo import summary
16 | from enum import Enum
17 | from datasets import load_dataset
18 | import bitsandbytes as bnb
19 | import wandb
20 | from utils import get_model_gradient_norm
21 | 
22 | size = 512
23 | to_tensor = transforms.ToTensor()
24 | image_transforms = transforms.Compose(
25 |     [
26 |         transforms.RandomCrop(size, pad_if_needed=True, padding_mode="reflect"),
27 |         transforms.ToTensor(),
28 |         transforms.Normalize([0.5], [0.5]),
29 |     ]
30 | )
31 | 
32 | 
33 | def collate_fn(batch):
34 |     processed_batch = {}
35 |     for label in ("img_better", "img_worse"):
36 |         imgs = [sample[label] for sample in batch]
37 |         imgs = [image_transforms(img) for img in imgs]
38 |         img_tensors = torch.stack(imgs)
39 |         processed_batch[label] = img_tensors
40 | 
41 |     return processed_batch
42 | 
43 | 
44 | def calculate_loss(
45 |     model: LatentAugmenter,
46 |     batch,
47 |     device: torch.device,
48 |     vae: AutoencoderKL,
49 |     lpips_fn,
50 |     dtype,
51 |     decoded_weight: float,
52 |     lpips_weight: float,
53 |     mse_latent_weight: float,
54 | ):
55 |     img_input = batch["img_worse"].to(device, dtype=dtype)
56 |     img_target = batch["img_better"].to(device, dtype=dtype)
57 |     latent_input = (
58 |         vae.config.scaling_factor * vae.encode(img_input).latent_dist.sample()
59 |     )
60 |     latent_target = (
61 |         vae.config.scaling_factor * vae.encode(img_target).latent_dist.sample()
62 |     )
63 |     size = latent_target.shape[-2:]
64 |     resized = model(latent_input, size=size)
65 |     mse_latent = F.mse_loss(resized, latent_target)
66 |     logs = {"mse_latent": mse_latent.cpu().item()}
67 |     loss = mse_latent_weight * mse_latent
68 |     if decoded_weight > 0 or lpips_weight > 0:
69 |         # both pixel-space losses need the decoded image
70 |         decoded = vae.decode(resized / vae.config.scaling_factor)[0]
71 |     if decoded_weight > 0:
72 |         decoded_loss = F.mse_loss(decoded, img_target)
73 |         logs["mse"] = decoded_loss.cpu().item()
74 |         loss = loss + decoded_weight * decoded_loss
75 |     if lpips_weight > 0:
76 |         lpips_loss = lpips_fn(decoded, img_target).mean()
77 |         logs["lpips"] = lpips_loss.cpu().item()
78 |         loss = loss + lpips_weight * lpips_loss
79 |     logs["loss"] = loss.cpu().item()
80 |     return loss, logs
81 | 
82 | 
83 | class Objective(Enum):
84 |     UPSCALE = "upscale"
85 |     AUGMENT = "augment"
86 | 
87 | 
88 | def train(
89 |     vae_path: str = "runwayml/stable-diffusion-v1-5",
90 |     objective: str = "upscale",
91 |     test_path: Optional[List[str]] = None,
92 |     test_steps: int = 1000,
93 |     test_batches: int = 10,
94 |     output_filename: str = "sdxl_resizer.pt",
95 |     steps: int = 10000,
96 |     save_steps: int = 5000,
97 |     batch_size: int = 2,
98 |     num_dataloader_workers: int = 0,
99 |     lr: float = 2e-4,
100 |     dropout: float = 0.0,
101 |     clip_grad_val: float = 50.0,
102 |     device: str = "cuda",
103 |     init_weights: Optional[str] = None,
104 |     fp16: bool = True,
105 |     use_bnb: bool = True,
106 |     use_wandb: bool = False,
107 | ):
108 |     device = torch.device(device)
109 |     objective = Objective(objective)
110 |     steps = int(steps)
111 | 
112 |     dataset = load_dataset(
113 |         "THUDM/ImageRewardDB", "2k_pair", verification_mode="no_checks"
114 |     )
115 |     train_dataset = dataset["train"]
116 |     test_dataset = dataset["test"]
117 | 
118 |     train_dataloader = DataLoader(
119 |         train_dataset,
120 |         batch_size=batch_size,
121 |         collate_fn=collate_fn,
122 |         num_workers=num_dataloader_workers,
123 |     )
124 |     test_dataloader = DataLoader(
125 |         test_dataset,
126 |         batch_size=batch_size,
127 |         collate_fn=collate_fn,
128 |         num_workers=num_dataloader_workers,
129 |     )
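Both loaders above depend on `collate_fn` turning raw `img_better`/`img_worse` PIL pairs into normalized tensors. A quick sanity-check sketch of that contract, reusing the `collate_fn` defined earlier in this file (the solid-color toy images and their sizes are arbitrary assumptions):

```python
import torch
from PIL import Image

fake_batch = [
    {"img_better": Image.new("RGB", (640, 480)), "img_worse": Image.new("RGB", (640, 480))}
    for _ in range(2)
]
out = collate_fn(fake_batch)  # collate_fn as defined above
assert out["img_better"].shape == torch.Size([2, 3, 512, 512])
# after Normalize([0.5], [0.5]) the pixel values sit roughly in [-1, 1]
```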
130 | 
131 |     vae_dtype = torch.float32
132 |     if fp16:
133 |         vae_dtype = torch.float16
134 | 
135 |     vae = (
136 |         AutoencoderKL.from_pretrained(vae_path, subfolder="vae")
137 |         .to(device)
138 |         .to(dtype=vae_dtype)
139 |     )
140 |     # Use this scale even with SD 1.5
141 |     vae.config.scaling_factor = 0.13025
142 | 
143 |     lpips_fn = lpips.LPIPS(net="vgg").to(device=device, dtype=vae_dtype)
144 | 
145 |     if init_weights:
146 |         model = LatentAugmenter.load_model(
147 |             init_weights,
148 |             device=device,
149 |             dropout=dropout,
150 |             dtype=torch.float32,
151 |         )
152 |     else:
153 |         model = LatentAugmenter(dropout=dropout).to(device)
154 | 
155 |     model.train()
156 |     model.requires_grad_(True)
157 | 
158 |     if use_wandb:
159 |         wandb.init(project="superprompt-latent-aug")
160 |         wandb.watch(model)
161 | 
162 |     print(summary(model))
163 | 
164 |     if use_bnb:
165 |         optimizer = bnb.optim.Adam8bit(model.parameters(), lr=lr)
166 |     else:
167 |         optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
168 |     scaler = torch.cuda.amp.GradScaler(enabled=True)
169 |     linear_scheduler = torch.optim.lr_scheduler.LinearLR(
170 |         optimizer, start_factor=0.001, total_iters=200
171 |     )
172 |     cosine_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, steps)
173 |     scheduler = torch.optim.lr_scheduler.SequentialLR(
174 |         optimizer, schedulers=[linear_scheduler, cosine_scheduler], milestones=[20]
175 |     )
176 | 
177 |     epoch = 0
178 |     step = 0
179 |     progress_bar = tqdm(range(steps))
180 |     progress_bar.set_description("Steps")
181 | 
182 |     while step < steps:
183 |         epoch += 1
184 |         for batch in train_dataloader:
185 |             step += 1
186 | 
187 |             # get loss
188 |             with torch.autocast(device_type="cuda", dtype=torch.float16):
189 |                 loss, logs = calculate_loss(
190 |                     model,
191 |                     batch,
192 |                     device,
193 |                     vae,
194 |                     lpips_fn,
195 |                     vae_dtype,
196 |                     lpips_weight=0,
197 |                     decoded_weight=0,
198 |                     mse_latent_weight=1,
199 |                 )
200 |             loss_rounded = round(loss.cpu().item(), 2)
201 | 
202 |             scaler.scale(loss).backward()
203 |             scaler.unscale_(optimizer)
204 | 
205 |             if clip_grad_val > 0:
206 |                 nn.utils.clip_grad_norm_(model.parameters(), clip_grad_val)
207 | 
208 |             norm_text, lr_text = round(get_model_gradient_norm(model), 3), scheduler.get_last_lr()[0]
209 |             progress_bar.set_postfix(
210 |                 loss=loss_rounded, lr=lr_text, norm=norm_text
211 |             )
212 | 
213 |             if use_wandb:
214 |                 wandb.log({**logs, "lr": lr_text, "norm": norm_text})
215 | 
216 |             scaler.step(optimizer)
217 |             scaler.update()
218 |             optimizer.zero_grad()
219 | 
220 |             progress_bar.update(1)
221 |             scheduler.step()
222 | 
223 |             if step >= steps:
224 |                 break
225 |             if (step % save_steps) == 0:
226 |                 base, ext = os.path.splitext(output_filename)
227 |                 save_filename = f"{base}-{step}{ext}"
228 |                 torch.save(model.state_dict(), save_filename)
229 |             if test_path and (step % test_steps) == 0:
230 |                 # local counter, kept distinct from the test_batches limit argument
231 |                 num_test_batches = 0
232 |                 test_logs = defaultdict(float)
233 |                 model.eval()
234 |                 for batch in test_dataloader:
235 |                     with torch.inference_mode():
236 |                         _, logs = calculate_loss(
237 |                             model,
238 |                             batch,
239 |                             device,
240 |                             vae,
241 |                             lpips_fn,
242 |                             vae_dtype,
243 |                             lpips_weight=1,
244 |                             decoded_weight=1,
245 |                             mse_latent_weight=1,
246 |                         )
247 |                     num_test_batches += 1
248 |                     for k in logs.keys():
249 |                         test_logs[k] += logs[k]
250 |                     if num_test_batches >= test_batches:
251 |                         break
252 |                 if use_wandb and num_test_batches > 0:
253 |                     wandb.log(
254 |                         {f"test_{k}": v / num_test_batches for k, v in test_logs.items()}
255 |                     )
256 |                 model.train()
257 | 
258 |     if use_wandb:
259 |         wandb.finish()
260 |     torch.save(model.state_dict(), output_filename)
261 |     print("Model saved")
262 | 
263 | 
264 | if __name__ == "__main__":
265 |     fire.Fire(train)
266 | 
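Once training finishes, the saved weights can be used as a standalone latent resizer. A hypothetical inference sketch (the checkpoint filename and the 2x scale are assumptions; `LatentAugmenter.load_model` and `forward(scale=...)` are defined in models/latent_aug.py below):

```python
import torch
from diffusers import AutoencoderKL
from models.latent_aug import LatentAugmenter

device = torch.device("cuda")
vae = AutoencoderKL.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="vae"
).to(device)
vae.config.scaling_factor = 0.13025  # same override as in train()

model = LatentAugmenter.load_model("sdxl_resizer-10000.pt", device=device)  # assumed filename

img = torch.randn(1, 3, 512, 512, device=device)  # stand-in for a normalized image batch
with torch.inference_mode():
    latents = vae.config.scaling_factor * vae.encode(img).latent_dist.sample()
    upscaled = model(latents, scale=2.0)  # 2x upscale in latent space
    pixels = vae.decode(upscaled / vae.config.scaling_factor).sample  # [1, 3, 1024, 1024]
```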
-------------------------------------------------------------------------------- /models/latent_aug.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from einops import rearrange 6 | 7 | 8 | def normalization(channels): 9 | return nn.GroupNorm(32, channels) 10 | 11 | 12 | def zero_module(module): 13 | for p in module.parameters(): 14 | p.detach().zero_() 15 | return module 16 | 17 | 18 | class AttnBlock(nn.Module): 19 | def __init__(self, in_channels): 20 | super().__init__() 21 | self.in_channels = in_channels 22 | 23 | self.norm = normalization(in_channels) 24 | self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) 25 | self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) 26 | self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) 27 | self.proj_out = nn.Conv2d( 28 | in_channels, in_channels, kernel_size=1, stride=1, padding=0 29 | ) 30 | 31 | def attention(self, h_: torch.Tensor) -> torch.Tensor: 32 | h_ = self.norm(h_) 33 | q = self.q(h_) 34 | k = self.k(h_) 35 | v = self.v(h_) 36 | 37 | b, c, h, w = q.shape 38 | q, k, v = map( 39 | lambda x: rearrange(x, "b c h w -> b 1 (h w) c").contiguous(), (q, k, v) 40 | ) 41 | h_ = nn.functional.scaled_dot_product_attention( 42 | q, k, v 43 | ) # scale is dim ** -0.5 per default 44 | 45 | return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b) 46 | 47 | def forward(self, x, **kwargs): 48 | h_ = x 49 | h_ = self.attention(h_) 50 | h_ = self.proj_out(h_) 51 | return x + h_ 52 | 53 | 54 | # TODO add other attn types 55 | def make_attn(in_channels, attn_kwargs=None): 56 | return AttnBlock(in_channels) 57 | 58 | 59 | class ResBlockEmb(nn.Module): 60 | def __init__( 61 | self, 62 | channels, 63 | emb_channels, 64 | dropout=0, 65 | out_channels=None, 66 | use_conv=False, 67 | use_scale_shift_norm=False, 68 | kernel_size=3, 69 | exchange_temb_dims=False, 70 | skip_t_emb=False, 71 | ): 72 | super().__init__() 73 | self.channels = channels 74 | self.emb_channels = emb_channels 75 | self.dropout = dropout 76 | self.out_channels = out_channels or channels 77 | self.use_conv = use_conv 78 | self.use_scale_shift_norm = use_scale_shift_norm 79 | self.exchange_temb_dims = exchange_temb_dims 80 | 81 | padding = kernel_size // 2 82 | 83 | self.in_layers = nn.Sequential( 84 | normalization(channels), 85 | nn.SiLU(), 86 | nn.Conv2d(channels, self.out_channels, kernel_size, padding=padding), 87 | ) 88 | 89 | self.skip_t_emb = skip_t_emb 90 | self.emb_out_channels = ( 91 | 2 * self.out_channels if use_scale_shift_norm else self.out_channels 92 | ) 93 | if self.skip_t_emb: 94 | print(f"Skipping timestep embedding in {self.__class__.__name__}") 95 | assert not self.use_scale_shift_norm 96 | self.emb_layers = None 97 | self.exchange_temb_dims = False 98 | else: 99 | self.emb_layers = nn.Sequential( 100 | nn.SiLU(), 101 | nn.Linear( 102 | emb_channels, 103 | self.emb_out_channels, 104 | ), 105 | ) 106 | 107 | self.out_layers = nn.Sequential( 108 | normalization(self.out_channels), 109 | nn.SiLU(), 110 | nn.Dropout(p=dropout), 111 | zero_module( 112 | nn.Conv2d( 113 | self.out_channels, 114 | self.out_channels, 115 | kernel_size, 116 | padding=padding, 117 | ) 118 | ), 119 | ) 120 | 121 | if self.out_channels == channels: 122 | self.skip_connection = nn.Identity() 123 | elif use_conv: 124 | self.skip_connection = nn.Conv2d( 125 | channels, 
self.out_channels, kernel_size, padding=padding 126 | ) 127 | else: 128 | self.skip_connection = nn.Conv2d(channels, self.out_channels, 1) 129 | 130 | def forward(self, x, emb): 131 | h = self.in_layers(x) 132 | 133 | if self.skip_t_emb: 134 | emb_out = torch.zeros_like(h) 135 | else: 136 | emb_out = self.emb_layers(emb).type(h.dtype) 137 | while len(emb_out.shape) < len(h.shape): 138 | emb_out = emb_out[..., None] 139 | if self.use_scale_shift_norm: 140 | out_norm, out_rest = self.out_layers[0], self.out_layers[1:] 141 | scale, shift = torch.chunk(emb_out, 2, dim=1) 142 | h = out_norm(h) * (1 + scale) + shift 143 | h = out_rest(h) 144 | else: 145 | if self.exchange_temb_dims: 146 | emb_out = rearrange(emb_out, "b t c ... -> b c t ...") 147 | h = h + emb_out 148 | h = self.out_layers(h) 149 | return self.skip_connection(x) + h 150 | 151 | 152 | class LatentAugmenter(nn.Module): 153 | def __init__(self, in_blocks=10, out_blocks=10, channels=128, dropout=0, attn=True): 154 | super().__init__() 155 | self.conv_in = nn.Conv2d(4, channels, 3, padding=1) 156 | 157 | self.channels = channels 158 | embed_dim = 16 159 | self.embed = nn.Sequential( 160 | nn.Linear(1, embed_dim), 161 | nn.SiLU(), 162 | nn.Linear(embed_dim, embed_dim), 163 | ) 164 | 165 | self.in_blocks = nn.ModuleList([]) 166 | for b in range(in_blocks): 167 | if (b == 1 or b == in_blocks - 1) and attn: 168 | self.in_blocks.append(make_attn(channels)) 169 | self.in_blocks.append(ResBlockEmb(channels, embed_dim, dropout)) 170 | 171 | self.out_blocks = nn.ModuleList([]) 172 | for b in range(out_blocks): 173 | if (b == 1 or b == out_blocks - 1) and attn: 174 | self.out_blocks.append(make_attn(channels)) 175 | self.out_blocks.append(ResBlockEmb(channels, embed_dim, dropout)) 176 | 177 | self.norm_out = normalization(channels) 178 | self.conv_out = nn.Conv2d(channels, 4, 3, padding=1) 179 | 180 | @classmethod 181 | def load_model(cls, filename, device="cpu", dtype=torch.float32, dropout=0): 182 | if not "weights_only" in torch.load.__code__.co_varnames: 183 | weights = torch.load(filename, map_location=torch.device("cpu")) 184 | else: 185 | weights = torch.load( 186 | filename, map_location=torch.device("cpu"), weights_only=True 187 | ) 188 | in_blocks = 0 189 | out_blocks = 0 190 | in_tfs = 0 191 | out_tfs = 0 192 | channels = weights["conv_in.bias"].shape[0] 193 | for k in weights.keys(): 194 | k = k.split(".") 195 | if k[0] == "in_blocks": 196 | in_blocks = max(in_blocks, int(k[1])) 197 | if k[2] == "q" and k[3] == "weight": 198 | in_tfs += 1 199 | if k[0] == "out_blocks": 200 | out_blocks = max(out_blocks, int(k[1])) 201 | if k[2] == "q" and k[3] == "weight": 202 | out_tfs += 1 203 | in_blocks = in_blocks + 1 - in_tfs 204 | out_blocks = out_blocks + 1 - out_tfs 205 | augmenter = cls( 206 | in_blocks=in_blocks, 207 | out_blocks=out_blocks, 208 | channels=channels, 209 | dropout=dropout, 210 | attn=(out_tfs != 0), 211 | ) 212 | augmenter.load_state_dict(weights) 213 | augmenter.eval() 214 | augmenter.to(device, dtype=dtype) 215 | return augmenter 216 | 217 | def forward(self, x, scale=None, size=None): 218 | if scale is None and size is None: 219 | raise ValueError("Either scale or size needs to be not None") 220 | if scale is not None and size is not None: 221 | raise ValueError("Both scale or size can't be not None") 222 | if scale is not None: 223 | size = (x.shape[-2] * scale, x.shape[-1] * scale) 224 | size = tuple([int(round(i)) for i in size]) 225 | else: 226 | scale = size[-1] / x.shape[-1] 227 | 228 | scale = 
torch.tensor([scale - 1], dtype=x.dtype).to(x.device).unsqueeze(0) 229 | emb = self.embed(scale) 230 | 231 | x = self.conv_in(x) 232 | 233 | for b in self.in_blocks: 234 | if isinstance(b, ResBlockEmb): 235 | x = b(x, emb) 236 | else: 237 | x = b(x) 238 | x = F.interpolate(x, size=size, mode="bilinear") 239 | for b in self.out_blocks: 240 | if isinstance(b, ResBlockEmb): 241 | x = b(x, emb) 242 | else: 243 | x = b(x) 244 | 245 | x = self.norm_out(x) 246 | x = F.silu(x) 247 | x = self.conv_out(x) 248 | return x 249 | -------------------------------------------------------------------------------- /notebooks/create_dataset.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "22ae025e-8f57-4454-a01a-7b9cf50098cc", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "!pip install datasets transformers flair -q" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": 2, 16 | "id": "044e834a-f6ab-4c93-8a1a-0747371c6cd7", 17 | "metadata": {}, 18 | "outputs": [ 19 | { 20 | "name": "stderr", 21 | "output_type": "stream", 22 | "text": [ 23 | "/home/ubuntu/.local/lib/python3.8/site-packages/pandas/core/computation/expressions.py:20: UserWarning: Pandas requires version '2.7.3' or newer of 'numexpr' (version '2.7.1' currently installed).\n", 24 | " from pandas.core.computation.check import NUMEXPR_INSTALLED\n" 25 | ] 26 | }, 27 | { 28 | "data": { 29 | "application/vnd.jupyter.widget-view+json": { 30 | "model_id": "85a3f01bf9384318a16b4c4aac0a0823", 31 | "version_major": 2, 32 | "version_minor": 0 33 | }, 34 | "text/plain": [ 35 | "Downloading readme: 0%| | 0.00/777 [00:00] 185.54M 88.2MB/s in 2.1s \n", 183 | "\n", 184 | "2023-04-22 14:49:08 (88.2 MB/s) - ‘metadata.parquet’ saved [194548652/194548652]\n", 185 | "\n" 186 | ] 187 | } 188 | ], 189 | "source": [ 190 | "!wget https://huggingface.co/datasets/poloclub/diffusiondb/resolve/main/metadata.parquet" 191 | ] 192 | }, 193 | { 194 | "cell_type": "code", 195 | "execution_count": null, 196 | "id": "f957480f-5cd9-4d19-95ab-01fc4dfe1cae", 197 | "metadata": {}, 198 | "outputs": [], 199 | "source": [] 200 | } 201 | ], 202 | "metadata": { 203 | "kernelspec": { 204 | "display_name": "Python 3", 205 | "language": "python", 206 | "name": "python3" 207 | }, 208 | "language_info": { 209 | "codemirror_mode": { 210 | "name": "ipython", 211 | "version": 3 212 | }, 213 | "file_extension": ".py", 214 | "mimetype": "text/x-python", 215 | "name": "python", 216 | "nbconvert_exporter": "python", 217 | "pygments_lexer": "ipython3", 218 | "version": "3.8.10" 219 | } 220 | }, 221 | "nbformat": 4, 222 | "nbformat_minor": 5 223 | } 224 | -------------------------------------------------------------------------------- /train_cross_encoder.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | from datasets import load_dataset 3 | import spacy 4 | import wandb 5 | from models.clip_emb_aug import CLIPEmbeddingAugmenter 6 | from models.cross_encoder import CrossEncoder, multiple_negatives_ranking_loss 7 | import fire 8 | import torch 9 | import wandb 10 | from torch.optim import AdamW 11 | from tqdm import tqdm 12 | from torch.utils.data import DataLoader 13 | from torchinfo import summary 14 | from utils import get_available_device, get_model_gradient_norm, compute_ndcg 15 | from typing import Dict, List 16 | from transformers import get_cosine_schedule_with_warmup 17 | 
from sklearn.metrics import recall_score, precision_score
18 | from dataclasses import dataclass
19 | from torch import Tensor
20 | import torch.nn.functional as F
21 | from copy import copy
22 | 
23 | if not spacy.util.is_package("en_core_web_sm"):
24 |     spacy.cli.download("en_core_web_sm")
25 | 
26 | torch.manual_seed(0)
27 | 
28 | 
29 | @dataclass
30 | class LossFnOutput:
31 |     loss: Tensor
32 |     subject_embedding: Tensor
33 |     descriptor_embedding: Tensor
34 |     subject_text: List[str]
35 |     descriptor_text: List[str]
36 |     scores: Tensor
37 |     labels: Tensor
38 | 
39 | 
40 | def loss_fn_emb_aug(
41 |     batch: Dict,
42 |     model: CLIPEmbeddingAugmenter,
43 | ) -> LossFnOutput:
44 |     out = {}
45 |     for key in ("subject", "descriptor"):
46 |         emb = model(batch[key])
47 |         out[key] = batch[key]
48 |         out[f"emb_{key}"] = emb
49 |     # get the embeddings for both the subject and the descriptor batch
50 |     scores, labels = multiple_negatives_ranking_loss(
51 |         out["emb_subject"], out["emb_descriptor"]
52 |     )
53 |     loss = F.cross_entropy(scores, labels)
54 |     return LossFnOutput(
55 |         loss,
56 |         out["emb_subject"],
57 |         out["emb_descriptor"],
58 |         out["subject"],
59 |         out["descriptor"],
60 |         scores,
61 |         labels,
62 |     )
63 | 
64 | 
65 | # eval_every and valid_every are in terms of batches
66 | def main(use_wandb: bool = False, eval_every: int = 100, valid_every: int = 100):
67 |     device = get_available_device()
68 | 
69 |     print("Loading dataset..")
70 |     dataset = load_dataset("roborovski/diffusiondb-seq2seq")
71 | 
72 |     print("Loading model..")
73 |     model = CrossEncoder(device)
74 |     print(summary(model))
75 | 
76 |     # Hyperparameters
77 |     num_epochs: int = 200
78 |     learning_rate: float = 2e-5
79 |     batch_size: int = 64
80 |     warmup_steps: int = 10
81 |     max_grad_norm = 1
82 | 
83 |     samples_table = wandb.Table(
84 |         data=[],
85 |         columns=[
86 |             "epoch",
87 |             "subject",
88 |             "descriptor",
89 |         ],
90 |     )
91 | 
92 |     optimizer = AdamW(model.parameters(), lr=learning_rate)
93 |     steps_per_epoch = len(dataset["train"]) // batch_size
94 |     num_training_steps = steps_per_epoch * num_epochs
95 | 
96 |     scheduler = get_cosine_schedule_with_warmup(
97 |         optimizer, num_warmup_steps=warmup_steps, num_training_steps=num_training_steps
98 |     )
99 | 
100 |     if use_wandb:
101 |         wandb.init(project="superprompt-cross-encoder")
102 |         wandb.watch(model)
103 | 
104 |     dataset = dataset["train"].train_test_split(test_size=48, seed=42)
105 |     train_loader = DataLoader(dataset["train"], batch_size=batch_size)
106 |     eval_loader = DataLoader(dataset["test"], batch_size=len(dataset["test"]))
107 | 
108 |     for epoch in range(num_epochs):
109 |         train_iter = tqdm(train_loader, total=len(train_loader))
110 |         for j, batch in enumerate(train_iter):
111 |             out = loss_fn_emb_aug(batch, model)
112 |             lr = optimizer.param_groups[0]["lr"]
113 | 
114 |             loss_formatted = round(out.loss.item(), 4)
115 |             lr_formatted = round(lr, 8)
116 |             gradient_norm = round(get_model_gradient_norm(model), 4)
117 |             log_dict = {
118 |                 "loss": loss_formatted,
119 |                 "lr": lr_formatted,
120 |                 "grad_norm": gradient_norm,
121 |                 "epoch": epoch,
122 |             }
123 | 
124 |             train_iter.set_postfix(log=log_dict)
125 |             if use_wandb:
126 |                 wandb.log(log_dict)
127 | 
128 |             # Backward pass and optimization
129 |             out.loss.backward()
130 |             torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
131 |             optimizer.step()
132 |             optimizer.zero_grad()
133 |             scheduler.step()
134 | 
135 |             if j % eval_every == 0:
136 |                 print("---Running eval---")
137 |                 eval_iter = tqdm(eval_loader, total=len(eval_loader))
138 |                 model.eval()
139 |                 for batch in eval_iter:
140 |                     out = loss_fn_emb_aug(batch, model)
141 |                     true_rankings = out.labels.cpu().detach().numpy()
142 |                     pred_rankings = (
143 |                         torch.argmax(out.scores, dim=1).cpu().detach().numpy()
144 |                     )
145 | 
146 |                     k_value = 10
147 | 
148 |                     loss_formatted = round(out.loss.item(), 4)
149 | 
150 |                     # accuracy - how many of the top 10 are correct
151 |                     # precision - number of correct results divided by the number of all returned results
152 |                     # recall - number of correct results divided by the number of results that should have been returned
153 |                     # MRR - mean reciprocal rank
154 |                     # NDCG - normalized discounted cumulative gain
155 |                     # MAP - mean average precision
156 | 
157 |                     # hits = number of correct results in top 10
158 |                     # mrr = 1 / rank of first correct result
159 |                     hits, mrr, sum_precisions = 0, 0, 0
160 |                     for rank in range(0, k_value):
161 |                         if true_rankings[rank] in pred_rankings[:k_value]:
162 |                             hits += 1
163 |                             if hits == 1:
164 |                                 mrr = 1 / (rank + 1)
165 |                             sum_precisions += hits / (rank + 1)
166 | 
167 |                     precision = hits / k_value
168 |                     recall = hits / len(true_rankings)
169 |                     denom = precision + recall
170 |                     f1_score = 2 * (precision * recall) / denom if denom > 0 else 0.0
171 |                     avg_precision = sum_precisions / k_value
172 |                     ndcg = compute_ndcg(true_rankings, pred_rankings, k_value)
173 | 
174 |                     num_samples_to_log = 4
175 | 
176 |                     subject_text = out.subject_text[:num_samples_to_log]
177 |                     subject_descriptors_ranked = [
178 |                         out.descriptor_text[i] for i in pred_rankings
179 |                     ][:num_samples_to_log]
180 | 
181 |                     log_dict = {
182 |                         "eval_loss": loss_formatted,
183 |                         "ndcg": ndcg,
184 |                         "mrr": mrr,
185 |                         "recall": recall,
186 |                         "precision": precision,
187 |                         "f1": f1_score,
188 |                         "map": avg_precision,
189 |                         "epoch": epoch,
190 |                     }
191 | 
192 |                     # wandb tables must be re-logged in full each time; see
193 |                     # https://docs.wandb.ai/guides/track/log/log-tables
194 |                     for rank in range(num_samples_to_log):
195 |                         samples_table.add_data(
196 |                             epoch, subject_text[rank], subject_descriptors_ranked[rank]
197 |                         )
198 |                     if use_wandb:
199 |                         log_dict["samples"] = samples_table
200 |                         wandb.log(log_dict)
201 | 
202 |                     print("---Eval stats---")
203 |                     print(log_dict)
204 | 
205 |                 model.train()
206 | 
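The hit-counting block above is an ad hoc approximation of the standard retrieval metrics it names. For comparison, a textbook mean reciprocal rank over a batch of queries looks like the sketch below (toy data; the helper is an illustration and is not wired into this script):

```python
def mean_reciprocal_rank(relevant_ids, ranked_ids_per_query):
    """Average of 1 / (rank of the first relevant result), one term per query."""
    total = 0.0
    for rel, ranked in zip(relevant_ids, ranked_ids_per_query):
        for pos, doc_id in enumerate(ranked, start=1):
            if doc_id == rel:
                total += 1.0 / pos
                break
    return total / len(relevant_ids)

# query 0: relevant doc 2 is ranked second -> 1/2; query 1: doc 0 is first -> 1/1
print(mean_reciprocal_rank([2, 0], [[1, 2, 3], [0, 5, 9]]))  # 0.75
```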
207 |     # # TODO rewrite this. use the retrieved annotations to generate images
208 |     # if i % valid_every == 0:
209 |     #     eval_iter = tqdm(eval_loader, total=len(eval_loader))
210 |     #     pipe = StableDiffusionPipeline.from_pretrained(
211 |     #         "runwayml/stable-diffusion-v1-5",
212 |     #         torch_dtype=torch.float16,
213 |     #         safety_checker=None,
214 |     #     )
215 |     #     pipe = pipe.to("cuda")
216 |     #     if is_xformers_available():
217 |     #         pipe.unet.enable_xformers_memory_efficient_attention()
218 |     #     for batch in eval_iter:
219 | 
220 |     #         loss = loss_fn_emb_aug(batch, model)
221 | 
222 |     #         log_dict = {
223 |     #             "loss": loss.item(),
224 |     #             "unmasked": [],
225 |     #             "encoded": [],
226 |     #         }
227 | 
228 |     #         # clear out directory
229 |     #         Path("out").mkdir(exist_ok=True)
230 |     #         shutil.rmtree("out")
231 |     #         Path("out").mkdir(exist_ok=True)
232 | 
233 |     #         generations = pipe(prompt="").images
234 |     #         os.makedirs("out", exist_ok=True)
235 |     #         for i, generation in enumerate(generations):
236 |     #             generation.save(f"out/{key}_{i}.png")
237 |     #             if use_wandb:
238 |     #                 log_dict[key].append(wandb.Image(generation))
239 | 
240 |     #         if use_wandb:
241 |     #             wandb.log(log_dict)
242 | 
243 |     #         del pipe
244 |     #         gc.collect()
245 |     #         torch.cuda.empty_cache()
246 | 
247 |     if use_wandb:
248 |         wandb.finish()
249 | 
250 | 
251 | if __name__ == "__main__":
252 |     fire.Fire(main)
253 | 
--------------------------------------------------------------------------------
/train_bert_seq2seq.py:
--------------------------------------------------------------------------------
1 | from argparse import Namespace
2 | from enum import IntEnum
3 | import torch
4 | import torch.nn as nn
5 | from utils import (
6 |     get_available_device,
7 |     sample_prompt_pairs,
8 |     sample_translate_pairs,
9 | )
10 | from models.rnn import Attention, Encoder, Decoder, Seq2Seq
11 | from torch.optim import AdamW
12 | import time
13 | import math
14 | from datasets import load_dataset, Dataset, ReadInstruction
15 | from transformers import (
16 |     BertTokenizer,
17 |     DataCollatorForSeq2Seq,
18 | )
19 | import random
20 | import wandb
21 | import numpy as np
22 | from torch.utils.data import DataLoader
23 | import fire
24 | from typing import List
25 | from tabulate import tabulate
26 | import textwrap
27 | 
28 | SEED = 1234
29 | 
30 | random.seed(SEED)
31 | np.random.seed(SEED)
32 | torch.manual_seed(SEED)
33 | torch.cuda.manual_seed(SEED)
34 | torch.backends.cudnn.deterministic = True
35 | 
36 | 
37 | class Task(IntEnum):
38 |     DIFFUSION = 1
39 |     TRANSLATE = 2
40 | 
41 | 
42 | class Args(Namespace):
43 |     enc_emb_dim = 256
44 |     dec_emb_dim = 256
45 |     enc_hid_dim = 512
46 |     dec_hid_dim = 512
47 |     enc_dropout = 0.2
48 |     dec_dropout = 0.2
49 |     n_epochs = 10
50 |     max_norm = 1
51 |     max_length = 64
52 |     batch_size = 64
53 |     log_freq = 2
54 |     # this is in samples
55 |     valid_freq = 128
56 |     task = Task.DIFFUSION.value
57 |     sample_limit = 10e5
58 | 
59 | max_length = Args.max_length
60 | 
61 | def tokenize_batch(batch):
62 |     src = [f"[BOS] {s} [EOS]" for s in batch["src"]]
63 |     src = tokenizer(
64 |         src,  # tokenize the wrapped text, not the raw column
65 |         truncation=True,
66 |         return_length=True,
67 |         padding="max_length",
68 |         max_length=max_length,
69 |         return_tensors="pt",
70 |     )
71 | 
72 |     trg = [f"[BOS] {s} [EOS]" for s in batch["trg"]]
73 |     trg = tokenizer(
74 |         trg,  # tokenize the wrapped text, not the raw column
75 |         truncation=True,
76 |         return_length=True,
77 |         padding="max_length",
78 |         max_length=max_length,
79 |         return_tensors="pt",
80 |     )
81 |     return {
82 |         "src_input_ids": src["input_ids"],
83 |         "src_len": src["length"],
84 |         "trg_input_ids": trg["input_ids"],
85 |     }
86 | 
87 | 
88 | tokenizer: BertTokenizer = 
BertTokenizer.from_pretrained( 89 | "bert-base-uncased", use_fast=True 90 | ) 91 | tokenizer.add_special_tokens( 92 | {"pad_token": "[PAD]", "bos_token": "[BOS]", "eos_token": "[EOS]"} 93 | ) 94 | 95 | print("Task: ", Task(Args.task).name) 96 | valid_src = [] 97 | 98 | if Args.task == Task.DIFFUSION.value: 99 | dataset = load_dataset( 100 | "roborovski/diffusiondb-masked-no-descriptors", 101 | split=ReadInstruction("train", to=25, unit="%") 102 | ) 103 | dataset = dataset.rename_columns({"masked": "src", "prompt": "trg"}) 104 | dataset = dataset.map( 105 | tokenize_batch, 106 | batched=True, 107 | batch_size=256, 108 | remove_columns=["src", "trg"], 109 | ) 110 | dataset = dataset.train_test_split(test_size=0.1) 111 | valid_dataset = Dataset.from_dict( 112 | { 113 | "src": [x[0] for x in sample_prompt_pairs], 114 | "trg": [x[1] for x in sample_prompt_pairs], 115 | } 116 | ) 117 | elif Args.task == Task.TRANSLATE.value: 118 | dataset = load_dataset("bentrevett/multi30k") 119 | dataset = dataset.rename_columns({"de": "src", "en": "trg"}) 120 | dataset = dataset.map( 121 | tokenize_batch, 122 | batched=True, 123 | batch_size=Args.batch_size, 124 | remove_columns=["src", "trg"], 125 | ) 126 | valid_dataset = Dataset.from_dict( 127 | { 128 | "src": [x[0] for x in sample_translate_pairs], 129 | "trg": [x[1] for x in sample_translate_pairs], 130 | } 131 | ) 132 | 133 | valid_src = [valid_dataset["src"], valid_dataset["trg"]] 134 | valid_dataset = tokenize_batch(valid_dataset) 135 | 136 | input_dim_size = tokenizer.vocab_size 137 | attn = Attention(Args.enc_hid_dim, Args.dec_hid_dim) 138 | enc = Encoder( 139 | input_dim_size, 140 | Args.enc_emb_dim, 141 | Args.enc_hid_dim, 142 | Args.dec_hid_dim, 143 | Args.enc_dropout, 144 | ) 145 | dec = Decoder( 146 | input_dim_size, 147 | Args.dec_emb_dim, 148 | Args.enc_hid_dim, 149 | Args.dec_hid_dim, 150 | Args.dec_dropout, 151 | attn, 152 | ) 153 | 154 | device = get_available_device() 155 | model = Seq2Seq(enc, dec, tokenizer.pad_token_id, device).to(device) 156 | 157 | print("Device: ", device) 158 | 159 | 160 | def init_weights(m): 161 | for name, param in m.named_parameters(): 162 | if "weight" in name: 163 | nn.init.normal_(param.data, mean=0, std=0.01) 164 | else: 165 | nn.init.constant_(param.data, 0) 166 | 167 | 168 | model.apply(init_weights) 169 | 170 | optimizer = AdamW(model.parameters(), lr=3e-4) 171 | 172 | criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id) 173 | 174 | 175 | def train(model: Seq2Seq, epoch: int, valid_table_data: List[str], dataset: Dataset, optimizer, criterion, use_wandb: bool): 176 | model.train() 177 | 178 | epoch_loss = 0 179 | loader = DataLoader(dataset, batch_size=Args.batch_size, shuffle=True) 180 | 181 | validate(model, valid_table_data, 0, use_wandb) 182 | for i, batch in enumerate(loader): 183 | if i > Args.sample_limit: 184 | print("Sample limit reached, returning.") 185 | break 186 | src_input_ids = torch.stack(batch["src_input_ids"]).to(device) 187 | trg_input_ids = torch.stack(batch["trg_input_ids"]).to(device) 188 | src_len = batch["src_len"].to(device) 189 | 190 | optimizer.zero_grad() 191 | 192 | output = model(src_input_ids, src_len, trg_input_ids) 193 | 194 | # trg = [trg len, batch size] 195 | # output = [trg len, batch size, output dim] 196 | 197 | output_dim = output.shape[-1] 198 | 199 | output = output[1:].view(-1, output_dim) 200 | trg_input_ids = trg_input_ids[1:].view(-1) 201 | 202 | # trg = [(trg len - 1) * batch size] 203 | # output = [(trg len - 1) * batch size, output dim] 
204 | 
205 |         loss = criterion(output, trg_input_ids)
206 |         loss_rounded = round(loss.item(), 3)
207 |         # loss.requires_grad = True
208 |         print(f"Batch {i}: {loss_rounded}")
209 |         if i % Args.log_freq == 0 and use_wandb:
210 |             wandb.log({"loss": loss_rounded})
211 | 
212 |         if i % Args.valid_freq == 0:
213 |             validate(model, valid_table_data, i, use_wandb)
214 | 
215 |         loss.backward()
216 | 
217 |         torch.nn.utils.clip_grad_norm_(model.parameters(), Args.max_norm)
218 | 
219 |         optimizer.step()
220 | 
221 |         epoch_loss += loss.item()
222 | 
223 |     return epoch_loss / len(dataset)
224 | 
225 | 
226 | def evaluate(model: Seq2Seq, dataset: Dataset, criterion):
227 |     model.eval()
228 | 
229 |     epoch_loss = 0
230 |     loader = DataLoader(dataset, batch_size=Args.batch_size, shuffle=True)
231 | 
232 |     with torch.no_grad():
233 |         for i, batch in enumerate(loader):
234 |             src_input_ids = torch.stack(batch["src_input_ids"]).to(device)
235 |             trg_input_ids = torch.stack(batch["trg_input_ids"]).to(device)
236 |             src_len = batch["src_len"].to(device)
237 | 
238 |             # no optimizer interaction during evaluation
239 | 
240 |             output = model(src_input_ids, src_len, trg_input_ids)
241 | 
242 |             # trg = [trg len, batch size]
243 |             # output = [trg len, batch size, output dim]
244 | 
245 |             output_dim = output.shape[-1]
246 | 
247 |             output = output[1:].view(-1, output_dim)
248 |             trg_input_ids = trg_input_ids[1:].view(-1)
249 | 
250 |             # trg = [(trg len - 1) * batch size]
251 |             # output = [(trg len - 1) * batch size, output dim]
252 | 
253 | 
254 | 
255 | 
256 | 
257 | 
258 | 
259 |             if output.shape[0] != trg_input_ids.shape[0]:
260 |                 print(
261 |                     "output shape : ",
262 |                     output.shape,
263 |                     " does not match trg_input_ids: ",
264 |                     trg_input_ids.shape,
265 |                 )
266 |                 continue
267 |             loss = criterion(output, trg_input_ids)
268 | 
269 |             epoch_loss += loss.item()
270 | 
271 |     return epoch_loss / len(dataset)
272 | 
273 | 
274 | def validate(model: Seq2Seq, valid_table_data: List[str], batch_idx: int, use_wandb: bool):
275 |     model.eval()
276 | 
277 |     with torch.no_grad():
278 |         src_input_ids = valid_dataset["src_input_ids"].transpose(1, 0).to(device)
279 |         trg_input_ids = valid_dataset["trg_input_ids"].transpose(1, 0).to(device)
280 |         src_len = valid_dataset["src_len"].to(device)
281 | 
282 |         outputs = model(src_input_ids, src_len, trg_input_ids)
283 | 
284 |         # greedy decode: take the argmax token at each step
285 |         outputs = outputs.argmax(dim=-1)
286 |         outputs = outputs.transpose(1, 0)
287 |         output_ls = outputs.squeeze().tolist()
288 |         outputs = [tokenizer.decode(x) for x in output_ls]
289 |         for i in range(len(valid_src)):
290 |             source_text = textwrap.wrap(valid_src[0][i], width=50, break_long_words=True)
291 |             expected_text = textwrap.wrap(valid_src[1][i], width=50, break_long_words=True)
292 |             generated_text = textwrap.wrap(outputs[i], width=50, break_long_words=True)
293 |             valid_table_data.append(
294 |                 [
295 |                     source_text, expected_text, generated_text
296 |                 ]
297 |             )
298 |         print(tabulate(valid_table_data, headers=["input", "expected", "output"]))
299 |         if use_wandb:
300 |             sample_table = wandb.Table(
301 |                 columns=["input", "expected", "output"],
302 |                 data=valid_table_data,
303 |             )
304 |             wandb.log({"sample": sample_table})
305 |     model.train()
306 | 
307 | 
308 | def epoch_time(start_time, end_time):
309 |     elapsed_time = end_time - start_time
310 |     elapsed_mins = int(elapsed_time / 60)
311 | 
elapsed_secs = int(elapsed_time - (elapsed_mins * 60)) 312 | return elapsed_mins, elapsed_secs 313 | 314 | def main(use_wandb: bool = False): 315 | 316 | best_valid_loss = float("inf") 317 | 318 | if use_wandb: 319 | wandb.init(config=Args, project="superprompt-seq2seq-rnn") 320 | wandb.watch(model, log_freq=Args.log_freq) 321 | print("wandb initialized") 322 | 323 | valid_table_data = [] 324 | 325 | for epoch in range(Args.n_epochs): 326 | start_time = time.time() 327 | 328 | train_loss = train(model, epoch, valid_table_data, dataset["train"], optimizer, criterion, use_wandb) 329 | print("train_loss: ", train_loss) 330 | eval_loss = evaluate(model, dataset["test"], criterion) 331 | if use_wandb: 332 | wandb.log( 333 | { 334 | "train_loss": train_loss, 335 | "eval_loss": eval_loss, 336 | "epoch": epoch, 337 | "lr": optimizer.param_groups[0]["lr"], 338 | "PPL": math.exp(train_loss), 339 | } 340 | ) 341 | print("eval_loss", eval_loss) 342 | 343 | end_time = time.time() 344 | 345 | epoch_mins, epoch_secs = epoch_time(start_time, end_time) 346 | 347 | if eval_loss < best_valid_loss: 348 | best_valid_loss = eval_loss 349 | task = Args.task 350 | torch.save(model.state_dict(), f"model-{epoch}-task{task}.pt") 351 | 352 | print(f"Epoch: {epoch+1:02} | Time: {epoch_mins}m {epoch_secs}s") 353 | print(f"\tTrain Loss: {train_loss:.3f} | Train PPL: {math.exp(train_loss):7.3f}") 354 | print(f"\t Val. Loss: {eval_loss:.3f} | Val. PPL: {math.exp(eval_loss):7.3f}") 355 | 356 | 357 | if __name__ == "__main__": 358 | fire.Fire(main) -------------------------------------------------------------------------------- /notebooks/bert.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "2e1c04ae-a3a2-4b05-8d3c-a724e7673625", 7 | "metadata": {}, 8 | "outputs": [ 9 | { 10 | "name": "stdout", 11 | "output_type": "stream", 12 | "text": [ 13 | "Note: you may need to restart the kernel to use updated packages.\n", 14 | "Defaulting to user installation because normal site-packages is not writeable\n", 15 | "Collecting protobuf<=3.20.1\n", 16 | " Using cached protobuf-3.20.1-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.whl (1.0 MB)\n", 17 | "Installing collected packages: protobuf\n", 18 | " Attempting uninstall: protobuf\n", 19 | " Found existing installation: protobuf 3.20.1\n", 20 | " Uninstalling protobuf-3.20.1:\n", 21 | " Successfully uninstalled protobuf-3.20.1\n", 22 | "Successfully installed protobuf-3.20.1\n" 23 | ] 24 | } 25 | ], 26 | "source": [ 27 | "%pip install transformers datasets wandb -q\n", 28 | "import torch\n", 29 | "import torch.nn as nn\n", 30 | "import torch.nn.functional as F\n", 31 | "import math\n", 32 | "\n", 33 | "!pip install 'protobuf<=3.20.1' --force-reinstall\n", 34 | "!export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python\n" 35 | ] 36 | }, 37 | { 38 | "cell_type": "code", 39 | "execution_count": 2, 40 | "id": "0cf14902-c183-4ce8-bd2d-36d0ad23029d", 41 | "metadata": {}, 42 | "outputs": [], 43 | "source": [ 44 | "\n", 45 | "class PositionEmbed(nn.Module):\n", 46 | " \n", 47 | " def __init__(self, dim_model, max_len=512) -> None:\n", 48 | " super().__init__()\n", 49 | " \n", 50 | " # embedding is a matrix of size (max_len, dim_model)\n", 51 | " # for each possible position i, j contains the sinusoid of frequency i / 10000^(2j/dim_model)\n", 52 | " pe = torch.zeros(max_len, dim_model)\n", 53 | " pe.requires_grad = False\n", 54 | " \n", 55 | " # create a 2D tensor with 
the position indices\n", 56 | " position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)\n", 57 | " div_term = (torch.arange(0, dim_model, 2, dtype=torch.float) * -(math.log(10000.0) / dim_model)).exp()\n", 58 | "\n", 59 | " # for each 2 entries, starting at 0, we get a sin and cos activation\n", 60 | " pe[:, 0::2] = torch.sin(position * div_term)\n", 61 | " pe[:, 1::2] = torch.cos(position * div_term)\n", 62 | "\n", 63 | " pe = pe.unsqueeze(0)\n", 64 | " self.register_buffer('pe', pe)\n", 65 | "\n", 66 | " def forward(self, x):\n", 67 | " # get the position embeddings for all tokens up to the current position idx\n", 68 | " return self.pe[:, :x.size(1)]\n", 69 | "\n", 70 | "class BERTEmbedding(nn.Module):\n", 71 | " def __init__(self, vocab_size, embed_size=512, dropout=0.1) -> None:\n", 72 | " super().__init__()\n", 73 | " \n", 74 | " self.token = nn.Embedding(vocab_size, embed_size, padding_idx=0)\n", 75 | " embedding_dim = self.token.embedding_dim\n", 76 | " self.position = PositionEmbed(dim_model=embedding_dim)\n", 77 | " self.dropout = nn.Dropout(p=dropout)\n", 78 | " self.embed_size = embed_size\n", 79 | "\n", 80 | " def forward(self, sequence):\n", 81 | " x = self.token(sequence)\n", 82 | " x = x + self.position(x)\n", 83 | " x = self.dropout(x)\n", 84 | " return x\n", 85 | "\n", 86 | "# Compute a single attention head\n", 87 | "class Attention(nn.Module):\n", 88 | " \n", 89 | " # matrix multiplication of query and key, then scaled by the square root of the dimension of the query\n", 90 | " def forward(self, query, key, value, mask=None, dropout=None):\n", 91 | " scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(query.size(-1))\n", 92 | " if mask is not None:\n", 93 | " scores = scores.masked_fill(mask == 0, -1e9)\n", 94 | "\n", 95 | " p_attn = F.softmax(scores, dim=-1)\n", 96 | "\n", 97 | " if dropout is not None:\n", 98 | " p_attn = dropout(p_attn)\n", 99 | "\n", 100 | " return torch.matmul(p_attn, value), p_attn\n", 101 | "\n", 102 | "class MultiHeadAttention(nn.Module):\n", 103 | " def __init__(self, attn_heads, hidden, dropout=0.1) -> None:\n", 104 | " super().__init__()\n", 105 | " assert hidden % attn_heads == 0\n", 106 | "\n", 107 | " # We assume d_v always equals d_k\n", 108 | " self.d_k = hidden // attn_heads\n", 109 | " self.h = attn_heads\n", 110 | "\n", 111 | " # linear layers for query, key and value\n", 112 | " self.linear_layers = nn.ModuleList([nn.Linear(hidden, hidden) for _ in range(3)])\n", 113 | " # final linear layer for output\n", 114 | " self.output_linear = nn.Linear(hidden, hidden)\n", 115 | " \n", 116 | " # attention - performed per batch of queries\n", 117 | " self.attention = Attention()\n", 118 | "\n", 119 | " self.dropout = nn.Dropout(p=dropout)\n", 120 | "\n", 121 | " def forward(self, query, key, value, mask=None):\n", 122 | " batch_size = query.size(0)\n", 123 | " \n", 124 | " # linear projection from hidden to d_k * h\n", 125 | " # i.e. 
for each linear layer, we get the query, key and value\n", 126 | " # these represent the linear layer for each head\n", 127 | " query, key, value = [l(x).view(batch_size, -1, self.h, self.d_k).transpose(1, 2) for l, x in zip(self.linear_layers, (query, key, value))]\n", 128 | "\n", 129 | " # compute attention for all heads in a batch\n", 130 | " x, attention = self.attention(query, key, value, mask=mask, dropout=self.dropout)\n", 131 | "\n", 132 | " # concatenate all heads\n", 133 | " x = x.transpose(1, 2).contiguous().view(batch_size, -1, self.h * self.d_k)\n", 134 | "\n", 135 | " # apply final linear layer\n", 136 | " return self.output_linear(x)\n", 137 | "\n", 138 | "class SublayerConnection(nn.Module):\n", 139 | " \n", 140 | " def __init__(self, hidden, dropout) -> None:\n", 141 | " super(SublayerConnection, self).__init__()\n", 142 | " self.norm = nn.LayerNorm(hidden)\n", 143 | " self.dropout = nn.Dropout(p=dropout)\n", 144 | "\n", 145 | " def forward(self, x, sublayer):\n", 146 | " return x + self.dropout(sublayer(self.norm(x)))\n", 147 | "\n", 148 | "# Feed forward layer, with dropout and GELU activation\n", 149 | "class PositionwiseFeedForward(nn.Module):\n", 150 | " \n", 151 | " def __init__(self, hidden, feed_forward_hidden, dropout=0.1) -> None:\n", 152 | " super(PositionwiseFeedForward, self).__init__()\n", 153 | " self.w_1 = nn.Linear(hidden, feed_forward_hidden)\n", 154 | " self.w_2 = nn.Linear(feed_forward_hidden, hidden)\n", 155 | " self.dropout = nn.Dropout(p=dropout)\n", 156 | " # gelu is the same as RELU with a slight dip before 0\n", 157 | " self.activation = nn.GELU()\n", 158 | "\n", 159 | " def forward(self, x):\n", 160 | " return self.w_2(self.dropout(self.activation(self.w_1(x))))\n", 161 | "\n", 162 | "class TransformerBlock(nn.Module):\n", 163 | " \n", 164 | " def __init__(self, hidden, attn_heads, feed_forward_hidden, dropout) -> None:\n", 165 | " super().__init__()\n", 166 | " self.attention = MultiHeadAttention(attn_heads, hidden)\n", 167 | " self.feed_forward = PositionwiseFeedForward(hidden, feed_forward_hidden, dropout)\n", 168 | " self.input_sublayer = SublayerConnection(hidden, dropout)\n", 169 | " self.output_sublayer = SublayerConnection(hidden, dropout)\n", 170 | " self.dropout = nn.Dropout(p=dropout)\n", 171 | "\n", 172 | " def forward(self, x, mask):\n", 173 | " x = self.input_sublayer(x, lambda _x: self.attention(_x, _x, _x, mask))\n", 174 | " x = self.output_sublayer(x, self.feed_forward)\n", 175 | " return self.dropout(x)\n", 176 | "\n", 177 | "class BERT(nn.Module):\n", 178 | " \n", 179 | " def __init__(self, vocab_size, hidden=768, n_layers=12, attn_heads=12, dropout=0.1):\n", 180 | " super().__init__()\n", 181 | " self.hidden = hidden\n", 182 | " self.n_layers = n_layers\n", 183 | " self.attn_heads = attn_heads \n", 184 | "\n", 185 | " self.feed_forward_hidden = hidden * 4 # 4 is hyperparameter\n", 186 | "\n", 187 | " self.embedding = BERTEmbedding(vocab_size, hidden, dropout)\n", 188 | "\n", 189 | " self.transformer_blocks = nn.ModuleList(\n", 190 | " [TransformerBlock(hidden, attn_heads, hidden*4, dropout) for _ in range(n_layers)]\n", 191 | " )\n", 192 | "\n", 193 | " # masked LM\n", 194 | " self.linear = nn.Linear(hidden, vocab_size)\n", 195 | " self.softmax = nn.LogSoftmax(dim=-1)\n", 196 | "\n", 197 | " def forward(self, x):\n", 198 | " \n", 199 | " # attention mask for padded token\n", 200 | " # torch.ByteTensor([batch_size, 1, seq_len, seq_len)\n", 201 | " mask = (x > 0).unsqueeze(1).repeat(1, x.size(1), 1).unsqueeze(1)\n", 202 | "\n", 
203 | " # get the embedding for the input sequence\n", 204 | " x = self.embedding(x)\n", 205 | "\n", 206 | " for transformer in self.transformer_blocks:\n", 207 | " x = transformer(x, mask)\n", 208 | "\n", 209 | " # masked LM\n", 210 | " x = self.softmax(self.linear(x))\n", 211 | " \n", 212 | " return x" 213 | ] 214 | }, 215 | { 216 | "cell_type": "code", 217 | "execution_count": 3, 218 | "id": "5f05ba6c", 219 | "metadata": {}, 220 | "outputs": [ 221 | { 222 | "name": "stderr", 223 | "output_type": "stream", 224 | "text": [ 225 | "/home/ubuntu/.local/lib/python3.8/site-packages/pandas/core/computation/expressions.py:20: UserWarning: Pandas requires version '2.7.3' or newer of 'numexpr' (version '2.7.1' currently installed).\n", 226 | " from pandas.core.computation.check import NUMEXPR_INSTALLED\n", 227 | "Found cached dataset parquet (/home/ubuntu/.cache/huggingface/datasets/Gustavosta___parquet/Gustavosta--Stable-Diffusion-Prompts-d22aeec0ba2a9fdb/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)\n" 228 | ] 229 | }, 230 | { 231 | "data": { 232 | "application/vnd.jupyter.widget-view+json": { 233 | "model_id": "1c424e957f864cf1a5a754c60e0680b0", 234 | "version_major": 2, 235 | "version_minor": 0 236 | }, 237 | "text/plain": [ 238 | " 0%| | 0/2 [00:00 1:\n", 341 | " print(\"Using %d GPUS for BERT\" % torch.cuda.device_count())\n", 342 | " self.model = nn.DataParallel(self.model, device_ids=cuda_devices)\n", 343 | "\n", 344 | " # Setting the train and test data loader\n", 345 | " self.train_data = train_dataloader\n", 346 | " self.test_data = test_dataloader\n", 347 | "\n", 348 | " # Setting the Adam optimizer with hyper-param\n", 349 | " self.optim = Adam(self.model.parameters(), lr=lr, betas=betas, weight_decay=weight_decay)\n", 350 | " self.optim_schedule = ScheduledOptim(self.optim, self.model.hidden, n_warmup_steps=warmup_steps)\n", 351 | "\n", 352 | " # Using Negative Log Likelihood Loss function for predicting the masked_token\n", 353 | " self.criterion = nn.NLLLoss(ignore_index=0)\n", 354 | "\n", 355 | " self.log_freq = log_freq\n", 356 | "\n", 357 | " def train(self, epoch):\n", 358 | " self.iteration(epoch, self.train_data)\n", 359 | "\n", 360 | " def test(self, epoch):\n", 361 | " self.iteration(epoch, self.test_data, train=False)\n", 362 | "\n", 363 | " def iteration(self, epoch, data_loader, train=True):\n", 364 | " str_code = \"train\" if train else \"test\"\n", 365 | "\n", 366 | " # Setting the tqdm progress bar\n", 367 | " data_iter = tqdm.tqdm(enumerate(data_loader),\n", 368 | " desc=\"EP_%s:%d\" % (str_code, epoch),\n", 369 | " total=len(data_loader),\n", 370 | " bar_format=\"{l_bar}{r_bar}\")\n", 371 | "\n", 372 | " avg_loss = 0.0\n", 373 | "\n", 374 | " for i, data in data_iter:\n", 375 | " # 0. batch_data will be sent into the device(GPU or cpu)\n", 376 | " input_ids = data[\"input_ids\"].to(self.device)\n", 377 | "\n", 378 | " # 1. forward the next_sentence_prediction and masked_lm model\n", 379 | " mask_lm_output = self.model.forward(input_ids)\n", 380 | "\n", 381 | " # 2-2. NLLLoss of predicting masked token word\n", 382 | " loss = self.criterion(mask_lm_output.transpose(1, 2), input_ids)\n", 383 | "\n", 384 | " # 3. 
backward pass and optimization, in train mode only\n", 385 | " if train:\n", 386 | " self.optim_schedule.zero_grad()\n", 387 | " loss.backward()\n", 388 | " self.optim_schedule.step_and_update_lr()\n", 389 | "\n", 390 | " # accumulate the loss for a running average\n", 391 | " avg_loss += loss.item()\n", 392 | "\n", 393 | " post_fix = {\n", 394 | " \"epoch\": epoch,\n", 395 | " \"iter\": i,\n", 396 | " \"avg_loss\": avg_loss / (i + 1),\n", 397 | " \"loss\": loss.item()\n", 398 | " }\n", 399 | "\n", 400 | " if i % self.log_freq == 0:\n", 401 | " wandb.log(post_fix)\n", 402 | " data_iter.write(str(post_fix))\n", 403 | " print(\"EP%d_%s, avg_loss=\" % (epoch, str_code), avg_loss / len(data_iter))\n", 404 | "\n", 405 | " def save(self, epoch, file_path=\"output/bert_trained.model\"):\n", 406 | " \"\"\"\n", 407 | " Saves the current BERT model to file_path\n", 408 | "\n", 409 | " :param epoch: current epoch number\n", 410 | " :param file_path: model output path; the file is saved as file_path + \".ep%d\" % epoch\n", 411 | " :return: final_output_path\n", 412 | " \"\"\"\n", 413 | " output_path = file_path + \".ep%d\" % epoch\n", 414 | " torch.save(self.bert.cpu(), output_path)\n", 415 | " self.bert.to(self.device)\n", 416 | " print(\"EP:%d Model Saved on:\" % epoch, output_path)\n", 417 | " return output_path\n" 418 | ] 419 | }, 420 | { 421 | "cell_type": "code", 422 | "execution_count": null, 423 | "id": "c64545c8", 424 | "metadata": {}, 425 | "outputs": [ 426 | { 427 | "name": "stdout", 428 | "output_type": "stream", 429 | "text": [ 430 | "Namespace(adam_beta1=0.9, adam_beta2=0.999, adam_weight_decay=0.01, attn_heads=8, batch_size=64, cuda_devices=[0], epochs=10, hidden=256, layers=8, log_freq=50, lr=0.001, num_workers=4, output_path='/home/ubuntu/superprompt/saved', with_cuda=True)\n" 431 | ] 432 | }, 433 | { 434 | "data": { 435 | "text/html": [ 436 | "Finishing last run (ID:1yv71wqu) before initializing another..." 437 | ], 438 | "text/plain": [ 439 | "" 440 | ] 441 | }, 442 | "metadata": {}, 443 | "output_type": "display_data" 444 | }, 445 | { 446 | "data": { 447 | "text/html": [ 448 | "Waiting for W&B process to finish... (success). 
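One thing worth flagging in the trainer above: the NLL loss is computed against the full `input_ids` with only the padding id ignored, so the model is scored on reconstructing every non-pad token, not just the `[MASK]`ed positions. The more common masked-LM objective keeps labels only where tokens were corrupted. A minimal sketch of that labeling, assuming a `masked_positions` boolean tensor that the code above does not currently produce:

```python
import torch
import torch.nn as nn

def masked_lm_loss(log_probs: torch.Tensor,
                   input_ids: torch.Tensor,
                   masked_positions: torch.Tensor) -> torch.Tensor:
    # log_probs: (batch, seq_len, vocab) log-probabilities, e.g. the LogSoftmax
    # output of the BERT module above
    # masked_positions: (batch, seq_len) bool, True where a token was masked out
    labels = input_ids.clone()
    labels[~masked_positions] = -100  # ignore every position that was not masked
    criterion = nn.NLLLoss(ignore_index=-100)
    # NLLLoss expects (batch, vocab, seq_len), hence the transpose
    return criterion(log_probs.transpose(1, 2), labels)
```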
449 | ], 450 | "text/plain": [ 451 | "" 452 | ] 453 | }, 454 | "metadata": {}, 455 | "output_type": "display_data" 456 | } 457 | ], 458 | "source": [ 459 | "from argparse import Namespace\n", 460 | "from torch.utils.data import DataLoader\n", 461 | "import gc\n", 462 | "gc.collect()\n", 463 | "\n", 464 | "torch.cuda.empty_cache()\n", 465 | "\n", 466 | "\n", 467 | "args = Namespace(\n", 468 | " hidden=256,\n", 469 | " batch_size=64,\n", 470 | " layers=8,\n", 471 | " attn_heads=8,\n", 472 | " adam_weight_decay=0.01,\n", 473 | " adam_beta1=0.9,\n", 474 | " output_path=\"/home/ubuntu/superprompt/saved\",\n", 475 | " epochs=10,\n", 476 | " log_freq=50,\n", 477 | " adam_beta2=0.999,\n", 478 | " cuda_devices=[0],\n", 479 | " num_workers=4,\n", 480 | " lr=1e-3,\n", 481 | " with_cuda=True,\n", 482 | ")\n", 483 | "\n", 484 | "print(args)\n", 485 | "# !wandb init\n", 486 | "wandb.init(config=args, project=\"superprompt\")\n", 487 | "\n", 488 | "\n", 489 | "print(\"Building BERT model\")\n", 490 | "\n", 491 | "bert = BERT(\n", 492 | " tokenizer.vocab_size,\n", 493 | " hidden=args.hidden,\n", 494 | " n_layers=args.layers,\n", 495 | " attn_heads=args.attn_heads,\n", 496 | ")\n", 497 | "wandb.watch(bert, log_freq=args.log_freq)\n", 498 | "\n", 499 | "\n", 500 | "train_dataloader = DataLoader(\n", 501 | " dataset[\"train\"], batch_size=args.batch_size, num_workers=args.num_workers\n", 502 | ")\n", 503 | "test_dataloader = DataLoader(\n", 504 | " dataset[\"test\"], batch_size=args.batch_size, num_workers=args.num_workers\n", 505 | ")\n", 506 | "\n", 507 | "print(\"Creating BERT Trainer\")\n", 508 | "trainer = BERTTrainer(\n", 509 | " bert,\n", 510 | " tokenizer.vocab_size,\n", 511 | " train_dataloader=train_dataloader,\n", 512 | " test_dataloader=test_dataloader,\n", 513 | " lr=args.lr,\n", 514 | " betas=(args.adam_beta1, args.adam_beta2),\n", 515 | " weight_decay=args.adam_weight_decay,\n", 516 | " log_freq=args.log_freq,\n", 517 | " with_cuda=args.with_cuda,\n", 518 | " cuda_devices=args.cuda_devices,\n", 519 | ")\n", 520 | "\n", 521 | "for epoch in range(args.epochs):\n", 522 | " trainer.train(epoch)\n", 523 | " trainer.save(epoch, args.output_path)\n", 524 | "\n", 525 | " if test_dataloader is not None:\n", 526 | " trainer.test(epoch)\n", 527 | "\n" 528 | ] 529 | }, 530 | { 531 | "cell_type": "code", 532 | "execution_count": null, 533 | "id": "d1089e36-90ed-47a4-a465-50e5c8e6834d", 534 | "metadata": {}, 535 | "outputs": [], 536 | "source": [] 537 | } 538 | ], 539 | "metadata": { 540 | "kernelspec": { 541 | "display_name": "Python 3", 542 | "language": "python", 543 | "name": "python3" 544 | }, 545 | "language_info": { 546 | "codemirror_mode": { 547 | "name": "ipython", 548 | "version": 3 549 | }, 550 | "file_extension": ".py", 551 | "mimetype": "text/x-python", 552 | "name": "python", 553 | "nbconvert_exporter": "python", 554 | "pygments_lexer": "ipython3", 555 | "version": "3.8.10" 556 | } 557 | }, 558 | "nbformat": 4, 559 | "nbformat_minor": 5 560 | } 561 | -------------------------------------------------------------------------------- /train_t5.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import sys 4 | import warnings 5 | from dataclasses import dataclass, field 6 | from typing import Optional 7 | 8 | import datasets 9 | import evaluate 10 | import numpy as np 11 | from datasets import load_dataset, DatasetDict 12 | 13 | import transformers 14 | from transformers import ( 15 | AutoConfig, 16 | AutoModelForSeq2SeqLM, 17 | 
AutoTokenizer, 18 | DataCollatorForSeq2Seq, 19 | Seq2SeqTrainer, 20 | Seq2SeqTrainingArguments, 21 | default_data_collator, 22 | set_seed, 23 | ) 24 | from transformers.trainer_utils import get_last_checkpoint 25 | 26 | 27 | logger = logging.getLogger(__name__) 28 | 29 | 30 | @dataclass 31 | class ModelArguments: 32 | """ 33 | Arguments pertaining to which model/config/tokenizer we are going to fine-tune from. 34 | """ 35 | 36 | model_name_or_path: str = field( 37 | metadata={ 38 | "help": "Path to pretrained model or model identifier from huggingface.co/models" 39 | } 40 | ) 41 | config_name: Optional[str] = field( 42 | default=None, 43 | metadata={ 44 | "help": "Pretrained config name or path if not the same as model_name" 45 | }, 46 | ) 47 | tokenizer_name: Optional[str] = field( 48 | default=None, 49 | metadata={ 50 | "help": "Pretrained tokenizer name or path if not the same as model_name" 51 | }, 52 | ) 53 | cache_dir: Optional[str] = field( 54 | default=None, 55 | metadata={ 56 | "help": "Where to store the pretrained models downloaded from huggingface.co" 57 | }, 58 | ) 59 | use_fast_tokenizer: bool = field( 60 | default=True, 61 | metadata={ 62 | "help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not." 63 | }, 64 | ) 65 | model_revision: str = field( 66 | default="main", 67 | metadata={ 68 | "help": "The specific model version to use (can be a branch name, tag name or commit id)." 69 | }, 70 | ) 71 | token: Optional[str] = field( 72 | default=None, 73 | metadata={ 74 | "help": ( 75 | "The token to use as HTTP bearer authorization for remote files. If not specified, will use the token " 76 | "generated when running `huggingface-cli login` (stored in `~/.huggingface`)." 77 | ) 78 | }, 79 | ) 80 | trust_remote_code: bool = field( 81 | default=False, 82 | metadata={ 83 | "help": ( 84 | "Whether or not to allow for custom models defined on the Hub in their own modeling files. This option" 85 | "should only be set to `True` for repositories you trust and in which you have read the code, as it will " 86 | "execute code present on the Hub on your local machine." 87 | ) 88 | }, 89 | ) 90 | 91 | 92 | @dataclass 93 | class DataTrainingArguments: 94 | """ 95 | Arguments pertaining to what data we are going to input our model for training and eval. 96 | """ 97 | 98 | dataset_name: Optional[str] = field( 99 | default="roborovski/upsampled-prompts-parti", 100 | metadata={"help": "The name of the dataset to use (via the datasets library)."}, 101 | ) 102 | dataset_config_name: Optional[str] = field( 103 | default=None, 104 | metadata={ 105 | "help": "The configuration name of the dataset to use (via the datasets library)." 106 | }, 107 | ) 108 | train_file: Optional[str] = field( 109 | default=None, metadata={"help": "The input training data file (a jsonlines)."} 110 | ) 111 | validation_file: Optional[str] = field( 112 | default=None, 113 | metadata={ 114 | "help": "An optional input evaluation data file to evaluate the metrics (sacrebleu) on a jsonlines file." 115 | }, 116 | ) 117 | test_file: Optional[str] = field( 118 | default=None, 119 | metadata={ 120 | "help": "An optional input test data file to evaluate the metrics (sacrebleu) on a jsonlines file." 
121 | }, 122 | ) 123 | overwrite_cache: bool = field( 124 | default=False, 125 | metadata={"help": "Overwrite the cached training and evaluation sets"}, 126 | ) 127 | preprocessing_num_workers: Optional[int] = field( 128 | default=None, 129 | metadata={"help": "The number of processes to use for the preprocessing."}, 130 | ) 131 | max_source_length: Optional[int] = field( 132 | default=1024, 133 | metadata={ 134 | "help": ( 135 | "The maximum total input sequence length after tokenization. Sequences longer " 136 | "than this will be truncated, sequences shorter will be padded." 137 | ) 138 | }, 139 | ) 140 | max_target_length: Optional[int] = field( 141 | default=128, 142 | metadata={ 143 | "help": ( 144 | "The maximum total sequence length for target text after tokenization. Sequences longer " 145 | "than this will be truncated, sequences shorter will be padded." 146 | ) 147 | }, 148 | ) 149 | val_max_target_length: Optional[int] = field( 150 | default=None, 151 | metadata={ 152 | "help": ( 153 | "The maximum total sequence length for validation target text after tokenization. Sequences longer " 154 | "than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`. " 155 | "This argument is also used to override the ``max_length`` param of ``model.generate``, which is used " 156 | "during ``evaluate`` and ``predict``." 157 | ) 158 | }, 159 | ) 160 | pad_to_max_length: bool = field( 161 | default=False, 162 | metadata={ 163 | "help": ( 164 | "Whether to pad all samples to model maximum sentence length. " 165 | "If False, will pad the samples dynamically when batching to the maximum length in the batch. More " 166 | "efficient on GPU but very bad for TPU." 167 | ) 168 | }, 169 | ) 170 | max_train_samples: Optional[int] = field( 171 | default=None, 172 | metadata={ 173 | "help": ( 174 | "For debugging purposes or quicker training, truncate the number of training examples to this " 175 | "value if set." 176 | ) 177 | }, 178 | ) 179 | max_eval_samples: Optional[int] = field( 180 | default=None, 181 | metadata={ 182 | "help": ( 183 | "For debugging purposes or quicker training, truncate the number of evaluation examples to this " 184 | "value if set." 185 | ) 186 | }, 187 | ) 188 | max_predict_samples: Optional[int] = field( 189 | default=None, 190 | metadata={ 191 | "help": ( 192 | "For debugging purposes or quicker training, truncate the number of prediction examples to this " 193 | "value if set." 194 | ) 195 | }, 196 | ) 197 | num_beams: Optional[int] = field( 198 | default=1, 199 | metadata={ 200 | "help": ( 201 | "Number of beams to use for evaluation. This argument will be passed to ``model.generate``, " 202 | "which is used during ``evaluate`` and ``predict``." 203 | ) 204 | }, 205 | ) 206 | ignore_pad_token_for_loss: bool = field( 207 | default=True, 208 | metadata={ 209 | "help": "Whether to ignore the tokens corresponding to padded labels in the loss computation or not." 210 | }, 211 | ) 212 | source_prefix: Optional[str] = field( 213 | default="Convert the following prompt:", 214 | metadata={ 215 | "help": "A prefix to add before every source text (useful for T5 models)." 
216 | }, 217 | ) 218 | forced_bos_token: Optional[str] = field( 219 | default=None, 220 | metadata={ 221 | "help": ( 222 | "The token to force as the first generated token after the :obj:`decoder_start_token_id`.Useful for" 223 | " multilingual models like :doc:`mBART <../model_doc/mbart>` where the first generated token needs to" 224 | " be the target language token.(Usually it is the target language token)" 225 | ) 226 | }, 227 | ) 228 | 229 | 230 | def main(): 231 | # See all possible arguments in src/transformers/training_args.py 232 | # or by passing the --help flag to this script. 233 | # We now keep distinct sets of args, for a cleaner separation of concerns. 234 | 235 | model_args = ModelArguments(model_name_or_path="t5-base") 236 | data_args = DataTrainingArguments() 237 | training_args = Seq2SeqTrainingArguments(output_dir="./output") 238 | 239 | # Setup logging 240 | logging.basicConfig( 241 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 242 | datefmt="%m/%d/%Y %H:%M:%S", 243 | handlers=[logging.StreamHandler(sys.stdout)], 244 | ) 245 | 246 | transformers.utils.logging.set_verbosity_info() 247 | 248 | log_level = training_args.get_process_log_level() 249 | logger.setLevel(log_level) 250 | datasets.utils.logging.set_verbosity(log_level) 251 | transformers.utils.logging.set_verbosity(log_level) 252 | transformers.utils.logging.enable_default_handler() 253 | transformers.utils.logging.enable_explicit_format() 254 | 255 | # Log on each process the small summary: 256 | logger.warning( 257 | f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, " 258 | + f"distributed training: {training_args.parallel_mode.value == 'distributed'}, 16-bits training: {training_args.fp16}" 259 | ) 260 | logger.info(f"Training/evaluation parameters {training_args}") 261 | 262 | # Detecting last checkpoint. 263 | last_checkpoint = None 264 | if ( 265 | os.path.isdir(training_args.output_dir) 266 | and training_args.do_train 267 | and not training_args.overwrite_output_dir 268 | ): 269 | last_checkpoint = get_last_checkpoint(training_args.output_dir) 270 | if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0: 271 | raise ValueError( 272 | f"Output directory ({training_args.output_dir}) already exists and is not empty. " 273 | "Use --overwrite_output_dir to overcome." 274 | ) 275 | elif ( 276 | last_checkpoint is not None and training_args.resume_from_checkpoint is None 277 | ): 278 | logger.info( 279 | f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change " 280 | "the `--output_dir` or add `--overwrite_output_dir` to train from scratch." 281 | ) 282 | 283 | # Set seed before initializing model. 284 | set_seed(training_args.seed) 285 | 286 | assert data_args.dataset_name 287 | raw_datasets: DatasetDict = load_dataset( 288 | data_args.dataset_name, 289 | data_args.dataset_config_name, 290 | cache_dir=model_args.cache_dir, 291 | token=model_args.token, 292 | ) # type: ignore 293 | # Load pretrained model and tokenizer 294 | # 295 | # Distributed training: 296 | # The .from_pretrained methods guarantee that only one local process can concurrently 297 | # download model & vocab. 
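A caveat on the hardcoded arguments at the top of `main()`: `Seq2SeqTrainingArguments` defaults `do_train`, `do_eval`, and `do_predict` to `False`, so as constructed the script falls through to the "There is nothing to do" branch below. A sketch of two ways to drive it; the `HfArgumentParser` variant mirrors the stock Hugging Face seq2seq examples and is an assumption about the intended usage, not something this file currently does:

```python
from transformers import HfArgumentParser, Seq2SeqTrainingArguments

# Option 1: enable the phases explicitly when constructing the args
training_args = Seq2SeqTrainingArguments(
    output_dir="./output",
    do_train=True,
    do_eval=True,
    predict_with_generate=True,  # needed for the sacrebleu compute_metrics path
)

# Option 2: parse everything from the command line, so flags such as
# --model_name_or_path, --do_train, and --output_dir work as in the HF examples
parser = HfArgumentParser(
    (ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments)
)
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
```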
298 | config = AutoConfig.from_pretrained( 299 | model_args.config_name 300 | if model_args.config_name 301 | else model_args.model_name_or_path, 302 | cache_dir=model_args.cache_dir, 303 | revision=model_args.model_revision, 304 | token=model_args.token, 305 | trust_remote_code=model_args.trust_remote_code, 306 | ) 307 | tokenizer = AutoTokenizer.from_pretrained( 308 | model_args.tokenizer_name 309 | if model_args.tokenizer_name 310 | else model_args.model_name_or_path, 311 | cache_dir=model_args.cache_dir, 312 | use_fast=model_args.use_fast_tokenizer, 313 | revision=model_args.model_revision, 314 | token=model_args.token, 315 | trust_remote_code=model_args.trust_remote_code, 316 | ) 317 | model = AutoModelForSeq2SeqLM.from_pretrained( 318 | model_args.model_name_or_path, 319 | from_tf=bool(".ckpt" in model_args.model_name_or_path), 320 | config=config, 321 | cache_dir=model_args.cache_dir, 322 | revision=model_args.model_revision, 323 | token=model_args.token, 324 | trust_remote_code=model_args.trust_remote_code, 325 | ) 326 | 327 | # We resize the embeddings only when necessary to avoid index errors. If you are creating a model from scratch 328 | # on a small vocab and want a smaller embedding size, remove this test. 329 | embedding_size = model.get_input_embeddings().weight.shape[0] 330 | if len(tokenizer) > embedding_size: 331 | model.resize_token_embeddings(len(tokenizer)) 332 | 333 | start_phrase = "upsample:" 334 | model.config.decoder_start_token_id = tokenizer.convert_tokens_to_ids(start_phrase) 335 | 336 | prefix = data_args.source_prefix if data_args.source_prefix is not None else "" 337 | 338 | # Preprocessing the datasets. 339 | # We need to tokenize inputs and targets. 340 | if training_args.do_train: 341 | column_names = raw_datasets["train"].column_names 342 | elif training_args.do_eval: 343 | column_names = raw_datasets["validation"].column_names 344 | elif training_args.do_predict: 345 | column_names = raw_datasets["test"].column_names 346 | else: 347 | logger.info( 348 | "There is nothing to do. Please pass `do_train`, `do_eval` and/or `do_predict`." 349 | ) 350 | return 351 | 352 | source_column = "Prompt" 353 | target_column = "Upsampled" 354 | 355 | # Temporarily set max_target_length for training. 356 | max_target_length = data_args.max_target_length 357 | padding = "max_length" if data_args.pad_to_max_length else False 358 | 359 | if training_args.label_smoothing_factor > 0 and not hasattr( 360 | model, "prepare_decoder_input_ids_from_labels" 361 | ): 362 | logger.warning( 363 | "label_smoothing is enabled but the `prepare_decoder_input_ids_from_labels` method is not defined for " 364 | f"`{model.__class__.__name__}`. This will lead to loss being calculated twice and will take up more memory" 365 | ) 366 | 367 | def preprocess_function(examples): 368 | inputs = [ex[source_column] for ex in examples] 369 | targets = [ex[target_column] for ex in examples] 370 | inputs = [prefix + inp for inp in inputs] 371 | 372 | model_inputs = tokenizer( 373 | inputs, 374 | max_length=data_args.max_source_length, 375 | padding=padding, 376 | truncation=True, 377 | ) 378 | 379 | # Tokenize targets with the `text_target` keyword argument 380 | labels = tokenizer( 381 | text_target=targets, 382 | max_length=max_target_length, 383 | padding=padding, 384 | truncation=True, 385 | ) 386 | 387 | # If we are padding here, replace all tokenizer.pad_token_id in the labels by -100 when we want to ignore 388 | # padding in the loss. 
389 | if padding == "max_length" and data_args.ignore_pad_token_for_loss: 390 | labels["input_ids"] = [ 391 | [(l if l != tokenizer.pad_token_id else -100) for l in label] 392 | for label in labels["input_ids"] # type: ignore 393 | ] 394 | 395 | model_inputs["labels"] = labels["input_ids"] 396 | return model_inputs 397 | 398 | train_dataset = raw_datasets["train"] 399 | if training_args.do_train: 400 | if data_args.max_train_samples is not None: 401 | max_train_samples = min(len(train_dataset), data_args.max_train_samples) 402 | train_dataset = train_dataset.select(range(max_train_samples)) 403 | with training_args.main_process_first(desc="train dataset map pre-processing"): 404 | train_dataset = train_dataset.map( 405 | preprocess_function, 406 | batched=True, 407 | num_proc=data_args.preprocessing_num_workers, 408 | remove_columns=column_names, 409 | load_from_cache_file=not data_args.overwrite_cache, 410 | desc="Running tokenizer on train dataset", 411 | ) 412 | 413 | eval_dataset = raw_datasets["validation"] 414 | 415 | if training_args.do_eval: 416 | max_target_length = data_args.val_max_target_length 417 | if data_args.max_eval_samples is not None: 418 | max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples) 419 | eval_dataset = eval_dataset.select(range(max_eval_samples)) 420 | with training_args.main_process_first( 421 | desc="validation dataset map pre-processing" 422 | ): 423 | eval_dataset = eval_dataset.map( 424 | preprocess_function, 425 | batched=True, 426 | num_proc=data_args.preprocessing_num_workers, 427 | remove_columns=column_names, 428 | load_from_cache_file=not data_args.overwrite_cache, 429 | desc="Running tokenizer on validation dataset", 430 | ) 431 | 432 | predict_dataset = raw_datasets["test"] 433 | if training_args.do_predict: 434 | max_target_length = data_args.val_max_target_length 435 | if "test" not in raw_datasets: 436 | raise ValueError("--do_predict requires a test dataset") 437 | if data_args.max_predict_samples is not None: 438 | max_predict_samples = min( 439 | len(predict_dataset), data_args.max_predict_samples 440 | ) 441 | predict_dataset = predict_dataset.select(range(max_predict_samples)) 442 | with training_args.main_process_first( 443 | desc="prediction dataset map pre-processing" 444 | ): 445 | predict_dataset = predict_dataset.map( 446 | preprocess_function, 447 | batched=True, 448 | num_proc=data_args.preprocessing_num_workers, 449 | remove_columns=column_names, 450 | load_from_cache_file=not data_args.overwrite_cache, 451 | desc="Running tokenizer on prediction dataset", 452 | ) 453 | 454 | # Data collator 455 | label_pad_token_id = ( 456 | -100 if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id 457 | ) 458 | assert label_pad_token_id 459 | assert tokenizer.pad_token_id 460 | if data_args.pad_to_max_length: 461 | data_collator = default_data_collator 462 | else: 463 | data_collator = DataCollatorForSeq2Seq( 464 | tokenizer, 465 | model=model, 466 | label_pad_token_id=label_pad_token_id, 467 | pad_to_multiple_of=8 if training_args.fp16 else None, 468 | ) 469 | 470 | # Metric 471 | metric = evaluate.load("sacrebleu") 472 | 473 | def postprocess_text(preds, labels): 474 | preds = [pred.strip() for pred in preds] 475 | labels = [[label.strip()] for label in labels] 476 | 477 | return preds, labels 478 | 479 | def compute_metrics(eval_preds): 480 | preds, labels = eval_preds 481 | if isinstance(preds, tuple): 482 | preds = preds[0] 483 | assert tokenizer.pad_token_id 484 | # Replace -100s used for padding as we can't 
decode them 485 | preds = np.where(preds != -100, preds, tokenizer.pad_token_id) 486 | decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True) 487 | labels = np.where(labels != -100, labels, tokenizer.pad_token_id) 488 | decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True) 489 | 490 | # Some simple post-processing 491 | decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels) 492 | 493 | result = metric.compute(predictions=decoded_preds, references=decoded_labels) 494 | assert result 495 | result = {"bleu": result["score"]} 496 | 497 | prediction_lens = [ 498 | np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds 499 | ] 500 | result["gen_len"] = np.mean(prediction_lens) 501 | result = {k: round(v, 4) for k, v in result.items()} 502 | return result 503 | 504 | # Initialize our Trainer 505 | trainer = Seq2SeqTrainer( 506 | model=model, 507 | args=training_args, 508 | train_dataset=train_dataset, # type: ignore 509 | eval_dataset=eval_dataset, # type: ignore 510 | tokenizer=tokenizer, 511 | data_collator=data_collator, 512 | compute_metrics=compute_metrics 513 | if training_args.predict_with_generate 514 | else None, 515 | ) 516 | 517 | # Training 518 | if training_args.do_train: 519 | checkpoint = None 520 | if training_args.resume_from_checkpoint is not None: 521 | checkpoint = training_args.resume_from_checkpoint 522 | elif last_checkpoint is not None: 523 | checkpoint = last_checkpoint 524 | train_result = trainer.train(resume_from_checkpoint=checkpoint) 525 | trainer.save_model() # Saves the tokenizer too for easy upload 526 | 527 | metrics = train_result.metrics 528 | max_train_samples = ( 529 | data_args.max_train_samples 530 | if data_args.max_train_samples is not None 531 | else len(train_dataset) 532 | ) 533 | metrics["train_samples"] = min(max_train_samples, len(train_dataset)) 534 | 535 | trainer.log_metrics("train", metrics) 536 | trainer.save_metrics("train", metrics) 537 | trainer.save_state() 538 | 539 | # Evaluation 540 | results = {} 541 | max_length = ( 542 | training_args.generation_max_length 543 | if training_args.generation_max_length is not None 544 | else data_args.val_max_target_length 545 | ) 546 | num_beams = ( 547 | data_args.num_beams 548 | if data_args.num_beams is not None 549 | else training_args.generation_num_beams 550 | ) 551 | if training_args.do_eval: 552 | logger.info("*** Evaluate ***") 553 | 554 | metrics = trainer.evaluate( 555 | max_length=max_length, num_beams=num_beams, metric_key_prefix="eval" 556 | ) 557 | max_eval_samples = ( 558 | data_args.max_eval_samples 559 | if data_args.max_eval_samples is not None 560 | else len(eval_dataset) 561 | ) 562 | metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset)) 563 | 564 | trainer.log_metrics("eval", metrics) 565 | trainer.save_metrics("eval", metrics) 566 | 567 | if training_args.do_predict: 568 | logger.info("*** Predict ***") 569 | 570 | predict_results = trainer.predict( 571 | predict_dataset, # type: ignore 572 | metric_key_prefix="predict", 573 | max_length=max_length, 574 | num_beams=num_beams, 575 | ) 576 | metrics = predict_results.metrics 577 | assert metrics 578 | max_predict_samples = ( 579 | data_args.max_predict_samples 580 | if data_args.max_predict_samples is not None 581 | else len(predict_dataset) 582 | ) 583 | metrics["predict_samples"] = min(max_predict_samples, len(predict_dataset)) 584 | 585 | trainer.log_metrics("predict", metrics) 586 | trainer.save_metrics("predict", metrics) 587 | 588 | if 
trainer.is_world_process_zero(): 589 | if training_args.predict_with_generate: 590 | predictions = predict_results.predictions 591 | predictions = np.where( 592 | predictions != -100, predictions, tokenizer.pad_token_id 593 | ) 594 | predictions = tokenizer.batch_decode( 595 | predictions, 596 | skip_special_tokens=True, 597 | clean_up_tokenization_spaces=True, 598 | ) 599 | predictions = [pred.strip() for pred in predictions] 600 | output_prediction_file = os.path.join( 601 | training_args.output_dir, "generated_predictions.txt" 602 | ) 603 | with open(output_prediction_file, "w", encoding="utf-8") as writer: 604 | writer.write("\n".join(predictions)) 605 | 606 | kwargs = {"finetuned_from": model_args.model_name_or_path, "tasks": "translation"} 607 | if data_args.dataset_name is not None: 608 | kwargs["dataset_tags"] = data_args.dataset_name 609 | if data_args.dataset_config_name is not None: 610 | kwargs["dataset_args"] = data_args.dataset_config_name 611 | kwargs[ 612 | "dataset" 613 | ] = f"{data_args.dataset_name} {data_args.dataset_config_name}" 614 | else: 615 | kwargs["dataset"] = data_args.dataset_name 616 | 617 | return results 618 | 619 | 620 | def _mp_fn(index): 621 | # For xla_spawn (TPUs) 622 | main() 623 | 624 | 625 | if __name__ == "__main__": 626 | main() 627 | -------------------------------------------------------------------------------- /notebooks/finetune.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "id": "lND0ci7hx0_E" 7 | }, 8 | "source": [ 9 | "# Fine-tuning a masked language model (PyTorch)" 10 | ] 11 | }, 12 | { 13 | "cell_type": "markdown", 14 | "metadata": { 15 | "id": "oS2ZEetax0_G" 16 | }, 17 | "source": [ 18 | "Install the Transformers, Datasets, and Evaluate libraries to run this notebook." 
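The evaluation cells later in this notebook report perplexity, which is simply the exponential of the mean cross-entropy loss returned by `trainer.evaluate()`. A worked example with a hypothetical loss value:

```python
import math

eval_loss = 3.08                  # hypothetical mean NLL per predicted token
perplexity = math.exp(eval_loss)  # ~= 21.76, in line with the ~21.75 reported below
print(f">>> Perplexity: {perplexity:.2f}")
```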
19 | ] 20 | }, 21 | { 22 | "cell_type": "code", 23 | "execution_count": 1, 24 | "metadata": { 25 | "id": "u47KiPV1x0_H" 26 | }, 27 | "outputs": [ 28 | { 29 | "name": "stdout", 30 | "output_type": "stream", 31 | "text": [ 32 | "Defaulting to user installation because normal site-packages is not writeable\n", 33 | "Requirement already satisfied: datasets in ./.local/lib/python3.8/site-packages (2.11.0)\n", 34 | "Collecting evaluate\n", 35 | " Downloading evaluate-0.4.0-py3-none-any.whl (81 kB)\n", 36 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m81.4/81.4 kB\u001b[0m \u001b[31m24.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", 37 | "\u001b[?25hRequirement already satisfied: transformers[sentencepiece] in ./.local/lib/python3.8/site-packages (4.28.1)\n", 38 | "Requirement already satisfied: requests>=2.19.0 in ./.local/lib/python3.8/site-packages (from datasets) (2.28.1)\n", 39 | "Requirement already satisfied: numpy>=1.17 in ./.local/lib/python3.8/site-packages (from datasets) (1.23.4)\n", 40 | "Requirement already satisfied: multiprocess in ./.local/lib/python3.8/site-packages (from datasets) (0.70.14)\n", 41 | "Requirement already satisfied: aiohttp in ./.local/lib/python3.8/site-packages (from datasets) (3.8.4)\n", 42 | "Requirement already satisfied: dill<0.3.7,>=0.3.0 in ./.local/lib/python3.8/site-packages (from datasets) (0.3.6)\n", 43 | "Requirement already satisfied: pyyaml>=5.1 in /usr/lib/python3/dist-packages (from datasets) (5.3.1)\n", 44 | "Requirement already satisfied: responses<0.19 in ./.local/lib/python3.8/site-packages (from datasets) (0.18.0)\n", 45 | "Requirement already satisfied: tqdm>=4.62.1 in ./.local/lib/python3.8/site-packages (from datasets) (4.64.1)\n", 46 | "Requirement already satisfied: packaging in ./.local/lib/python3.8/site-packages (from datasets) (21.3)\n", 47 | "Requirement already satisfied: pyarrow>=8.0.0 in ./.local/lib/python3.8/site-packages (from datasets) (11.0.0)\n", 48 | "Requirement already satisfied: xxhash in ./.local/lib/python3.8/site-packages (from datasets) (3.2.0)\n", 49 | "Requirement already satisfied: pandas in ./.local/lib/python3.8/site-packages (from datasets) (1.5.1)\n", 50 | "Requirement already satisfied: fsspec[http]>=2021.11.1 in ./.local/lib/python3.8/site-packages (from datasets) (2023.4.0)\n", 51 | "Requirement already satisfied: huggingface-hub<1.0.0,>=0.11.0 in ./.local/lib/python3.8/site-packages (from datasets) (0.13.4)\n", 52 | "Requirement already satisfied: filelock in /usr/lib/python3/dist-packages (from transformers[sentencepiece]) (3.0.12)\n", 53 | "Requirement already satisfied: tokenizers!=0.11.3,<0.14,>=0.11.1 in ./.local/lib/python3.8/site-packages (from transformers[sentencepiece]) (0.13.3)\n", 54 | "Requirement already satisfied: regex!=2019.12.17 in ./.local/lib/python3.8/site-packages (from transformers[sentencepiece]) (2023.3.23)\n", 55 | "Requirement already satisfied: sentencepiece!=0.1.92,>=0.1.91 in ./.local/lib/python3.8/site-packages (from transformers[sentencepiece]) (0.1.98)\n", 56 | "Requirement already satisfied: protobuf<=3.20.2 in /usr/lib/python3/dist-packages (from transformers[sentencepiece]) (3.11.4)\n", 57 | "Requirement already satisfied: attrs>=17.3.0 in /usr/lib/python3/dist-packages (from aiohttp->datasets) (19.3.0)\n", 58 | "Requirement already satisfied: yarl<2.0,>=1.0 in ./.local/lib/python3.8/site-packages (from aiohttp->datasets) (1.9.1)\n", 59 | "Requirement already satisfied: multidict<7.0,>=4.5 in ./.local/lib/python3.8/site-packages 
(from aiohttp->datasets) (6.0.4)\n", 60 | "Requirement already satisfied: aiosignal>=1.1.2 in ./.local/lib/python3.8/site-packages (from aiohttp->datasets) (1.3.1)\n", 61 | "Requirement already satisfied: frozenlist>=1.1.1 in ./.local/lib/python3.8/site-packages (from aiohttp->datasets) (1.3.3)\n", 62 | "Requirement already satisfied: async-timeout<5.0,>=4.0.0a3 in ./.local/lib/python3.8/site-packages (from aiohttp->datasets) (4.0.2)\n", 63 | "Requirement already satisfied: charset-normalizer<4.0,>=2.0 in ./.local/lib/python3.8/site-packages (from aiohttp->datasets) (2.1.1)\n", 64 | "Requirement already satisfied: typing-extensions>=3.7.4.3 in ./.local/lib/python3.8/site-packages (from huggingface-hub<1.0.0,>=0.11.0->datasets) (4.4.0)\n", 65 | "Requirement already satisfied: pyparsing!=3.0.5,>=2.0.2 in /usr/lib/python3/dist-packages (from packaging->datasets) (2.4.6)\n", 66 | "Requirement already satisfied: certifi>=2017.4.17 in /usr/lib/python3/dist-packages (from requests>=2.19.0->datasets) (2019.11.28)\n", 67 | "Requirement already satisfied: idna<4,>=2.5 in /usr/lib/python3/dist-packages (from requests>=2.19.0->datasets) (2.8)\n", 68 | "Requirement already satisfied: urllib3<1.27,>=1.21.1 in ./.local/lib/python3.8/site-packages (from requests>=2.19.0->datasets) (1.26.15)\n", 69 | "Requirement already satisfied: pytz>=2020.1 in ./.local/lib/python3.8/site-packages (from pandas->datasets) (2022.5)\n", 70 | "Requirement already satisfied: python-dateutil>=2.8.1 in ./.local/lib/python3.8/site-packages (from pandas->datasets) (2.8.2)\n", 71 | "Requirement already satisfied: six>=1.5 in /usr/lib/python3/dist-packages (from python-dateutil>=2.8.1->pandas->datasets) (1.14.0)\n", 72 | "Installing collected packages: evaluate\n", 73 | "Successfully installed evaluate-0.4.0\n", 74 | "\n", 75 | "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip available: \u001b[0m\u001b[31;49m22.3\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m23.1.1\u001b[0m\n", 76 | "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpython3 -m pip install --upgrade pip\u001b[0m\n", 77 | "Defaulting to user installation because normal site-packages is not writeable\n", 78 | "Collecting accelerate\n", 79 | " Downloading accelerate-0.18.0-py3-none-any.whl (215 kB)\n", 80 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m215.3/215.3 kB\u001b[0m \u001b[31m42.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", 81 | "\u001b[?25hRequirement already satisfied: torch>=1.4.0 in /usr/lib/python3/dist-packages (from accelerate) (1.12.1)\n", 82 | "Requirement already satisfied: packaging>=20.0 in ./.local/lib/python3.8/site-packages (from accelerate) (21.3)\n", 83 | "Requirement already satisfied: psutil in /usr/lib/python3/dist-packages (from accelerate) (5.5.1)\n", 84 | "Requirement already satisfied: numpy>=1.17 in ./.local/lib/python3.8/site-packages (from accelerate) (1.23.4)\n", 85 | "Requirement already satisfied: pyyaml in /usr/lib/python3/dist-packages (from accelerate) (5.3.1)\n", 86 | "Requirement already satisfied: pyparsing!=3.0.5,>=2.0.2 in /usr/lib/python3/dist-packages (from packaging>=20.0->accelerate) (2.4.6)\n", 87 | "Installing collected packages: accelerate\n", 88 | "Successfully installed accelerate-0.18.0\n", 89 | "\n", 90 | "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip available: 
\u001b[0m\u001b[31;49m22.3\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m23.1.1\u001b[0m\n", 91 | "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpython3 -m pip install --upgrade pip\u001b[0m\n", 92 | "\u001b[1;31mE: \u001b[0mCould not open lock file /var/lib/dpkg/lock-frontend - open (13: Permission denied)\u001b[0m\n", 93 | "\u001b[1;31mE: \u001b[0mUnable to acquire the dpkg frontend lock (/var/lib/dpkg/lock-frontend), are you root?\u001b[0m\n" 94 | ] 95 | } 96 | ], 97 | "source": [ 98 | "!pip install datasets evaluate transformers[sentencepiece]\n", 99 | "!pip install accelerate\n", 100 | "# To run the training on TPU, you will need to uncomment the following line:\n", 101 | "# !pip install cloud-tpu-client==0.10 torch==1.9.0 https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.9-cp37-cp37m-linux_x86_64.whl\n", 102 | "!apt install git-lfs" 103 | ] 104 | }, 105 | { 106 | "cell_type": "markdown", 107 | "metadata": { 108 | "id": "xMScaAZHx0_H" 109 | }, 110 | "source": [ 111 | "You will need to setup git, adapt your email and name in the following cell." 112 | ] 113 | }, 114 | { 115 | "cell_type": "code", 116 | "execution_count": 2, 117 | "metadata": { 118 | "id": "VYNhDrEOx0_I" 119 | }, 120 | "outputs": [], 121 | "source": [ 122 | "!git config --global user.email \"you@example.com\"\n", 123 | "!git config --global user.name \"Your Name\"" 124 | ] 125 | }, 126 | { 127 | "cell_type": "markdown", 128 | "metadata": { 129 | "id": "nbeQSBhrx0_I" 130 | }, 131 | "source": [ 132 | "You will also need to be logged in to the Hugging Face Hub. Execute the following and enter your credentials." 133 | ] 134 | }, 135 | { 136 | "cell_type": "code", 137 | "execution_count": 4, 138 | "metadata": { 139 | "id": "eawZEJ3sx0_I" 140 | }, 141 | "outputs": [ 142 | { 143 | "name": "stdout", 144 | "output_type": "stream", 145 | "text": [ 146 | "Token is valid.\n", 147 | "\u001b[1m\u001b[31mCannot authenticate through git-credential as no helper is defined on your machine.\n", 148 | "You might have to re-authenticate when pushing to the Hugging Face Hub.\n", 149 | "Run the following command in your terminal in case you want to set the 'store' credential helper as default.\n", 150 | "\n", 151 | "git config --global credential.helper store\n", 152 | "\n", 153 | "Read https://git-scm.com/book/en/v2/Git-Tools-Credential-Storage for more details.\u001b[0m\n", 154 | "Token has not been saved to git credential helper.\n", 155 | "Your token has been saved to /home/ubuntu/.cache/huggingface/token\n", 156 | "Login successful\n" 157 | ] 158 | } 159 | ], 160 | "source": [ 161 | "from huggingface_hub import notebook_login, login\n", 162 | "\n", 163 | "login(\"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\", add_to_git_credential=True)" 164 | ] 165 | }, 166 | { 167 | "cell_type": "code", 168 | "execution_count": 5, 169 | "metadata": { 170 | "id": "p_tCZtEYx0_I" 171 | }, 172 | "outputs": [ 173 | { 174 | "data": { 175 | "application/vnd.jupyter.widget-view+json": { 176 | "model_id": "7eb1e1f4880d45c2bff5dbfd4bfbd6ff", 177 | "version_major": 2, 178 | "version_minor": 0 179 | }, 180 | "text/plain": [ 181 | "Downloading (…)lve/main/config.json: 0%| | 0.00/483 [00:00>> DistilBERT number of parameters: 67M'\n", 222 | "'>>> BERT number of parameters: 110M'\n" 223 | ] 224 | } 225 | ], 226 | "source": [ 227 | "distilbert_num_parameters = model.num_parameters() / 1_000_000\n", 228 | "print(f\"'>>> DistilBERT number of parameters: 
{round(distilbert_num_parameters)}M'\")\n", 229 | "print(f\"'>>> BERT number of parameters: 110M'\")" 230 | ] 231 | }, 232 | { 233 | "cell_type": "code", 234 | "execution_count": 8, 235 | "metadata": { 236 | "id": "akBuTf5gx0_J" 237 | }, 238 | "outputs": [ 239 | { 240 | "data": { 241 | "application/vnd.jupyter.widget-view+json": { 242 | "model_id": "a96ff83bea304a6796a62b75ea2250d3", 243 | "version_major": 2, 244 | "version_minor": 0 245 | }, 246 | "text/plain": [ 247 | "Downloading (…)okenizer_config.json: 0%| | 0.00/28.0 [00:00>> a stylized potrait of a necromancer with stylized, fine details. stylized setting. very stylized style. trending on art station'\n", 301 | "'>>> a miniature potrait of a necromancer with miniature, fine details. miniature setting. very miniature style. trending on art station'\n", 302 | "'>>> a simple potrait of a necromancer with simple, fine details. simple setting. very simple style. trending on art station'\n", 303 | "'>>> a typical potrait of a necromancer with typical, fine details. typical setting. very typical style. trending on art station'\n", 304 | "'>>> a fine potrait of a necromancer with fine, fine details. fine setting. very fine style. trending on art station'\n" 305 | ] 306 | } 307 | ], 308 | "source": [ 309 | "import torch\n", 310 | "\n", 311 | "text = \"a [MASK] potrait of a necromancer with [MASK], fine details. [MASK] setting. very [MASK] style. trending on art station\"\n", 312 | "\n", 313 | "inputs = tokenizer(text, return_tensors=\"pt\")\n", 314 | "token_logits = model(**inputs).logits\n", 315 | "# Find the location of [MASK] and extract its logits\n", 316 | "mask_token_index = torch.where(inputs[\"input_ids\"] == tokenizer.mask_token_id)[1]\n", 317 | "mask_token_logits = token_logits[0, mask_token_index, :]\n", 318 | "# Pick the [MASK] candidates with the highest logits\n", 319 | "top_5_tokens = torch.topk(mask_token_logits, 5, dim=1).indices[0].tolist()\n", 320 | "\n", 321 | "for token in top_5_tokens:\n", 322 | " print(f\"'>>> {text.replace(tokenizer.mask_token, tokenizer.decode([token]))}'\")" 323 | ] 324 | }, 325 | { 326 | "cell_type": "code", 327 | "execution_count": null, 328 | "metadata": { 329 | "id": "jVJB3flWx0_K", 330 | "outputId": "cee77dc3-114c-4d1c-86e7-8552a534c8ef" 331 | }, 332 | "outputs": [ 333 | { 334 | "data": { 335 | "text/plain": [ 336 | "DatasetDict({\n", 337 | " train: Dataset({\n", 338 | " features: ['text', 'label'],\n", 339 | " num_rows: 25000\n", 340 | " })\n", 341 | " test: Dataset({\n", 342 | " features: ['text', 'label'],\n", 343 | " num_rows: 25000\n", 344 | " })\n", 345 | " unsupervised: Dataset({\n", 346 | " features: ['text', 'label'],\n", 347 | " num_rows: 50000\n", 348 | " })\n", 349 | "})" 350 | ] 351 | }, 352 | "execution_count": null, 353 | "metadata": {}, 354 | "output_type": "execute_result" 355 | } 356 | ], 357 | "source": [ 358 | "from datasets import load_dataset\n", 359 | "\n", 360 | "imdb_dataset = load_dataset(\"imdb\")\n", 361 | "imdb_dataset" 362 | ] 363 | }, 364 | { 365 | "cell_type": "code", 366 | "execution_count": null, 367 | "metadata": { 368 | "id": "J9qH0g6Ux0_K", 369 | "outputId": "f8a50d55-e4b0-44b5-c1db-3e37cb95a4f3" 370 | }, 371 | "outputs": [ 372 | { 373 | "data": { 374 | "text/plain": [ 375 | "\n", 376 | "'>>> Review: This is your typical Priyadarshan movie--a bunch of loony characters out on some silly mission. His signature climax has the entire cast of the film coming together and fighting each other in some crazy moshpit over hidden money. 
Whether it is a winning lottery ticket in Malamaal Weekly, black money in Hera Pheri, \"kodokoo\" in Phir Hera Pheri, etc., etc., the director is becoming ridiculously predictable. Don\\'t get me wrong; as clichéd and preposterous his movies may be, I usually end up enjoying the comedy. However, in most his previous movies there has actually been some good humor, (Hungama and Hera Pheri being noteworthy ones). Now, the hilarity of his films is fading as he is using the same formula over and over again.

Songs are good. Tanushree Datta looks awesome. Rajpal Yadav is irritating, and Tusshar is not a whole lot better. Kunal Khemu is OK, and Sharman Joshi is the best.'\n", 377 | "'>>> Label: 0'\n", 378 | "\n", 379 | "'>>> Review: Okay, the story makes no sense, the characters lack any dimensionally, the best dialogue is ad-libs about the low quality of movie, the cinematography is dismal, and only editing saves a bit of the muddle, but Sam\" Peckinpah directed the film. Somehow, his direction is not enough. For those who appreciate Peckinpah and his great work, this movie is a disappointment. Even a great cast cannot redeem the time the viewer wastes with this minimal effort.

The proper response to the movie is the contempt that the director San Peckinpah, James Caan, Robert Duvall, Burt Young, Bo Hopkins, Arthur Hill, and even Gig Young bring to their work. Watch the great Peckinpah films. Skip this mess.'\n", 380 | "'>>> Label: 0'\n", 381 | "\n", 382 | "'>>> Review: I saw this movie at the theaters when I was about 6 or 7 years old. I loved it then, and have recently come to own a VHS version.

My 4 and 6 year old children love this movie and have been asking again and again to watch it.

I have enjoyed watching it again too. Though I have to admit it is not as good on a little TV.

I do not have older children so I do not know what they would think of it.

The songs are very cute. My daughter keeps singing them over and over.

Hope this helps.'\n", 383 | "'>>> Label: 1'" 384 | ] 385 | }, 386 | "execution_count": null, 387 | "metadata": {}, 388 | "output_type": "execute_result" 389 | } 390 | ], 391 | "source": [ 392 | "sample = imdb_dataset[\"train\"].shuffle(seed=42).select(range(3))\n", 393 | "\n", 394 | "for row in sample:\n", 395 | " print(f\"\\n'>>> Review: {row['text']}'\")\n", 396 | " print(f\"'>>> Label: {row['label']}'\")" 397 | ] 398 | }, 399 | { 400 | "cell_type": "code", 401 | "execution_count": null, 402 | "metadata": { 403 | "id": "2lwhXCKnx0_K", 404 | "outputId": "c7c11b17-38a6-43e2-8c30-48b6a4569e7c" 405 | }, 406 | "outputs": [ 407 | { 408 | "data": { 409 | "text/plain": [ 410 | "DatasetDict({\n", 411 | " train: Dataset({\n", 412 | " features: ['attention_mask', 'input_ids', 'word_ids'],\n", 413 | " num_rows: 25000\n", 414 | " })\n", 415 | " test: Dataset({\n", 416 | " features: ['attention_mask', 'input_ids', 'word_ids'],\n", 417 | " num_rows: 25000\n", 418 | " })\n", 419 | " unsupervised: Dataset({\n", 420 | " features: ['attention_mask', 'input_ids', 'word_ids'],\n", 421 | " num_rows: 50000\n", 422 | " })\n", 423 | "})" 424 | ] 425 | }, 426 | "execution_count": null, 427 | "metadata": {}, 428 | "output_type": "execute_result" 429 | } 430 | ], 431 | "source": [ 432 | "def tokenize_function(examples):\n", 433 | " result = tokenizer(examples[\"text\"])\n", 434 | " if tokenizer.is_fast:\n", 435 | " result[\"word_ids\"] = [result.word_ids(i) for i in range(len(result[\"input_ids\"]))]\n", 436 | " return result\n", 437 | "\n", 438 | "\n", 439 | "# Use batched=True to activate fast multithreading!\n", 440 | "tokenized_datasets = imdb_dataset.map(\n", 441 | " tokenize_function, batched=True, remove_columns=[\"text\", \"label\"]\n", 442 | ")\n", 443 | "tokenized_datasets" 444 | ] 445 | }, 446 | { 447 | "cell_type": "code", 448 | "execution_count": null, 449 | "metadata": { 450 | "id": "8MKAszP7x0_L", 451 | "outputId": "45a3d62a-157e-42e4-d81a-6ecdc389454e" 452 | }, 453 | "outputs": [ 454 | { 455 | "data": { 456 | "text/plain": [ 457 | "512" 458 | ] 459 | }, 460 | "execution_count": null, 461 | "metadata": {}, 462 | "output_type": "execute_result" 463 | } 464 | ], 465 | "source": [ 466 | "tokenizer.model_max_length" 467 | ] 468 | }, 469 | { 470 | "cell_type": "code", 471 | "execution_count": null, 472 | "metadata": { 473 | "id": "CIxBK6qzx0_L" 474 | }, 475 | "outputs": [], 476 | "source": [ 477 | "chunk_size = 128" 478 | ] 479 | }, 480 | { 481 | "cell_type": "code", 482 | "execution_count": null, 483 | "metadata": { 484 | "id": "egmt144Qx0_L", 485 | "outputId": "5207152c-665d-46ac-860f-6ccf6d1e7ecf" 486 | }, 487 | "outputs": [ 488 | { 489 | "data": { 490 | "text/plain": [ 491 | "'>>> Review 0 length: 200'\n", 492 | "'>>> Review 1 length: 559'\n", 493 | "'>>> Review 2 length: 192'" 494 | ] 495 | }, 496 | "execution_count": null, 497 | "metadata": {}, 498 | "output_type": "execute_result" 499 | } 500 | ], 501 | "source": [ 502 | "# Slicing produces a list of lists for each feature\n", 503 | "tokenized_samples = tokenized_datasets[\"train\"][:3]\n", 504 | "\n", 505 | "for idx, sample in enumerate(tokenized_samples[\"input_ids\"]):\n", 506 | " print(f\"'>>> Review {idx} length: {len(sample)}'\")" 507 | ] 508 | }, 509 | { 510 | "cell_type": "code", 511 | "execution_count": null, 512 | "metadata": { 513 | "id": "l3C04NyYx0_L", 514 | "outputId": "503ade65-a34f-43e0-c459-d344fa54bdcb" 515 | }, 516 | "outputs": [ 517 | { 518 | "data": { 519 | "text/plain": [ 520 | "'>>> Concatenated reviews length: 951'" 521 | ] 
522 | }, 523 | "execution_count": null, 524 | "metadata": {}, 525 | "output_type": "execute_result" 526 | } 527 | ], 528 | "source": [ 529 | "concatenated_examples = {\n", 530 | " k: sum(tokenized_samples[k], []) for k in tokenized_samples.keys()\n", 531 | "}\n", 532 | "total_length = len(concatenated_examples[\"input_ids\"])\n", 533 | "print(f\"'>>> Concatenated reviews length: {total_length}'\")" 534 | ] 535 | }, 536 | { 537 | "cell_type": "code", 538 | "execution_count": null, 539 | "metadata": { 540 | "id": "hcm9_riAx0_L", 541 | "outputId": "2a1c247f-eb65-447c-9196-755bb148139f" 542 | }, 543 | "outputs": [ 544 | { 545 | "data": { 546 | "text/plain": [ 547 | "'>>> Chunk length: 128'\n", 548 | "'>>> Chunk length: 128'\n", 549 | "'>>> Chunk length: 128'\n", 550 | "'>>> Chunk length: 128'\n", 551 | "'>>> Chunk length: 128'\n", 552 | "'>>> Chunk length: 128'\n", 553 | "'>>> Chunk length: 128'\n", 554 | "'>>> Chunk length: 55'" 555 | ] 556 | }, 557 | "execution_count": null, 558 | "metadata": {}, 559 | "output_type": "execute_result" 560 | } 561 | ], 562 | "source": [ 563 | "chunks = {\n", 564 | " k: [t[i : i + chunk_size] for i in range(0, total_length, chunk_size)]\n", 565 | " for k, t in concatenated_examples.items()\n", 566 | "}\n", 567 | "\n", 568 | "for chunk in chunks[\"input_ids\"]:\n", 569 | " print(f\"'>>> Chunk length: {len(chunk)}'\")" 570 | ] 571 | }, 572 | { 573 | "cell_type": "code", 574 | "execution_count": null, 575 | "metadata": { 576 | "id": "xH-wD_Kjx0_L" 577 | }, 578 | "outputs": [], 579 | "source": [ 580 | "def group_texts(examples):\n", 581 | " # Concatenate all texts\n", 582 | " concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}\n", 583 | " # Compute length of concatenated texts\n", 584 | " total_length = len(concatenated_examples[list(examples.keys())[0]])\n", 585 | " # We drop the last chunk if it's smaller than chunk_size\n", 586 | " total_length = (total_length // chunk_size) * chunk_size\n", 587 | " # Split by chunks of max_len\n", 588 | " result = {\n", 589 | " k: [t[i : i + chunk_size] for i in range(0, total_length, chunk_size)]\n", 590 | " for k, t in concatenated_examples.items()\n", 591 | " }\n", 592 | " # Create a new labels column\n", 593 | " result[\"labels\"] = result[\"input_ids\"].copy()\n", 594 | " return result" 595 | ] 596 | }, 597 | { 598 | "cell_type": "code", 599 | "execution_count": null, 600 | "metadata": { 601 | "id": "QbMqtknMx0_M", 602 | "outputId": "9fe23fb5-1480-48bb-c724-5494367a0b52" 603 | }, 604 | "outputs": [ 605 | { 606 | "data": { 607 | "text/plain": [ 608 | "DatasetDict({\n", 609 | " train: Dataset({\n", 610 | " features: ['attention_mask', 'input_ids', 'labels', 'word_ids'],\n", 611 | " num_rows: 61289\n", 612 | " })\n", 613 | " test: Dataset({\n", 614 | " features: ['attention_mask', 'input_ids', 'labels', 'word_ids'],\n", 615 | " num_rows: 59905\n", 616 | " })\n", 617 | " unsupervised: Dataset({\n", 618 | " features: ['attention_mask', 'input_ids', 'labels', 'word_ids'],\n", 619 | " num_rows: 122963\n", 620 | " })\n", 621 | "})" 622 | ] 623 | }, 624 | "execution_count": null, 625 | "metadata": {}, 626 | "output_type": "execute_result" 627 | } 628 | ], 629 | "source": [ 630 | "lm_datasets = tokenized_datasets.map(group_texts, batched=True)\n", 631 | "lm_datasets" 632 | ] 633 | }, 634 | { 635 | "cell_type": "code", 636 | "execution_count": null, 637 | "metadata": { 638 | "id": "8OTAjcACx0_M", 639 | "outputId": "d150b146-89f1-46d7-da39-4bc2cd67ffa9" 640 | }, 641 | "outputs": [ 642 | { 643 | "data": { 644 | 
"text/plain": [ 645 | "\".... at.......... high. a classic line : inspector : i'm here to sack one of your teachers. student : welcome to bromwell high. i expect that many adults of my age think that bromwell high is far fetched. what a pity that it isn't! [SEP] [CLS] homelessness ( or houselessness as george carlin stated ) has been an issue for years but never a plan to help those on the street that were once considered human who did everything from going to school, work, or vote for the matter. most people think of the homeless\"" 646 | ] 647 | }, 648 | "execution_count": null, 649 | "metadata": {}, 650 | "output_type": "execute_result" 651 | } 652 | ], 653 | "source": [ 654 | "tokenizer.decode(lm_datasets[\"train\"][1][\"input_ids\"])" 655 | ] 656 | }, 657 | { 658 | "cell_type": "code", 659 | "execution_count": null, 660 | "metadata": { 661 | "id": "UxQ_p9j8x0_M" 662 | }, 663 | "outputs": [], 664 | "source": [ 665 | "from transformers import DataCollatorForLanguageModeling\n", 666 | "\n", 667 | "data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=0.15)" 668 | ] 669 | }, 670 | { 671 | "cell_type": "code", 672 | "execution_count": null, 673 | "metadata": { 674 | "id": "qGgO9KZ1x0_M" 675 | }, 676 | "outputs": [], 677 | "source": [ 678 | "samples = [lm_datasets[\"train\"][i] for i in range(2)]\n", 679 | "for sample in samples:\n", 680 | " _ = sample.pop(\"word_ids\")\n", 681 | "\n", 682 | "for chunk in data_collator(samples)[\"input_ids\"]:\n", 683 | " print(f\"\\n'>>> {tokenizer.decode(chunk)}'\")" 684 | ] 685 | }, 686 | { 687 | "cell_type": "code", 688 | "execution_count": null, 689 | "metadata": { 690 | "id": "e2FIRtMgx0_M" 691 | }, 692 | "outputs": [], 693 | "source": [ 694 | "import collections\n", 695 | "import numpy as np\n", 696 | "\n", 697 | "from transformers import default_data_collator\n", 698 | "\n", 699 | "wwm_probability = 0.2\n", 700 | "\n", 701 | "\n", 702 | "def whole_word_masking_data_collator(features):\n", 703 | " for feature in features:\n", 704 | " word_ids = feature.pop(\"word_ids\")\n", 705 | "\n", 706 | " # Create a map between words and corresponding token indices\n", 707 | " mapping = collections.defaultdict(list)\n", 708 | " current_word_index = -1\n", 709 | " current_word = None\n", 710 | " for idx, word_id in enumerate(word_ids):\n", 711 | " if word_id is not None:\n", 712 | " if word_id != current_word:\n", 713 | " current_word = word_id\n", 714 | " current_word_index += 1\n", 715 | " mapping[current_word_index].append(idx)\n", 716 | "\n", 717 | " # Randomly mask words\n", 718 | " mask = np.random.binomial(1, wwm_probability, (len(mapping),))\n", 719 | " input_ids = feature[\"input_ids\"]\n", 720 | " labels = feature[\"labels\"]\n", 721 | " new_labels = [-100] * len(labels)\n", 722 | " for word_id in np.where(mask)[0]:\n", 723 | " word_id = word_id.item()\n", 724 | " for idx in mapping[word_id]:\n", 725 | " new_labels[idx] = labels[idx]\n", 726 | " input_ids[idx] = tokenizer.mask_token_id\n", 727 | " feature[\"labels\"] = new_labels\n", 728 | "\n", 729 | " return default_data_collator(features)" 730 | ] 731 | }, 732 | { 733 | "cell_type": "code", 734 | "execution_count": null, 735 | "metadata": { 736 | "id": "f_7hHzEHx0_M", 737 | "outputId": "fbdf18b4-d964-48cb-f5ac-65ece874e3e9" 738 | }, 739 | "outputs": [ 740 | { 741 | "data": { 742 | "text/plain": [ 743 | "'>>> [CLS] bromwell high is a cartoon comedy [MASK] it ran at the same time as some other programs about school life, such as \" teachers \". 
my 35 years in the teaching profession lead me to believe that bromwell high\\'s satire is much closer to reality than is \" teachers \". the scramble to survive financially, the insightful students who can see right through their pathetic teachers\\'pomp, the pettiness of the whole situation, all remind me of the schools i knew and their students. when i saw the episode in which a student repeatedly tried to burn down the school, i immediately recalled.....'\n", 744 | "\n", 745 | "'>>> .... [MASK] [MASK] [MASK] [MASK]....... high. a classic line : inspector : i\\'m here to sack one of your teachers. student : welcome to bromwell high. i expect that many adults of my age think that bromwell high is far fetched. what a pity that it isn\\'t! [SEP] [CLS] homelessness ( or houselessness as george carlin stated ) has been an issue for years but never a plan to help those on the street that were once considered human who did everything from going to school, work, or vote for the matter. most people think of the homeless'" 746 | ] 747 | }, 748 | "execution_count": null, 749 | "metadata": {}, 750 | "output_type": "execute_result" 751 | } 752 | ], 753 | "source": [ 754 | "samples = [lm_datasets[\"train\"][i] for i in range(2)]\n", 755 | "batch = whole_word_masking_data_collator(samples)\n", 756 | "\n", 757 | "for chunk in batch[\"input_ids\"]:\n", 758 | " print(f\"\\n'>>> {tokenizer.decode(chunk)}'\")" 759 | ] 760 | }, 761 | { 762 | "cell_type": "code", 763 | "execution_count": null, 764 | "metadata": { 765 | "id": "KLH73RGmx0_N", 766 | "outputId": "2bcfe332-8621-4467-ad44-389d650a5cfd" 767 | }, 768 | "outputs": [ 769 | { 770 | "data": { 771 | "text/plain": [ 772 | "DatasetDict({\n", 773 | " train: Dataset({\n", 774 | " features: ['attention_mask', 'input_ids', 'labels', 'word_ids'],\n", 775 | " num_rows: 10000\n", 776 | " })\n", 777 | " test: Dataset({\n", 778 | " features: ['attention_mask', 'input_ids', 'labels', 'word_ids'],\n", 779 | " num_rows: 1000\n", 780 | " })\n", 781 | "})" 782 | ] 783 | }, 784 | "execution_count": null, 785 | "metadata": {}, 786 | "output_type": "execute_result" 787 | } 788 | ], 789 | "source": [ 790 | "train_size = 10_000\n", 791 | "test_size = int(0.1 * train_size)\n", 792 | "\n", 793 | "downsampled_dataset = lm_datasets[\"train\"].train_test_split(\n", 794 | " train_size=train_size, test_size=test_size, seed=42\n", 795 | ")\n", 796 | "downsampled_dataset" 797 | ] 798 | }, 799 | { 800 | "cell_type": "code", 801 | "execution_count": null, 802 | "metadata": { 803 | "id": "dbZlpDWAx0_N" 804 | }, 805 | "outputs": [], 806 | "source": [ 807 | "from huggingface_hub import notebook_login\n", 808 | "\n", 809 | "notebook_login()" 810 | ] 811 | }, 812 | { 813 | "cell_type": "code", 814 | "execution_count": null, 815 | "metadata": { 816 | "id": "1ndl-0eZx0_N" 817 | }, 818 | "outputs": [], 819 | "source": [ 820 | "from transformers import TrainingArguments\n", 821 | "\n", 822 | "batch_size = 64\n", 823 | "# Show the training loss with every epoch\n", 824 | "logging_steps = len(downsampled_dataset[\"train\"]) // batch_size\n", 825 | "model_name = model_checkpoint.split(\"/\")[-1]\n", 826 | "\n", 827 | "training_args = TrainingArguments(\n", 828 | " output_dir=f\"{model_name}-finetuned-imdb\",\n", 829 | " overwrite_output_dir=True,\n", 830 | " evaluation_strategy=\"epoch\",\n", 831 | " learning_rate=2e-5,\n", 832 | " weight_decay=0.01,\n", 833 | " per_device_train_batch_size=batch_size,\n", 834 | " per_device_eval_batch_size=batch_size,\n", 835 | " push_to_hub=True,\n", 836 | " 
fp16=True,\n", 837 | " logging_steps=logging_steps,\n", 838 | ")" 839 | ] 840 | }, 841 | { 842 | "cell_type": "code", 843 | "execution_count": null, 844 | "metadata": { 845 | "id": "sFDhmmOex0_N" 846 | }, 847 | "outputs": [], 848 | "source": [ 849 | "from transformers import Trainer\n", 850 | "\n", 851 | "trainer = Trainer(\n", 852 | " model=model,\n", 853 | " args=training_args,\n", 854 | " train_dataset=downsampled_dataset[\"train\"],\n", 855 | " eval_dataset=downsampled_dataset[\"test\"],\n", 856 | " data_collator=data_collator,\n", 857 | " tokenizer=tokenizer,\n", 858 | ")" 859 | ] 860 | }, 861 | { 862 | "cell_type": "code", 863 | "execution_count": null, 864 | "metadata": { 865 | "id": "BQO1C3xrx0_N", 866 | "outputId": "f9ea43de-5aa2-498d-b6fd-8861cc1ea921" 867 | }, 868 | "outputs": [ 869 | { 870 | "data": { 871 | "text/plain": [ 872 | ">>> Perplexity: 21.75" 873 | ] 874 | }, 875 | "execution_count": null, 876 | "metadata": {}, 877 | "output_type": "execute_result" 878 | } 879 | ], 880 | "source": [ 881 | "import math\n", 882 | "\n", 883 | "eval_results = trainer.evaluate()\n", 884 | "print(f\">>> Perplexity: {math.exp(eval_results['eval_loss']):.2f}\")" 885 | ] 886 | }, 887 | { 888 | "cell_type": "code", 889 | "execution_count": null, 890 | "metadata": { 891 | "id": "2Na-5AjWx0_N" 892 | }, 893 | "outputs": [], 894 | "source": [ 895 | "trainer.train()" 896 | ] 897 | }, 898 | { 899 | "cell_type": "code", 900 | "execution_count": null, 901 | "metadata": { 902 | "id": "j0waxKaax0_O", 903 | "outputId": "fcd3fa5c-7162-4f14-b7af-bf5107b44af8" 904 | }, 905 | "outputs": [ 906 | { 907 | "data": { 908 | "text/plain": [ 909 | ">>> Perplexity: 11.32" 910 | ] 911 | }, 912 | "execution_count": null, 913 | "metadata": {}, 914 | "output_type": "execute_result" 915 | } 916 | ], 917 | "source": [ 918 | "eval_results = trainer.evaluate()\n", 919 | "print(f\">>> Perplexity: {math.exp(eval_results['eval_loss']):.2f}\")" 920 | ] 921 | }, 922 | { 923 | "cell_type": "code", 924 | "execution_count": null, 925 | "metadata": { 926 | "id": "dSHg6monx0_O" 927 | }, 928 | "outputs": [], 929 | "source": [ 930 | "trainer.push_to_hub()" 931 | ] 932 | }, 933 | { 934 | "cell_type": "code", 935 | "execution_count": null, 936 | "metadata": { 937 | "id": "glIQJkz9x0_O" 938 | }, 939 | "outputs": [], 940 | "source": [ 941 | "def insert_random_mask(batch):\n", 942 | " features = [dict(zip(batch, t)) for t in zip(*batch.values())]\n", 943 | " masked_inputs = data_collator(features)\n", 944 | " # Create a new \"masked\" column for each column in the dataset\n", 945 | " return {\"masked_\" + k: v.numpy() for k, v in masked_inputs.items()}" 946 | ] 947 | }, 948 | { 949 | "cell_type": "code", 950 | "execution_count": null, 951 | "metadata": { 952 | "id": "AcaEIqx-x0_O" 953 | }, 954 | "outputs": [], 955 | "source": [ 956 | "downsampled_dataset = downsampled_dataset.remove_columns([\"word_ids\"])\n", 957 | "eval_dataset = downsampled_dataset[\"test\"].map(\n", 958 | " insert_random_mask,\n", 959 | " batched=True,\n", 960 | " remove_columns=downsampled_dataset[\"test\"].column_names,\n", 961 | ")\n", 962 | "eval_dataset = eval_dataset.rename_columns(\n", 963 | " {\n", 964 | " \"masked_input_ids\": \"input_ids\",\n", 965 | " \"masked_attention_mask\": \"attention_mask\",\n", 966 | " \"masked_labels\": \"labels\",\n", 967 | " }\n", 968 | ")" 969 | ] 970 | }, 971 | { 972 | "cell_type": "code", 973 | "execution_count": null, 974 | "metadata": { 975 | "id": "kJLdyN80x0_O" 976 | }, 977 | "outputs": [], 978 | "source": [ 979 | "from 
torch.utils.data import DataLoader\n", 980 | "from transformers import default_data_collator\n", 981 | "\n", 982 | "batch_size = 64\n", 983 | "train_dataloader = DataLoader(\n", 984 | " downsampled_dataset[\"train\"],\n", 985 | " shuffle=True,\n", 986 | " batch_size=batch_size,\n", 987 | " collate_fn=data_collator,\n", 988 | ")\n", 989 | "eval_dataloader = DataLoader(\n", 990 | " eval_dataset, batch_size=batch_size, collate_fn=default_data_collator\n", 991 | ")" 992 | ] 993 | }, 994 | { 995 | "cell_type": "code", 996 | "execution_count": null, 997 | "metadata": { 998 | "id": "CYxQmfp8x0_O" 999 | }, 1000 | "outputs": [], 1001 | "source": [ 1002 | "from torch.optim import AdamW\n", 1003 | "\n", 1004 | "optimizer = AdamW(model.parameters(), lr=5e-5)" 1005 | ] 1006 | }, 1007 | { 1008 | "cell_type": "code", 1009 | "execution_count": null, 1010 | "metadata": { 1011 | "id": "YqF_KamSx0_O" 1012 | }, 1013 | "outputs": [], 1014 | "source": [ 1015 | "from accelerate import Accelerator\n", 1016 | "\n", 1017 | "accelerator = Accelerator()\n", 1018 | "model, optimizer, train_dataloader, eval_dataloader = accelerator.prepare(\n", 1019 | " model, optimizer, train_dataloader, eval_dataloader\n", 1020 | ")" 1021 | ] 1022 | }, 1023 | { 1024 | "cell_type": "code", 1025 | "execution_count": null, 1026 | "metadata": { 1027 | "id": "CZ4kWWi0x0_O" 1028 | }, 1029 | "outputs": [], 1030 | "source": [ 1031 | "from transformers import get_scheduler\n", 1032 | "\n", 1033 | "num_train_epochs = 3\n", 1034 | "num_update_steps_per_epoch = len(train_dataloader)\n", 1035 | "num_training_steps = num_train_epochs * num_update_steps_per_epoch\n", 1036 | "\n", 1037 | "lr_scheduler = get_scheduler(\n", 1038 | " \"linear\",\n", 1039 | " optimizer=optimizer,\n", 1040 | " num_warmup_steps=0,\n", 1041 | " num_training_steps=num_training_steps,\n", 1042 | ")" 1043 | ] 1044 | }, 1045 | { 1046 | "cell_type": "code", 1047 | "execution_count": null, 1048 | "metadata": { 1049 | "id": "LA2v6r1Wx0_O", 1050 | "outputId": "e632d18a-8b1b-4e62-abb2-b2139959ac0d" 1051 | }, 1052 | "outputs": [ 1053 | { 1054 | "data": { 1055 | "text/plain": [ 1056 | "'lewtun/distilbert-base-uncased-finetuned-imdb-accelerate'" 1057 | ] 1058 | }, 1059 | "execution_count": null, 1060 | "metadata": {}, 1061 | "output_type": "execute_result" 1062 | } 1063 | ], 1064 | "source": [ 1065 | "from huggingface_hub import get_full_repo_name\n", 1066 | "\n", 1067 | "model_name = \"distilbert-base-uncased-finetuned-imdb-accelerate\"\n", 1068 | "repo_name = get_full_repo_name(model_name)\n", 1069 | "repo_name" 1070 | ] 1071 | }, 1072 | { 1073 | "cell_type": "code", 1074 | "execution_count": null, 1075 | "metadata": { 1076 | "id": "nWbzglLex0_P" 1077 | }, 1078 | "outputs": [], 1079 | "source": [ 1080 | "from huggingface_hub import Repository\n", 1081 | "\n", 1082 | "output_dir = model_name\n", 1083 | "repo = Repository(output_dir, clone_from=repo_name)" 1084 | ] 1085 | }, 1086 | { 1087 | "cell_type": "code", 1088 | "execution_count": null, 1089 | "metadata": { 1090 | "id": "kfDNqAq2x0_P", 1091 | "outputId": "232887c0-b311-4c65-d3c9-91c1ace24dbb" 1092 | }, 1093 | "outputs": [ 1094 | { 1095 | "data": { 1096 | "text/plain": [ 1097 | ">>> Epoch 0: Perplexity: 11.397545307900472\n", 1098 | ">>> Epoch 1: Perplexity: 10.904909330983092\n", 1099 | ">>> Epoch 2: Perplexity: 10.729503505340409" 1100 | ] 1101 | }, 1102 | "execution_count": null, 1103 | "metadata": {}, 1104 | "output_type": "execute_result" 1105 | } 1106 | ], 1107 | "source": [ 1108 | "from tqdm.auto import tqdm\n", 1109 | 
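"# Custom training loop with Accelerate: each epoch runs one pass of\n", "# gradient updates over train_dataloader, then an eval pass in which\n", "# per-process losses are gathered, truncated to len(eval_dataset), and\n", "# turned into perplexity via exp(mean loss); the model is saved and\n", "# pushed to the Hub at the end of every epoch.\n",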
"import torch\n", 1110 | "import math\n", 1111 | "\n", 1112 | "progress_bar = tqdm(range(num_training_steps))\n", 1113 | "\n", 1114 | "for epoch in range(num_train_epochs):\n", 1115 | " # Training\n", 1116 | " model.train()\n", 1117 | " for batch in train_dataloader:\n", 1118 | " outputs = model(**batch)\n", 1119 | " loss = outputs.loss\n", 1120 | " accelerator.backward(loss)\n", 1121 | "\n", 1122 | " optimizer.step()\n", 1123 | " lr_scheduler.step()\n", 1124 | " optimizer.zero_grad()\n", 1125 | " progress_bar.update(1)\n", 1126 | "\n", 1127 | " # Evaluation\n", 1128 | " model.eval()\n", 1129 | " losses = []\n", 1130 | " for step, batch in enumerate(eval_dataloader):\n", 1131 | " with torch.no_grad():\n", 1132 | " outputs = model(**batch)\n", 1133 | "\n", 1134 | " loss = outputs.loss\n", 1135 | " losses.append(accelerator.gather(loss.repeat(batch_size)))\n", 1136 | "\n", 1137 | " losses = torch.cat(losses)\n", 1138 | " losses = losses[: len(eval_dataset)]\n", 1139 | " try:\n", 1140 | " perplexity = math.exp(torch.mean(losses))\n", 1141 | " except OverflowError:\n", 1142 | " perplexity = float(\"inf\")\n", 1143 | "\n", 1144 | " print(f\">>> Epoch {epoch}: Perplexity: {perplexity}\")\n", 1145 | "\n", 1146 | " # Save and upload\n", 1147 | " accelerator.wait_for_everyone()\n", 1148 | " unwrapped_model = accelerator.unwrap_model(model)\n", 1149 | " unwrapped_model.save_pretrained(output_dir, save_function=accelerator.save)\n", 1150 | " if accelerator.is_main_process:\n", 1151 | " tokenizer.save_pretrained(output_dir)\n", 1152 | " repo.push_to_hub(\n", 1153 | " commit_message=f\"Training in progress epoch {epoch}\", blocking=False\n", 1154 | " )" 1155 | ] 1156 | }, 1157 | { 1158 | "cell_type": "code", 1159 | "execution_count": null, 1160 | "metadata": { 1161 | "id": "NI5lIeCBx0_P" 1162 | }, 1163 | "outputs": [], 1164 | "source": [ 1165 | "from transformers import pipeline\n", 1166 | "\n", 1167 | "mask_filler = pipeline(\n", 1168 | " \"fill-mask\", model=\"huggingface-course/distilbert-base-uncased-finetuned-imdb\"\n", 1169 | ")" 1170 | ] 1171 | }, 1172 | { 1173 | "cell_type": "code", 1174 | "execution_count": null, 1175 | "metadata": { 1176 | "id": "6iCYlwlux0_P", 1177 | "outputId": "34a8c467-19ea-4834-8e31-6fd888308287" 1178 | }, 1179 | "outputs": [ 1180 | { 1181 | "data": { 1182 | "text/plain": [ 1183 | "'>>> this is a great movie.'\n", 1184 | "'>>> this is a great film.'\n", 1185 | "'>>> this is a great story.'\n", 1186 | "'>>> this is a great movies.'\n", 1187 | "'>>> this is a great character.'" 1188 | ] 1189 | }, 1190 | "execution_count": null, 1191 | "metadata": {}, 1192 | "output_type": "execute_result" 1193 | } 1194 | ], 1195 | "source": [ 1196 | "preds = mask_filler(text)\n", 1197 | "\n", 1198 | "for pred in preds:\n", 1199 | " print(f\">>> {pred['sequence']}\")" 1200 | ] 1201 | } 1202 | ], 1203 | "metadata": { 1204 | "colab": { 1205 | "name": "Fine-tuning a masked language model (PyTorch)", 1206 | "provenance": [] 1207 | }, 1208 | "kernelspec": { 1209 | "display_name": "Python 3", 1210 | "language": "python", 1211 | "name": "python3" 1212 | }, 1213 | "language_info": { 1214 | "codemirror_mode": { 1215 | "name": "ipython", 1216 | "version": 3 1217 | }, 1218 | "file_extension": ".py", 1219 | "mimetype": "text/x-python", 1220 | "name": "python", 1221 | "nbconvert_exporter": "python", 1222 | "pygments_lexer": "ipython3", 1223 | "version": "3.8.10" 1224 | } 1225 | }, 1226 | "nbformat": 4, 1227 | "nbformat_minor": 4 1228 | } 1229 | 
--------------------------------------------------------------------------------
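A note on the perplexity values reported in `finetune.ipynb` above: for a masked language model, perplexity is just the exponential of the mean cross-entropy loss over the masked tokens, which is why `math.exp(eval_results["eval_loss"])` is the entire computation. A minimal self-contained sketch of that relationship; the loss values here are back-derived from the perplexities the notebook prints (loss = ln(perplexity)) and are illustrative, not re-run results:

```python
import math

# perplexity = exp(mean masked-token cross-entropy loss)
# 3.0797 = ln(21.75) and 2.4265 = ln(11.32), the notebook's two eval results
for eval_loss in (3.0797, 2.4265):
    print(f"loss={eval_loss:.4f} -> perplexity={math.exp(eval_loss):.2f}")
```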