├── imgs ├── sample.txt ├── mmllm.png └── flamingo_ss.png ├── notebooks └── README.md ├── vlm ├── __pycache__ │ ├── vlm.cpython-311.pyc │ └── dataset.cpython-311.pyc ├── modeling │ ├── __pycache__ │ │ ├── llm.cpython-311.pyc │ │ ├── projection.cpython-311.pyc │ │ ├── image_encoder.cpython-311.pyc │ │ └── vision_encoder.cpython-311.pyc │ ├── vision_encoder.py │ ├── llm.py │ └── projection.py ├── lightning_train.py ├── train.py ├── dataset.py └── vlm.py └── README.md /imgs/sample.txt: -------------------------------------------------------------------------------- 1 | sample_text 2 | -------------------------------------------------------------------------------- /imgs/mmllm.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alexander-moore/vlm/HEAD/imgs/mmllm.png -------------------------------------------------------------------------------- /notebooks/README.md: -------------------------------------------------------------------------------- 1 | Ignore this notebook - go back to vlm/train and vlm/lightning_train 2 | -------------------------------------------------------------------------------- /imgs/flamingo_ss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alexander-moore/vlm/HEAD/imgs/flamingo_ss.png -------------------------------------------------------------------------------- /vlm/__pycache__/vlm.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alexander-moore/vlm/HEAD/vlm/__pycache__/vlm.cpython-311.pyc -------------------------------------------------------------------------------- /vlm/__pycache__/dataset.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alexander-moore/vlm/HEAD/vlm/__pycache__/dataset.cpython-311.pyc -------------------------------------------------------------------------------- /vlm/modeling/__pycache__/llm.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alexander-moore/vlm/HEAD/vlm/modeling/__pycache__/llm.cpython-311.pyc -------------------------------------------------------------------------------- /vlm/modeling/__pycache__/projection.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alexander-moore/vlm/HEAD/vlm/modeling/__pycache__/projection.cpython-311.pyc -------------------------------------------------------------------------------- /vlm/modeling/__pycache__/image_encoder.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alexander-moore/vlm/HEAD/vlm/modeling/__pycache__/image_encoder.cpython-311.pyc -------------------------------------------------------------------------------- /vlm/modeling/__pycache__/vision_encoder.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alexander-moore/vlm/HEAD/vlm/modeling/__pycache__/vision_encoder.cpython-311.pyc -------------------------------------------------------------------------------- /vlm/lightning_train.py: -------------------------------------------------------------------------------- 1 | """ 2 | Use lightning for DDP training of VLM 3 | 4 | """ 5 | 6 | import lightning as L 7 
| from lit_vlm import VLM_LitModel
8 | from vlm import build_vlm
9 | import argparse
10 | 
11 | parser = argparse.ArgumentParser(description='VLM Training Settings')
12 | 
13 | parser.add_argument('--gpu_ids', help='Comma-separated list of GPU Numbers to use',
14 |                     default='0', type=str)
15 | 
16 | parser.add_argument('--version_name', help = 'version name in save_dir',
17 |                     type = str, default = 'test')
18 | 
19 | args = parser.parse_args()
20 | 
21 | 
22 | model = build_vlm()
23 | lit_model = VLM_LitModel(model = model)
24 | 
25 | from lightning.pytorch.strategies import DDPStrategy
26 | 
27 | 
28 | # Data: COCO image-caption pairs
29 | from dataset import get_coco_dataset
30 | from torch.utils.data import DataLoader
31 | 
32 | gpu_ids = [int(i) for i in args.gpu_ids.split(',')]
33 | 
34 | train_dataset = get_coco_dataset(mode='train')
35 | train_dataloader = DataLoader(train_dataset, batch_size = 1)
36 | 
37 | # Validate on the held-out val split
38 | val_dataset = get_coco_dataset(mode = 'val')
39 | val_dataloader = DataLoader(val_dataset, batch_size = 1)
40 | 
41 | from lightning.pytorch.loggers import TensorBoardLogger
42 | save_dir = 'logs'
43 | version_name = args.version_name
44 | 
45 | logger = TensorBoardLogger(save_dir=save_dir, version=version_name, name="trackers")
46 | 
47 | trainer = L.Trainer(accelerator="gpu", devices=gpu_ids, logger=logger,
48 |                     strategy=DDPStrategy() if len(gpu_ids) > 1 else "auto",
49 |                     gradient_clip_val = 1, max_steps = 60000)
50 | 
51 | 
52 | trainer.fit(model=lit_model, train_dataloaders=train_dataloader,
53 |             val_dataloaders=val_dataloader)
54 | 
--------------------------------------------------------------------------------
/vlm/modeling/vision_encoder.py:
--------------------------------------------------------------------------------
1 | """
2 | Vision Encoder
3 | --------------
4 | 
5 | This vision model encodes images into representations which can be sequenced/patched into tokens by our custom "image tokenizer" projection.
6 | It provides the visual information that the language model will decode.
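For the default 'google/vit-base-patch16-224' checkpoint, the last hidden state has shape (batch, 197, 768): 196 patch tokens plus the [CLS] token, each 768-dimensional. The projection module ("image tokenizer") then maps this 768-dim space to the LLM embedding width (3072 for Phi-3-mini).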
7 | """ 8 | 9 | from transformers import ViTImageProcessor, ViTForImageClassification 10 | from PIL import Image 11 | import requests 12 | 13 | def get_image_encoder(image_encoder_str, peft = False): 14 | processor = ViTImageProcessor.from_pretrained(image_encoder_str) 15 | model = ViTForImageClassification.from_pretrained(image_encoder_str) 16 | 17 | if peft: 18 | 19 | from peft import get_peft_config, get_peft_model, LoraConfig, TaskType 20 | 21 | peft_config = LoraConfig( 22 | task_type=None, inference_mode=False, r=8, 23 | lora_alpha=32, lora_dropout=0.1, target_modules=['dense'] 24 | ) 25 | 26 | model = get_peft_model(model, peft_config) 27 | model.print_trainable_parameters() 28 | 29 | else: 30 | model.requires_grad = False 31 | 32 | return processor, model 33 | 34 | if __name__ == '__main__': 35 | """ 36 | Test image encoder, verify encoded size 37 | """ 38 | #url = 'http://images.cocodataset.org/val2017/000000039769.jpg' 39 | #image = Image.open(requests.get(url, stream=True).raw) 40 | image = Image.open('/data/coco2017/val2017/000000187990.jpg') 41 | 42 | processor, model = get_image_encoder('google/vit-base-patch16-224', peft = True) 43 | 44 | inputs = processor(images=image, return_tensors="pt") 45 | outputs = model(**inputs, output_hidden_states = True) 46 | 47 | embeddings = outputs.hidden_states[-1] 48 | print('made embedings', embeddings.shape) 49 | 50 | -------------------------------------------------------------------------------- /vlm/train.py: -------------------------------------------------------------------------------- 1 | """ 2 | Train a vlm 3 | """ 4 | 5 | device = 7 6 | # def batch_to_device(batch): 7 | # for key, value in batch.items(): 8 | # try: 9 | # batch[key] = value.to(device) 10 | # print(key, value) 11 | # except: 12 | # pass 13 | 14 | def train_model(model, n_epochs): 15 | model = model.to(device) 16 | model.train() 17 | optimizer.train() 18 | 19 | for _ in range(n_epochs): 20 | losses = [] 21 | for bi, batch in enumerate(train_dataloader): 22 | optimizer.zero_grad() 23 | 24 | #batch = batch_to_device(batch) 25 | batch['image'] = batch['image'].to(device) 26 | 27 | logits, loss = model.forward(batch) 28 | loss.backward() 29 | optimizer.step() 30 | 31 | losses.append(loss.data.item()) 32 | if bi % 100 == 0: 33 | print(sum(losses) / len(losses)) 34 | 35 | val_metrics = val_step() 36 | print(val_metrics) 37 | 38 | def val_step(): 39 | model.val() 40 | optimizer.val() 41 | 42 | val_metrics = 0 43 | 44 | model.train() 45 | optimizer.train() 46 | 47 | return val_metrics 48 | 49 | if __name__ == '__main__': 50 | """ 51 | Train a model 52 | """ 53 | 54 | # Model 55 | from vlm import build_vlm 56 | 57 | model = build_vlm().to(device) 58 | print(model) 59 | 60 | # Optimizer 61 | from schedulefree import AdamWScheduleFree 62 | optimizer = AdamWScheduleFree(model.parameters(), lr = 3e-4) 63 | 64 | # Data 65 | from dataset import get_coco_dataset#, get_pokemon_dataset 66 | from torch.utils.data import DataLoader 67 | 68 | #val_dataset = get_pokemon_dataset() 69 | #val_dataset = get_coco_dataset() 70 | train_dataset = get_coco_dataset(mode='train') 71 | train_dataloader = DataLoader(train_dataset, batch_size = 1) 72 | 73 | val_dataset = get_coco_dataset(mode = 'val') 74 | val_dataloader = DataLoader(val_dataset, batch_size = 1) 75 | 76 | print(train_dataset, val_dataset) 77 | n_epochs = 2 78 | 79 | train_model(model, n_epochs) 80 | -------------------------------------------------------------------------------- /vlm/modeling/llm.py: 
--------------------------------------------------------------------------------
1 | """
2 | Language Model Base
3 | -------------------
4 | 
5 | We need a language model to serve as the starting point. This language model will be frozen during training and used to decode custom language-vision inputs.
6 | """
7 | import huggingface_hub
8 | import transformers
9 | #from transformers import PhiForCausalLM
10 | from transformers import AutoTokenizer, AutoModelForCausalLM
11 | 
12 | # Optional peft components for LoRA
13 | from peft import LoraConfig, PeftModel, get_peft_model
14 | 
15 | 
16 | def get_llm(llm_name, peft = False):
17 |     """
18 |     Load a causal LM and its tokenizer. The base weights are frozen; if peft=True, LoRA adapters are attached so only the adapter weights remain trainable.
19 |     """
20 |     tokenizer = AutoTokenizer.from_pretrained(llm_name)
21 |     model = AutoModelForCausalLM.from_pretrained(llm_name)
22 | 
23 |     # Freeze LLM
24 |     for name, param in model.named_parameters():
25 |         param.requires_grad = False
26 | 
27 |     # Optional PEFT. Note: target_modules depends on the architecture
28 |     # (e.g. gate_proj/up_proj/down_proj for Mistral, gate_up_proj/down_proj for Phi-3).
29 |     if peft:
30 |         peft_config = LoraConfig(
31 |             task_type=None, inference_mode=False, r=8, lora_alpha=32, lora_dropout=0.1, target_modules=['o_proj', 'down_proj']
32 |         )
33 | 
34 |         model = get_peft_model(model, peft_config)
35 | 
36 |     return tokenizer, model
37 | 
38 | if __name__ == '__main__':
39 |     """
40 |     Be able to call llm.py to test things
41 |     """
42 | 
43 |     # Instruct-tuned alternative:
44 | 
45 |     # tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")
46 |     # model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")
47 | 
48 |     # Load model directly
49 |     # Rather than an instruct model, we likely want to fine-tune our own base model
50 |     tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
51 |     model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1")
52 | 
53 |     # Class-specific alternative to the Auto classes:
54 |     # from transformers import MistralForCausalLM
55 | 
56 |     #model = MistralForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1")
57 |     #tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
58 | 
59 |     prompt = "Hey, are you conscious? Can you talk to me?"
60 |     inputs = tokenizer(prompt, return_tensors="pt")
61 | 
62 |     # Generate
63 |     generate_ids = model.generate(inputs.input_ids, max_length=30)
64 |     print(tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0])
65 | 
--------------------------------------------------------------------------------
/vlm/modeling/projection.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | 
3 | 
4 | """
5 | Currently just a simple two-layer linear projection.
6 | The literature often uses a small attention module (e.g. BLIP-2's Q-Former or
a mini-GPT-style resampler) to pool the image tokens and then project them to the LLM embedding dimension.
7 | """
8 | 
9 | def get_image_tokenizer(projector_str, insize, outsize):
10 |     if 'qformer' in projector_str:
11 |         #config = BlipConfig(in_dim = insize,
12 |         #                    out_dim = outsize)
13 |         #return QFormer(config)
14 |         raise NotImplementedError('qformer projector is not implemented yet')
15 |     else:
16 |         return ImageTokenizer(insize, outsize)
17 | 
18 | def get_in_size(image_encoder):
19 |     return image_encoder.config.hidden_size
20 | 
21 | def get_out_size(lang_tokenizer, llm):
22 |     input_ids = lang_tokenizer.encode('hi', add_special_tokens=True, return_tensors="pt")
23 |     vec = llm.get_input_embeddings()(input_ids)
24 |     embedded_tokens_size = vec.size()[-1]
25 |     return embedded_tokens_size
26 | 
27 | class ImageTokenizer(nn.Module):
28 |     def __init__(self, in_dim, out_dim):
29 |         super(ImageTokenizer, self).__init__()
30 |         self.fc1 = nn.Linear(in_dim, out_dim)
31 |         self.fc2 = nn.Linear(out_dim, out_dim)
32 |         self.activ = nn.GELU()
33 | 
34 |     def forward(self, x):
35 |         """
36 |         Forward maps vision_encoder outputs to the LLM token input size:
37 |         (bs, seq_length, in_size) -> (bs, seq_length, out_size)
38 |         """
39 |         return self.fc2(self.activ(self.fc1(x)))
40 | 
41 | # class BlipConfig():
42 | #     """
43 | #     Build the config to pass to Blip2Qformer
44 | #     """
45 | #     def __init__(self, in_dim, out_dim):
46 | #         self.hidden_size = in_dim
47 | #         self.num_attention_heads = 8
48 | 
49 | 
50 | # def QFormer(config):
51 | #     from transformers import Blip2QFormerModel
52 | #     return Blip2QFormerModel()
53 | 
54 | def qformer():
55 |     import torch
56 |     from qformer import QFormer
57 | 
58 |     # Create a random tensor of shape (1, 32, 512)
59 |     x = torch.randn(1, 32, 512)
60 | 
61 |     # Create a random image tensor of shape (1, 3, 224, 224)
62 |     img = torch.randn(1, 3, 224, 224)
63 | 
64 |     # Create an instance of the QFormer model with the following parameters:
65 |     # - input_size: 512
66 |     # - num_heads: 8
67 |     # - num_layers: 8
68 |     # - dropout: 0.1
69 |     # - num_classes: 2
70 |     # - num_patches: 2
71 |     qformer = QFormer(512, 8, 8, 0.1, 2, 2)
72 | 
73 |     # Apply the QFormer model to the input tensors x and img
74 |     y = qformer(x, img)
75 | 
76 | 
77 |     # Print the shape of the output tensor y
78 |     print(y.shape)
79 | 
80 |     # The Q-Former output sequence would then be passed to the LLM
81 |     # in place of the simple linear projection above.
82 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Multimodal LLMs from Scratch
2 | 
3 | I will blog about updates and considerations to the design of this repo here:
4 | https://medium.com/@ammpersonal77
5 | 
6 | Multimodal LLMs are an exciting new direction for multimodal AI research. Multimodality for LLMs comes from utilizing a pretrained LLM in conjunction with a pretrained domain encoder (for example, an image encoder). A new, untrained adapter module translates from the image encoder embedding space to the LLM token embedding space. The system is trained end-to-end so that the "domain tokenizer" / "adapter" component learns to translate between the data domain and the language space, giving the LLM multimodal understanding without fine-tuning the extremely heavyweight LLM backbone. We also add PEFT components and experiment with adapter design, training schemes, and eventually end-to-end multimodality - read the blog for more.
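At its core, the adapter is just a small trainable projection between two frozen embedding spaces. Below is a minimal sketch mirroring `ImageTokenizer` in `vlm/modeling/projection.py`; the 768/3072 widths assume the ViT-Base encoder and Phi-3-mini LLM used as defaults in `vlm/vlm.py`.

```python
import torch
import torch.nn as nn

class Adapter(nn.Module):
    """Project frozen image-encoder states to the LLM embedding width."""
    def __init__(self, in_dim=768, out_dim=3072):
        super().__init__()
        self.fc1 = nn.Linear(in_dim, out_dim)
        self.fc2 = nn.Linear(out_dim, out_dim)
        self.act = nn.GELU()

    def forward(self, x):                  # (B, N, in_dim) -> (B, N, out_dim)
        return self.fc2(self.act(self.fc1(x)))

# ViT-Base emits 197 tokens per image (196 patches + [CLS]), each 768-dim.
image_states = torch.randn(1, 197, 768)    # stand-in for vision encoder hidden states
image_tokens = Adapter()(image_states)     # -> (1, 197, 3072)
print(image_tokens.shape)
```

Only this adapter (and any PEFT/LoRA weights left trainable) receives gradients; the projected image tokens are concatenated with the text embeddings and fed to the frozen LLM through `inputs_embeds`.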
7 | 
8 | Multimodal VLM for now:
9 | - Use a ViT image encoder
10 |   - I want to experiment with the Segment Anything Model and other ongoing LLNL research on optimal image encodings for different vision-language models, depending on the downstream task
11 | - Use an image-caption pair dataset. Experiment with 'augmentation' / bootstrapping, for example image-text, text-image, text-image-text (all bootstrapped from the same image-text sample)
12 |   - In the LLaMA literature this is cited as prefix-content-suffix, from another publication
13 | 
14 | - Experiment with new position encodings and combinations of sequence encoding and image encoding. Should an image inside a text block get a sequence position encoding as well as a summed image (x, y) position embedding?
15 | 
16 | - Notes:
17 |   - Two-way attention over image and text? This assumes structured input, though?
18 |   - I think losing interwoven text and image would be fine, in order to use two-way attention as an adapter
19 |     rather than as an LLM replacement - though this would work, I'm sure.
20 |   - Maybe an alternate multimodal approach is just two-way attention over image and text. Consider!
21 |   - Q-Former (BLIP-2): https://huggingface.co/docs/transformers/main/en/model_doc/blip-2#transformers.Blip2QFormerModel
22 |     - Just look into this as an adapter model - again, not sure about the requirement for structured input
23 | 
24 | ![Flamingo_screenshot - multimodal motivation](imgs/flamingo_ss.png)
25 | 
26 | *The [Flamingo](https://arxiv.org/abs/2204.14198) authors motivate the use of multimodal (interwoven vision and text) language models as few-shot inference models. Imagine training a separate model for each of these tasks!* :fearful:
27 | 
28 | 
29 | ![General Architecture](imgs/mmllm.png)
30 | 
31 | *Here we share a general approach to multimodal language models with text output. Special tokens are added to the language model tokenizer, and end-to-end fine-tuning adapts the "image tokenizer" module to translate visual information for the language model.*
32 | 
--------------------------------------------------------------------------------
/vlm/dataset.py:
--------------------------------------------------------------------------------
1 | """
2 | Dataset and dataloading for vision-language models.
3 | 
4 | We need to do a few things here:
5 | load image-text pairs
6 | format them
7 | 
8 | Do we do this here or in the model?
9 | The model needs a way to take a set [dict1, dict2, ...] where dict_i = {text, image},
10 | then format these into token sequences, pad them, and collate them into batches.
11 | 
12 | So the dataset can't really do all of that - it needs to live in the model, which has access to the tokenizers, padding embedding, etc.
13 | """
14 | 
15 | """
16 | Sample Data
17 | -----------
18 | 
19 | Sample data is borrowed from the `cppe-5` dataset. I use this data since it has images, string labels, and some interesting annotations such as bboxes we may enjoy using.
20 | It is also a small, reasonable size for testing.
21 | 
22 | Quote: https://huggingface.co/docs/datasets/en/object_detection
23 | The dataset has the following fields:
24 | 
25 | image: PIL.Image.Image object containing the image.
26 | image_id: The image ID.
27 | height: The image height.
28 | width: The image width.
29 | objects: A dictionary containing bounding box metadata for the objects in the image:
30 | id: The annotation id.
31 | area: The area of the bounding box.
32 | bbox: The object’s bounding box (in the coco format).
33 | category: The object’s category, with possible values including Coverall (0), Face_Shield (1), Gloves (2), Goggles (3) and Mask (4). 34 | """ 35 | 36 | from datasets import load_dataset 37 | import matplotlib.pyplot as plt 38 | import warnings 39 | import random 40 | from torch.utils.data import Dataset 41 | import torchvision.datasets as dset 42 | import torch.nn as nn 43 | from torchvision.transforms import ToTensor, Compose, Resize 44 | from random import choice 45 | 46 | def get_pokemon_dataset(): 47 | """ 48 | Get image-caption pokemon dataset. use this as val 49 | """ 50 | return load_dataset("lambdalabs/pokemon-blip-captions") 51 | 52 | def get_coco_dataset(mode = 'train'): 53 | """ 54 | Abcd 55 | """ 56 | coco_dataset = dset.CocoDetection(root = f'/data/coco2017/{mode}2017', 57 | annFile = f'/data/coco2017/annotations/captions_{mode}2017.json' 58 | ) 59 | return Coco_Wrapper(coco_dataset) 60 | 61 | class Coco_Wrapper(Dataset): 62 | def __init__(self, coco_dataset): 63 | self.dataset = coco_dataset 64 | #self.transforms = Compose(ToTensor(), Resize((256,256), antialias=True)) 65 | self.totensor = ToTensor() 66 | self.resize = Resize((256, 256), antialias=True) 67 | self.len = len(coco_dataset) 68 | 69 | def __len__(self): 70 | return self.len 71 | 72 | def __getitem__(self, idx): 73 | """ 74 | Get and transform items 75 | """ 76 | img, target = self.dataset[idx] 77 | image = self.totensor(img) 78 | #image = self.resize(image) 79 | 80 | caption = choice(target)['caption'] 81 | 82 | sample = {'image': image, 83 | 'caption': caption} 84 | return sample 85 | 86 | 87 | 88 | from torch.utils.data import DataLoader 89 | if __name__ == '__main__': 90 | """ 91 | This function is used to test datasets returns the correct 92 | """ 93 | # Load a test dataset 94 | dataset = get_coco_dataset() 95 | 96 | batch = dataset[0] 97 | print(batch['image'].shape, batch['caption']) 98 | 99 | dataloader = DataLoader(dataset, batch_size = 2) 100 | batch = next(iter(dataloader)) 101 | print(batch['image'].shape, batch['caption']) 102 | 103 | -------------------------------------------------------------------------------- /vlm/vlm.py: -------------------------------------------------------------------------------- 1 | """ 2 | Seems like we want our own module so we can train multiple elements at once 3 | """ 4 | import torch 5 | import torch.nn as nn 6 | import torchvision 7 | import random 8 | 9 | from modeling import vision_encoder, llm, projection 10 | 11 | def build_vlm(image_encoder_str = 'google/vit-base-patch16-224', 12 | llm_str = "microsoft/Phi-3-mini-4k-instruct", 13 | projector_str = 'twoLayerLinear'): 14 | """ 15 | Assemble a VLM from an image encoder, projection, and llm 16 | """ 17 | # Image processor and image encoder model - loaded from a Huggingface ViT 18 | image_processor, image_encoder = vision_encoder.get_image_encoder(image_encoder_str, peft = True) 19 | 20 | # Language model tokenizer and llm 21 | #language_tokenizer, language_model = llm.get_llm("mistralai/Mistral-7B-v0.1") 22 | language_tokenizer, language_model = llm.get_llm(llm_str, peft = True) 23 | 24 | # "Image tokenizer" projects from image encoder transformer activations to LLM input dimension 25 | #image_tokenizer = projection.ImageTokenizer(in_dim = 768, out_dim = 3072) 26 | insize = projection.get_in_size(image_encoder) 27 | outsize = projection.get_out_size(language_tokenizer, language_model) 28 | image_tokenizer = projection.get_image_tokenizer(projector_str, insize, outsize) 29 | 30 | vlm = 
VisionLanguageModel(image_processor, 31 | image_encoder, 32 | image_tokenizer, 33 | language_tokenizer, 34 | language_model 35 | ) 36 | 37 | for name, param in vlm.vision_encoder.named_parameters(): 38 | param.requires_grad = False 39 | 40 | for name, param in vlm.language_model.named_parameters(): 41 | param.requires_grad = False 42 | 43 | return vlm 44 | 45 | 46 | class VisionLanguageModel(nn.Module): 47 | """ 48 | SAM for images 49 | """ 50 | def __init__(self, 51 | vision_processor, 52 | vision_encoder, 53 | vision_tokenizer, 54 | language_tokenizer, 55 | language_model, 56 | ): 57 | super(VisionLanguageModel, self).__init__() # initialize self._modules as OrderedDict - enables nested nn modules 58 | #self.device = 'cuda' if torch.cuda.is_available() else 'cpu' 59 | 60 | # Model structural components 61 | self.vision_processor = vision_processor 62 | self.vision_encoder = vision_encoder 63 | self.vision_tokenizer = vision_tokenizer 64 | 65 | self.language_model = language_model 66 | self.language_tokenizer = language_tokenizer 67 | 68 | # Language components - custom tokens we use to format the prompt 69 | # bos_token = tokenizer.bos_token # begin sentence 70 | imstart_str = '<|imstart|>' 71 | imend_str = '<|imend|>' 72 | #textend_str = '<|endoftext|>' 73 | language_prompt = "This image contains: " 74 | 75 | language_tokenizer.add_tokens(imstart_str, special_tokens = True) 76 | language_tokenizer.add_tokens(imend_str, special_tokens = True) 77 | 78 | # Update langauge model's token embedding matrix 79 | self.language_model.resize_token_embeddings(len(self.language_tokenizer)) 80 | 81 | # Hold vector representing elements of our custom prompt 82 | self.start_vec = self.embed_ints(torch.tensor(self.language_tokenizer(imstart_str)['input_ids'])).unsqueeze(0) 83 | self.end_vec = self.embed_ints(torch.tensor(self.language_tokenizer(imend_str)['input_ids'])).unsqueeze(0) 84 | self.query_vec = self.embed_ints(torch.tensor(self.language_tokenizer(language_prompt)['input_ids'])).unsqueeze(0) 85 | #self.textend_vec = self.embed_ints(torch.tensor(self.language_tokenizer(textend_str)['input_ids'], device = self.device)).unsqueeze(0) 86 | 87 | # Loss 88 | self.bceloss = nn.BCEWithLogitsLoss() 89 | 90 | def forward(self, batch): 91 | """ 92 | Given a batch, format the input, do the forward, get the logits, return logits, loss 93 | Predicts the next token given an image, random substring of caption 94 | """ 95 | device = batch['image'].device 96 | tokenized_image = self.image_forward(batch['image']) 97 | 98 | # Tokenize string 99 | int_captions = torch.LongTensor(self.language_tokenizer(batch['caption'])['input_ids']).to(device) 100 | 101 | predict_at_index = random.randint(1, int_captions.shape[1] - 2) 102 | caption_prefix = self.embed_ints(int_captions[:, :predict_at_index]) 103 | caption_target = int_captions[:, predict_at_index] 104 | 105 | self.start_vec = self.start_vec.to(device) 106 | self.end_vec = self.end_vec.to(device) 107 | self.query_vec = self.query_vec.to(device) 108 | 109 | # Structure token sequence 110 | 111 | # Here, rather than simple concat, need to use the LLM text formatting, then break up the ids and embed them 112 | # then cat the embeddings which use the conversation formatting to the image embeddings to the suffix embeddings 113 | # same thing below for generate - need to cat in correct positions and make sure we are using the provided chat templtae a la tici 114 | llm_input = torch.cat((self.start_vec, tokenized_image, self.end_vec, self.query_vec, caption_prefix), dim = 
1)#.permute(0,2,1) 115 | 116 | # Forward with frozen llm 117 | output = self.language_model.forward(inputs_embeds = llm_input) 118 | logits = output.logits 119 | 120 | #print(logits.shape) 121 | last_logit = logits[:, -1, :] 122 | 123 | loss = self.loss_function(last_logit, caption_target) 124 | 125 | return logits, loss 126 | 127 | def generate(self, batch, max_new_tokens): 128 | device = batch['image'].device 129 | #self.language_model.assisted_decoding 130 | 131 | # idx is (B, T) array of indices in the current context 132 | tokenized_image = self.image_forward(batch['image']) 133 | 134 | # get initial prompt for llm 135 | # Tokenize string 136 | int_captions = torch.LongTensor(self.language_tokenizer(batch['caption'])['input_ids']).to(device) 137 | 138 | predict_at_index = 0 139 | caption_prefix = self.embed_ints(int_captions[:, :predict_at_index]) 140 | 141 | self.start_vec = self.start_vec.to(device) 142 | self.end_vec = self.end_vec.to(device) 143 | self.query_vec = self.query_vec.to(device) 144 | 145 | # Structure token sequence 146 | llm_input = torch.cat((self.start_vec, tokenized_image, self.end_vec, self.query_vec, caption_prefix), dim = 1)#.permute(0,2,1) 147 | 148 | logit_outputs = [] 149 | for _ in range(max_new_tokens): 150 | # Forward on prompt 151 | outputs = self.language_model.forward(inputs_embeds = llm_input) 152 | logit_output = outputs.logits[:, -1, :] 153 | logit_outputs.append(logit_output) 154 | 155 | #print(llm_vector_prompt.shape, logit_output.shape) 156 | # Add EMBEDDED output to current sequence 157 | int_output = logit_output.argmax(dim = 1) 158 | #print('int output', int_output) 159 | new_vec = self.embed_ints(int_output).unsqueeze(0) 160 | llm_input = torch.cat((llm_input, new_vec), dim = 1) 161 | 162 | logit_outputs = torch.stack(logit_outputs, dim = 1) 163 | #print('generate constructed outputs', logit_outputs.shape) 164 | 165 | # Logits to ints 166 | int_outputs = logit_outputs.argmax(dim = 2) 167 | str_outputs = self.language_tokenizer.decode(int_outputs[0]) 168 | print('Caption: [', batch['caption'], ']') 169 | print('Str out: [', str_outputs, ']') 170 | return str_outputs 171 | 172 | def image_forward(self, image): 173 | """ 174 | Set of PIL images? 175 | Or single pil image? 176 | """ 177 | # Vision processor should have all necessary transforms and normalize 178 | #print('in img forward', image.shape) 179 | inputs = self.vision_processor(image, return_tensors='pt').to(self.vision_encoder.device) 180 | 181 | #print('what is this', type(image), image.shape) 182 | # Encode from image pixels to token sequence 183 | 184 | encoded_image = self.vision_encoder(**inputs, output_hidden_states = True) 185 | encoded_image = encoded_image.hidden_states[-1] 186 | 187 | #print('encoded', encoded_image.shape) 188 | # Project representation to language tokens with trainable 'image tokenizer' 189 | tokenized_image = self.vision_tokenizer(encoded_image) 190 | 191 | return tokenized_image 192 | 193 | def embed_ints(self, tokens): 194 | """ 195 | Use the model's existing integer tokens to return vector embeddings: 196 | """ 197 | return self.language_model.get_input_embeddings()(tokens) 198 | 199 | def loss_function(self, logits, int_labels): 200 | """ 201 | logits FloatTensor shape: [B*T, vocab_size] (sequence of probabilities over vocab) 202 | labels intTensor shape: [B*T] (sequence of int vocab positions) 203 | - what is b*t? 
Shouldn't loss be (B, vocab_size) logits against (B,) integer targets? Here it is: forward() only scores the final position, so logits arrive as (B, vocab_size) and int_labels as (B,).
204 |         """
205 | 
206 |         return torch.nn.functional.cross_entropy(logits, int_labels)
207 | 
208 | 
209 | if __name__ == '__main__':
210 |     """
211 |     Build the VLM to check that model construction works
212 |     """
213 |     model = build_vlm()
214 |     #model.print_trainable_parameters()
215 | 
--------------------------------------------------------------------------------