├── paligemma.py
├── .gitignore
├── train_idefics2.py
├── smolvlm.py
├── README.md
├── knowledge_distillation.md
├── LICENSE
├── Faster_foundation_models_with_torch_compile.ipynb
├── gemma3n_fine_tuning_on_all_modalities.py
├── Smol_VLM_FT.ipynb
├── Gemma_3n_Video_Vibe_Tests.ipynb
└── Gemma3n_Fine_tuning_on_All_Modalities.ipynb

/paligemma.py:
--------------------------------------------------------------------------------
1 | from datasets import load_dataset
2 | import torch
3 | from transformers import PaliGemmaProcessor, PaliGemmaForConditionalGeneration, Trainer, TrainingArguments, BitsAndBytesConfig
4 | from peft import get_peft_model, LoraConfig
5 | import os
6 | 
7 | USE_LORA = False
8 | USE_QLORA = False
9 | FREEZE_VISION = False
10 | 
11 | ds = load_dataset('merve/vqav2-small', split="validation")
12 | ds = ds.train_test_split(test_size=0.5)["train"]
13 | 
14 | model_id = "google/paligemma2-3b-pt-448"
15 | processor = PaliGemmaProcessor.from_pretrained(model_id)
16 | 
17 | device = "cuda" if torch.cuda.is_available() else "cpu"
18 | 
19 | image_token = processor.tokenizer.convert_tokens_to_ids("<image>")
20 | 
21 | def collate_fn(examples):
22 |     texts = ["answer en " + example["question"] for example in examples]
23 |     labels = [example['multiple_choice_answer'] for example in examples]
24 |     images = [example["image"].convert("RGB") for example in examples]
25 |     tokens = processor(text=texts, images=images, suffix=labels,
26 |                        return_tensors="pt", padding="longest")
27 | 
28 |     tokens = tokens.to(torch.bfloat16).to(device)
29 |     return tokens
30 | 
31 | 
32 | if USE_LORA or USE_QLORA:
33 |     lora_config = LoraConfig(
34 |         r=8,
35 |         target_modules=["q_proj", "o_proj", "k_proj", "v_proj", "gate_proj", "up_proj", "down_proj"],
36 |         task_type="CAUSAL_LM",
37 |     )
38 |     if USE_QLORA:
39 |         bnb_config = BitsAndBytesConfig(
40 |             load_in_4bit=True,
41 |             bnb_4bit_quant_type="nf4",
42 |             bnb_4bit_compute_dtype=torch.bfloat16
43 |         )
44 |     model = PaliGemmaForConditionalGeneration.from_pretrained(model_id, device_map="auto",
45 |                                                               quantization_config=bnb_config if USE_QLORA else None,
46 |                                                               torch_dtype=torch.bfloat16)
47 |     model = get_peft_model(model, lora_config)
48 |     model = model.to(device)
49 |     model.print_trainable_parameters()
50 | else:
51 |     model = PaliGemmaForConditionalGeneration.from_pretrained(model_id, device_map="auto").to(device)
52 |     model = model.to(device)
53 | 
54 | if FREEZE_VISION:
55 |     for param in model.vision_tower.parameters():
56 |         param.requires_grad = False
57 | 
58 |     for param in model.multi_modal_projector.parameters():
59 |         param.requires_grad = False
60 | 
61 | 
62 | args = TrainingArguments(
63 |     num_train_epochs=3,
64 |     remove_unused_columns=False,
65 |     per_device_train_batch_size=4,
66 |     gradient_accumulation_steps=4,
67 |     warmup_steps=2,
68 |     learning_rate=2e-5,
69 |     weight_decay=1e-6,
70 |     adam_beta2=0.999,
71 |     logging_steps=100,
72 |     optim="adamw_hf",
73 |     save_strategy="steps",
74 |     save_steps=1000,
75 |     save_total_limit=1,
76 |     push_to_hub=True,
77 |     output_dir="paligemma_vqav2",
78 |     bf16=True,
79 |     report_to=["tensorboard"],
80 |     dataloader_pin_memory=False
81 | )
82 | 
83 | 
84 | trainer = Trainer(
85 |     model=model,
86 |     train_dataset=ds,
87 |     data_collator=collate_fn,
88 |     args=args
89 | )
90 | 
91 | trainer.train()
92 | 
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 | 
6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | #.idea/ 161 | -------------------------------------------------------------------------------- /train_idefics2.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from peft import LoraConfig, prepare_model_for_kbit_training, get_peft_model 3 | from transformers import AutoProcessor, BitsAndBytesConfig, Idefics3ForConditionalGeneration 4 | from datasets import load_dataset 5 | 6 | 7 | DEVICE = "cuda:4" 8 | PCI_BUS_ID=4 9 | CUDA_VISIBLE_DEVICES=4 10 | USE_LORA = False 11 | USE_QLORA = True 12 | model_id = "HuggingFaceM4/Idefics3-8B-Llama3" 13 | 14 | processor = AutoProcessor.from_pretrained( 15 | model_id 16 | ) 17 | 18 | 19 | if USE_QLORA or USE_LORA: 20 | lora_config = LoraConfig( 21 | r=8, 22 | lora_alpha=8, 23 | lora_dropout=0.1, 24 | target_modules=['down_proj','o_proj','k_proj','q_proj','gate_proj','up_proj','v_proj'], 25 | use_dora=False if USE_QLORA else True, 26 | init_lora_weights="gaussian" 27 | ) 28 | lora_config.inference_mode = False 29 | if USE_QLORA: 30 | bnb_config = BitsAndBytesConfig( 31 | load_in_4bit=True, 32 | bnb_4bit_use_double_quant=True, 33 | bnb_4bit_quant_type="nf4", 34 | bnb_4bit_compute_dtype=torch.bfloat16 35 | ) 36 | 37 | model = Idefics3ForConditionalGeneration.from_pretrained( 38 | model_id, 39 | quantization_config=bnb_config if USE_QLORA else None, 40 | _attn_implementation="flash_attention_2", 41 | device_map="auto" 42 | ) 43 | model.add_adapter(lora_config) 44 | model.enable_adapters() 45 | model = prepare_model_for_kbit_training(model) 46 | model = get_peft_model(model, lora_config) 47 | print(model.get_nb_trainable_parameters()) 48 | 49 | 50 | else: 51 | model = Idefics3ForConditionalGeneration.from_pretrained( 52 | model_id, 53 | torch_dtype=torch.bfloat16, 54 | _attn_implementation="flash_attention_2", 55 | ).to(DEVICE) 56 | 57 | # if you'd like to only fine-tune LLM 58 | for param in model.model.vision_model.parameters(): 59 | param.requires_grad = False 60 | 61 | ds = load_dataset('merve/vqav2-small', trust_remote_code=True) 62 | split_ds = ds["validation"].train_test_split(test_size=0.8) 63 | train_ds = split_ds["train"] 64 | 65 | image_token_id = processor.tokenizer.additional_special_tokens_ids[ 66 | processor.tokenizer.additional_special_tokens.index("")] 67 | 68 | def collate_fn(examples): 69 | texts = [] 70 | images = [] 71 | for example in examples: 72 | image = 
example["image"] 73 | question = example["question"] 74 | answer = example["multiple_choice_answer"] 75 | messages = [ 76 | { 77 | "role": "user", 78 | "content": [ 79 | {"type": "text", "text": "Answer briefly."}, 80 | {"type": "image"}, 81 | {"type": "text", "text": question} 82 | ] 83 | }, 84 | { 85 | "role": "assistant", 86 | "content": [ 87 | {"type": "text", "text": answer} 88 | ] 89 | } 90 | ] 91 | text = processor.apply_chat_template(messages, add_generation_prompt=False) 92 | texts.append(text.strip()) 93 | images.append([image]) 94 | 95 | batch = processor(text=texts, images=images, return_tensors="pt", padding=True) 96 | labels = batch["input_ids"].clone() 97 | labels[labels == processor.tokenizer.pad_token_id] = -100 98 | labels[labels == image_token_id] = -100 99 | batch["labels"] = labels 100 | 101 | return batch 102 | 103 | from transformers import TrainingArguments, Trainer 104 | 105 | training_args = TrainingArguments( 106 | num_train_epochs=1, 107 | per_device_train_batch_size=1, # increase for QLoRA 108 | gradient_accumulation_steps=8, 109 | warmup_steps=50, 110 | learning_rate=1e-4, 111 | weight_decay=0.01, 112 | logging_steps=25, 113 | save_strategy="steps", 114 | save_steps=250, 115 | save_total_limit=1, 116 | optim="adamw_hf", # for 8-bit, pick paged_adamw_hf 117 | #evaluation_strategy="epoch", 118 | bf16=True, 119 | output_dir="./idefics3-llama-vqav2", 120 | hub_model_id="idefics3-llama-vqav2", 121 | remove_unused_columns=False, 122 | ) 123 | 124 | trainer = Trainer( 125 | model=model, 126 | args=training_args, 127 | data_collator=collate_fn, 128 | train_dataset=train_ds, 129 | ) 130 | 131 | trainer.train() 132 | trainer.push_to_hub() -------------------------------------------------------------------------------- /smolvlm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from peft import LoraConfig, prepare_model_for_kbit_training, get_peft_model 3 | from transformers import AutoProcessor, BitsAndBytesConfig, Idefics3ForConditionalGeneration 4 | from transformers import TrainingArguments, Trainer 5 | from datasets import load_dataset 6 | import os 7 | from PIL import Image 8 | from transformers.image_utils import load_image 9 | 10 | USE_LORA = False 11 | USE_QLORA = True 12 | SMOL = True 13 | 14 | model_id = "HuggingFaceTB/SmolVLM-Base" if SMOL else "HuggingFaceM4/Idefics3-8B-Llama3" 15 | 16 | processor = AutoProcessor.from_pretrained( 17 | model_id 18 | ) 19 | 20 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" 21 | os.environ["CUDA_VISIBLE_DEVICES"] = "1, 4" 22 | if USE_QLORA or USE_LORA: 23 | lora_config = LoraConfig( 24 | r=8, 25 | lora_alpha=8, 26 | lora_dropout=0.1, 27 | target_modules=['down_proj','o_proj','k_proj','q_proj','gate_proj','up_proj','v_proj'], 28 | use_dora=False if USE_QLORA else True, 29 | init_lora_weights="gaussian" 30 | ) 31 | lora_config.inference_mode = False 32 | if USE_QLORA: 33 | bnb_config = BitsAndBytesConfig( 34 | load_in_4bit=True, 35 | bnb_4bit_use_double_quant=True, 36 | bnb_4bit_quant_type="nf4", 37 | bnb_4bit_compute_dtype=torch.bfloat16 38 | ) 39 | 40 | model = Idefics3ForConditionalGeneration.from_pretrained( 41 | model_id, 42 | quantization_config=bnb_config if USE_QLORA else None, 43 | _attn_implementation="flash_attention_2", 44 | device_map="auto" 45 | ) 46 | model.add_adapter(lora_config) 47 | model.enable_adapters() 48 | model = prepare_model_for_kbit_training(model) 49 | model = get_peft_model(model, lora_config) 50 | print(model.get_nb_trainable_parameters()) 51 
| else: 52 | model = Idefics3ForConditionalGeneration.from_pretrained( 53 | model_id, 54 | torch_dtype=torch.bfloat16, 55 | _attn_implementation="flash_attention_2", 56 | ).to(DEVICE) 57 | 58 | # if you'd like to only fine-tune LLM 59 | for param in model.model.vision_model.parameters(): 60 | param.requires_grad = False 61 | 62 | ds = load_dataset('merve/vqav2-small', trust_remote_code=True) 63 | 64 | split_ds = ds["validation"].train_test_split(test_size=0.8) 65 | train_ds = split_ds["train"] 66 | 67 | 68 | image_token_id = processor.tokenizer.additional_special_tokens_ids[ 69 | processor.tokenizer.additional_special_tokens.index("")] 70 | def collate_fn(examples): 71 | texts = [] 72 | images = [] 73 | for example in examples: 74 | image = example["image"] 75 | if image.mode != 'RGB': 76 | image = image.convert('RGB') 77 | question = example["question"] 78 | answer = example["multiple_choice_answer"] 79 | messages = [ 80 | { 81 | "role": "user", 82 | "content": [ 83 | {"type": "text", "text": "Answer briefly."}, 84 | {"type": "image"}, 85 | {"type": "text", "text": question} 86 | ] 87 | }, 88 | { 89 | "role": "assistant", 90 | "content": [ 91 | {"type": "text", "text": answer} 92 | ] 93 | } 94 | ] 95 | text = processor.apply_chat_template(messages, add_generation_prompt=False) 96 | texts.append(text.strip()) 97 | images.append([image]) 98 | 99 | batch = processor(text=texts, images=images, return_tensors="pt", padding=True) 100 | labels = batch["input_ids"].clone() 101 | labels[labels == processor.tokenizer.pad_token_id] = -100 102 | labels[labels == image_token_id] = -100 103 | batch["labels"] = labels 104 | 105 | return batch 106 | 107 | 108 | model_name = model_id.split("/")[-1] 109 | 110 | training_args = TrainingArguments( 111 | num_train_epochs=1, 112 | per_device_train_batch_size=8, 113 | gradient_accumulation_steps=4, 114 | warmup_steps=50, 115 | learning_rate=1e-4, 116 | weight_decay=0.01, 117 | logging_steps=25, 118 | save_strategy="steps", 119 | save_steps=250, 120 | save_total_limit=1, 121 | optim="paged_adamw_8bit", # for 8-bit, keep this, else adamw_hf 122 | bf16=True, # underlying precision for 8bit 123 | output_dir=f"./{model_name}-vqav2", 124 | hub_model_id=f"{model_name}-vqav2", 125 | report_to="tensorboard", 126 | remove_unused_columns=False, 127 | gradient_checkpointing=True 128 | ) 129 | trainer = Trainer( 130 | model=model, 131 | args=training_args, 132 | data_collator=collate_fn, 133 | train_dataset=train_ds, 134 | ) 135 | 136 | trainer.train() 137 | trainer.push_to_hub() -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ![Smol](https://github.com/merveenoyan/smol-vision/assets/53175384/930d5b36-bb9d-4ab6-8b5a-4fec28c48f80) 2 | 3 | # Smol Vision 🐣 4 | 5 | Recipes for shrinking, optimizing, customizing cutting edge vision and multimodal AI models. 
6 | 
7 | Latest examples 👇🏻
8 | - [Fine-tune Kosmos2.5 on OCR with bounding boxes](https://github.com/merveenoyan/smol-vision/blob/main/Grounded_Fine_tuning%20GH.ipynb)
9 | - [Fine-tune Florence-2 on document question answering](https://github.com/merveenoyan/smol-vision/blob/main/Fine_tune_Florence_2.ipynb)
10 | - [Fine-tune DINOv3 on image classification](https://github.com/merveenoyan/smol-vision/blob/main/DINOv3_FT.ipynb)
11 | 
12 | **Note:** GitHub has refused to render notebooks properly for a long time now, so the smol-vision notebooks with rich outputs now live [here](https://huggingface.co/merve/smol-vision). I still update this repository, but it is inconvenient to read here.
13 | 
14 | |                              | Notebook | Description |
15 | |------------------------------|----------|-------------|
16 | | Quantization/ONNX | [Faster and Smaller Zero-shot Object Detection with Optimum](https://github.com/merveenoyan/smol-vision/blob/main/Faster_Zero_shot_Object_Detection_with_Optimum.ipynb) | Quantize the state-of-the-art zero-shot object detection model OWLv2 using Optimum ONNXRuntime tools. |
17 | | VLM Fine-tuning | [Fine-tune PaliGemma](https://github.com/merveenoyan/smol-vision/blob/main/Fine_tune_PaliGemma.ipynb) | Fine-tune the state-of-the-art vision language backbone PaliGemma using transformers. |
18 | | Intro to Optimum/ORT | [Optimizing DETR with 🤗 Optimum](https://github.com/merveenoyan/smol-vision/blob/main/Reduce_any_model_to_fp16_using_%F0%9F%A4%97_Optimum_DETR.ipynb) | A soft introduction to exporting vision models to ONNX and quantizing them. |
19 | | Model Shrinking | [Knowledge Distillation for Computer Vision](https://huggingface.co/docs/transformers/en/tasks/knowledge_distillation_for_image_classification) | Knowledge distillation for image classification.
| 20 | | Quantization | [Fit in vision models using Quanto](https://github.com/merveenoyan/smol-vision/blob/main/Fit_in_vision_models_using_quanto.ipynb) | Fit in vision models to smaller hardware using quanto | 21 | | Speed-up | [Faster foundation models with torch.compile](https://github.com/merveenoyan/smol-vision/blob/main/Faster_foundation_models_with_torch_compile.ipynb) | Improving latency for foundation models using `torch.compile` | 22 | | [NEW] VLM Fine-tuning | [Fine-tune Florence-2](https://github.com/merveenoyan/smol-vision/blob/main/Fine_tune_Florence_2.ipynb) | Fine-tune Florence-2 on DocVQA dataset | 23 | | VLM Fine-tuning | [QLoRA/Fine-tune IDEFICS3 or SmolVLM on VQAv2](https://github.com/merveenoyan/smol-vision/blob/main/Smol_VLM_FT.ipynb) | QLoRA/Full Fine-tune IDEFICS3 or SmolVLM on VQAv2 dataset | 24 | | VLM Fine-tuning (Script) | [QLoRA Fine-tune IDEFICS3 on VQAv2](https://github.com/merveenoyan/smol-vision/blob/main/smolvlm.py) | QLoRA/Full Fine-tune IDEFICS3 or SmolVLM on VQAv2 dataset | 25 | | [NEW] VLM Fine-tuning | [Grounded Fine-tuning](https://github.com/merveenoyan/smol-vision/blob/main/Grounded_Fine_tuning%20GH.ipynb) | Grounded fine-tuning for vision-language models | 26 | | [NEW] Vision Model Fine-tuning | [Fine-tune DINOv3](https://github.com/merveenoyan/smol-vision/blob/main/DINOv3_FT.ipynb) | Fine-tune DINOv3 for vision tasks | 27 | | Multimodal RAG | [Multimodal RAG using ColPali and Qwen2-VL](https://github.com/merveenoyan/smol-vision/blob/main/ColPali_%2B_Qwen2_VL.ipynb) | Learn to retrieve documents and pipeline to RAG without hefty document processing using ColPali through Byaldi and do the generation with Qwen2-VL | 28 | | Multimodal Retriever Fine-tuning | [Fine-tune ColPali for Multimodal RAG](https://github.com/merveenoyan/smol-vision/blob/main/Finetune_ColPali.ipynb) | Learn to apply contrastive fine-tuning on ColPali to customize it for your own multimodal document RAG use case | 29 | | Any-to-Any Fine-tuning | [Fine-tune Gemma-3n for all modalities (audio-text-image)](https://github.com/merveenoyan/smol-vision/blob/main/Gemma3n_Fine_tuning_on_All_Modalities.ipynb) | Fine-tune Gemma-3n model to handle any modality: audio, text, and image. | 30 | | Any-to-Any RAG | [Any-to-Any (Video) RAG with OmniEmbed and Qwen](https://github.com/merveenoyan/smol-vision/blob/main/Any_to_Any_RAG.ipynb) | Do retrieval and generation across modalities (including video) using OmniEmbed and Qwen. | 31 | | Speed-up/Memory Optimization | Vision language model serving using TGI (SOON) | Explore speed-ups and memory improvements for vision-language model serving with text-generation inference | 32 | | Quantization/Optimum/ORT | All levels of quantization and graph optimizations for Image Segmentation using Optimum (SOON) | End-to-end model optimization using Optimum | 33 | -------------------------------------------------------------------------------- /knowledge_distillation.md: -------------------------------------------------------------------------------- 1 | 16 | # Knowledge Distillation for Computer Vision 17 | 18 | [[open-in-colab]] 19 | 20 | Knowledge distillation is a technique used to transfer knowledge from a larger, more complex model (teacher) to a smaller, simpler model (student). To distill knowledge from one model to another, we take a pre-trained teacher model trained on a certain task (image classification for this case) and randomly initialize a student model to be trained on image classification. 
Next, we train the student model to minimize the difference between its outputs and the teacher's outputs, thus making it mimic the teacher's behavior. The technique was first introduced in [Distilling the Knowledge in a Neural Network by Hinton et al.](https://arxiv.org/abs/1503.02531). In this guide, we will do task-specific knowledge distillation, using the [beans dataset](https://huggingface.co/datasets/beans).
21 | 
22 | This guide demonstrates how you can distill a [fine-tuned ViT model](https://huggingface.co/merve/vit-mobilenet-beans-224) (teacher model) to a [MobileNet](https://huggingface.co/google/mobilenet_v2_1.4_224) (student model) using the [Trainer API](https://huggingface.co/docs/transformers/en/main_classes/trainer#trainer) of 🤗 Transformers.
23 | 
24 | Let's install the libraries needed for distillation and for evaluating the process.
25 | 
26 | ```bash
27 | pip install transformers datasets accelerate tensorboard evaluate --upgrade
28 | ```
29 | 
30 | In this example, we are using the `merve/beans-vit-224` model as the teacher model. It is an image classification model based on `google/vit-base-patch16-224-in21k`, fine-tuned on the beans dataset. We will distill this model to a randomly initialized MobileNetV2.
31 | 
32 | We will now load the dataset.
33 | 
34 | ```python
35 | from datasets import load_dataset
36 | 
37 | dataset = load_dataset("beans")
38 | ```
39 | 
40 | We can use an image processor from either of the models, as in this case they return the same output at the same resolution. We will use the `map()` method of `dataset` to apply the preprocessing to every split of the dataset.
41 | 
42 | ```python
43 | from transformers import AutoImageProcessor
44 | teacher_processor = AutoImageProcessor.from_pretrained("merve/beans-vit-224")
45 | 
46 | def process(examples):
47 |     processed_inputs = teacher_processor(examples["image"])
48 |     return processed_inputs
49 | 
50 | processed_datasets = dataset.map(process, batched=True)
51 | ```
52 | 
53 | Essentially, we want the student model (a randomly initialized MobileNet) to mimic the teacher model (a fine-tuned vision transformer). To achieve this, we first get the logits from the teacher and the student. Then, we divide each of them by the parameter `temperature`, which controls how soft each target distribution is. A parameter called `lambda` weighs the importance of the distillation loss. In this example, we will use `temperature=5` and `lambda=0.5`. We will use the Kullback-Leibler divergence loss to measure the divergence between the student and the teacher. Given two distributions P and Q, KL divergence measures how much extra information we need to represent P using Q. If the two are identical, their KL divergence is zero, as no extra information is needed to explain P from Q. This makes KL divergence a natural fit for knowledge distillation.
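To make the loss concrete before wrapping it in a trainer, here is a small self-contained sketch (with made-up logits purely for illustration) of the computation we are about to implement: the temperature-softened distributions, the KL term scaled by `temperature**2`, and the weighted sum with the regular cross-entropy loss.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

temperature = 5.0
lambda_param = 0.5

# made-up logits for a batch of 2 examples and 3 classes (illustration only)
student_logits = torch.tensor([[1.0, 2.0, 0.5], [0.2, 0.1, 3.0]])
teacher_logits = torch.tensor([[1.2, 2.5, 0.3], [0.0, 0.4, 2.8]])
labels = torch.tensor([1, 2])

# soften both distributions with the temperature;
# KLDivLoss expects log-probabilities for the input and probabilities for the target
soft_teacher = F.softmax(teacher_logits / temperature, dim=-1)
soft_student = F.log_softmax(student_logits / temperature, dim=-1)

# the temperature**2 factor keeps gradient magnitudes comparable across temperatures
distillation_loss = nn.KLDivLoss(reduction="batchmean")(soft_student, soft_teacher) * temperature**2

# regular cross-entropy of the student against the ground-truth labels
student_target_loss = F.cross_entropy(student_logits, labels)

loss = (1.0 - lambda_param) * student_target_loss + lambda_param * distillation_loss
print(loss)
```

The trainer subclass below performs exactly this computation inside `compute_loss`, with the student model's own `loss` output playing the role of the cross-entropy term.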
54 | 55 | 56 | ```python 57 | from transformers import TrainingArguments, Trainer 58 | import torch 59 | import torch.nn as nn 60 | import torch.nn.functional as F 61 | 62 | 63 | class ImageDistilTrainer(Trainer): 64 | def __init__(self, teacher_model=None, student_model=None, temperature=None, lambda_param=None, *args, **kwargs): 65 | super().__init__(model=student_model, *args, **kwargs) 66 | self.teacher = teacher_model 67 | self.student = student_model 68 | self.loss_function = nn.KLDivLoss(reduction="batchmean") 69 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 70 | self.teacher.to(device) 71 | self.teacher.eval() 72 | self.temperature = temperature 73 | self.lambda_param = lambda_param 74 | 75 | def compute_loss(self, student, inputs, return_outputs=False): 76 | student_output = self.student(**inputs) 77 | 78 | with torch.no_grad(): 79 | teacher_output = self.teacher(**inputs) 80 | 81 | # Compute soft targets for teacher and student 82 | soft_teacher = F.softmax(teacher_output.logits / self.temperature, dim=-1) 83 | soft_student = F.log_softmax(student_output.logits / self.temperature, dim=-1) 84 | 85 | # Compute the loss 86 | distillation_loss = self.loss_function(soft_student, soft_teacher) * (self.temperature ** 2) 87 | 88 | # Compute the true label loss 89 | student_target_loss = student_output.loss 90 | 91 | # Calculate final loss 92 | loss = (1. - self.lambda_param) * student_target_loss + self.lambda_param * distillation_loss 93 | return (loss, student_output) if return_outputs else loss 94 | ``` 95 | 96 | We will now login to Hugging Face Hub so we can push our model to the Hugging Face Hub through the `Trainer`. 97 | 98 | ```python 99 | from huggingface_hub import notebook_login 100 | 101 | notebook_login() 102 | ``` 103 | 104 | Let's set the `TrainingArguments`, the teacher model and the student model. 105 | 106 | ```python 107 | from transformers import AutoModelForImageClassification, MobileNetV2Config, MobileNetV2ForImageClassification 108 | 109 | training_args = TrainingArguments( 110 | output_dir="my-awesome-model", 111 | num_train_epochs=30, 112 | fp16=True, 113 | logging_dir=f"{repo_name}/logs", 114 | logging_strategy="epoch", 115 | eval_strategy="epoch", 116 | save_strategy="epoch", 117 | load_best_model_at_end=True, 118 | metric_for_best_model="accuracy", 119 | report_to="tensorboard", 120 | push_to_hub=True, 121 | hub_strategy="every_save", 122 | hub_model_id=repo_name, 123 | ) 124 | 125 | num_labels = len(processed_datasets["train"].features["labels"].names) 126 | 127 | # initialize models 128 | teacher_model = AutoModelForImageClassification.from_pretrained( 129 | "merve/beans-vit-224", 130 | num_labels=num_labels, 131 | ignore_mismatched_sizes=True 132 | ) 133 | 134 | # training MobileNetV2 from scratch 135 | student_config = MobileNetV2Config() 136 | student_config.num_labels = num_labels 137 | student_model = MobileNetV2ForImageClassification(student_config) 138 | ``` 139 | 140 | We can use `compute_metrics` function to evaluate our model on the test set. This function will be used during the training process to compute the `accuracy` & `f1` of our model. 
141 | 
142 | ```python
143 | import evaluate
144 | import numpy as np
145 | 
146 | accuracy = evaluate.load("accuracy")
147 | 
148 | def compute_metrics(eval_pred):
149 |     predictions, labels = eval_pred
150 |     acc = accuracy.compute(references=labels, predictions=np.argmax(predictions, axis=1))
151 |     return {"accuracy": acc["accuracy"]}
152 | ```
153 | 
154 | Let's initialize the `Trainer` with the training arguments we defined. We will also initialize our data collator.
155 | 
156 | ```python
157 | from transformers import DefaultDataCollator
158 | 
159 | data_collator = DefaultDataCollator()
160 | trainer = ImageDistilTrainer(
161 |     student_model=student_model,
162 |     teacher_model=teacher_model,
163 |     args=training_args,
164 |     train_dataset=processed_datasets["train"],
165 |     eval_dataset=processed_datasets["validation"],
166 |     data_collator=data_collator,
167 |     tokenizer=teacher_processor,
168 |     compute_metrics=compute_metrics,
169 |     temperature=5,
170 |     lambda_param=0.5
171 | )
172 | ```
173 | 
174 | We can now train our model.
175 | 
176 | ```python
177 | trainer.train()
178 | ```
179 | 
180 | We can evaluate the model on the test set.
181 | 
182 | ```python
183 | trainer.evaluate(processed_datasets["test"])
184 | ```
185 | 
186 | On the test set, our model reaches 72 percent accuracy. As a sanity check on the efficiency of distillation, we also trained MobileNet on the beans dataset from scratch with the same hyperparameters and observed 63 percent accuracy on the test set. We invite readers to try different pre-trained teacher models, student architectures, and distillation parameters, and to report their findings. The training logs and checkpoints for the distilled model can be found in [this repository](https://huggingface.co/merve/vit-mobilenet-beans-224), and the MobileNetV2 trained from scratch can be found in this [repository](https://huggingface.co/merve/resnet-mobilenet-beans-5).
187 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 | 
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 | 
7 | 1. Definitions.
8 | 
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 | 
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 | 
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 | 
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 | 
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /Faster_foundation_models_with_torch_compile.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "provenance": [], 7 | "machine_shape": "hm", 8 | "gpuType": "L4" 9 | }, 10 | "kernelspec": { 11 | "name": "python3", 12 | "display_name": "Python 3" 13 | }, 14 | "language_info": { 15 | "name": "python" 16 | }, 17 | "accelerator": "GPU" 18 | }, 19 | "cells": [ 20 | { 21 | "cell_type": "markdown", 22 | "source": [ 23 | "# Faster Foundation Models with `torch.compile`" 24 | ], 25 | "metadata": { 26 | "id": "axYlcDTznci4" 27 | } 28 | }, 29 | { 30 | "cell_type": "markdown", 31 | "source": [ 32 | "## Introduction to `torch.compile()`" 33 | ], 34 | "metadata": { 35 | "id": "B-yw8KMWsjfY" 36 | } 37 | }, 38 | { 39 | "cell_type": "markdown", 40 | "source": [ 41 | "This guide aims to provide a benchmark on the inference speed-ups introduced with `torch.compile()` with no reduction in model performance for foundation models in 🤗 Transformers.\n", 42 | "\n", 43 | "Most used `torch.compile` modes are following:\n", 44 | "\n", 45 | "- \"default\" is the default mode, which is a good balance between performance and overhead\n", 46 | "\n", 47 | "- \"reduce-overhead\" reduces the overhead of python with CUDA graphs, useful for small batches, consumes a lot of memory. As of now only works for CUDA only graphs which do not mutate inputs.\n", 48 | "\n", 49 | "If you have a lot of memory to use, the best speed-up is through `reduce-overhead`. How much speed-up one can get depends on the model, so in this tutorial we will check the most used foundation models." 50 | ], 51 | "metadata": { 52 | "id": "AmmT4aDnqgOB" 53 | } 54 | }, 55 | { 56 | "cell_type": "markdown", 57 | "source": [ 58 | "## OWLv2\n", 59 | "\n", 60 | "OWLv2 is a zero-shot object detection model released by Google Brain. We will load base version." 61 | ], 62 | "metadata": { 63 | "id": "5sCfbPTn7wBE" 64 | } 65 | }, 66 | { 67 | "cell_type": "markdown", 68 | "source": [ 69 | "Let's load the model and processor for OWLv2." 
70 | ], 71 | "metadata": { 72 | "id": "joeX3J315K0G" 73 | } 74 | }, 75 | { 76 | "cell_type": "code", 77 | "source": [ 78 | "from PIL import Image\n", 79 | "import requests\n", 80 | "\n", 81 | "url = 'https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg'\n", 82 | "image = Image.open(requests.get(url, stream=True).raw)" 83 | ], 84 | "metadata": { 85 | "id": "Ztfcdqkul62z" 86 | }, 87 | "execution_count": 1, 88 | "outputs": [] 89 | }, 90 | { 91 | "cell_type": "code", 92 | "source": [ 93 | "from transformers import AutoProcessor, Owlv2ForObjectDetection\n", 94 | "import torch\n", 95 | "import numpy as np\n", 96 | "\n", 97 | "processor = AutoProcessor.from_pretrained(\"google/owlv2-base-patch16-ensemble\")\n", 98 | "model = Owlv2ForObjectDetection.from_pretrained(\"google/owlv2-base-patch16-ensemble\").to(\"cuda\")\n", 99 | "\n", 100 | "texts = [[\"a photo of a bee\", \"a photo of a bird\"]]\n", 101 | "inputs = processor(text=texts, images=image, return_tensors=\"pt\").to(\"cuda\")" 102 | ], 103 | "metadata": { 104 | "id": "84npPHCQpHZ6", 105 | "colab": { 106 | "base_uri": "https://localhost:8080/" 107 | }, 108 | "outputId": "f30c41c7-b897-460d-d2a4-a1276bf2263e" 109 | }, 110 | "execution_count": 2, 111 | "outputs": [ 112 | { 113 | "output_type": "stream", 114 | "name": "stderr", 115 | "text": [ 116 | "/usr/local/lib/python3.10/dist-packages/huggingface_hub/utils/_token.py:89: UserWarning: \n", 117 | "The secret `HF_TOKEN` does not exist in your Colab secrets.\n", 118 | "To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.\n", 119 | "You will be able to reuse this secret in all of your notebooks.\n", 120 | "Please note that authentication is recommended but still optional to access public models or datasets.\n", 121 | " warnings.warn(\n" 122 | ] 123 | } 124 | ] 125 | }, 126 | { 127 | "cell_type": "markdown", 128 | "source": [ 129 | "We can now get to benchmarking. We will benchmark the model itself and the compiled model." 
130 | ], 131 | "metadata": { 132 | "id": "3AedkjLu5PRo" 133 | } 134 | }, 135 | { 136 | "cell_type": "code", 137 | "source": [ 138 | "starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)\n", 139 | "repetitions = 30\n", 140 | "timings=np.zeros((repetitions,1))\n", 141 | "\n", 142 | "for _ in range(10):\n", 143 | " _ = model(**inputs)\n", 144 | "\n", 145 | "with torch.no_grad():\n", 146 | " for rep in range(repetitions):\n", 147 | " torch.cuda.synchronize()\n", 148 | " starter.record()\n", 149 | " output = model(**inputs)\n", 150 | " ender.record()\n", 151 | " torch.cuda.synchronize()\n", 152 | " curr_time = starter.elapsed_time(ender)\n", 153 | " timings[rep] = curr_time\n", 154 | "\n", 155 | "mean_syn = np.sum(timings) / repetitions\n", 156 | "print(mean_syn)\n" 157 | ], 158 | "metadata": { 159 | "id": "RQQSEgkQtXEV", 160 | "colab": { 161 | "base_uri": "https://localhost:8080/" 162 | }, 163 | "outputId": "8003590b-c4bc-4b3d-9b1b-dade853b8dd8" 164 | }, 165 | "execution_count": 3, 166 | "outputs": [ 167 | { 168 | "output_type": "stream", 169 | "name": "stdout", 170 | "text": [ 171 | "255.7331792195638\n" 172 | ] 173 | } 174 | ] 175 | }, 176 | { 177 | "cell_type": "code", 178 | "source": [ 179 | "starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)\n", 180 | "timings=np.zeros((repetitions,1))\n", 181 | "\n", 182 | "compiled_model = torch.compile(model, mode=\"reduce-overhead\").to(\"cuda\")\n", 183 | "\n", 184 | "for _ in range(30):\n", 185 | " with torch.no_grad():\n", 186 | " _ = compiled_model(**inputs)\n", 187 | "\n", 188 | "\n", 189 | "with torch.no_grad():\n", 190 | " for rep in range(repetitions):\n", 191 | " torch.cuda.synchronize()\n", 192 | " starter.record()\n", 193 | " output = compiled_model(**inputs)\n", 194 | " ender.record()\n", 195 | " torch.cuda.synchronize()\n", 196 | " curr_time = starter.elapsed_time(ender)\n", 197 | " timings[rep] = curr_time\n", 198 | "\n", 199 | "mean_syn = np.sum(timings) / repetitions\n", 200 | "print(mean_syn)" 201 | ], 202 | "metadata": { 203 | "id": "bEZiNgaupOx6", 204 | "colab": { 205 | "base_uri": "https://localhost:8080/" 206 | }, 207 | "outputId": "e5d47875-1e40-4997-e533-94bf0ff34d14" 208 | }, 209 | "execution_count": 4, 210 | "outputs": [ 211 | { 212 | "output_type": "stream", 213 | "name": "stderr", 214 | "text": [ 215 | "/usr/lib/python3.10/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n", 216 | " self.pid = os.fork()\n", 217 | "/usr/local/lib/python3.10/dist-packages/torch/_inductor/compile_fx.py:124: UserWarning: TensorFloat32 tensor cores for float32 matrix multiplication available but not enabled. Consider setting `torch.set_float32_matmul_precision('high')` for better performance.\n", 218 | " warnings.warn(\n", 219 | "skipping cudagraphs due to skipping cudagraphs due to cpu device. 
Found from : \n", 220 | " File \"/usr/local/lib/python3.10/dist-packages/transformers/models/owlv2/modeling_owlv2.py\", line 1711, in forward\n", 221 | " pred_boxes = self.box_predictor(image_feats, feature_map)\n", 222 | " File \"/usr/local/lib/python3.10/dist-packages/transformers/models/owlv2/modeling_owlv2.py\", line 1374, in box_predictor\n", 223 | " box_bias = self.box_bias.to(feature_map.device)\n", 224 | "\n" 225 | ] 226 | }, 227 | { 228 | "output_type": "stream", 229 | "name": "stdout", 230 | "text": [ 231 | "154.6884775797526\n" 232 | ] 233 | } 234 | ] 235 | }, 236 | { 237 | "cell_type": "markdown", 238 | "source": [ 239 | "We got nearly 40 percent speed-up! You can also increase the batch size and see how much further speed-up you can get." 240 | ], 241 | "metadata": { 242 | "id": "d_0d7DwN6gBt" 243 | } 244 | }, 245 | { 246 | "cell_type": "code", 247 | "source": [ 248 | "texts = [[\"a photo of a bee\", \"a photo of a bird\"] for _ in range(8)]\n", 249 | "images = [image for _ in range(8)]\n", 250 | "inputs = processor(text=texts, images=image, return_tensors=\"pt\").to(\"cuda\")" 251 | ], 252 | "metadata": { 253 | "id": "exKoOptB61UL" 254 | }, 255 | "execution_count": 11, 256 | "outputs": [] 257 | }, 258 | { 259 | "cell_type": "code", 260 | "source": [ 261 | "starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)\n", 262 | "repetitions = 30\n", 263 | "timings=np.zeros((repetitions,1))\n", 264 | "\n", 265 | "for _ in range(10):\n", 266 | " _ = model(**inputs)\n", 267 | "\n", 268 | "with torch.no_grad():\n", 269 | " for rep in range(repetitions):\n", 270 | " torch.cuda.synchronize()\n", 271 | " starter.record()\n", 272 | " output = model(**inputs)\n", 273 | " ender.record()\n", 274 | " torch.cuda.synchronize()\n", 275 | " curr_time = starter.elapsed_time(ender)\n", 276 | " timings[rep] = curr_time\n", 277 | "\n", 278 | "mean_syn = np.sum(timings) / repetitions\n", 279 | "print(mean_syn)" 280 | ], 281 | "metadata": { 282 | "colab": { 283 | "base_uri": "https://localhost:8080/" 284 | }, 285 | "id": "EFj9Pgra7Km8", 286 | "outputId": "5fefb8c0-9e86-478c-e9e2-0dbc0fa8a37b" 287 | }, 288 | "execution_count": 12, 289 | "outputs": [ 290 | { 291 | "output_type": "stream", 292 | "name": "stdout", 293 | "text": [ 294 | "269.3023401896159\n" 295 | ] 296 | } 297 | ] 298 | }, 299 | { 300 | "cell_type": "code", 301 | "source": [ 302 | "starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)\n", 303 | "timings=np.zeros((repetitions,1))\n", 304 | "\n", 305 | "compiled_model = torch.compile(model, mode=\"reduce-overhead\").to(\"cuda\")\n", 306 | "\n", 307 | "for _ in range(30):\n", 308 | " with torch.no_grad():\n", 309 | " _ = compiled_model(**inputs)\n", 310 | "\n", 311 | "\n", 312 | "with torch.no_grad():\n", 313 | " for rep in range(repetitions):\n", 314 | " torch.cuda.synchronize()\n", 315 | " starter.record()\n", 316 | " output = compiled_model(**inputs)\n", 317 | " ender.record()\n", 318 | " torch.cuda.synchronize()\n", 319 | " curr_time = starter.elapsed_time(ender)\n", 320 | " timings[rep] = curr_time\n", 321 | "\n", 322 | "mean_syn = np.sum(timings) / repetitions\n", 323 | "print(mean_syn)" 324 | ], 325 | "metadata": { 326 | "colab": { 327 | "base_uri": "https://localhost:8080/" 328 | }, 329 | "id": "OuQZmgTK7UCo", 330 | "outputId": "7184eb1d-b545-4bb6-b544-3effd5c2545a" 331 | }, 332 | "execution_count": 13, 333 | "outputs": [ 334 | { 335 | "output_type": "stream", 336 | "name": "stdout", 337 | "text": [ 338 | 
"159.77137603759766\n" 339 | ] 340 | } 341 | ] 342 | } 343 | ] 344 | } -------------------------------------------------------------------------------- /gemma3n_fine_tuning_on_all_modalities.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """Gemma3n Fine-tuning on All Modalities.ipynb 3 | 4 | Automatically generated by Colab. 5 | 6 | Original file is located at 7 | https://colab.research.google.com/drive/1iEZUJuvKJpGU8t50BqfkiCQmGkaR6gd4 8 | 9 | # Fine-tune Gemma3n on FineVideo 10 | 11 | In this notebook, we will see how to fine-tune Gemma3n an videos with audios inside. 12 | Using all three modalities is very costly compute-wise, so keep in mind that this is an educational tutorial to fit the model in 40GB VRAM. 13 | """ 14 | 15 | !pip install -U -q timm transformers trl peft datasets 16 | 17 | import io 18 | import os 19 | import zipfile 20 | 21 | import torch 22 | from datasets import load_dataset 23 | from PIL import Image 24 | from transformers import AutoProcessor, Gemma3nForConditionalGeneration 25 | 26 | from trl import ( 27 | SFTConfig, 28 | SFTTrainer, 29 | ) 30 | 31 | """## Download videos and preprocessing 32 | 33 | FineVideo is a quite large dataset, we don't need a ton of examples, so we stream the dataset, check the duration and download the videos shorter than 30 secs. 34 | """ 35 | 36 | from datasets import load_dataset 37 | import json 38 | import os 39 | 40 | dataset = load_dataset("HuggingFaceFV/finevideo", split="train", streaming=True) 41 | 42 | 43 | os.makedirs("videos", exist_ok=True) 44 | os.makedirs("metadata", exist_ok=True) 45 | 46 | for idx, sample in enumerate(dataset): 47 | data = sample["json"] 48 | duration = data.get("duration_seconds", 0) 49 | if duration < 30: 50 | video_filename = f"videos/sample_{idx}.mp4" 51 | with open(video_filename, 'wb') as video_file: 52 | video_file.write(sample['mp4']) 53 | 54 | json_filename = f"metadata/sample_{idx}.json" 55 | with open(json_filename, 'w') as json_file: 56 | json.dump(sample['json'], json_file) 57 | 58 | print(f"Number of items in content/videos: {len(os.listdir('videos'))}") 59 | 60 | """In FineVideo some frames are dark so we downsample 6 frames and if we can't get meaningful videos we remove them.""" 61 | 62 | import cv2 63 | from PIL import Image 64 | import numpy as np 65 | 66 | def is_dark(frame, threshold=10): 67 | return np.max(frame) < threshold # all pixels are very close to 0 68 | 69 | def downsample_video(video_path): 70 | vidcap = cv2.VideoCapture(video_path) 71 | total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT)) 72 | fps = vidcap.get(cv2.CAP_PROP_FPS) 73 | 74 | frames = [] 75 | 76 | # Generate 8 evenly spaced indices, skip first and last 77 | full_indices = np.linspace(0, total_frames - 1, 8, dtype=int)[1:-1] 78 | 79 | for i in full_indices: 80 | found_valid = False 81 | for offset in [0, -1, 1, -2, 2]: # Try nearby frames if original is dark 82 | candidate_idx = i + offset 83 | if 0 <= candidate_idx < total_frames: 84 | vidcap.set(cv2.CAP_PROP_POS_FRAMES, candidate_idx) 85 | success, image = vidcap.read() 86 | if success: 87 | if not is_dark(image): 88 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 89 | pil_image = Image.fromarray(image) 90 | timestamp = round(candidate_idx / fps, 2) 91 | frames.append((pil_image, timestamp)) 92 | found_valid = True 93 | break 94 | if not found_valid: 95 | print(f"Warning: Could not find non-dark frame near index {i}") 96 | 97 | vidcap.release() 98 | 99 | # If still fewer than 8, try 
to top off by scanning more frames 100 | if len(frames) < 6: 101 | print("Trying to top off with additional non-dark frames...") 102 | idx = 0 103 | while len(frames) < 8 and idx < total_frames: 104 | vidcap.set(cv2.CAP_PROP_POS_FRAMES, idx) 105 | success, image = vidcap.read() 106 | if success and not is_dark(image): 107 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 108 | pil_image = Image.fromarray(image) 109 | timestamp = round(idx / fps, 2) 110 | # Avoid adding duplicate timestamps 111 | if not any(ts == timestamp for _, ts in frames): 112 | frames.append((pil_image, timestamp)) 113 | idx += 1 114 | 115 | return frames[:8] # Ensure exactly 8 frames 116 | 117 | import os 118 | import glob 119 | 120 | def remove_dark_videos(video_dir, metadata_dir, audio_dir): 121 | """ 122 | Remove videos (and their metadata/audio files) if all frames are dark. 123 | """ 124 | video_paths = glob.glob(os.path.join(video_dir, "*.mp4")) 125 | 126 | for video_path in video_paths: 127 | filename = os.path.basename(video_path) 128 | base_name = os.path.splitext(filename)[0] 129 | 130 | frames = downsample_video(video_path) 131 | if len(frames) < 6: 132 | try: 133 | os.remove(video_path) 134 | print(f"Deleted: {video_path}") 135 | except Exception as e: 136 | print(f"Failed to delete {video_path}: {e}") 137 | 138 | metadata_path = os.path.join(metadata_dir, f"{base_name}.json") 139 | if os.path.exists(metadata_path): 140 | os.remove(metadata_path) 141 | 142 | # Remove audio 143 | audio_path = os.path.join(audio_dir, f"{base_name}.wav") 144 | if os.path.exists(audio_path): 145 | os.remove(audio_path) 146 | 147 | remove_dark_videos( 148 | video_dir="videos", 149 | metadata_dir="metadata", 150 | audio_dir="audios" 151 | ) 152 | 153 | """Gemma-3n accepts video (image frames) and audio separately, so we strip audio from video.""" 154 | 155 | import os 156 | import subprocess 157 | 158 | video_dir = "videos" 159 | audio_dir = "audios" 160 | os.makedirs(audio_dir, exist_ok=True) 161 | 162 | for filename in os.listdir(video_dir): 163 | if not filename.endswith(".mp4"): 164 | continue 165 | 166 | idx = filename.split("_")[1].split(".")[0] 167 | video_path = os.path.join(video_dir, filename) 168 | audio_path = os.path.join(audio_dir, f"sample_{idx}.wav") 169 | 170 | subprocess.run([ 171 | "ffmpeg", "-i", video_path, 172 | "-q:a", "0", "-map", "a", 173 | audio_path, 174 | "-y" 175 | ], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) 176 | 177 | """Construct a new dataset with audio, video, metadata (video categories). This dataset is very cool, it has some questions and answers, captions and more so get creative if you have the GPU VRAM to do so. 
Here we solve an easier task for educational purposes.""" 178 | 179 | from datasets import Dataset 180 | import json 181 | 182 | def gen(): 183 | meta_dir = "metadata" 184 | for filename in os.listdir(meta_dir): 185 | if not filename.endswith(".json"): 186 | continue 187 | 188 | idx = filename.split("_")[1].split(".")[0] 189 | if os.path.exists(f"videos/sample_{idx}.mp4"): 190 | video_filename = f"sample_{idx}.mp4" 191 | audio_filename = f"sample_{idx}.wav" 192 | json_path = os.path.join(meta_dir, filename) 193 | 194 | with open(json_path, "r") as f: 195 | metadata = json.load(f) 196 | 197 | 198 | yield { 199 | "video": video_filename, 200 | "audio": audio_filename, 201 | "content_parent_category": metadata["content_parent_category"], 202 | "sample_index": int(idx) 203 | } 204 | else: 205 | pass 206 | 207 | dataset = Dataset.from_generator(gen) 208 | 209 | """We will speed up and downsample the audios to save space during training.""" 210 | 211 | import torchaudio 212 | from torchaudio.transforms import Resample 213 | import os 214 | import torch 215 | 216 | def preprocess_audio(audio_path, target_sample_rate=16000, max_duration_sec=5, speedup_factor=1.25): 217 | waveform, sample_rate = torchaudio.load(audio_path) 218 | 219 | if waveform.shape[0] > 1: 220 | waveform = waveform.mean(dim=0, keepdim=True) # downmix stereo to mono 221 | 222 | if sample_rate != target_sample_rate: 223 | resampler = Resample(orig_freq=sample_rate, new_freq=target_sample_rate) 224 | waveform = resampler(waveform) 225 | sample_rate = target_sample_rate 226 | 227 | if speedup_factor > 1.0: 228 | indices = torch.arange(0, waveform.shape[1], step=speedup_factor).long() # naive speed-up by dropping samples 229 | if indices[-1] >= waveform.shape[1]: 230 | indices = indices[:-1] 231 | waveform = waveform[:, indices] 232 | 233 | max_length = int(target_sample_rate * max_duration_sec) 234 | if waveform.shape[1] > max_length: 235 | waveform = waveform[:, :max_length] # keep at most max_duration_sec seconds 236 | 237 | torchaudio.save(audio_path, waveform, sample_rate) 238 | 239 | for file_name in os.listdir("audios"): 240 | if file_name.lower().endswith(".wav"): 241 | audio_path = os.path.join("audios", file_name) 242 | preprocess_audio(audio_path) 243 | 244 | dataset = dataset.train_test_split(test_size=0.10, seed=42) 245 | 246 | """### Load the model 247 | 248 | Make sure you have your Hugging Face token in your Colab secrets. 249 | """ 250 | 251 | model = Gemma3nForConditionalGeneration.from_pretrained( 252 | "google/gemma-3n-E2B-it", torch_dtype=torch.bfloat16, 253 | ) 254 | processor = AutoProcessor.from_pretrained( 255 | "google/gemma-3n-E2B-it", 256 | ) 257 | processor.tokenizer.padding_side = "right" 258 | 259 | processor.tokenizer.all_special_ids 260 | 261 | """Write our dataset collator. We will train the model to predict the category of a video, which is a relatively easy task. You can do much more: FineVideo also has a QnA section, so you could train this model for open-ended QnA if you have a lot of VRAM and patience. Open-ended tasks are harder to work with, and this notebook's purpose is educational: showing how to feed all the different modalities. 262 | 263 | In the collator we also downsample each video to 6 frames with the helper we wrote above. For better results you would want more frames. 264 | """ 265 | 266 | def collate_fn(examples): 267 | video_path = examples[0]["video"] 268 | audio_path = examples[0]["audio"] 269 | sample_idx = video_path.split("_")[1].split(".")[0] 270 | frames = downsample_video(f"videos/{video_path}") 271 | 272 | text = "Based on the video, predict the category of it." 
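    # Build a chat-formatted sample for apply_chat_template: one user turn that interleaves
    # a "Frame <timestamp>:" text entry with the corresponding frame image (saved to disk so
    # the processor can load it from a path), then a single audio entry, followed by an
    # assistant turn containing the target category string that serves as the label.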
273 | message = [ 274 | { 275 | "role": "user", 276 | "content": [ 277 | {"type": "text", "text": text} 278 | ], 279 | }, 280 | ] 281 | # this is how video inference should be formatted in Gemma3n 282 | for frame in frames: 283 | image, timestamp = frame 284 | message[0]["content"].append({"type": "text", "text": f"Frame {timestamp}:"}) 285 | timestamp = str(timestamp).replace(".", "_") 286 | image.save(f"image_idx_{sample_idx}_{timestamp}.png") 287 | message[0]["content"].append({"type": "image", "url": f"image_idx_{sample_idx}_{timestamp}.png"}) 288 | 289 | message[0]["content"].append({"type": "audio", "audio": f"audios/{audio_path}"}) 290 | message.append({"role": "assistant", "content": [{"type": "text", "text": examples[0]["content_parent_category"]}]}) 291 | inputs = processor.apply_chat_template( 292 | message, 293 | add_generation_prompt=False, 294 | tokenize=True, 295 | return_dict=True, 296 | return_tensors="pt", 297 | padding=True, 298 | ).to(model.device) 299 | 300 | labels = inputs["input_ids"].clone() 301 | special_token_ids = processor.tokenizer.all_special_ids 302 | 303 | special_token_ids_tensor = torch.tensor(special_token_ids, device=labels.device) 304 | mask = torch.isin(labels, special_token_ids_tensor) 305 | labels[mask] = -100 306 | 307 | inputs["labels"] = labels 308 | if torch.all(inputs["pixel_values"] == 0): 309 | print("Frames are dark") 310 | 311 | return inputs 312 | 313 | """## Training 314 | 315 | We do LoRA fine-tuning again to save up on space. 316 | """ 317 | 318 | from peft import LoraConfig 319 | peft_config = LoraConfig( 320 | task_type="CAUSAL_LM", 321 | r=16, 322 | target_modules="all-linear", 323 | lora_alpha=32, 324 | lora_dropout=0.05, 325 | bias="none", 326 | use_rslora=False, 327 | use_dora=False, 328 | modules_to_save=None 329 | ) 330 | 331 | model.gradient_checkpointing_disable() 332 | 333 | model.config.use_cache = False 334 | 335 | training_args = SFTConfig( 336 | output_dir="/content/gemma-3n-finevideo", 337 | eval_strategy='epoch', 338 | per_device_train_batch_size=1, 339 | per_device_eval_batch_size=1, 340 | gradient_accumulation_steps=4, 341 | gradient_checkpointing=False, 342 | learning_rate=1e-05, 343 | num_train_epochs=3.0, 344 | logging_steps=10, 345 | save_steps=100, 346 | bf16=True, 347 | report_to=["tensorboard"], 348 | dataset_kwargs={'skip_prepare_dataset': True}, 349 | remove_unused_columns=False, 350 | max_seq_length=None, 351 | push_to_hub=True, 352 | dataloader_pin_memory=False, 353 | ) 354 | 355 | trainer = SFTTrainer( 356 | model=model, 357 | args=training_args, 358 | data_collator=collate_fn, 359 | train_dataset=dataset["train"], 360 | eval_dataset=dataset["test"] if training_args.eval_strategy != "no" else None, 361 | processing_class=processor.tokenizer, 362 | peft_config=peft_config, 363 | ) 364 | 365 | trainer.train() 366 | 367 | """Test the model with a video of snowboarding.""" 368 | 369 | !wget https://huggingface.co/datasets/merve/vlm_test_images/resolve/main/IMG_8137.mp4 370 | 371 | model = trainer.model # trainer has the adapter 372 | 373 | """Strip audio and downsample video.""" 374 | 375 | audio_path = "/content/test_audio.wav" 376 | subprocess.run([ 377 | "ffmpeg", "-i", "/content/IMG_8137.mp4", 378 | "-q:a", "0", "-map", "a", 379 | f"{audio_path}", 380 | "-y" 381 | ], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) 382 | 383 | frames = downsample_video("/content/IMG_8137.mp4") 384 | 385 | # repeat the chat template 386 | text = "Based on the video, predict the category of it." 
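# Rebuild the same chat structure used in the training collator (frame images interleaved with
# "Frame <timestamp>:" markers, plus the extracted audio clip), but without an assistant turn;
# with add_generation_prompt=True below, the fine-tuned model generates the category itself.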
387 | message = [ 388 | { 389 | "role": "user", 390 | "content": [ 391 | {"type": "text", "text": text} 392 | ], 393 | }, 394 | ] 395 | for frame in frames: 396 | image, timestamp = frame 397 | message[0]["content"].append({"type": "text", "text": f"Frame {timestamp}:"}) 398 | timestamp = str(timestamp).replace(".", "_") 399 | image.save(f"test_frame_{timestamp}.png") 400 | message[0]["content"].append({"type": "image", "url": f"test_frame_{timestamp}.png"}) 401 | 402 | message[0]["content"].append({"type": "audio", "audio": f"{audio_path}"}) 403 | 404 | message 405 | 406 | inputs = processor.apply_chat_template( 407 | message, 408 | add_generation_prompt=True, 409 | tokenize=True, 410 | return_dict=True, 411 | return_tensors="pt", 412 | padding=True, 413 | ).to(model.device).to(model.dtype) 414 | 415 | input_len = inputs["input_ids"].shape[-1] 416 | 417 | with torch.inference_mode(): 418 | generation = model.generate(**inputs, max_new_tokens=100, do_sample=False) 419 | generation = generation[0][input_len:] 420 | 421 | decoded = processor.decode(generation, skip_special_tokens=True) 422 | print(decoded) 423 | 424 | """Thanks a lot for reading! Keep training the model further with more data, or unfreeze more layers, for better performance 💗""" 425 | 426 | -------------------------------------------------------------------------------- /Smol_VLM_FT.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "id": "nc0g2NLpUSGr" 7 | }, 8 | "source": [ 9 | "# Fine-tune SmolVLM on Visual Question Answering using Consumer GPU with QLoRA\n", 10 | "\n", 11 | "In this notebook we will fine-tune SmolVLM on the VQAv2 dataset. With this notebook you can also fine-tune Idefics3, since both models share the same model class/architecture.\n", 12 | "\n", 13 | "We will use some techniques in this notebook that let you fine-tune the model on an L4 with a batch size of 4, using only around 16.4 GB of VRAM. We tested the notebook in that setup, but since we had access to an A100, it was last run on an A100." 14 | ] 15 | }, 16 | { 17 | "cell_type": "code", 18 | "execution_count": 1, 19 | "metadata": { 20 | "colab": { 21 | "base_uri": "https://localhost:8080/" 22 | }, 23 | "id": "WIhA1lQ7j0kw", 24 | "outputId": "d152531d-8a63-459f-d0b5-f61a47b268d2" 25 | }, 26 | "outputs": [], 27 | "source": [ 28 | "!pip install -q accelerate datasets peft bitsandbytes tensorboard" 29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": 2, 34 | "metadata": { 35 | "colab": { 36 | "base_uri": "https://localhost:8080/" 37 | }, 38 | "id": "XyJaqZZ3uYYl", 39 | "outputId": "eff31ad7-7a77-4391-a1ed-6a871e667be5" 40 | }, 41 | "outputs": [], 42 | "source": [ 43 | "!pip install -q flash-attn --no-build-isolation" 44 | ] 45 | }, 46 | { 47 | "cell_type": "markdown", 48 | "metadata": { 49 | "id": "wAeMA0heVBjT" 50 | }, 51 | "source": [ 52 | "We will push our model to the Hub, so we need to authenticate ourselves." 
53 | ] 54 | }, 55 | { 56 | "cell_type": "code", 57 | "execution_count": null, 58 | "metadata": { 59 | "colab": { 60 | "base_uri": "https://localhost:8080/", 61 | "height": 17, 62 | "referenced_widgets": [ 63 | "261a3abc28d74e4ca5af6f9df8cea3e5", 64 | "b6284cfacfd642278a7809a154463d69", 65 | "62c12672f59349b9ade248bee799fa5a", 66 | "9af532f878ab491096358d3bc83250d8", 67 | "599303d9f1204c85bca500c859dd0d87", 68 | "00617a46b15d45648c4796a91c96ec57", 69 | "5492da586f594365afc30ee6da1bf67c", 70 | "86aa1abb905346bf8956754a9704f250", 71 | "eeb2fbfd6cd54c4aa3983dc334a5377d", 72 | "ed34441fca164b389dfea1eabdba6e4a", 73 | "99f5b0432c1849128fa181b88925c77b", 74 | "5e529d6d6c4e40b4863961ea63bf259a", 75 | "ebfcd83e42ec46afb772d53ad7f35d43", 76 | "94958be916d6439d87dcd45c59178bec", 77 | "31a0c4a7fcff4744be56adf4125ef4e6", 78 | "2c975a8158bf49b389d47a5c4e40c97b", 79 | "b474bf8f464d40d8865665e4c7f0a411", 80 | "f8a75ac273fc408f923bf9d7f7263db8", 81 | "dd08ce6386184df38f47348e547738d8", 82 | "3aef5e8d5d9e4bd29bd3790ad139c02c" 83 | ] 84 | }, 85 | "id": "yKd5xtSGj7cm", 86 | "outputId": "63b352c0-3f7d-4945-add2-52102246d7b2" 87 | }, 88 | "outputs": [], 89 | "source": [ 90 | "from huggingface_hub import notebook_login\n", 91 | "\n", 92 | "notebook_login()" 93 | ] 94 | }, 95 | { 96 | "cell_type": "markdown", 97 | "metadata": { 98 | "id": "WRq8ve-LVAzU" 99 | }, 100 | "source": [ 101 | "In this notebook we will not do full fine-tuning but use QLoRA method, which loads an adapter to the quantized version of the model, saving space. If you want to do full fine-tuning, set `USE_LORA` and `USE_QLORA` to False. If you want to do LoRA, set `USE_QLORA` to False and `USE_LORA` to True." 102 | ] 103 | }, 104 | { 105 | "cell_type": "code", 106 | "execution_count": 1, 107 | "metadata": {}, 108 | "outputs": [], 109 | "source": [ 110 | "import os\n", 111 | "os.environ[\"CUDA_DEVICE_ORDER\"] = \"PCI_BUS_ID\"\n", 112 | "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"1, 2\"" 113 | ] 114 | }, 115 | { 116 | "cell_type": "markdown", 117 | "metadata": { 118 | "id": "WIVhpp0EyZO2" 119 | }, 120 | "source": [ 121 | "The model as is is holding 2.7 GB of GPU RAM 💗" 122 | ] 123 | }, 124 | { 125 | "cell_type": "markdown", 126 | "metadata": { 127 | "id": "LMTtg3dl3NX2" 128 | }, 129 | "source": [ 130 | "## Loading the dataset and Preprocessing" 131 | ] 132 | }, 133 | { 134 | "cell_type": "markdown", 135 | "metadata": { 136 | "id": "pWHMWTSZ3Pyr" 137 | }, 138 | "source": [ 139 | "We will load a small portion of the VQAv2 dataset. We are loading a small portion of the model for education purposes." 
140 | ] 141 | }, 142 | { 143 | "cell_type": "code", 144 | "execution_count": 3, 145 | "metadata": { 146 | "id": "POOqKqYRka5O" 147 | }, 148 | "outputs": [], 149 | "source": [ 150 | "from datasets import load_dataset\n", 151 | "ds = load_dataset('merve/vqav2-small', trust_remote_code=True)" 152 | ] 153 | }, 154 | { 155 | "cell_type": "code", 156 | "execution_count": null, 157 | "metadata": {}, 158 | "outputs": [], 159 | "source": [ 160 | "import torch\n", 161 | "from peft import LoraConfig, prepare_model_for_kbit_training, get_peft_model\n", 162 | "from transformers import AutoProcessor, BitsAndBytesConfig, Idefics3ForConditionalGeneration\n", 163 | "\n", 164 | "USE_LORA = False\n", 165 | "USE_QLORA = True\n", 166 | "SMOL = True\n", 167 | "\n", 168 | "model_id = \"HuggingFaceTB/SmolVLM-Base\" if SMOL else \"HuggingFaceM4/Idefics3-8B-Llama3\"\n", 169 | "\n", 170 | "processor = AutoProcessor.from_pretrained(\n", 171 | " model_id\n", 172 | ")\n", 173 | "\n", 174 | "if USE_QLORA or USE_LORA:\n", 175 | " lora_config = LoraConfig(\n", 176 | " r=8,\n", 177 | " lora_alpha=8,\n", 178 | " lora_dropout=0.1,\n", 179 | " target_modules=['down_proj','o_proj','k_proj','q_proj','gate_proj','up_proj','v_proj'],\n", 180 | " use_dora=False if USE_QLORA else True,\n", 181 | " init_lora_weights=\"gaussian\"\n", 182 | " )\n", 183 | " lora_config.inference_mode = False\n", 184 | " if USE_QLORA:\n", 185 | " bnb_config = BitsAndBytesConfig(\n", 186 | " load_in_4bit=True,\n", 187 | " bnb_4bit_use_double_quant=True,\n", 188 | " bnb_4bit_quant_type=\"nf4\",\n", 189 | " bnb_4bit_compute_dtype=torch.bfloat16\n", 190 | " )\n", 191 | "\n", 192 | " model = Idefics3ForConditionalGeneration.from_pretrained(\n", 193 | " model_id,\n", 194 | " quantization_config=bnb_config if USE_QLORA else None,\n", 195 | " _attn_implementation=\"flash_attention_2\",\n", 196 | " device_map=\"auto\"\n", 197 | " )\n", 198 | " model.add_adapter(lora_config)\n", 199 | " model.enable_adapters()\n", 200 | " model = prepare_model_for_kbit_training(model)\n", 201 | " model = get_peft_model(model, lora_config)\n", 202 | " print(model.get_nb_trainable_parameters())\n", 203 | "else:\n", 204 | " model = Idefics3ForConditionalGeneration.from_pretrained(\n", 205 | " model_id,\n", 206 | " torch_dtype=torch.bfloat16,\n", 207 | " _attn_implementation=\"flash_attention_2\",\n", 208 | " ).to(DEVICE)\n", 209 | "\n", 210 | " # if you'd like to only fine-tune LLM\n", 211 | " for param in model.model.vision_model.parameters():\n", 212 | " param.requires_grad = False" 213 | ] 214 | }, 215 | { 216 | "cell_type": "code", 217 | "execution_count": 6, 218 | "metadata": { 219 | "id": "Znf9vMo5rnSd" 220 | }, 221 | "outputs": [], 222 | "source": [ 223 | "split_ds = ds[\"validation\"].train_test_split(test_size=0.5)\n", 224 | "train_ds = split_ds[\"train\"]" 225 | ] 226 | }, 227 | { 228 | "cell_type": "code", 229 | "execution_count": 7, 230 | "metadata": { 231 | "colab": { 232 | "base_uri": "https://localhost:8080/" 233 | }, 234 | "id": "FIDioFlRuYYn", 235 | "outputId": "79b697a7-d245-4fdc-b0e8-d9ffa8627953" 236 | }, 237 | "outputs": [ 238 | { 239 | "data": { 240 | "text/plain": [ 241 | "Dataset({\n", 242 | " features: ['multiple_choice_answer', 'question', 'image'],\n", 243 | " num_rows: 10717\n", 244 | "})" 245 | ] 246 | }, 247 | "execution_count": 7, 248 | "metadata": {}, 249 | "output_type": "execute_result" 250 | } 251 | ], 252 | "source": [ 253 | "train_ds" 254 | ] 255 | }, 256 | { 257 | "cell_type": "markdown", 258 | "metadata": { 259 | "id": "5nwMO3n0X7Hv" 260 | }, 
261 | "source": [ 262 | "Let's write our data collating function. We will apply prompt template to have questions and answers together so model can learn to answer. Then we pass the formatted prompts and images to the processor which processes both." 263 | ] 264 | }, 265 | { 266 | "cell_type": "code", 267 | "execution_count": 8, 268 | "metadata": { 269 | "id": "e0krVLZ-wNMl" 270 | }, 271 | "outputs": [], 272 | "source": [ 273 | "image_token_id = processor.tokenizer.additional_special_tokens_ids[\n", 274 | " processor.tokenizer.additional_special_tokens.index(\"\")]\n", 275 | "\n", 276 | "def collate_fn(examples):\n", 277 | " texts = []\n", 278 | " images = []\n", 279 | " for example in examples:\n", 280 | " image = example[\"image\"]\n", 281 | " if image.mode != 'RGB':\n", 282 | " image = image.convert('RGB')\n", 283 | " question = example[\"question\"]\n", 284 | " answer = example[\"multiple_choice_answer\"]\n", 285 | " messages = [\n", 286 | " {\n", 287 | " \"role\": \"user\",\n", 288 | " \"content\": [\n", 289 | " {\"type\": \"text\", \"text\": \"Answer briefly.\"},\n", 290 | " {\"type\": \"image\"},\n", 291 | " {\"type\": \"text\", \"text\": question}\n", 292 | " ]\n", 293 | " },\n", 294 | " {\n", 295 | " \"role\": \"assistant\",\n", 296 | " \"content\": [\n", 297 | " {\"type\": \"text\", \"text\": answer}\n", 298 | " ]\n", 299 | " }\n", 300 | " ]\n", 301 | " text = processor.apply_chat_template(messages, add_generation_prompt=False)\n", 302 | " texts.append(text.strip())\n", 303 | " images.append([image])\n", 304 | "\n", 305 | " batch = processor(text=texts, images=images, return_tensors=\"pt\", padding=True)\n", 306 | " labels = batch[\"input_ids\"].clone()\n", 307 | " labels[labels == processor.tokenizer.pad_token_id] = -100\n", 308 | " labels[labels == image_token_id] = -100\n", 309 | " batch[\"labels\"] = labels\n", 310 | "\n", 311 | " return batch" 312 | ] 313 | }, 314 | { 315 | "cell_type": "markdown", 316 | "metadata": { 317 | "id": "kEYDjWpE3LD5" 318 | }, 319 | "source": [ 320 | "## Training" 321 | ] 322 | }, 323 | { 324 | "cell_type": "markdown", 325 | "metadata": { 326 | "id": "QvAs896cdwg8" 327 | }, 328 | "source": [ 329 | "We can now initialize `Trainer` and initialize `TrainingArguments` to pass to `Trainer`.\n", 330 | "\n", 331 | "Some notes:\n", 332 | "- If you use 8-bit QLoRA with the below setup it uses around 16.4 GB VRAM (beautiful, fits comfortably inside L4, Colab free tier)\n", 333 | "- We use gradient accumulation to simulate a larger batch size.\n", 334 | "- We also save up on memory from intermediate activations by using gradient checkpointing.\n", 335 | "\n", 336 | "**Disclaimer:** \n", 337 | "The techniques here aren't free lunch. The latter two will add additional compute to the training, thus slow down a bit (for reference on two A100s with bsz of 16, we were able to train for 2 hrs 43 mins with the gradient accumulation steps of 4, disabling it reduced it with 2 hr 35 mins). \n", 338 | "If you want to speed-up, you might play around, reduce to 4-bit precision and have a higher batch size. Note that 4-bit might result in model learning less." 
339 | ] 340 | }, 341 | { 342 | "cell_type": "code", 343 | "execution_count": 15, 344 | "metadata": { 345 | "id": "QNE2yWAYrAhD" 346 | }, 347 | "outputs": [], 348 | "source": [ 349 | "from transformers import TrainingArguments, Trainer\n", 350 | "\n", 351 | "model_name = model_id.split(\"/\")[-1]\n", 352 | "\n", 353 | "training_args = TrainingArguments(\n", 354 | " num_train_epochs=1,\n", 355 | " per_device_train_batch_size=16,\n", 356 | " gradient_accumulation_steps=4,\n", 357 | " warmup_steps=50,\n", 358 | " learning_rate=1e-4,\n", 359 | " weight_decay=0.01,\n", 360 | " logging_steps=25,\n", 361 | " save_strategy=\"steps\",\n", 362 | " save_steps=250,\n", 363 | " save_total_limit=1,\n", 364 | " optim=\"paged_adamw_8bit\", # for 8-bit, keep this, else adamw_hf\n", 365 | " bf16=True, # underlying precision for 8bit\n", 366 | " output_dir=f\"./{model_name}-vqav2\",\n", 367 | " hub_model_id=f\"{model_name}-vqav2\",\n", 368 | " report_to=\"tensorboard\",\n", 369 | " remove_unused_columns=False,\n", 370 | " gradient_checkpointing=True\n", 371 | ")\n" 372 | ] 373 | }, 374 | { 375 | "cell_type": "code", 376 | "execution_count": 16, 377 | "metadata": { 378 | "id": "oBBSDpBhreJd" 379 | }, 380 | "outputs": [ 381 | { 382 | "name": "stderr", 383 | "output_type": "stream", 384 | "text": [ 385 | "Detected kernel version 5.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n" 386 | ] 387 | } 388 | ], 389 | "source": [ 390 | "trainer = Trainer(\n", 391 | " model=model,\n", 392 | " args=training_args,\n", 393 | " data_collator=collate_fn,\n", 394 | " train_dataset=train_ds,\n", 395 | ")" 396 | ] 397 | }, 398 | { 399 | "cell_type": "code", 400 | "execution_count": null, 401 | "metadata": {}, 402 | "outputs": [], 403 | "source": [ 404 | "trainer.train()" 405 | ] 406 | }, 407 | { 408 | "cell_type": "code", 409 | "execution_count": null, 410 | "metadata": { 411 | "id": "0hN0QD9_uYYo" 412 | }, 413 | "outputs": [], 414 | "source": [ 415 | "trainer.push_to_hub()" 416 | ] 417 | } 418 | ], 419 | "metadata": { 420 | "accelerator": "GPU", 421 | "colab": { 422 | "gpuType": "A100", 423 | "provenance": [] 424 | }, 425 | "kernelspec": { 426 | "display_name": "Python 3 (ipykernel)", 427 | "language": "python", 428 | "name": "python3" 429 | }, 430 | "language_info": { 431 | "codemirror_mode": { 432 | "name": "ipython", 433 | "version": 3 434 | }, 435 | "file_extension": ".py", 436 | "mimetype": "text/x-python", 437 | "name": "python", 438 | "nbconvert_exporter": "python", 439 | "pygments_lexer": "ipython3", 440 | "version": "3.12.4" 441 | }, 442 | "widgets": { 443 | "application/vnd.jupyter.widget-state+json": { 444 | "00617a46b15d45648c4796a91c96ec57": { 445 | "model_module": "@jupyter-widgets/controls", 446 | "model_module_version": "1.5.0", 447 | "model_name": "HTMLModel", 448 | "state": { 449 | "_dom_classes": [], 450 | "_model_module": "@jupyter-widgets/controls", 451 | "_model_module_version": "1.5.0", 452 | "_model_name": "HTMLModel", 453 | "_view_count": null, 454 | "_view_module": "@jupyter-widgets/controls", 455 | "_view_module_version": "1.5.0", 456 | "_view_name": "HTMLView", 457 | "description": "", 458 | "description_tooltip": null, 459 | "layout": "IPY_MODEL_2c975a8158bf49b389d47a5c4e40c97b", 460 | "placeholder": "​", 461 | "style": "IPY_MODEL_b474bf8f464d40d8865665e4c7f0a411", 462 | "value": "\nPro Tip: If you don't already have one, you can create a dedicated\n'notebooks' token with 'write' 
access, that you can then easily reuse for all\nnotebooks. " 463 | } 464 | }, 465 | "261a3abc28d74e4ca5af6f9df8cea3e5": { 466 | "model_module": "@jupyter-widgets/controls", 467 | "model_module_version": "1.5.0", 468 | "model_name": "VBoxModel", 469 | "state": { 470 | "_dom_classes": [], 471 | "_model_module": "@jupyter-widgets/controls", 472 | "_model_module_version": "1.5.0", 473 | "_model_name": "VBoxModel", 474 | "_view_count": null, 475 | "_view_module": "@jupyter-widgets/controls", 476 | "_view_module_version": "1.5.0", 477 | "_view_name": "VBoxView", 478 | "box_style": "", 479 | "children": [], 480 | "layout": "IPY_MODEL_5492da586f594365afc30ee6da1bf67c" 481 | } 482 | }, 483 | "2c975a8158bf49b389d47a5c4e40c97b": { 484 | "model_module": "@jupyter-widgets/base", 485 | "model_module_version": "1.2.0", 486 | "model_name": "LayoutModel", 487 | "state": { 488 | "_model_module": "@jupyter-widgets/base", 489 | "_model_module_version": "1.2.0", 490 | "_model_name": "LayoutModel", 491 | "_view_count": null, 492 | "_view_module": "@jupyter-widgets/base", 493 | "_view_module_version": "1.2.0", 494 | "_view_name": "LayoutView", 495 | "align_content": null, 496 | "align_items": null, 497 | "align_self": null, 498 | "border": null, 499 | "bottom": null, 500 | "display": null, 501 | "flex": null, 502 | "flex_flow": null, 503 | "grid_area": null, 504 | "grid_auto_columns": null, 505 | "grid_auto_flow": null, 506 | "grid_auto_rows": null, 507 | "grid_column": null, 508 | "grid_gap": null, 509 | "grid_row": null, 510 | "grid_template_areas": null, 511 | "grid_template_columns": null, 512 | "grid_template_rows": null, 513 | "height": null, 514 | "justify_content": null, 515 | "justify_items": null, 516 | "left": null, 517 | "margin": null, 518 | "max_height": null, 519 | "max_width": null, 520 | "min_height": null, 521 | "min_width": null, 522 | "object_fit": null, 523 | "object_position": null, 524 | "order": null, 525 | "overflow": null, 526 | "overflow_x": null, 527 | "overflow_y": null, 528 | "padding": null, 529 | "right": null, 530 | "top": null, 531 | "visibility": null, 532 | "width": null 533 | } 534 | }, 535 | "31a0c4a7fcff4744be56adf4125ef4e6": { 536 | "model_module": "@jupyter-widgets/controls", 537 | "model_module_version": "1.5.0", 538 | "model_name": "ButtonStyleModel", 539 | "state": { 540 | "_model_module": "@jupyter-widgets/controls", 541 | "_model_module_version": "1.5.0", 542 | "_model_name": "ButtonStyleModel", 543 | "_view_count": null, 544 | "_view_module": "@jupyter-widgets/base", 545 | "_view_module_version": "1.2.0", 546 | "_view_name": "StyleView", 547 | "button_color": null, 548 | "font_weight": "" 549 | } 550 | }, 551 | "3aef5e8d5d9e4bd29bd3790ad139c02c": { 552 | "model_module": "@jupyter-widgets/controls", 553 | "model_module_version": "1.5.0", 554 | "model_name": "DescriptionStyleModel", 555 | "state": { 556 | "_model_module": "@jupyter-widgets/controls", 557 | "_model_module_version": "1.5.0", 558 | "_model_name": "DescriptionStyleModel", 559 | "_view_count": null, 560 | "_view_module": "@jupyter-widgets/base", 561 | "_view_module_version": "1.2.0", 562 | "_view_name": "StyleView", 563 | "description_width": "" 564 | } 565 | }, 566 | "5492da586f594365afc30ee6da1bf67c": { 567 | "model_module": "@jupyter-widgets/base", 568 | "model_module_version": "1.2.0", 569 | "model_name": "LayoutModel", 570 | "state": { 571 | "_model_module": "@jupyter-widgets/base", 572 | "_model_module_version": "1.2.0", 573 | "_model_name": "LayoutModel", 574 | "_view_count": null, 575 | 
"_view_module": "@jupyter-widgets/base", 576 | "_view_module_version": "1.2.0", 577 | "_view_name": "LayoutView", 578 | "align_content": null, 579 | "align_items": "center", 580 | "align_self": null, 581 | "border": null, 582 | "bottom": null, 583 | "display": "flex", 584 | "flex": null, 585 | "flex_flow": "column", 586 | "grid_area": null, 587 | "grid_auto_columns": null, 588 | "grid_auto_flow": null, 589 | "grid_auto_rows": null, 590 | "grid_column": null, 591 | "grid_gap": null, 592 | "grid_row": null, 593 | "grid_template_areas": null, 594 | "grid_template_columns": null, 595 | "grid_template_rows": null, 596 | "height": null, 597 | "justify_content": null, 598 | "justify_items": null, 599 | "left": null, 600 | "margin": null, 601 | "max_height": null, 602 | "max_width": null, 603 | "min_height": null, 604 | "min_width": null, 605 | "object_fit": null, 606 | "object_position": null, 607 | "order": null, 608 | "overflow": null, 609 | "overflow_x": null, 610 | "overflow_y": null, 611 | "padding": null, 612 | "right": null, 613 | "top": null, 614 | "visibility": null, 615 | "width": "50%" 616 | } 617 | }, 618 | "599303d9f1204c85bca500c859dd0d87": { 619 | "model_module": "@jupyter-widgets/controls", 620 | "model_module_version": "1.5.0", 621 | "model_name": "ButtonModel", 622 | "state": { 623 | "_dom_classes": [], 624 | "_model_module": "@jupyter-widgets/controls", 625 | "_model_module_version": "1.5.0", 626 | "_model_name": "ButtonModel", 627 | "_view_count": null, 628 | "_view_module": "@jupyter-widgets/controls", 629 | "_view_module_version": "1.5.0", 630 | "_view_name": "ButtonView", 631 | "button_style": "", 632 | "description": "Login", 633 | "disabled": false, 634 | "icon": "", 635 | "layout": "IPY_MODEL_94958be916d6439d87dcd45c59178bec", 636 | "style": "IPY_MODEL_31a0c4a7fcff4744be56adf4125ef4e6", 637 | "tooltip": "" 638 | } 639 | }, 640 | "5e529d6d6c4e40b4863961ea63bf259a": { 641 | "model_module": "@jupyter-widgets/base", 642 | "model_module_version": "1.2.0", 643 | "model_name": "LayoutModel", 644 | "state": { 645 | "_model_module": "@jupyter-widgets/base", 646 | "_model_module_version": "1.2.0", 647 | "_model_name": "LayoutModel", 648 | "_view_count": null, 649 | "_view_module": "@jupyter-widgets/base", 650 | "_view_module_version": "1.2.0", 651 | "_view_name": "LayoutView", 652 | "align_content": null, 653 | "align_items": null, 654 | "align_self": null, 655 | "border": null, 656 | "bottom": null, 657 | "display": null, 658 | "flex": null, 659 | "flex_flow": null, 660 | "grid_area": null, 661 | "grid_auto_columns": null, 662 | "grid_auto_flow": null, 663 | "grid_auto_rows": null, 664 | "grid_column": null, 665 | "grid_gap": null, 666 | "grid_row": null, 667 | "grid_template_areas": null, 668 | "grid_template_columns": null, 669 | "grid_template_rows": null, 670 | "height": null, 671 | "justify_content": null, 672 | "justify_items": null, 673 | "left": null, 674 | "margin": null, 675 | "max_height": null, 676 | "max_width": null, 677 | "min_height": null, 678 | "min_width": null, 679 | "object_fit": null, 680 | "object_position": null, 681 | "order": null, 682 | "overflow": null, 683 | "overflow_x": null, 684 | "overflow_y": null, 685 | "padding": null, 686 | "right": null, 687 | "top": null, 688 | "visibility": null, 689 | "width": null 690 | } 691 | }, 692 | "62c12672f59349b9ade248bee799fa5a": { 693 | "model_module": "@jupyter-widgets/controls", 694 | "model_module_version": "1.5.0", 695 | "model_name": "PasswordModel", 696 | "state": { 697 | "_dom_classes": [], 698 | 
"_model_module": "@jupyter-widgets/controls", 699 | "_model_module_version": "1.5.0", 700 | "_model_name": "PasswordModel", 701 | "_view_count": null, 702 | "_view_module": "@jupyter-widgets/controls", 703 | "_view_module_version": "1.5.0", 704 | "_view_name": "PasswordView", 705 | "continuous_update": true, 706 | "description": "Token:", 707 | "description_tooltip": null, 708 | "disabled": false, 709 | "layout": "IPY_MODEL_ed34441fca164b389dfea1eabdba6e4a", 710 | "placeholder": "​", 711 | "style": "IPY_MODEL_99f5b0432c1849128fa181b88925c77b", 712 | "value": "" 713 | } 714 | }, 715 | "86aa1abb905346bf8956754a9704f250": { 716 | "model_module": "@jupyter-widgets/base", 717 | "model_module_version": "1.2.0", 718 | "model_name": "LayoutModel", 719 | "state": { 720 | "_model_module": "@jupyter-widgets/base", 721 | "_model_module_version": "1.2.0", 722 | "_model_name": "LayoutModel", 723 | "_view_count": null, 724 | "_view_module": "@jupyter-widgets/base", 725 | "_view_module_version": "1.2.0", 726 | "_view_name": "LayoutView", 727 | "align_content": null, 728 | "align_items": null, 729 | "align_self": null, 730 | "border": null, 731 | "bottom": null, 732 | "display": null, 733 | "flex": null, 734 | "flex_flow": null, 735 | "grid_area": null, 736 | "grid_auto_columns": null, 737 | "grid_auto_flow": null, 738 | "grid_auto_rows": null, 739 | "grid_column": null, 740 | "grid_gap": null, 741 | "grid_row": null, 742 | "grid_template_areas": null, 743 | "grid_template_columns": null, 744 | "grid_template_rows": null, 745 | "height": null, 746 | "justify_content": null, 747 | "justify_items": null, 748 | "left": null, 749 | "margin": null, 750 | "max_height": null, 751 | "max_width": null, 752 | "min_height": null, 753 | "min_width": null, 754 | "object_fit": null, 755 | "object_position": null, 756 | "order": null, 757 | "overflow": null, 758 | "overflow_x": null, 759 | "overflow_y": null, 760 | "padding": null, 761 | "right": null, 762 | "top": null, 763 | "visibility": null, 764 | "width": null 765 | } 766 | }, 767 | "94958be916d6439d87dcd45c59178bec": { 768 | "model_module": "@jupyter-widgets/base", 769 | "model_module_version": "1.2.0", 770 | "model_name": "LayoutModel", 771 | "state": { 772 | "_model_module": "@jupyter-widgets/base", 773 | "_model_module_version": "1.2.0", 774 | "_model_name": "LayoutModel", 775 | "_view_count": null, 776 | "_view_module": "@jupyter-widgets/base", 777 | "_view_module_version": "1.2.0", 778 | "_view_name": "LayoutView", 779 | "align_content": null, 780 | "align_items": null, 781 | "align_self": null, 782 | "border": null, 783 | "bottom": null, 784 | "display": null, 785 | "flex": null, 786 | "flex_flow": null, 787 | "grid_area": null, 788 | "grid_auto_columns": null, 789 | "grid_auto_flow": null, 790 | "grid_auto_rows": null, 791 | "grid_column": null, 792 | "grid_gap": null, 793 | "grid_row": null, 794 | "grid_template_areas": null, 795 | "grid_template_columns": null, 796 | "grid_template_rows": null, 797 | "height": null, 798 | "justify_content": null, 799 | "justify_items": null, 800 | "left": null, 801 | "margin": null, 802 | "max_height": null, 803 | "max_width": null, 804 | "min_height": null, 805 | "min_width": null, 806 | "object_fit": null, 807 | "object_position": null, 808 | "order": null, 809 | "overflow": null, 810 | "overflow_x": null, 811 | "overflow_y": null, 812 | "padding": null, 813 | "right": null, 814 | "top": null, 815 | "visibility": null, 816 | "width": null 817 | } 818 | }, 819 | "99f5b0432c1849128fa181b88925c77b": { 820 | "model_module": 
"@jupyter-widgets/controls", 821 | "model_module_version": "1.5.0", 822 | "model_name": "DescriptionStyleModel", 823 | "state": { 824 | "_model_module": "@jupyter-widgets/controls", 825 | "_model_module_version": "1.5.0", 826 | "_model_name": "DescriptionStyleModel", 827 | "_view_count": null, 828 | "_view_module": "@jupyter-widgets/base", 829 | "_view_module_version": "1.2.0", 830 | "_view_name": "StyleView", 831 | "description_width": "" 832 | } 833 | }, 834 | "9af532f878ab491096358d3bc83250d8": { 835 | "model_module": "@jupyter-widgets/controls", 836 | "model_module_version": "1.5.0", 837 | "model_name": "CheckboxModel", 838 | "state": { 839 | "_dom_classes": [], 840 | "_model_module": "@jupyter-widgets/controls", 841 | "_model_module_version": "1.5.0", 842 | "_model_name": "CheckboxModel", 843 | "_view_count": null, 844 | "_view_module": "@jupyter-widgets/controls", 845 | "_view_module_version": "1.5.0", 846 | "_view_name": "CheckboxView", 847 | "description": "Add token as git credential?", 848 | "description_tooltip": null, 849 | "disabled": false, 850 | "indent": true, 851 | "layout": "IPY_MODEL_5e529d6d6c4e40b4863961ea63bf259a", 852 | "style": "IPY_MODEL_ebfcd83e42ec46afb772d53ad7f35d43", 853 | "value": true 854 | } 855 | }, 856 | "b474bf8f464d40d8865665e4c7f0a411": { 857 | "model_module": "@jupyter-widgets/controls", 858 | "model_module_version": "1.5.0", 859 | "model_name": "DescriptionStyleModel", 860 | "state": { 861 | "_model_module": "@jupyter-widgets/controls", 862 | "_model_module_version": "1.5.0", 863 | "_model_name": "DescriptionStyleModel", 864 | "_view_count": null, 865 | "_view_module": "@jupyter-widgets/base", 866 | "_view_module_version": "1.2.0", 867 | "_view_name": "StyleView", 868 | "description_width": "" 869 | } 870 | }, 871 | "b6284cfacfd642278a7809a154463d69": { 872 | "model_module": "@jupyter-widgets/controls", 873 | "model_module_version": "1.5.0", 874 | "model_name": "HTMLModel", 875 | "state": { 876 | "_dom_classes": [], 877 | "_model_module": "@jupyter-widgets/controls", 878 | "_model_module_version": "1.5.0", 879 | "_model_name": "HTMLModel", 880 | "_view_count": null, 881 | "_view_module": "@jupyter-widgets/controls", 882 | "_view_module_version": "1.5.0", 883 | "_view_name": "HTMLView", 884 | "description": "", 885 | "description_tooltip": null, 886 | "layout": "IPY_MODEL_86aa1abb905346bf8956754a9704f250", 887 | "placeholder": "​", 888 | "style": "IPY_MODEL_eeb2fbfd6cd54c4aa3983dc334a5377d", 889 | "value": "

Copy a token from your Hugging Face\ntokens page and paste it below.
Immediately click login after copying\nyour token or it might be stored in plain text in this notebook file.
" 890 | } 891 | }, 892 | "dd08ce6386184df38f47348e547738d8": { 893 | "model_module": "@jupyter-widgets/base", 894 | "model_module_version": "1.2.0", 895 | "model_name": "LayoutModel", 896 | "state": { 897 | "_model_module": "@jupyter-widgets/base", 898 | "_model_module_version": "1.2.0", 899 | "_model_name": "LayoutModel", 900 | "_view_count": null, 901 | "_view_module": "@jupyter-widgets/base", 902 | "_view_module_version": "1.2.0", 903 | "_view_name": "LayoutView", 904 | "align_content": null, 905 | "align_items": null, 906 | "align_self": null, 907 | "border": null, 908 | "bottom": null, 909 | "display": null, 910 | "flex": null, 911 | "flex_flow": null, 912 | "grid_area": null, 913 | "grid_auto_columns": null, 914 | "grid_auto_flow": null, 915 | "grid_auto_rows": null, 916 | "grid_column": null, 917 | "grid_gap": null, 918 | "grid_row": null, 919 | "grid_template_areas": null, 920 | "grid_template_columns": null, 921 | "grid_template_rows": null, 922 | "height": null, 923 | "justify_content": null, 924 | "justify_items": null, 925 | "left": null, 926 | "margin": null, 927 | "max_height": null, 928 | "max_width": null, 929 | "min_height": null, 930 | "min_width": null, 931 | "object_fit": null, 932 | "object_position": null, 933 | "order": null, 934 | "overflow": null, 935 | "overflow_x": null, 936 | "overflow_y": null, 937 | "padding": null, 938 | "right": null, 939 | "top": null, 940 | "visibility": null, 941 | "width": null 942 | } 943 | }, 944 | "ebfcd83e42ec46afb772d53ad7f35d43": { 945 | "model_module": "@jupyter-widgets/controls", 946 | "model_module_version": "1.5.0", 947 | "model_name": "DescriptionStyleModel", 948 | "state": { 949 | "_model_module": "@jupyter-widgets/controls", 950 | "_model_module_version": "1.5.0", 951 | "_model_name": "DescriptionStyleModel", 952 | "_view_count": null, 953 | "_view_module": "@jupyter-widgets/base", 954 | "_view_module_version": "1.2.0", 955 | "_view_name": "StyleView", 956 | "description_width": "" 957 | } 958 | }, 959 | "ed34441fca164b389dfea1eabdba6e4a": { 960 | "model_module": "@jupyter-widgets/base", 961 | "model_module_version": "1.2.0", 962 | "model_name": "LayoutModel", 963 | "state": { 964 | "_model_module": "@jupyter-widgets/base", 965 | "_model_module_version": "1.2.0", 966 | "_model_name": "LayoutModel", 967 | "_view_count": null, 968 | "_view_module": "@jupyter-widgets/base", 969 | "_view_module_version": "1.2.0", 970 | "_view_name": "LayoutView", 971 | "align_content": null, 972 | "align_items": null, 973 | "align_self": null, 974 | "border": null, 975 | "bottom": null, 976 | "display": null, 977 | "flex": null, 978 | "flex_flow": null, 979 | "grid_area": null, 980 | "grid_auto_columns": null, 981 | "grid_auto_flow": null, 982 | "grid_auto_rows": null, 983 | "grid_column": null, 984 | "grid_gap": null, 985 | "grid_row": null, 986 | "grid_template_areas": null, 987 | "grid_template_columns": null, 988 | "grid_template_rows": null, 989 | "height": null, 990 | "justify_content": null, 991 | "justify_items": null, 992 | "left": null, 993 | "margin": null, 994 | "max_height": null, 995 | "max_width": null, 996 | "min_height": null, 997 | "min_width": null, 998 | "object_fit": null, 999 | "object_position": null, 1000 | "order": null, 1001 | "overflow": null, 1002 | "overflow_x": null, 1003 | "overflow_y": null, 1004 | "padding": null, 1005 | "right": null, 1006 | "top": null, 1007 | "visibility": null, 1008 | "width": null 1009 | } 1010 | }, 1011 | "eeb2fbfd6cd54c4aa3983dc334a5377d": { 1012 | "model_module": 
"@jupyter-widgets/controls", 1013 | "model_module_version": "1.5.0", 1014 | "model_name": "DescriptionStyleModel", 1015 | "state": { 1016 | "_model_module": "@jupyter-widgets/controls", 1017 | "_model_module_version": "1.5.0", 1018 | "_model_name": "DescriptionStyleModel", 1019 | "_view_count": null, 1020 | "_view_module": "@jupyter-widgets/base", 1021 | "_view_module_version": "1.2.0", 1022 | "_view_name": "StyleView", 1023 | "description_width": "" 1024 | } 1025 | }, 1026 | "f8a75ac273fc408f923bf9d7f7263db8": { 1027 | "model_module": "@jupyter-widgets/controls", 1028 | "model_module_version": "1.5.0", 1029 | "model_name": "LabelModel", 1030 | "state": { 1031 | "_dom_classes": [], 1032 | "_model_module": "@jupyter-widgets/controls", 1033 | "_model_module_version": "1.5.0", 1034 | "_model_name": "LabelModel", 1035 | "_view_count": null, 1036 | "_view_module": "@jupyter-widgets/controls", 1037 | "_view_module_version": "1.5.0", 1038 | "_view_name": "LabelView", 1039 | "description": "", 1040 | "description_tooltip": null, 1041 | "layout": "IPY_MODEL_dd08ce6386184df38f47348e547738d8", 1042 | "placeholder": "​", 1043 | "style": "IPY_MODEL_3aef5e8d5d9e4bd29bd3790ad139c02c", 1044 | "value": "Connecting..." 1045 | } 1046 | } 1047 | } 1048 | } 1049 | }, 1050 | "nbformat": 4, 1051 | "nbformat_minor": 4 1052 | } 1053 | -------------------------------------------------------------------------------- /Gemma_3n_Video_Vibe_Tests.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "colab_type": "text", 7 | "id": "view-in-github" 8 | }, 9 | "source": [ 10 | "\"Open" 11 | ] 12 | }, 13 | { 14 | "cell_type": "markdown", 15 | "metadata": { 16 | "id": "onFz3_7AqnaB" 17 | }, 18 | "source": [ 19 | "## Gemma 3n Video with Audio Inference" 20 | ] 21 | }, 22 | { 23 | "cell_type": "markdown", 24 | "metadata": { 25 | "id": "KKUnhy4JqqAg" 26 | }, 27 | "source": [ 28 | "In this notebook we'll infer Gemma-3n videos with audios inside." 29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": null, 34 | "metadata": { 35 | "id": "Vf-VvnrNjuxF" 36 | }, 37 | "outputs": [], 38 | "source": [ 39 | "!pip install -U -q transformers timm datasets" 40 | ] 41 | }, 42 | { 43 | "cell_type": "markdown", 44 | "metadata": { 45 | "id": "gcJbxIPLqvjH" 46 | }, 47 | "source": [ 48 | "We will load three examples from FineVideo dataset and Gemma-3n model so make sure you have access to both and provide access token." 49 | ] 50 | }, 51 | { 52 | "cell_type": "code", 53 | "execution_count": null, 54 | "metadata": {}, 55 | "outputs": [], 56 | "source": [ 57 | "from huggingface_hub import login\n", 58 | "login()" 59 | ] 60 | }, 61 | { 62 | "cell_type": "code", 63 | "execution_count": null, 64 | "metadata": {}, 65 | "outputs": [], 66 | "source": [ 67 | "from transformers import AutoProcessor, Gemma3nForConditionalGeneration\n", 68 | "import torch\n", 69 | "model = Gemma3nForConditionalGeneration.from_pretrained(\n", 70 | " \"google/gemma-3n-E4B-it\", torch_dtype=torch.bfloat16,\n", 71 | ").to(\"cuda\")\n", 72 | "processor = AutoProcessor.from_pretrained(\n", 73 | " \"google/gemma-3n-E4B-it\",\n", 74 | ")\n", 75 | "processor.tokenizer.padding_side = \"right\"" 76 | ] 77 | }, 78 | { 79 | "cell_type": "markdown", 80 | "metadata": { 81 | "id": "mQzrURJlNRwW" 82 | }, 83 | "source": [ 84 | "Download video for inference." 
85 | ] 86 | }, 87 | { 88 | "cell_type": "code", 89 | "execution_count": null, 90 | "metadata": {}, 91 | "outputs": [], 92 | "source": [ 93 | "!wget https://huggingface.co/datasets/merve/vlm_test_images/resolve/main/IMG_8137.mp4" 94 | ] 95 | }, 96 | { 97 | "cell_type": "markdown", 98 | "metadata": { 99 | "id": "KXlBj7dVtUFZ" 100 | }, 101 | "source": [ 102 | "Strip audios from video." 103 | ] 104 | }, 105 | { 106 | "cell_type": "code", 107 | "execution_count": null, 108 | "metadata": { 109 | "colab": { 110 | "base_uri": "https://localhost:8080/" 111 | }, 112 | "id": "FQhKimtlMOHe", 113 | "outputId": "ef05231a-ce56-4733-b0be-d6b423a143ae" 114 | }, 115 | "outputs": [ 116 | { 117 | "data": { 118 | "text/plain": [ 119 | "CompletedProcess(args=['ffmpeg', '-i', 'IMG_8137.mp4', '-q:a', '0', '-map', 'a', 'audios/audio.wav', '-y'], returncode=0)" 120 | ] 121 | }, 122 | "execution_count": 57, 123 | "metadata": {}, 124 | "output_type": "execute_result" 125 | } 126 | ], 127 | "source": [ 128 | "import os\n", 129 | "import subprocess\n", 130 | "filename = \"IMG_8137.mp4\"\n", 131 | "audio_path = os.path.join(\"audios\", f\"audio.wav\")\n", 132 | "\n", 133 | "subprocess.run([\n", 134 | " \"ffmpeg\", \"-i\", filename,\n", 135 | " \"-q:a\", \"0\", \"-map\", \"a\",\n", 136 | " audio_path,\n", 137 | " \"-y\"\n", 138 | "], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)" 139 | ] 140 | }, 141 | { 142 | "cell_type": "code", 143 | "execution_count": null, 144 | "metadata": { 145 | "id": "6e_cExwMjx7v" 146 | }, 147 | "outputs": [], 148 | "source": [ 149 | "import cv2\n", 150 | "from PIL import Image\n", 151 | "import numpy as np\n", 152 | "\n", 153 | "def downsample_video(video_path):\n", 154 | " vidcap = cv2.VideoCapture(video_path)\n", 155 | " total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))\n", 156 | " fps = vidcap.get(cv2.CAP_PROP_FPS)\n", 157 | "\n", 158 | " frames = []\n", 159 | " frame_indices = np.linspace(0, total_frames - 1, 7, dtype=int)\n", 160 | "\n", 161 | " for i in frame_indices:\n", 162 | " vidcap.set(cv2.CAP_PROP_POS_FRAMES, i)\n", 163 | " success, image = vidcap.read()\n", 164 | " if success:\n", 165 | " image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) # Convert from BGR to RGB\n", 166 | " pil_image = Image.fromarray(image)\n", 167 | " timestamp = round(i / fps, 2)\n", 168 | " frames.append((pil_image, timestamp))\n", 169 | "\n", 170 | " vidcap.release()\n", 171 | " return frames\n" 172 | ] 173 | }, 174 | { 175 | "cell_type": "markdown", 176 | "metadata": { 177 | "id": "mRKCPRabuMs6" 178 | }, 179 | "source": [ 180 | "We will generate descriptions to videos and compare them to irl description in the metadata for the vibecheck.\n", 181 | "\n", 182 | "We need to downsample video to frames." 
183 | ] 184 | }, 185 | { 186 | "cell_type": "code", 187 | "execution_count": null, 188 | "metadata": { 189 | "id": "UMJESbFulYTi" 190 | }, 191 | "outputs": [], 192 | "source": [ 193 | "frames = downsample_video(filename)" 194 | ] 195 | }, 196 | { 197 | "cell_type": "code", 198 | "execution_count": null, 199 | "metadata": { 200 | "colab": { 201 | "base_uri": "https://localhost:8080/" 202 | }, 203 | "id": "wJKdYXasMfEG", 204 | "outputId": "2cff578c-df4d-41ca-8d9e-f85b4fed3456" 205 | }, 206 | "outputs": [ 207 | { 208 | "data": { 209 | "text/plain": [ 210 | "[(, np.float64(0.0)),\n", 211 | " (, np.float64(1.03)),\n", 212 | " (, np.float64(2.09)),\n", 213 | " (, np.float64(3.12)),\n", 214 | " (, np.float64(4.17)),\n", 215 | " (, np.float64(5.21)),\n", 216 | " (, np.float64(6.26))]" 217 | ] 218 | }, 219 | "execution_count": 52, 220 | "metadata": {}, 221 | "output_type": "execute_result" 222 | } 223 | ], 224 | "source": [ 225 | "frames" 226 | ] 227 | }, 228 | { 229 | "cell_type": "code", 230 | "execution_count": null, 231 | "metadata": { 232 | "id": "u8itVHCflZYQ" 233 | }, 234 | "outputs": [], 235 | "source": [ 236 | "messages = [\n", 237 | " {\n", 238 | " \"role\": \"system\",\n", 239 | " \"content\": [{\"type\": \"text\", \"text\": \"You are a helpful assistant.\"}]\n", 240 | " },\n", 241 | " {\n", 242 | " \"role\": \"user\",\n", 243 | " \"content\": [\n", 244 | " {\"type\": \"text\", \"text\": f\"What is happening in this video? Summarize the events.\"}]\n", 245 | " }\n", 246 | "]\n", 247 | "for frame in frames:\n", 248 | " image, timestamp = frame\n", 249 | " messages[1][\"content\"].append({\"type\": \"text\", \"text\": f\"Frame {timestamp}: \"})\n", 250 | " image.save(f\"image_{timestamp}.png\")\n", 251 | " messages[1][\"content\"].append({\"type\": \"image\", \"url\": f\"./image_{timestamp}.png\"})\n", 252 | "messages[1][\"content\"].append({\"type\": \"audio\", \"audio\": f\"audios/audio.wav\"})" 253 | ] 254 | }, 255 | { 256 | "cell_type": "code", 257 | "execution_count": null, 258 | "metadata": { 259 | "colab": { 260 | "base_uri": "https://localhost:8080/" 261 | }, 262 | "id": "dBX4mNxXxGoC", 263 | "outputId": "b738e828-bf9b-4f13-bbb2-9f38bea50b6a" 264 | }, 265 | "outputs": [ 266 | { 267 | "data": { 268 | "text/plain": [ 269 | "[{'role': 'system',\n", 270 | " 'content': [{'type': 'text', 'text': 'You are a helpful assistant.'}]},\n", 271 | " {'role': 'user',\n", 272 | " 'content': [{'type': 'text',\n", 273 | " 'text': 'What is happening in this video? 
Summarize the events.'},\n", 274 | " {'type': 'text', 'text': 'Frame 0.0: '},\n", 275 | " {'type': 'image', 'url': './image_0.0.png'},\n", 276 | " {'type': 'text', 'text': 'Frame 1.03: '},\n", 277 | " {'type': 'image', 'url': './image_1.03.png'},\n", 278 | " {'type': 'text', 'text': 'Frame 2.09: '},\n", 279 | " {'type': 'image', 'url': './image_2.09.png'},\n", 280 | " {'type': 'text', 'text': 'Frame 3.12: '},\n", 281 | " {'type': 'image', 'url': './image_3.12.png'},\n", 282 | " {'type': 'text', 'text': 'Frame 4.17: '},\n", 283 | " {'type': 'image', 'url': './image_4.17.png'},\n", 284 | " {'type': 'text', 'text': 'Frame 5.21: '},\n", 285 | " {'type': 'image', 'url': './image_5.21.png'},\n", 286 | " {'type': 'text', 'text': 'Frame 6.26: '},\n", 287 | " {'type': 'image', 'url': './image_6.26.png'},\n", 288 | " {'type': 'audio', 'audio': 'audios/audio.wav'}]}]" 289 | ] 290 | }, 291 | "execution_count": 59, 292 | "metadata": {}, 293 | "output_type": "execute_result" 294 | } 295 | ], 296 | "source": [ 297 | "messages" 298 | ] 299 | }, 300 | { 301 | "cell_type": "code", 302 | "execution_count": null, 303 | "metadata": { 304 | "id": "e4f0qr67lcjo" 305 | }, 306 | "outputs": [], 307 | "source": [ 308 | "#processor.tokenizer.padding_side = \"right\"\n", 309 | "inputs = processor.apply_chat_template(\n", 310 | " messages, add_generation_prompt=True, tokenize=True,\n", 311 | " return_dict=True, return_tensors=\"pt\"\n", 312 | ").to(model.device).to(model.dtype)" 313 | ] 314 | }, 315 | { 316 | "cell_type": "code", 317 | "execution_count": null, 318 | "metadata": { 319 | "colab": { 320 | "base_uri": "https://localhost:8080/" 321 | }, 322 | "id": "EOiBpgkI9kXi", 323 | "outputId": "911a6013-f76f-4fed-c402-8039d67b1e05" 324 | }, 325 | "outputs": [ 326 | { 327 | "data": { 328 | "text/plain": [ 329 | "2087" 330 | ] 331 | }, 332 | "execution_count": 61, 333 | "metadata": {}, 334 | "output_type": "execute_result" 335 | } 336 | ], 337 | "source": [ 338 | "inputs[\"input_ids\"].shape[-1]" 339 | ] 340 | }, 341 | { 342 | "cell_type": "code", 343 | "execution_count": null, 344 | "metadata": { 345 | "colab": { 346 | "base_uri": "https://localhost:8080/" 347 | }, 348 | "id": "yJ95UXBqvXPM", 349 | "outputId": "721839dc-aa78-401b-e802-b858690980da" 350 | }, 351 | "outputs": [ 352 | { 353 | "name": "stderr", 354 | "output_type": "stream", 355 | "text": [ 356 | "The following generation flags are not valid and may be ignored: ['top_p', 'top_k']. Set `TRANSFORMERS_VERBOSITY=info` for more details.\n" 357 | ] 358 | } 359 | ], 360 | "source": [ 361 | "with torch.inference_mode():\n", 362 | " generation = model.generate(**inputs, max_new_tokens=200, do_sample=False)" 363 | ] 364 | }, 365 | { 366 | "cell_type": "code", 367 | "execution_count": null, 368 | "metadata": { 369 | "colab": { 370 | "base_uri": "https://localhost:8080/" 371 | }, 372 | "id": "3ifVZy9c74St", 373 | "outputId": "f8ab51c6-e5a3-4a16-875b-d07404041396" 374 | }, 375 | "outputs": [ 376 | { 377 | "name": "stdout", 378 | "output_type": "stream", 379 | "text": [ 380 | "Here's a summary of what's happening in the video:\n", 381 | "\n", 382 | "The video appears to be taken at a ski resort. The main subject is a person snowboarding down a snowy slope. \n", 383 | "\n", 384 | "**Initial Scene (0.0 - 1.03):** The snowboarder is initially positioned on the slope, seemingly having fallen or stopped. 
Other skiers and snowboarders are visible in the background, waiting at what looks like a lift station.\n", 385 | "\n", 386 | "**Mid-Video (1.03 - 6.26):** The snowboarder gets back up and continues down the slope. They navigate past other people, including skiers and snowboarders, and eventually reach a lift station. The video shows the snowboarder interacting with others at the lift, possibly waiting for the lift to start or having just gotten off. There are also other skiers and snowboarders around the lift station.\n", 387 | "\n", 388 | "**End Scene (6.26):** The snowboarder is still at the lift station,\n" 389 | ] 390 | } 391 | ], 392 | "source": [ 393 | "input_len = inputs[\"input_ids\"].shape[-1]\n", 394 | "\n", 395 | "generation = generation[0][input_len:]\n", 396 | "\n", 397 | "decoded = processor.decode(generation, skip_special_tokens=True)\n", 398 | "print(decoded)" 399 | ] 400 | } 401 | ], 402 | "metadata": { 403 | "accelerator": "GPU", 404 | "colab": { 405 | "gpuType": "A100", 406 | "include_colab_link": true, 407 | "machine_shape": "hm", 408 | "provenance": [] 409 | }, 410 | "kernelspec": { 411 | "display_name": "Python 3", 412 | "name": "python3" 413 | }, 414 | "language_info": { 415 | "name": "python" 416 | }, 417 | "widgets": { 418 | "application/vnd.jupyter.widget-state+json": { 419 | "01dc23faab3d42cda41fdfdd2a7dfed5": { 420 | "model_module": "@jupyter-widgets/controls", 421 | "model_module_version": "1.5.0", 422 | "model_name": "HTMLModel", 423 | "state": { 424 | "_dom_classes": [], 425 | "_model_module": "@jupyter-widgets/controls", 426 | "_model_module_version": "1.5.0", 427 | "_model_name": "HTMLModel", 428 | "_view_count": null, 429 | "_view_module": "@jupyter-widgets/controls", 430 | "_view_module_version": "1.5.0", 431 | "_view_name": "HTMLView", 432 | "description": "", 433 | "description_tooltip": null, 434 | "layout": "IPY_MODEL_ed0fa93199b94fb486c125d4f322d59f", 435 | "placeholder": "​", 436 | "style": "IPY_MODEL_66f82e7ef3694c699e3d4a2bd826392b", 437 | "value": "Loading checkpoint shards: 100%" 438 | } 439 | }, 440 | "29416122cc0b4a5592668ddced7686ba": { 441 | "model_module": "@jupyter-widgets/controls", 442 | "model_module_version": "1.5.0", 443 | "model_name": "DescriptionStyleModel", 444 | "state": { 445 | "_model_module": "@jupyter-widgets/controls", 446 | "_model_module_version": "1.5.0", 447 | "_model_name": "DescriptionStyleModel", 448 | "_view_count": null, 449 | "_view_module": "@jupyter-widgets/base", 450 | "_view_module_version": "1.2.0", 451 | "_view_name": "StyleView", 452 | "description_width": "" 453 | } 454 | }, 455 | "2bfd51e3ae954008ae83704c24dbd6cb": { 456 | "model_module": "@jupyter-widgets/base", 457 | "model_module_version": "1.2.0", 458 | "model_name": "LayoutModel", 459 | "state": { 460 | "_model_module": "@jupyter-widgets/base", 461 | "_model_module_version": "1.2.0", 462 | "_model_name": "LayoutModel", 463 | "_view_count": null, 464 | "_view_module": "@jupyter-widgets/base", 465 | "_view_module_version": "1.2.0", 466 | "_view_name": "LayoutView", 467 | "align_content": null, 468 | "align_items": null, 469 | "align_self": null, 470 | "border": null, 471 | "bottom": null, 472 | "display": null, 473 | "flex": null, 474 | "flex_flow": null, 475 | "grid_area": null, 476 | "grid_auto_columns": null, 477 | "grid_auto_flow": null, 478 | "grid_auto_rows": null, 479 | "grid_column": null, 480 | "grid_gap": null, 481 | "grid_row": null, 482 | "grid_template_areas": null, 483 | "grid_template_columns": null, 484 | "grid_template_rows": null, 485 
| "height": null, 486 | "justify_content": null, 487 | "justify_items": null, 488 | "left": null, 489 | "margin": null, 490 | "max_height": null, 491 | "max_width": null, 492 | "min_height": null, 493 | "min_width": null, 494 | "object_fit": null, 495 | "object_position": null, 496 | "order": null, 497 | "overflow": null, 498 | "overflow_x": null, 499 | "overflow_y": null, 500 | "padding": null, 501 | "right": null, 502 | "top": null, 503 | "visibility": null, 504 | "width": null 505 | } 506 | }, 507 | "409f985be1134b468b81136fbdb54408": { 508 | "model_module": "@jupyter-widgets/controls", 509 | "model_module_version": "1.5.0", 510 | "model_name": "HTMLModel", 511 | "state": { 512 | "_dom_classes": [], 513 | "_model_module": "@jupyter-widgets/controls", 514 | "_model_module_version": "1.5.0", 515 | "_model_name": "HTMLModel", 516 | "_view_count": null, 517 | "_view_module": "@jupyter-widgets/controls", 518 | "_view_module_version": "1.5.0", 519 | "_view_name": "HTMLView", 520 | "description": "", 521 | "description_tooltip": null, 522 | "layout": "IPY_MODEL_c72dd3d6a4c246cfa6590c314783c8f0", 523 | "placeholder": "​", 524 | "style": "IPY_MODEL_c0e471e664dd41eab98efe08301ef5e1", 525 | "value": "

Copy a token from your Hugging Face\ntokens page and paste it below.
1294 | } 1295 | }, 1296 | "f8b84d8c06384680973ef6fe787b5a5d": { 1297 | "model_module": "@jupyter-widgets/controls", 1298 | "model_module_version": "1.5.0", 1299 | "model_name": "ProgressStyleModel", 1300 | "state": { 1301 | "_model_module": "@jupyter-widgets/controls", 1302 | "_model_module_version": "1.5.0", 1303 | "_model_name": "ProgressStyleModel", 1304 | "_view_count": null, 1305 | "_view_module": "@jupyter-widgets/base", 1306 | "_view_module_version": "1.2.0", 1307 | "_view_name": "StyleView", 1308 | "bar_color": null, 1309 | "description_width": "" 1310 | } 1311 | }, 1312 | "fee72c1c455549b59092028b855a082a": { 1313 | "model_module": "@jupyter-widgets/base", 1314 | "model_module_version": "1.2.0", 1315 | "model_name": "LayoutModel", 1316 | "state": { 1317 | "_model_module": "@jupyter-widgets/base", 1318 | "_model_module_version": "1.2.0", 1319 | "_model_name": "LayoutModel", 1320 | "_view_count": null, 1321 | "_view_module": "@jupyter-widgets/base", 1322 | "_view_module_version": "1.2.0", 1323 | "_view_name": "LayoutView", 1324 | "align_content": null, 1325 | "align_items": null, 1326 | "align_self": null, 1327 | "border": null, 1328 | "bottom": null, 1329 | "display": null, 1330 | "flex": null, 1331 | "flex_flow": null, 1332 | "grid_area": null, 1333 | "grid_auto_columns": null, 1334 | "grid_auto_flow": null, 1335 | "grid_auto_rows": null, 1336 | "grid_column": null, 1337 | "grid_gap": null, 1338 | "grid_row": null, 1339 | "grid_template_areas": null, 1340 | "grid_template_columns": null, 1341 | "grid_template_rows": null, 1342 | "height": null, 1343 | "justify_content": null, 1344 | "justify_items": null, 1345 | "left": null, 1346 | "margin": null, 1347 | "max_height": null, 1348 | "max_width": null, 1349 | "min_height": null, 1350 | "min_width": null, 1351 | "object_fit": null, 1352 | "object_position": null, 1353 | "order": null, 1354 | "overflow": null, 1355 | "overflow_x": null, 1356 | "overflow_y": null, 1357 | "padding": null, 1358 | "right": null, 1359 | "top": null, 1360 | "visibility": null, 1361 | "width": null 1362 | } 1363 | } 1364 | } 1365 | } 1366 | }, 1367 | "nbformat": 4, 1368 | "nbformat_minor": 0 1369 | } 1370 | -------------------------------------------------------------------------------- /Gemma3n_Fine_tuning_on_All_Modalities.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "id": "0eVo7Mc5GMyL" 7 | }, 8 | "source": [ 9 | "# Fine-tune Gemma3n on FineVideo\n", 10 | "\n", 11 | "In this notebook, we will see how to fine-tune Gemma3n an videos with audios inside.\n", 12 | "Using all three modalities is very costly compute-wise, so keep in mind that this is an educational tutorial to fit the model in 40GB VRAM." 
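Before starting, it can't hurt to check how much GPU memory you actually have. A quick optional sketch (not part of the original fine-tuning code):

```python
# Optional sanity check: report the total memory of the first CUDA device.
import torch

if torch.cuda.is_available():
    props = torch.cuda.get_device_properties(0)
    print(f"{props.name}: {props.total_memory / 1e9:.1f} GB total VRAM")
else:
    print("No CUDA device detected; this tutorial assumes a ~40 GB GPU such as an A100.")
```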
13 | ] 14 | }, 15 | { 16 | "cell_type": "code", 17 | "execution_count": null, 18 | "metadata": {}, 19 | "outputs": [], 20 | "source": [ 21 | "!pip install -U -q timm transformers trl peft datasets" 22 | ] 23 | }, 24 | { 25 | "cell_type": "code", 26 | "execution_count": 2, 27 | "metadata": { 28 | "id": "UxE2vzKsbov0" 29 | }, 30 | "outputs": [], 31 | "source": [ 32 | "import io\n", 33 | "import os\n", 34 | "import zipfile\n", 35 | "\n", 36 | "import torch\n", 37 | "from datasets import load_dataset\n", 38 | "from PIL import Image\n", 39 | "from transformers import AutoProcessor, Gemma3nForConditionalGeneration\n", 40 | "\n", 41 | "from trl import (\n", 42 | " SFTConfig,\n", 43 | " SFTTrainer,\n", 44 | ")" 45 | ] 46 | }, 47 | { 48 | "cell_type": "markdown", 49 | "metadata": { 50 | "id": "T06yJvcMiqO6" 51 | }, 52 | "source": [ 53 | "## Download videos and preprocessing\n", 54 | "\n", 55 | "FineVideo is a quite large dataset, we don't need a ton of examples, so we stream the dataset, check the duration and download the videos shorter than 30 secs." 56 | ] 57 | }, 58 | { 59 | "cell_type": "code", 60 | "execution_count": null, 61 | "metadata": { 62 | "id": "wBFfYgLxmg7b" 63 | }, 64 | "outputs": [], 65 | "source": [ 66 | "from datasets import load_dataset\n", 67 | "import json\n", 68 | "import os\n", 69 | "\n", 70 | "dataset = load_dataset(\"HuggingFaceFV/finevideo\", split=\"train\", streaming=True)\n", 71 | "\n", 72 | "\n", 73 | "os.makedirs(\"videos\", exist_ok=True)\n", 74 | "os.makedirs(\"metadata\", exist_ok=True)\n", 75 | "\n", 76 | "for idx, sample in enumerate(dataset):\n", 77 | " data = sample[\"json\"]\n", 78 | " duration = data.get(\"duration_seconds\", 0)\n", 79 | " if duration < 30:\n", 80 | " video_filename = f\"videos/sample_{idx}.mp4\"\n", 81 | " with open(video_filename, 'wb') as video_file:\n", 82 | " video_file.write(sample['mp4'])\n", 83 | "\n", 84 | " json_filename = f\"metadata/sample_{idx}.json\"\n", 85 | " with open(json_filename, 'w') as json_file:\n", 86 | " json.dump(sample['json'], json_file)\n" 87 | ] 88 | }, 89 | { 90 | "cell_type": "code", 91 | "execution_count": 7, 92 | "metadata": { 93 | "colab": { 94 | "base_uri": "https://localhost:8080/" 95 | }, 96 | "id": "K48dmmZTdZ1l", 97 | "outputId": "31c7c32b-1c40-4df4-eb51-11857d7b4da9" 98 | }, 99 | "outputs": [ 100 | { 101 | "name": "stdout", 102 | "output_type": "stream", 103 | "text": [ 104 | "Number of items in content/videos: 871\n" 105 | ] 106 | } 107 | ], 108 | "source": [ 109 | " print(f\"Number of items in content/videos: {len(os.listdir('videos'))}\")" 110 | ] 111 | }, 112 | { 113 | "cell_type": "markdown", 114 | "metadata": { 115 | "id": "QbkDI03qHMog" 116 | }, 117 | "source": [ 118 | "In FineVideo some frames are dark so we downsample 6 frames and if we can't get meaningful videos we remove them." 
119 | ] 120 | }, 121 | { 122 | "cell_type": "code", 123 | "execution_count": 10, 124 | "metadata": { 125 | "id": "0UMZi3tHb-BC" 126 | }, 127 | "outputs": [], 128 | "source": [ 129 | "import cv2\n", 130 | "from PIL import Image\n", 131 | "import numpy as np\n", 132 | "\n", 133 | "def is_dark(frame, threshold=10):\n", 134 | " return np.max(frame) < threshold # all pixels are very close to 0\n", 135 | "\n", 136 | "def downsample_video(video_path):\n", 137 | " vidcap = cv2.VideoCapture(video_path)\n", 138 | " total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))\n", 139 | " fps = vidcap.get(cv2.CAP_PROP_FPS)\n", 140 | "\n", 141 | " frames = []\n", 142 | "\n", 143 | " # Generate 8 evenly spaced indices, skip first and last\n", 144 | " full_indices = np.linspace(0, total_frames - 1, 8, dtype=int)[1:-1]\n", 145 | "\n", 146 | " for i in full_indices:\n", 147 | " found_valid = False\n", 148 | " for offset in [0, -1, 1, -2, 2]: # Try nearby frames if original is dark\n", 149 | " candidate_idx = i + offset\n", 150 | " if 0 <= candidate_idx < total_frames:\n", 151 | " vidcap.set(cv2.CAP_PROP_POS_FRAMES, candidate_idx)\n", 152 | " success, image = vidcap.read()\n", 153 | " if success:\n", 154 | " if not is_dark(image):\n", 155 | " image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)\n", 156 | " pil_image = Image.fromarray(image)\n", 157 | " timestamp = round(candidate_idx / fps, 2)\n", 158 | " frames.append((pil_image, timestamp))\n", 159 | " found_valid = True\n", 160 | " break\n", 161 | " if not found_valid:\n", 162 | " print(f\"Warning: Could not find non-dark frame near index {i}\")\n", 163 | "\n", 164 | " vidcap.release()\n", 165 | "\n", 166 | " # If still fewer than 8, try to top off by scanning more frames\n", 167 | " if len(frames) < 6:\n", 168 | " print(\"Trying to top off with additional non-dark frames...\")\n", 169 | " idx = 0\n", 170 | " while len(frames) < 8 and idx < total_frames:\n", 171 | " vidcap.set(cv2.CAP_PROP_POS_FRAMES, idx)\n", 172 | " success, image = vidcap.read()\n", 173 | " if success and not is_dark(image):\n", 174 | " image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)\n", 175 | " pil_image = Image.fromarray(image)\n", 176 | " timestamp = round(idx / fps, 2)\n", 177 | " # Avoid adding duplicate timestamps\n", 178 | " if not any(ts == timestamp for _, ts in frames):\n", 179 | " frames.append((pil_image, timestamp))\n", 180 | " idx += 1\n", 181 | "\n", 182 | " return frames[:8] # Ensure exactly 8 frames\n", 183 | "\n", 184 | "import os\n", 185 | "import glob\n", 186 | "\n", 187 | "def remove_dark_videos(video_dir, metadata_dir, audio_dir):\n", 188 | " \"\"\"\n", 189 | " Remove videos (and their metadata/audio files) if all frames are dark.\n", 190 | " \"\"\"\n", 191 | " video_paths = glob.glob(os.path.join(video_dir, \"*.mp4\"))\n", 192 | "\n", 193 | " for video_path in video_paths:\n", 194 | " filename = os.path.basename(video_path)\n", 195 | " base_name = os.path.splitext(filename)[0]\n", 196 | "\n", 197 | " frames = downsample_video(video_path)\n", 198 | " if len(frames) < 6:\n", 199 | " try:\n", 200 | " os.remove(video_path)\n", 201 | " print(f\"Deleted: {video_path}\")\n", 202 | " except Exception as e:\n", 203 | " print(f\"Failed to delete {video_path}: {e}\")\n", 204 | "\n", 205 | " metadata_path = os.path.join(metadata_dir, f\"{base_name}.json\")\n", 206 | " if os.path.exists(metadata_path):\n", 207 | " os.remove(metadata_path)\n", 208 | "\n", 209 | " # Remove audio\n", 210 | " audio_path = os.path.join(audio_dir, f\"{base_name}.wav\")\n", 211 | " if 
os.path.exists(audio_path):\n", 212 | " os.remove(audio_path)\n", 213 | "\n" 214 | ] 215 | }, 216 | { 217 | "cell_type": "code", 218 | "execution_count": null, 219 | "metadata": {}, 220 | "outputs": [], 221 | "source": [ 222 | "remove_dark_videos(\n", 223 | " video_dir=\"videos\",\n", 224 | " metadata_dir=\"metadata\",\n", 225 | " audio_dir=\"audios\"\n", 226 | " )" 227 | ] 228 | }, 229 | { 230 | "cell_type": "markdown", 231 | "metadata": { 232 | "id": "-qa4Tf8PwITC" 233 | }, 234 | "source": [ 235 | "Gemma-3n accepts video (image frames) and audio separately, so we strip audio from video." 236 | ] 237 | }, 238 | { 239 | "cell_type": "code", 240 | "execution_count": 8, 241 | "metadata": { 242 | "id": "OR7bhnCawHrF" 243 | }, 244 | "outputs": [], 245 | "source": [ 246 | "import os\n", 247 | "import subprocess\n", 248 | "\n", 249 | "video_dir = \"videos\"\n", 250 | "audio_dir = \"audios\"\n", 251 | "os.makedirs(audio_dir, exist_ok=True)\n", 252 | "\n", 253 | "for filename in os.listdir(video_dir):\n", 254 | " if not filename.endswith(\".mp4\"):\n", 255 | " continue\n", 256 | "\n", 257 | " idx = filename.split(\"_\")[1].split(\".\")[0]\n", 258 | " video_path = os.path.join(video_dir, filename)\n", 259 | " audio_path = os.path.join(audio_dir, f\"sample_{idx}.wav\")\n", 260 | "\n", 261 | " subprocess.run([\n", 262 | " \"ffmpeg\", \"-i\", video_path,\n", 263 | " \"-q:a\", \"0\", \"-map\", \"a\",\n", 264 | " audio_path,\n", 265 | " \"-y\"\n", 266 | " ], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)\n" 267 | ] 268 | }, 269 | { 270 | "cell_type": "markdown", 271 | "metadata": { 272 | "id": "uIlVtxDcwQcy" 273 | }, 274 | "source": [ 275 | "Construct a new dataset with audio, video, metadata (video categories). This dataset is very cool, it has some questions and answers, captions and more so get creative if you have the GPU VRAM to do so. Here we solve an easier task for educational purposes." 276 | ] 277 | }, 278 | { 279 | "cell_type": "markdown", 280 | "metadata": { 281 | "id": "CjtgRoSEd9TV" 282 | }, 283 | "source": [ 284 | "We will speed-up and downsample the audios to save space during training." 
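To get a rough sense of how much that saves per clip, here is a back-of-the-envelope sketch. The 30-second stereo 48 kHz input is an assumption; actual FineVideo clips vary.

```python
# Rough, assumed numbers: a 30 s stereo clip at 48 kHz, 16-bit PCM, versus the
# preprocessed version (mono, 16 kHz, capped at 5 s). Purely illustrative.
original_bytes = 30 * 48_000 * 2 * 2   # duration * sample rate * channels * bytes per sample
processed_bytes = 5 * 16_000 * 1 * 2   # 5 s cap, mono, 16 kHz
print(f"~{original_bytes / 1e6:.1f} MB -> ~{processed_bytes / 1e3:.0f} kB per clip")
```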
285 | ] 286 | }, 287 | { 288 | "cell_type": "code", 289 | "execution_count": null, 290 | "metadata": {}, 291 | "outputs": [], 292 | "source": [ 293 | "from datasets import Dataset\n", 294 | "import json\n", 295 | "\n", 296 | "def gen():\n", 297 | " meta_dir = \"metadata\"\n", 298 | " for filename in os.listdir(meta_dir):\n", 299 | " if not filename.endswith(\".json\"):\n", 300 | " continue\n", 301 | "\n", 302 | " idx = filename.split(\"_\")[1].split(\".\")[0]\n", 303 | " if os.path.exists(f\"videos/sample_{idx}.mp4\"):\n", 304 | " video_filename = f\"sample_{idx}.mp4\"\n", 305 | " audio_filename = f\"sample_{idx}.wav\"\n", 306 | " json_path = os.path.join(meta_dir, filename)\n", 307 | "\n", 308 | " with open(json_path, \"r\") as f:\n", 309 | " metadata = json.load(f)\n", 310 | "\n", 311 | "\n", 312 | " yield {\n", 313 | " \"video\": video_filename,\n", 314 | " \"audio\": audio_filename,\n", 315 | " \"content_parent_category\": metadata[\"content_parent_category\"],\n", 316 | " \"sample_index\": int(idx)\n", 317 | " }\n", 318 | " else:\n", 319 | " pass\n", 320 | "\n", 321 | "dataset = Dataset.from_generator(gen)\n" 322 | ] 323 | }, 324 | { 325 | "cell_type": "code", 326 | "execution_count": 14, 327 | "metadata": { 328 | "id": "8DDaQ86MD1Y3" 329 | }, 330 | "outputs": [], 331 | "source": [ 332 | "import torchaudio\n", 333 | "from torchaudio.transforms import Resample\n", 334 | "import os\n", 335 | "import torch\n", 336 | "\n", 337 | "def preprocess_audio(audio_path, target_sample_rate=16000, max_duration_sec=5, speedup_factor=1.25):\n", 338 | " waveform, sample_rate = torchaudio.load(audio_path)\n", 339 | "\n", 340 | " if waveform.shape[0] > 1:\n", 341 | " waveform = waveform.mean(dim=0, keepdim=True)\n", 342 | "\n", 343 | " if sample_rate != target_sample_rate:\n", 344 | " resampler = Resample(orig_freq=sample_rate, new_freq=target_sample_rate)\n", 345 | " waveform = resampler(waveform)\n", 346 | " sample_rate = target_sample_rate\n", 347 | "\n", 348 | " if speedup_factor > 1.0:\n", 349 | " indices = torch.arange(0, waveform.shape[1], step=speedup_factor).long()\n", 350 | " if indices[-1] >= waveform.shape[1]:\n", 351 | " indices = indices[:-1]\n", 352 | " waveform = waveform[:, indices]\n", 353 | "\n", 354 | " max_length = int(target_sample_rate * max_duration_sec)\n", 355 | " if waveform.shape[1] > max_length:\n", 356 | " waveform = waveform[:, :max_length]\n", 357 | "\n", 358 | " torchaudio.save(audio_path, waveform, sample_rate)\n" 359 | ] 360 | }, 361 | { 362 | "cell_type": "code", 363 | "execution_count": 15, 364 | "metadata": { 365 | "id": "IQ7L2_0bI1tP" 366 | }, 367 | "outputs": [], 368 | "source": [ 369 | "for file_name in os.listdir(\"audios\"):\n", 370 | " if file_name.lower().endswith(\".wav\"):\n", 371 | " audio_path = os.path.join(\"audios\", file_name)\n", 372 | " preprocess_audio(audio_path)" 373 | ] 374 | }, 375 | { 376 | "cell_type": "code", 377 | "execution_count": 16, 378 | "metadata": { 379 | "id": "pspaO2Lv4SxG" 380 | }, 381 | "outputs": [], 382 | "source": [ 383 | "dataset = dataset.train_test_split(test_size=0.10, seed=42)" 384 | ] 385 | }, 386 | { 387 | "cell_type": "markdown", 388 | "metadata": { 389 | "id": "hrvYdvQ9Hye4" 390 | }, 391 | "source": [ 392 | "### Load the model\n", 393 | "\n", 394 | "Make sure you have your Hugging Face token in your Colab secrets." 
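If the token is not picked up automatically, you can log in explicitly. A minimal sketch, assuming you saved the token under a Colab secret named `HF_TOKEN` (adjust the name to whatever you used):

```python
# Assumes a Colab secret named "HF_TOKEN"; outside Colab you can call login()
# with no arguments and paste the token interactively.
from google.colab import userdata
from huggingface_hub import login

login(token=userdata.get("HF_TOKEN"))
```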
395 | ] 396 | }, 397 | { 398 | "cell_type": "code", 399 | "execution_count": null, 400 | "metadata": {}, 401 | "outputs": [], 402 | "source": [ 403 | "model = Gemma3nForConditionalGeneration.from_pretrained(\n", 404 | " \"google/gemma-3n-E2B-it\", torch_dtype=torch.bfloat16,\n", 405 | ")\n", 406 | "processor = AutoProcessor.from_pretrained(\n", 407 | " \"google/gemma-3n-E2B-it\",\n", 408 | ")\n", 409 | "processor.tokenizer.padding_side = \"right\"" 410 | ] 411 | }, 412 | { 413 | "cell_type": "code", 414 | "execution_count": null, 415 | "metadata": { 416 | "colab": { 417 | "base_uri": "https://localhost:8080/" 418 | }, 419 | "id": "epPCxTFi3XQ2", 420 | "outputId": "f59ad356-5d7c-463e-9c6c-35eb0f0aa586" 421 | }, 422 | "outputs": [ 423 | { 424 | "data": { 425 | "text/plain": [ 426 | "[2, 1, 3, 0, 262273, 256000, 255999, 262272, 262144, 262145]" 427 | ] 428 | }, 429 | "execution_count": 24, 430 | "metadata": {}, 431 | "output_type": "execute_result" 432 | } 433 | ], 434 | "source": [ 435 | "processor.tokenizer.all_special_ids" 436 | ] 437 | }, 438 | { 439 | "cell_type": "markdown", 440 | "metadata": { 441 | "id": "i-xR4GHUeQ9l" 442 | }, 443 | "source": [ 444 | "Write our dataset collator. We will train model to predict category of a video (which can be done easily). You can do much better things, for instance FineVideo has QnA section, you can train this model to do open-ended QnA if you have a big VRAM and a lot of patience. Open-ended tasks are harder to work with, and this notebook carries educational purposes on feeding different modalities.\n", 445 | "\n", 446 | "In collator we also downsample videos to 6 frames, we have written the helper above. For better results you need more frames." 447 | ] 448 | }, 449 | { 450 | "cell_type": "code", 451 | "execution_count": 36, 452 | "metadata": { 453 | "id": "x_e3IjDCzioP" 454 | }, 455 | "outputs": [], 456 | "source": [ 457 | "def collate_fn(examples):\n", 458 | " video_path = examples[0][\"video\"]\n", 459 | " audio_path = examples[0][\"audio\"]\n", 460 | " sample_idx = filename.split(\"_\")[1].split(\".\")[0]\n", 461 | " frames = downsample_video(f\"videos/{video_path}\")\n", 462 | "\n", 463 | " text = \"Based on the video, predict the category of it.\"\n", 464 | " message = [\n", 465 | " {\n", 466 | " \"role\": \"user\",\n", 467 | " \"content\": [\n", 468 | " {\"type\": \"text\", \"text\": text}\n", 469 | " ],\n", 470 | " },\n", 471 | " ]\n", 472 | " # this is how video inference should be formatted in Gemma3n\n", 473 | " for frame in frames:\n", 474 | " image, timestamp = frame\n", 475 | " message[0][\"content\"].append({\"type\": \"text\", \"text\": f\"Frame {timestamp}:\"})\n", 476 | " timestamp = str(timestamp).replace(\".\", \"_\")\n", 477 | " image.save(f\"image_idx_{sample_idx}_{timestamp}.png\")\n", 478 | " message[0][\"content\"].append({\"type\": \"image\", \"url\": f\"image_idx_{sample_idx}_{timestamp}.png\"})\n", 479 | "\n", 480 | " message[0][\"content\"].append({\"type\": \"audio\", \"audio\": f\"audios/{audio_path}\"})\n", 481 | " message.append({\"role\": \"assistant\", \"content\": [{\"type\": \"text\", \"text\": examples[0][\"content_parent_category\"]}]})\n", 482 | " inputs = processor.apply_chat_template(\n", 483 | " message,\n", 484 | " add_generation_prompt=False,\n", 485 | " tokenize=True,\n", 486 | " return_dict=True,\n", 487 | " return_tensors=\"pt\",\n", 488 | " padding=True,\n", 489 | " ).to(model.device)\n", 490 | "\n", 491 | " labels = inputs[\"input_ids\"].clone()\n", 492 | " special_token_ids = 
processor.tokenizer.all_special_ids\n", 493 | "\n", 494 | " special_token_ids_tensor = torch.tensor(special_token_ids, device=labels.device)\n", 495 | " mask = torch.isin(labels, special_token_ids_tensor)\n", 496 | " labels[mask] = -100\n", 497 | "\n", 498 | " inputs[\"labels\"] = labels\n", 499 | " if torch.all(inputs[\"pixel_values\"] == 0):\n", 500 | " print(\"Frames are dark\")\n", 501 | "\n", 502 | " return inputs" 503 | ] 504 | }, 505 | { 506 | "cell_type": "markdown", 507 | "metadata": { 508 | "id": "wM6OxwNTiyZ1" 509 | }, 510 | "source": [ 511 | "## Training" 512 | ] 513 | }, 514 | { 515 | "cell_type": "markdown", 516 | "metadata": { 517 | "id": "Wj7yYQTQH7wg" 518 | }, 519 | "source": [ 520 | "We do LoRA fine-tuning again to save up on space." 521 | ] 522 | }, 523 | { 524 | "cell_type": "code", 525 | "execution_count": 58, 526 | "metadata": { 527 | "id": "uD3W2OO5-1PC" 528 | }, 529 | "outputs": [], 530 | "source": [ 531 | "from peft import LoraConfig\n", 532 | "peft_config = LoraConfig(\n", 533 | " task_type=\"CAUSAL_LM\",\n", 534 | " r=16,\n", 535 | " target_modules=\"all-linear\",\n", 536 | " lora_alpha=32,\n", 537 | " lora_dropout=0.05,\n", 538 | " bias=\"none\",\n", 539 | " use_rslora=False,\n", 540 | " use_dora=False,\n", 541 | " modules_to_save=None\n", 542 | ")" 543 | ] 544 | }, 545 | { 546 | "cell_type": "code", 547 | "execution_count": 59, 548 | "metadata": { 549 | "id": "CT7xlPul8RNJ" 550 | }, 551 | "outputs": [], 552 | "source": [ 553 | "model.gradient_checkpointing_disable()" 554 | ] 555 | }, 556 | { 557 | "cell_type": "code", 558 | "execution_count": 60, 559 | "metadata": { 560 | "id": "3stdS0v15tnY" 561 | }, 562 | "outputs": [], 563 | "source": [ 564 | "model.config.use_cache = False" 565 | ] 566 | }, 567 | { 568 | "cell_type": "code", 569 | "execution_count": 61, 570 | "metadata": { 571 | "id": "zG53iSes76H-" 572 | }, 573 | "outputs": [], 574 | "source": [ 575 | "training_args = SFTConfig(\n", 576 | " output_dir=\"/content/gemma-3n-finevideo\",\n", 577 | " eval_strategy='epoch',\n", 578 | " per_device_train_batch_size=1,\n", 579 | " per_device_eval_batch_size=1,\n", 580 | " gradient_accumulation_steps=4,\n", 581 | " gradient_checkpointing=False,\n", 582 | " learning_rate=1e-05,\n", 583 | " num_train_epochs=3.0,\n", 584 | " logging_steps=10,\n", 585 | " save_steps=100,\n", 586 | " bf16=True,\n", 587 | " report_to=[\"tensorboard\"],\n", 588 | " dataset_kwargs={'skip_prepare_dataset': True},\n", 589 | " remove_unused_columns=False,\n", 590 | " max_seq_length=None,\n", 591 | " push_to_hub=True,\n", 592 | " dataloader_pin_memory=False,\n", 593 | ")" 594 | ] 595 | }, 596 | { 597 | "cell_type": "code", 598 | "execution_count": 62, 599 | "metadata": { 600 | "colab": { 601 | "base_uri": "https://localhost:8080/" 602 | }, 603 | "id": "hPaplK2u70D9", 604 | "outputId": "4bd2f1cd-e4d2-4e38-e555-ec2e07528e02" 605 | }, 606 | "outputs": [ 607 | { 608 | "name": "stderr", 609 | "output_type": "stream", 610 | "text": [ 611 | "No label_names provided for model class `PeftModelForCausalLM`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. 
Note that empty label_names list will be used instead.\n" 612 | ] 613 | } 614 | ], 615 | "source": [ 616 | "trainer = SFTTrainer(\n", 617 | " model=model,\n", 618 | " args=training_args,\n", 619 | " data_collator=collate_fn,\n", 620 | " train_dataset=dataset[\"train\"],\n", 621 | " eval_dataset=dataset[\"test\"] if training_args.eval_strategy != \"no\" else None,\n", 622 | " processing_class=processor.tokenizer,\n", 623 | " peft_config=peft_config,\n", 624 | ")" 625 | ] 626 | }, 627 | { 628 | "cell_type": "code", 629 | "execution_count": null, 630 | "metadata": {}, 631 | "outputs": [], 632 | "source": [ 633 | "trainer.train()" 634 | ] 635 | }, 636 | { 637 | "cell_type": "markdown", 638 | "metadata": { 639 | "id": "qKtWUXVoUyKE" 640 | }, 641 | "source": [ 642 | "Test the model with a video of snowboarding." 643 | ] 644 | }, 645 | { 646 | "cell_type": "code", 647 | "execution_count": null, 648 | "metadata": {}, 649 | "outputs": [], 650 | "source": [ 651 | "!wget https://huggingface.co/datasets/merve/vlm_test_images/resolve/main/IMG_8137.mp4" 652 | ] 653 | }, 654 | { 655 | "cell_type": "code", 656 | "execution_count": 89, 657 | "metadata": { 658 | "id": "KBfMiUChc2Ky" 659 | }, 660 | "outputs": [], 661 | "source": [ 662 | "model = trainer.model # trainer has the adapter" 663 | ] 664 | }, 665 | { 666 | "cell_type": "markdown", 667 | "metadata": { 668 | "id": "R14WzyjbZCwI" 669 | }, 670 | "source": [ 671 | "Strip audio and downsample video." 672 | ] 673 | }, 674 | { 675 | "cell_type": "code", 676 | "execution_count": 97, 677 | "metadata": { 678 | "colab": { 679 | "base_uri": "https://localhost:8080/" 680 | }, 681 | "id": "RnJZ-QNJaOqp", 682 | "outputId": "c2f42e28-d427-4da7-cf86-6c3b70e6ee02" 683 | }, 684 | "outputs": [ 685 | { 686 | "data": { 687 | "text/plain": [ 688 | "CompletedProcess(args=['ffmpeg', '-i', '/content/IMG_8137.mp4', '-q:a', '0', '-map', 'a', '/content/test_audio.wav', '-y'], returncode=0)" 689 | ] 690 | }, 691 | "execution_count": 97, 692 | "metadata": {}, 693 | "output_type": "execute_result" 694 | } 695 | ], 696 | "source": [ 697 | "audio_path = \"/content/test_audio.wav\"\n", 698 | "subprocess.run([\n", 699 | " \"ffmpeg\", \"-i\", \"/content/IMG_8137.mp4\",\n", 700 | " \"-q:a\", \"0\", \"-map\", \"a\",\n", 701 | " f\"{audio_path}\",\n", 702 | " \"-y\"\n", 703 | " ], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)" 704 | ] 705 | }, 706 | { 707 | "cell_type": "code", 708 | "execution_count": 98, 709 | "metadata": { 710 | "id": "9drrCnfRYi6O" 711 | }, 712 | "outputs": [], 713 | "source": [ 714 | "frames = downsample_video(\"/content/IMG_8137.mp4\")\n", 715 | "\n", 716 | "# repeat the chat template\n", 717 | "text = \"Based on the video, predict the category of it.\"\n", 718 | "message = [\n", 719 | " {\n", 720 | " \"role\": \"user\",\n", 721 | " \"content\": [\n", 722 | " {\"type\": \"text\", \"text\": text}\n", 723 | " ],\n", 724 | " },\n", 725 | "]\n", 726 | "for frame in frames:\n", 727 | " image, timestamp = frame\n", 728 | " message[0][\"content\"].append({\"type\": \"text\", \"text\": f\"Frame {timestamp}:\"})\n", 729 | " timestamp = str(timestamp).replace(\".\", \"_\")\n", 730 | " image.save(f\"test_frame_{timestamp}.png\")\n", 731 | " message[0][\"content\"].append({\"type\": \"image\", \"url\": f\"test_frame_{timestamp}.png\"})\n", 732 | "\n", 733 | "message[0][\"content\"].append({\"type\": \"audio\", \"audio\": f\"{audio_path}\"})" 734 | ] 735 | }, 736 | { 737 | "cell_type": "code", 738 | "execution_count": 99, 739 | "metadata": { 740 | "colab": { 741 | 
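Since SFTTrainer wrapped the model with the PEFT adapter, it is worth sanity-checking how small the trained update actually is and keeping a local copy of the adapter weights. A minimal sketch (the output directory name is arbitrary):

```python
# trainer.model is a PeftModel here, so these PEFT helpers are available;
# the directory name below is just an example.
trainer.model.print_trainable_parameters()
trainer.model.save_pretrained("gemma-3n-finevideo-adapter")
```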
"base_uri": "https://localhost:8080/" 742 | }, 743 | "id": "7s1Dhxf_Z3xU", 744 | "outputId": "1eba1e9e-d859-4aa7-ff4e-992ef272df7c" 745 | }, 746 | "outputs": [ 747 | { 748 | "data": { 749 | "text/plain": [ 750 | "[{'role': 'user',\n", 751 | " 'content': [{'type': 'text',\n", 752 | " 'text': 'Based on the video, predict the category of it.'},\n", 753 | " {'type': 'text', 'text': 'Frame 0.88:'},\n", 754 | " {'type': 'image', 'url': 'test_frame_0_88.png'},\n", 755 | " {'type': 'text', 'text': 'Frame 1.79:'},\n", 756 | " {'type': 'image', 'url': 'test_frame_1_79.png'},\n", 757 | " {'type': 'text', 'text': 'Frame 2.67:'},\n", 758 | " {'type': 'image', 'url': 'test_frame_2_67.png'},\n", 759 | " {'type': 'text', 'text': 'Frame 3.57:'},\n", 760 | " {'type': 'image', 'url': 'test_frame_3_57.png'},\n", 761 | " {'type': 'text', 'text': 'Frame 4.45:'},\n", 762 | " {'type': 'image', 'url': 'test_frame_4_45.png'},\n", 763 | " {'type': 'text', 'text': 'Frame 5.36:'},\n", 764 | " {'type': 'image', 'url': 'test_frame_5_36.png'},\n", 765 | " {'type': 'audio', 'audio': '/content/test_audio.wav'}]}]" 766 | ] 767 | }, 768 | "execution_count": 99, 769 | "metadata": {}, 770 | "output_type": "execute_result" 771 | } 772 | ], 773 | "source": [ 774 | "message" 775 | ] 776 | }, 777 | { 778 | "cell_type": "code", 779 | "execution_count": 100, 780 | "metadata": { 781 | "id": "xNTQRMzsZyQz" 782 | }, 783 | "outputs": [], 784 | "source": [ 785 | "inputs = processor.apply_chat_template(\n", 786 | " message,\n", 787 | " add_generation_prompt=True,\n", 788 | " tokenize=True,\n", 789 | " return_dict=True,\n", 790 | " return_tensors=\"pt\",\n", 791 | " padding=True,\n", 792 | ").to(model.device).to(model.dtype)" 793 | ] 794 | }, 795 | { 796 | "cell_type": "code", 797 | "execution_count": 101, 798 | "metadata": { 799 | "colab": { 800 | "base_uri": "https://localhost:8080/" 801 | }, 802 | "id": "WNfnannnZ5-S", 803 | "outputId": "0afca313-a4f7-4c02-872e-665a853a19df" 804 | }, 805 | "outputs": [ 806 | { 807 | "name": "stderr", 808 | "output_type": "stream", 809 | "text": [ 810 | "The following generation flags are not valid and may be ignored: ['top_p', 'top_k']. Set `TRANSFORMERS_VERBOSITY=info` for more details.\n" 811 | ] 812 | }, 813 | { 814 | "name": "stdout", 815 | "output_type": "stream", 816 | "text": [ 817 | "Snowboarding\n" 818 | ] 819 | } 820 | ], 821 | "source": [ 822 | "input_len = inputs[\"input_ids\"].shape[-1]\n", 823 | "\n", 824 | "with torch.inference_mode():\n", 825 | " generation = model.generate(**inputs, max_new_tokens=100, do_sample=False)\n", 826 | " generation = generation[0][input_len:]\n", 827 | "\n", 828 | "decoded = processor.decode(generation, skip_special_tokens=True)\n", 829 | "print(decoded)" 830 | ] 831 | }, 832 | { 833 | "cell_type": "markdown", 834 | "metadata": { 835 | "id": "LOUBj5dgeddG" 836 | }, 837 | "source": [ 838 | "Thanks a lot for reading! 
Keep training the model further with more data or unfreeze the layers for better performance 💗" 839 | ] 840 | }, 841 | { 842 | "cell_type": "code", 843 | "execution_count": null, 844 | "metadata": { 845 | "id": "4KnNR6lneuKm" 846 | }, 847 | "outputs": [], 848 | "source": [] 849 | } 850 | ], 851 | "metadata": { 852 | "accelerator": "GPU", 853 | "colab": { 854 | "gpuType": "A100", 855 | "machine_shape": "hm", 856 | "provenance": [] 857 | }, 858 | "kernelspec": { 859 | "display_name": "Python 3", 860 | "name": "python3" 861 | }, 862 | "language_info": { 863 | "name": "python" 864 | }, 865 | "widgets": { 866 | "application/vnd.jupyter.widget-state+json": { 867 | "073975370eab45d9abc4f69f2b7b3d48": { 868 | "model_module": "@jupyter-widgets/base", 869 | "model_module_version": "1.2.0", 870 | "model_name": "LayoutModel", 871 | "state": { 872 | "_model_module": "@jupyter-widgets/base", 873 | "_model_module_version": "1.2.0", 874 | "_model_name": "LayoutModel", 875 | "_view_count": null, 876 | "_view_module": "@jupyter-widgets/base", 877 | "_view_module_version": "1.2.0", 878 | "_view_name": "LayoutView", 879 | "align_content": null, 880 | "align_items": null, 881 | "align_self": null, 882 | "border": null, 883 | "bottom": null, 884 | "display": null, 885 | "flex": null, 886 | "flex_flow": null, 887 | "grid_area": null, 888 | "grid_auto_columns": null, 889 | "grid_auto_flow": null, 890 | "grid_auto_rows": null, 891 | "grid_column": null, 892 | "grid_gap": null, 893 | "grid_row": null, 894 | "grid_template_areas": null, 895 | "grid_template_columns": null, 896 | "grid_template_rows": null, 897 | "height": null, 898 | "justify_content": null, 899 | "justify_items": null, 900 | "left": null, 901 | "margin": null, 902 | "max_height": null, 903 | "max_width": null, 904 | "min_height": null, 905 | "min_width": null, 906 | "object_fit": null, 907 | "object_position": null, 908 | "order": null, 909 | "overflow": null, 910 | "overflow_x": null, 911 | "overflow_y": null, 912 | "padding": null, 913 | "right": null, 914 | "top": null, 915 | "visibility": null, 916 | "width": null 917 | } 918 | }, 919 | "0d1dfc47d0704506bc6e521c07162b4b": { 920 | "model_module": "@jupyter-widgets/controls", 921 | "model_module_version": "1.5.0", 922 | "model_name": "DescriptionStyleModel", 923 | "state": { 924 | "_model_module": "@jupyter-widgets/controls", 925 | "_model_module_version": "1.5.0", 926 | "_model_name": "DescriptionStyleModel", 927 | "_view_count": null, 928 | "_view_module": "@jupyter-widgets/base", 929 | "_view_module_version": "1.2.0", 930 | "_view_name": "StyleView", 931 | "description_width": "" 932 | } 933 | }, 934 | "143d6079d1744eedb41e2e1182bd0f33": { 935 | "model_module": "@jupyter-widgets/controls", 936 | "model_module_version": "1.5.0", 937 | "model_name": "ProgressStyleModel", 938 | "state": { 939 | "_model_module": "@jupyter-widgets/controls", 940 | "_model_module_version": "1.5.0", 941 | "_model_name": "ProgressStyleModel", 942 | "_view_count": null, 943 | "_view_module": "@jupyter-widgets/base", 944 | "_view_module_version": "1.2.0", 945 | "_view_name": "StyleView", 946 | "bar_color": null, 947 | "description_width": "" 948 | } 949 | }, 950 | "1801493cd54742fd99752b2f605af1cb": { 951 | "model_module": "@jupyter-widgets/controls", 952 | "model_module_version": "1.5.0", 953 | "model_name": "FloatProgressModel", 954 | "state": { 955 | "_dom_classes": [], 956 | "_model_module": "@jupyter-widgets/controls", 957 | "_model_module_version": "1.5.0", 958 | "_model_name": "FloatProgressModel", 959 | 
"_view_count": null, 960 | "_view_module": "@jupyter-widgets/controls", 961 | "_view_module_version": "1.5.0", 962 | "_view_name": "ProgressView", 963 | "bar_style": "success", 964 | "description": "", 965 | "description_tooltip": null, 966 | "layout": "IPY_MODEL_20b59cdc19684e1c97517e36f5bf8d6a", 967 | "max": 1, 968 | "min": 0, 969 | "orientation": "horizontal", 970 | "style": "IPY_MODEL_143d6079d1744eedb41e2e1182bd0f33", 971 | "value": 1 972 | } 973 | }, 974 | "20b59cdc19684e1c97517e36f5bf8d6a": { 975 | "model_module": "@jupyter-widgets/base", 976 | "model_module_version": "1.2.0", 977 | "model_name": "LayoutModel", 978 | "state": { 979 | "_model_module": "@jupyter-widgets/base", 980 | "_model_module_version": "1.2.0", 981 | "_model_name": "LayoutModel", 982 | "_view_count": null, 983 | "_view_module": "@jupyter-widgets/base", 984 | "_view_module_version": "1.2.0", 985 | "_view_name": "LayoutView", 986 | "align_content": null, 987 | "align_items": null, 988 | "align_self": null, 989 | "border": null, 990 | "bottom": null, 991 | "display": null, 992 | "flex": null, 993 | "flex_flow": null, 994 | "grid_area": null, 995 | "grid_auto_columns": null, 996 | "grid_auto_flow": null, 997 | "grid_auto_rows": null, 998 | "grid_column": null, 999 | "grid_gap": null, 1000 | "grid_row": null, 1001 | "grid_template_areas": null, 1002 | "grid_template_columns": null, 1003 | "grid_template_rows": null, 1004 | "height": null, 1005 | "justify_content": null, 1006 | "justify_items": null, 1007 | "left": null, 1008 | "margin": null, 1009 | "max_height": null, 1010 | "max_width": null, 1011 | "min_height": null, 1012 | "min_width": null, 1013 | "object_fit": null, 1014 | "object_position": null, 1015 | "order": null, 1016 | "overflow": null, 1017 | "overflow_x": null, 1018 | "overflow_y": null, 1019 | "padding": null, 1020 | "right": null, 1021 | "top": null, 1022 | "visibility": null, 1023 | "width": "20px" 1024 | } 1025 | }, 1026 | "2e9d5cf7a5c6466a9e1de6d4f403cd95": { 1027 | "model_module": "@jupyter-widgets/controls", 1028 | "model_module_version": "1.5.0", 1029 | "model_name": "DescriptionStyleModel", 1030 | "state": { 1031 | "_model_module": "@jupyter-widgets/controls", 1032 | "_model_module_version": "1.5.0", 1033 | "_model_name": "DescriptionStyleModel", 1034 | "_view_count": null, 1035 | "_view_module": "@jupyter-widgets/base", 1036 | "_view_module_version": "1.2.0", 1037 | "_view_name": "StyleView", 1038 | "description_width": "" 1039 | } 1040 | }, 1041 | "3262178b8baf4741b06250d7416df1f3": { 1042 | "model_module": "@jupyter-widgets/base", 1043 | "model_module_version": "1.2.0", 1044 | "model_name": "LayoutModel", 1045 | "state": { 1046 | "_model_module": "@jupyter-widgets/base", 1047 | "_model_module_version": "1.2.0", 1048 | "_model_name": "LayoutModel", 1049 | "_view_count": null, 1050 | "_view_module": "@jupyter-widgets/base", 1051 | "_view_module_version": "1.2.0", 1052 | "_view_name": "LayoutView", 1053 | "align_content": null, 1054 | "align_items": null, 1055 | "align_self": null, 1056 | "border": null, 1057 | "bottom": null, 1058 | "display": null, 1059 | "flex": null, 1060 | "flex_flow": null, 1061 | "grid_area": null, 1062 | "grid_auto_columns": null, 1063 | "grid_auto_flow": null, 1064 | "grid_auto_rows": null, 1065 | "grid_column": null, 1066 | "grid_gap": null, 1067 | "grid_row": null, 1068 | "grid_template_areas": null, 1069 | "grid_template_columns": null, 1070 | "grid_template_rows": null, 1071 | "height": null, 1072 | "justify_content": null, 1073 | "justify_items": null, 1074 | 
"left": null, 1075 | "margin": null, 1076 | "max_height": null, 1077 | "max_width": null, 1078 | "min_height": null, 1079 | "min_width": null, 1080 | "object_fit": null, 1081 | "object_position": null, 1082 | "order": null, 1083 | "overflow": null, 1084 | "overflow_x": null, 1085 | "overflow_y": null, 1086 | "padding": null, 1087 | "right": null, 1088 | "top": null, 1089 | "visibility": null, 1090 | "width": null 1091 | } 1092 | }, 1093 | "3e25db05674d4d2f8fd839a0ec63e7d8": { 1094 | "model_module": "@jupyter-widgets/base", 1095 | "model_module_version": "1.2.0", 1096 | "model_name": "LayoutModel", 1097 | "state": { 1098 | "_model_module": "@jupyter-widgets/base", 1099 | "_model_module_version": "1.2.0", 1100 | "_model_name": "LayoutModel", 1101 | "_view_count": null, 1102 | "_view_module": "@jupyter-widgets/base", 1103 | "_view_module_version": "1.2.0", 1104 | "_view_name": "LayoutView", 1105 | "align_content": null, 1106 | "align_items": null, 1107 | "align_self": null, 1108 | "border": null, 1109 | "bottom": null, 1110 | "display": null, 1111 | "flex": null, 1112 | "flex_flow": null, 1113 | "grid_area": null, 1114 | "grid_auto_columns": null, 1115 | "grid_auto_flow": null, 1116 | "grid_auto_rows": null, 1117 | "grid_column": null, 1118 | "grid_gap": null, 1119 | "grid_row": null, 1120 | "grid_template_areas": null, 1121 | "grid_template_columns": null, 1122 | "grid_template_rows": null, 1123 | "height": null, 1124 | "justify_content": null, 1125 | "justify_items": null, 1126 | "left": null, 1127 | "margin": null, 1128 | "max_height": null, 1129 | "max_width": null, 1130 | "min_height": null, 1131 | "min_width": null, 1132 | "object_fit": null, 1133 | "object_position": null, 1134 | "order": null, 1135 | "overflow": null, 1136 | "overflow_x": null, 1137 | "overflow_y": null, 1138 | "padding": null, 1139 | "right": null, 1140 | "top": null, 1141 | "visibility": null, 1142 | "width": null 1143 | } 1144 | }, 1145 | "425f9f26bd0647b1989ecb704414aa9f": { 1146 | "model_module": "@jupyter-widgets/base", 1147 | "model_module_version": "1.2.0", 1148 | "model_name": "LayoutModel", 1149 | "state": { 1150 | "_model_module": "@jupyter-widgets/base", 1151 | "_model_module_version": "1.2.0", 1152 | "_model_name": "LayoutModel", 1153 | "_view_count": null, 1154 | "_view_module": "@jupyter-widgets/base", 1155 | "_view_module_version": "1.2.0", 1156 | "_view_name": "LayoutView", 1157 | "align_content": null, 1158 | "align_items": null, 1159 | "align_self": null, 1160 | "border": null, 1161 | "bottom": null, 1162 | "display": null, 1163 | "flex": null, 1164 | "flex_flow": null, 1165 | "grid_area": null, 1166 | "grid_auto_columns": null, 1167 | "grid_auto_flow": null, 1168 | "grid_auto_rows": null, 1169 | "grid_column": null, 1170 | "grid_gap": null, 1171 | "grid_row": null, 1172 | "grid_template_areas": null, 1173 | "grid_template_columns": null, 1174 | "grid_template_rows": null, 1175 | "height": null, 1176 | "justify_content": null, 1177 | "justify_items": null, 1178 | "left": null, 1179 | "margin": null, 1180 | "max_height": null, 1181 | "max_width": null, 1182 | "min_height": null, 1183 | "min_width": null, 1184 | "object_fit": null, 1185 | "object_position": null, 1186 | "order": null, 1187 | "overflow": null, 1188 | "overflow_x": null, 1189 | "overflow_y": null, 1190 | "padding": null, 1191 | "right": null, 1192 | "top": null, 1193 | "visibility": null, 1194 | "width": null 1195 | } 1196 | }, 1197 | "464ffcc84f48468b8f5d3f08412c6101": { 1198 | "model_module": "@jupyter-widgets/controls", 1199 | 
"model_module_version": "1.5.0", 1200 | "model_name": "DescriptionStyleModel", 1201 | "state": { 1202 | "_model_module": "@jupyter-widgets/controls", 1203 | "_model_module_version": "1.5.0", 1204 | "_model_name": "DescriptionStyleModel", 1205 | "_view_count": null, 1206 | "_view_module": "@jupyter-widgets/base", 1207 | "_view_module_version": "1.2.0", 1208 | "_view_name": "StyleView", 1209 | "description_width": "" 1210 | } 1211 | }, 1212 | "4846c29045294042b8d916cb0fd8f9d6": { 1213 | "model_module": "@jupyter-widgets/controls", 1214 | "model_module_version": "1.5.0", 1215 | "model_name": "DescriptionStyleModel", 1216 | "state": { 1217 | "_model_module": "@jupyter-widgets/controls", 1218 | "_model_module_version": "1.5.0", 1219 | "_model_name": "DescriptionStyleModel", 1220 | "_view_count": null, 1221 | "_view_module": "@jupyter-widgets/base", 1222 | "_view_module_version": "1.2.0", 1223 | "_view_name": "StyleView", 1224 | "description_width": "" 1225 | } 1226 | }, 1227 | "4eb3613e8efa4fd9adf2cfe27bfbd699": { 1228 | "model_module": "@jupyter-widgets/controls", 1229 | "model_module_version": "1.5.0", 1230 | "model_name": "HBoxModel", 1231 | "state": { 1232 | "_dom_classes": [], 1233 | "_model_module": "@jupyter-widgets/controls", 1234 | "_model_module_version": "1.5.0", 1235 | "_model_name": "HBoxModel", 1236 | "_view_count": null, 1237 | "_view_module": "@jupyter-widgets/controls", 1238 | "_view_module_version": "1.5.0", 1239 | "_view_name": "HBoxView", 1240 | "box_style": "", 1241 | "children": [ 1242 | "IPY_MODEL_c15cc5cb9d7947a99a01a30e430d0459", 1243 | "IPY_MODEL_1801493cd54742fd99752b2f605af1cb", 1244 | "IPY_MODEL_e5e518d8cf5f4aa5a0ecad6583f0d317" 1245 | ], 1246 | "layout": "IPY_MODEL_425f9f26bd0647b1989ecb704414aa9f" 1247 | } 1248 | }, 1249 | "5eeff3de00c5488db1817328e83bb992": { 1250 | "model_module": "@jupyter-widgets/base", 1251 | "model_module_version": "1.2.0", 1252 | "model_name": "LayoutModel", 1253 | "state": { 1254 | "_model_module": "@jupyter-widgets/base", 1255 | "_model_module_version": "1.2.0", 1256 | "_model_name": "LayoutModel", 1257 | "_view_count": null, 1258 | "_view_module": "@jupyter-widgets/base", 1259 | "_view_module_version": "1.2.0", 1260 | "_view_name": "LayoutView", 1261 | "align_content": null, 1262 | "align_items": null, 1263 | "align_self": null, 1264 | "border": null, 1265 | "bottom": null, 1266 | "display": null, 1267 | "flex": null, 1268 | "flex_flow": null, 1269 | "grid_area": null, 1270 | "grid_auto_columns": null, 1271 | "grid_auto_flow": null, 1272 | "grid_auto_rows": null, 1273 | "grid_column": null, 1274 | "grid_gap": null, 1275 | "grid_row": null, 1276 | "grid_template_areas": null, 1277 | "grid_template_columns": null, 1278 | "grid_template_rows": null, 1279 | "height": null, 1280 | "justify_content": null, 1281 | "justify_items": null, 1282 | "left": null, 1283 | "margin": null, 1284 | "max_height": null, 1285 | "max_width": null, 1286 | "min_height": null, 1287 | "min_width": null, 1288 | "object_fit": null, 1289 | "object_position": null, 1290 | "order": null, 1291 | "overflow": null, 1292 | "overflow_x": null, 1293 | "overflow_y": null, 1294 | "padding": null, 1295 | "right": null, 1296 | "top": null, 1297 | "visibility": null, 1298 | "width": null 1299 | } 1300 | }, 1301 | "94d5d3b00449488caa6d8badc443a74f": { 1302 | "model_module": "@jupyter-widgets/controls", 1303 | "model_module_version": "1.5.0", 1304 | "model_name": "HTMLModel", 1305 | "state": { 1306 | "_dom_classes": [], 1307 | "_model_module": "@jupyter-widgets/controls", 1308 | 
"_model_module_version": "1.5.0", 1309 | "_model_name": "HTMLModel", 1310 | "_view_count": null, 1311 | "_view_module": "@jupyter-widgets/controls", 1312 | "_view_module_version": "1.5.0", 1313 | "_view_name": "HTMLView", 1314 | "description": "", 1315 | "description_tooltip": null, 1316 | "layout": "IPY_MODEL_3262178b8baf4741b06250d7416df1f3", 1317 | "placeholder": "​", 1318 | "style": "IPY_MODEL_2e9d5cf7a5c6466a9e1de6d4f403cd95", 1319 | "value": "Loading checkpoint shards: 100%" 1320 | } 1321 | }, 1322 | "9c0857a4034f4780ab5e7fdd9aa9d09d": { 1323 | "model_module": "@jupyter-widgets/controls", 1324 | "model_module_version": "1.5.0", 1325 | "model_name": "ProgressStyleModel", 1326 | "state": { 1327 | "_model_module": "@jupyter-widgets/controls", 1328 | "_model_module_version": "1.5.0", 1329 | "_model_name": "ProgressStyleModel", 1330 | "_view_count": null, 1331 | "_view_module": "@jupyter-widgets/base", 1332 | "_view_module_version": "1.2.0", 1333 | "_view_name": "StyleView", 1334 | "bar_color": null, 1335 | "description_width": "" 1336 | } 1337 | }, 1338 | "9d2631150d5c4089bcc95f22a6698287": { 1339 | "model_module": "@jupyter-widgets/base", 1340 | "model_module_version": "1.2.0", 1341 | "model_name": "LayoutModel", 1342 | "state": { 1343 | "_model_module": "@jupyter-widgets/base", 1344 | "_model_module_version": "1.2.0", 1345 | "_model_name": "LayoutModel", 1346 | "_view_count": null, 1347 | "_view_module": "@jupyter-widgets/base", 1348 | "_view_module_version": "1.2.0", 1349 | "_view_name": "LayoutView", 1350 | "align_content": null, 1351 | "align_items": null, 1352 | "align_self": null, 1353 | "border": null, 1354 | "bottom": null, 1355 | "display": null, 1356 | "flex": null, 1357 | "flex_flow": null, 1358 | "grid_area": null, 1359 | "grid_auto_columns": null, 1360 | "grid_auto_flow": null, 1361 | "grid_auto_rows": null, 1362 | "grid_column": null, 1363 | "grid_gap": null, 1364 | "grid_row": null, 1365 | "grid_template_areas": null, 1366 | "grid_template_columns": null, 1367 | "grid_template_rows": null, 1368 | "height": null, 1369 | "justify_content": null, 1370 | "justify_items": null, 1371 | "left": null, 1372 | "margin": null, 1373 | "max_height": null, 1374 | "max_width": null, 1375 | "min_height": null, 1376 | "min_width": null, 1377 | "object_fit": null, 1378 | "object_position": null, 1379 | "order": null, 1380 | "overflow": null, 1381 | "overflow_x": null, 1382 | "overflow_y": null, 1383 | "padding": null, 1384 | "right": null, 1385 | "top": null, 1386 | "visibility": null, 1387 | "width": null 1388 | } 1389 | }, 1390 | "a33fedc485b346b1b9d4fb8b18e8ac64": { 1391 | "model_module": "@jupyter-widgets/controls", 1392 | "model_module_version": "1.5.0", 1393 | "model_name": "HBoxModel", 1394 | "state": { 1395 | "_dom_classes": [], 1396 | "_model_module": "@jupyter-widgets/controls", 1397 | "_model_module_version": "1.5.0", 1398 | "_model_name": "HBoxModel", 1399 | "_view_count": null, 1400 | "_view_module": "@jupyter-widgets/controls", 1401 | "_view_module_version": "1.5.0", 1402 | "_view_name": "HBoxView", 1403 | "box_style": "", 1404 | "children": [ 1405 | "IPY_MODEL_94d5d3b00449488caa6d8badc443a74f", 1406 | "IPY_MODEL_a60a111fc7c24bd7b21fed3f3dd64f29", 1407 | "IPY_MODEL_e830732fc2bc4848847ea85c772d0b98" 1408 | ], 1409 | "layout": "IPY_MODEL_3e25db05674d4d2f8fd839a0ec63e7d8" 1410 | } 1411 | }, 1412 | "a60a111fc7c24bd7b21fed3f3dd64f29": { 1413 | "model_module": "@jupyter-widgets/controls", 1414 | "model_module_version": "1.5.0", 1415 | "model_name": "FloatProgressModel", 1416 | 
"state": { 1417 | "_dom_classes": [], 1418 | "_model_module": "@jupyter-widgets/controls", 1419 | "_model_module_version": "1.5.0", 1420 | "_model_name": "FloatProgressModel", 1421 | "_view_count": null, 1422 | "_view_module": "@jupyter-widgets/controls", 1423 | "_view_module_version": "1.5.0", 1424 | "_view_name": "ProgressView", 1425 | "bar_style": "success", 1426 | "description": "", 1427 | "description_tooltip": null, 1428 | "layout": "IPY_MODEL_9d2631150d5c4089bcc95f22a6698287", 1429 | "max": 3, 1430 | "min": 0, 1431 | "orientation": "horizontal", 1432 | "style": "IPY_MODEL_9c0857a4034f4780ab5e7fdd9aa9d09d", 1433 | "value": 3 1434 | } 1435 | }, 1436 | "c022d8fabedc43ef9db0c8aca82d215e": { 1437 | "model_module": "@jupyter-widgets/base", 1438 | "model_module_version": "1.2.0", 1439 | "model_name": "LayoutModel", 1440 | "state": { 1441 | "_model_module": "@jupyter-widgets/base", 1442 | "_model_module_version": "1.2.0", 1443 | "_model_name": "LayoutModel", 1444 | "_view_count": null, 1445 | "_view_module": "@jupyter-widgets/base", 1446 | "_view_module_version": "1.2.0", 1447 | "_view_name": "LayoutView", 1448 | "align_content": null, 1449 | "align_items": null, 1450 | "align_self": null, 1451 | "border": null, 1452 | "bottom": null, 1453 | "display": null, 1454 | "flex": null, 1455 | "flex_flow": null, 1456 | "grid_area": null, 1457 | "grid_auto_columns": null, 1458 | "grid_auto_flow": null, 1459 | "grid_auto_rows": null, 1460 | "grid_column": null, 1461 | "grid_gap": null, 1462 | "grid_row": null, 1463 | "grid_template_areas": null, 1464 | "grid_template_columns": null, 1465 | "grid_template_rows": null, 1466 | "height": null, 1467 | "justify_content": null, 1468 | "justify_items": null, 1469 | "left": null, 1470 | "margin": null, 1471 | "max_height": null, 1472 | "max_width": null, 1473 | "min_height": null, 1474 | "min_width": null, 1475 | "object_fit": null, 1476 | "object_position": null, 1477 | "order": null, 1478 | "overflow": null, 1479 | "overflow_x": null, 1480 | "overflow_y": null, 1481 | "padding": null, 1482 | "right": null, 1483 | "top": null, 1484 | "visibility": null, 1485 | "width": null 1486 | } 1487 | }, 1488 | "c15cc5cb9d7947a99a01a30e430d0459": { 1489 | "model_module": "@jupyter-widgets/controls", 1490 | "model_module_version": "1.5.0", 1491 | "model_name": "HTMLModel", 1492 | "state": { 1493 | "_dom_classes": [], 1494 | "_model_module": "@jupyter-widgets/controls", 1495 | "_model_module_version": "1.5.0", 1496 | "_model_name": "HTMLModel", 1497 | "_view_count": null, 1498 | "_view_module": "@jupyter-widgets/controls", 1499 | "_view_module_version": "1.5.0", 1500 | "_view_name": "HTMLView", 1501 | "description": "", 1502 | "description_tooltip": null, 1503 | "layout": "IPY_MODEL_5eeff3de00c5488db1817328e83bb992", 1504 | "placeholder": "​", 1505 | "style": "IPY_MODEL_4846c29045294042b8d916cb0fd8f9d6", 1506 | "value": "Generating train split: " 1507 | } 1508 | }, 1509 | "e5e518d8cf5f4aa5a0ecad6583f0d317": { 1510 | "model_module": "@jupyter-widgets/controls", 1511 | "model_module_version": "1.5.0", 1512 | "model_name": "HTMLModel", 1513 | "state": { 1514 | "_dom_classes": [], 1515 | "_model_module": "@jupyter-widgets/controls", 1516 | "_model_module_version": "1.5.0", 1517 | "_model_name": "HTMLModel", 1518 | "_view_count": null, 1519 | "_view_module": "@jupyter-widgets/controls", 1520 | "_view_module_version": "1.5.0", 1521 | "_view_name": "HTMLView", 1522 | "description": "", 1523 | "description_tooltip": null, 1524 | "layout": 
"IPY_MODEL_c022d8fabedc43ef9db0c8aca82d215e", 1525 | "placeholder": "​", 1526 | "style": "IPY_MODEL_464ffcc84f48468b8f5d3f08412c6101", 1527 | "value": " 869/0 [00:00<00:00, 8490.20 examples/s]" 1528 | } 1529 | }, 1530 | "e830732fc2bc4848847ea85c772d0b98": { 1531 | "model_module": "@jupyter-widgets/controls", 1532 | "model_module_version": "1.5.0", 1533 | "model_name": "HTMLModel", 1534 | "state": { 1535 | "_dom_classes": [], 1536 | "_model_module": "@jupyter-widgets/controls", 1537 | "_model_module_version": "1.5.0", 1538 | "_model_name": "HTMLModel", 1539 | "_view_count": null, 1540 | "_view_module": "@jupyter-widgets/controls", 1541 | "_view_module_version": "1.5.0", 1542 | "_view_name": "HTMLView", 1543 | "description": "", 1544 | "description_tooltip": null, 1545 | "layout": "IPY_MODEL_073975370eab45d9abc4f69f2b7b3d48", 1546 | "placeholder": "​", 1547 | "style": "IPY_MODEL_0d1dfc47d0704506bc6e521c07162b4b", 1548 | "value": " 3/3 [00:00<00:00,  3.91it/s]" 1549 | } 1550 | } 1551 | } 1552 | } 1553 | }, 1554 | "nbformat": 4, 1555 | "nbformat_minor": 0 1556 | } 1557 | --------------------------------------------------------------------------------