├── dataset ├── __init__.py └── multimodal_dataset.py ├── docs └── teaser.png ├── models ├── __init__.py └── model_builders.py ├── .gitignore ├── checkpointer ├── __init__.py ├── convert_weights.py └── checkpointer.py ├── requirements.txt ├── .gitmodules ├── configs └── llama3_2 │ ├── eval │ ├── objaworld │ │ ├── rendering.yaml │ │ └── recognition.yaml │ ├── objectron │ │ └── recognition.yaml │ └── clevr │ │ ├── question-answering.yaml │ │ ├── rendering.yaml │ │ ├── recognition.yaml │ │ ├── recognition-colab-demo.yaml │ │ └── instruction-following │ │ ├── moving-with-relation.yaml │ │ ├── removal-with-relation.yaml │ │ ├── appearance-no-relation.yaml │ │ ├── insertion-with-relation.yaml │ │ └── appearance-with-relation.yaml │ └── train │ ├── objaworld │ ├── rendering.yaml │ └── recognition.yaml │ ├── objectron │ └── recognition.yaml │ └── clevr │ ├── rendering.yaml │ ├── recognition.yaml │ ├── question-answering.yaml │ └── instruction-following.yaml ├── scripts ├── tokenizer.py ├── compute_text_answer_accuracy.py ├── taming_transformers_utils.py ├── compute_ssim_l2loss.py ├── decode_image_embeddings.py ├── compute_jaccard_index.py ├── compute_jaccard_index_objectron.py └── generate_function.py ├── taming-kyvo.yml ├── kyvo.yml ├── DATA.md ├── LICENSE └── README.md /dataset/__init__.py: -------------------------------------------------------------------------------- 1 | from dataset.multimodal_dataset import * 2 | -------------------------------------------------------------------------------- /docs/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AadSah/kyvo/HEAD/docs/teaser.png -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from models.model_builders import * 2 | from models.model_component_builders import * 3 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | checkpoints/ 3 | kyvo-datasets-and-codebooks/ 4 | taming/ 5 | taming-transformers/ -------------------------------------------------------------------------------- /checkpointer/__init__.py: -------------------------------------------------------------------------------- 1 | from checkpointer.checkpointer import * 2 | from checkpointer.convert_weights import * 3 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==2.4.1 2 | torchvision==0.19.1 3 | torchao==0.5.0 4 | transformers==4.41.2 5 | bitsandbytes==0.43.1 -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "torchtune"] 2 | path = torchtune 3 | url = https://github.com/pytorch/torchtune.git 4 | [submodule "taming-transformers"] 5 | path = taming-transformers 6 | url = https://github.com/CompVis/taming-transformers.git 7 | -------------------------------------------------------------------------------- /configs/llama3_2/eval/objaworld/rendering.yaml: -------------------------------------------------------------------------------- 1 | # Config for running the InferenceRecipe in generate.py to generate output from an LLM 2 | # 3 | # To launch, run 
the following command from root torchtune directory: 4 | # tune run generate --config generation 5 | 6 | # Model arguments 7 | model: 8 | _component_: models.llama3_2_1b_clevr3d_sin_cos_plus_learned_numbers_SYNTHETIC_LIVINGROOM_PARK_LARGE_EP100 9 | 10 | checkpointer: 11 | _component_: checkpointer.FullModelMetaCheckpointer3D 12 | convert_weights_type: 3d_sin_cos_plus_learned_num 13 | checkpoint_dir: ./checkpoints/objaworld/rendering/ 14 | checkpoint_files: [ 15 | meta_model_9.pt 16 | ] 17 | output_dir: ./checkpoints/objaworld/rendering-inference 18 | model_type: LLAMA3_2 19 | 20 | device: cuda 21 | dtype: bf16 22 | 23 | seed: 1234 24 | 25 | # Tokenizer arguments 26 | tokenizer: 27 | _component_: torchtune.models.llama3.llama3_tokenizer 28 | path: ./llama-3-models/Llama3.2-1B-Instruct/tokenizer.model 29 | 30 | 31 | instruct_template: null 32 | chat_format: null 33 | max_new_tokens: 2500 34 | temperature: 1.0 35 | # top_k: 300 36 | 37 | load_text_data: False 38 | load_image_data: True 39 | load_three_d_data: True 40 | load_text_target_data: False 41 | load_image_target_data: False 42 | load_three_d_target_data: False 43 | 44 | image_file: "./kyvo-datasets-and-codebooks/pretokenized-data/objaworld/images/test_vqgan_indices.json" 45 | image_target_file: "./kyvo-datasets-and-codebooks/pretokenized-data/objaworld/images/test_vqgan_indices.json" 46 | 47 | text_file: "./kyvo-datasets-and-codebooks/pretokenized-data/objaworld/3d-scenes/test_tokenized_scenes.json" 48 | three_d_file: "./kyvo-datasets-and-codebooks/pretokenized-data/objaworld/3d-scenes/test_tokenized_scenes.json" 49 | three_d_target_file: "./kyvo-datasets-and-codebooks/pretokenized-data/objaworld/3d-scenes/test_tokenized_scenes.json" 50 | 51 | run_identifier: "objaworld_rendering_inference" 52 | image_embeddings_output_folder: "./checkpoints/objaworld/rendering-inference/image_embeddings" 53 | three_d_json_output_folder: "./checkpoints/objaworld/rendering-inference/three_d_json" 54 | 55 | dataset_name: "ObjaWorld" 56 | vqgan_type: "objaworld" 57 | vqgan_row_col_size: 16 58 | reorder_image_tokens: True 59 | task_type: "3-I" 60 | num_samples: -1 61 | sample_start_idx: 0 62 | image_token_offset: 129471 63 | 64 | # enable_kv_cache: True 65 | 66 | quantizer: null 67 | 68 | -------------------------------------------------------------------------------- /configs/llama3_2/eval/objectron/recognition.yaml: -------------------------------------------------------------------------------- 1 | # Config for running the InferenceRecipe in generate.py to generate output from an LLM 2 | # 3 | # To launch, run the following command from root torchtune directory: 4 | # tune run generate --config generation 5 | 6 | # Model arguments 7 | model: 8 | _component_: models.llama3_2_1b_clevr3d_sin_cos_plus_learned_numbers_omni3d_objectron_custom_finer 9 | 10 | checkpointer: 11 | _component_: checkpointer.FullModelMetaCheckpointer3D 12 | convert_weights_type: 3d_sin_cos_plus_learned_num 13 | checkpoint_dir: ./checkpoints/objectron/recognition/ 14 | checkpoint_files: [ 15 | meta_model_19.pt 16 | ] 17 | output_dir: ./checkpoints/objectron/recognition-inference 18 | model_type: LLAMA3_2 19 | 20 | device: cuda 21 | dtype: bf16 22 | 23 | seed: 1234 24 | 25 | # Tokenizer arguments 26 | tokenizer: 27 | _component_: torchtune.models.llama3.llama3_tokenizer 28 | path: ./llama-3-models/Llama3.2-1B-Instruct/tokenizer.model 29 | 30 | 31 | instruct_template: null 32 | chat_format: null 33 | max_new_tokens: 2500 34 | temperature: 1.0 35 | # top_k: 300 36 | 37 | 
load_text_data: False 38 | load_image_data: True 39 | load_three_d_data: True 40 | load_text_target_data: False 41 | load_image_target_data: False 42 | load_three_d_target_data: False 43 | 44 | image_file: "./kyvo-datasets-and-codebooks/pretokenized-data/objectron/images/test_vqgan_indices.json" 45 | image_target_file: "./kyvo-datasets-and-codebooks/pretokenized-data/objectron/images/test_vqgan_indices.json" 46 | 47 | text_file: "./kyvo-datasets-and-codebooks/pretokenized-data/objectron/3d-scenes/test_tokenized_scenes.json" 48 | three_d_file: "./kyvo-datasets-and-codebooks/pretokenized-data/objectron/3d-scenes/test_tokenized_scenes.json" 49 | three_d_target_file: "./kyvo-datasets-and-codebooks/pretokenized-data/objectron/3d-scenes/test_tokenized_scenes.json" 50 | 51 | run_identifier: "objectron_recognition_inference" 52 | image_embeddings_output_folder: "./checkpoints/objectron/recognition-inference/image_embeddings" 53 | three_d_json_output_folder: "./checkpoints/objectron/recognition-inference/three_d_json" 54 | 55 | dataset_name: "Objectron" 56 | vqgan_type: "objectron" 57 | vqgan_row_col_size: 16 58 | reorder_image_tokens: False 59 | task_type: "I-3" 60 | num_samples: -1 61 | sample_start_idx: 0 62 | image_token_offset: 128372 63 | 64 | # enable_kv_cache: True 65 | 66 | quantizer: null 67 | 68 | -------------------------------------------------------------------------------- /configs/llama3_2/eval/objaworld/recognition.yaml: -------------------------------------------------------------------------------- 1 | # Config for running the InferenceRecipe in generate.py to generate output from an LLM 2 | # 3 | # To launch, run the following command from root torchtune directory: 4 | # tune run generate --config generation 5 | 6 | # Model arguments 7 | model: 8 | _component_: models.llama3_2_1b_clevr3d_sin_cos_plus_learned_numbers_SYNTHETIC_LIVINGROOM_PARK_LARGE_EP100 9 | 10 | checkpointer: 11 | _component_: checkpointer.FullModelMetaCheckpointer3D 12 | convert_weights_type: 3d_sin_cos_plus_learned_num 13 | checkpoint_dir: ./checkpoints/objaworld/recognition/ 14 | checkpoint_files: [ 15 | meta_model_9.pt 16 | ] 17 | output_dir: ./checkpoints/objaworld/recognition-inference 18 | model_type: LLAMA3_2 19 | 20 | device: cuda 21 | dtype: bf16 22 | 23 | seed: 1234 24 | 25 | # Tokenizer arguments 26 | tokenizer: 27 | _component_: torchtune.models.llama3.llama3_tokenizer 28 | path: ./llama-3-models/Llama3.2-1B-Instruct/tokenizer.model 29 | 30 | 31 | instruct_template: null 32 | chat_format: null 33 | max_new_tokens: 2500 34 | temperature: 1.0 35 | # top_k: 300 36 | 37 | load_text_data: False 38 | load_image_data: True 39 | load_three_d_data: True 40 | load_text_target_data: False 41 | load_image_target_data: False 42 | load_three_d_target_data: False 43 | 44 | image_file: "./kyvo-datasets-and-codebooks/pretokenized-data/objaworld/images/test_vqgan_indices.json" 45 | image_target_file: "./kyvo-datasets-and-codebooks/pretokenized-data/objaworld/images/test_vqgan_indices.json" 46 | 47 | text_file: "./kyvo-datasets-and-codebooks/pretokenized-data/objaworld/3d-scenes/test_tokenized_scenes.json" 48 | three_d_file: "./kyvo-datasets-and-codebooks/pretokenized-data/objaworld/3d-scenes/test_tokenized_scenes.json" 49 | three_d_target_file: "./kyvo-datasets-and-codebooks/pretokenized-data/objaworld/3d-scenes/test_tokenized_scenes.json" 50 | 51 | run_identifier: "objaworld_recognition_inference" 52 | image_embeddings_output_folder: "./checkpoints/objaworld/recognition-inference/image_embeddings" 53 | 
three_d_json_output_folder: "./checkpoints/objaworld/recognition-inference/three_d_json" 54 | 55 | dataset_name: "ObjaWorld" 56 | vqgan_type: "objaworld" 57 | vqgan_row_col_size: 16 58 | reorder_image_tokens: False 59 | task_type: "I-3" 60 | num_samples: -1 61 | sample_start_idx: 0 62 | image_token_offset: 129471 63 | 64 | # enable_kv_cache: True 65 | 66 | quantizer: null 67 | 68 | -------------------------------------------------------------------------------- /configs/llama3_2/eval/clevr/question-answering.yaml: -------------------------------------------------------------------------------- 1 | # Config for running the InferenceRecipe in generate.py to generate output from an LLM 2 | # 3 | # To launch, run the following command from root torchtune directory: 4 | # tune run generate --config generation 5 | 6 | # Model arguments 7 | model: 8 | _component_: models.llama3_2_1b_clevr3d_sin_cos_plus_learned_numbers 9 | 10 | checkpointer: 11 | _component_: checkpointer.FullModelMetaCheckpointer3D 12 | convert_weights_type: 3d_sin_cos_plus_learned_num 13 | checkpoint_dir: ./checkpoints/clevr/question-answering/ 14 | checkpoint_files: [ 15 | meta_model_9.pt 16 | ] 17 | output_dir: ./checkpoints/clevr/question-answering-inference 18 | model_type: LLAMA3_2 19 | 20 | device: cuda 21 | dtype: bf16 22 | 23 | seed: 1234 24 | 25 | # Tokenizer arguments 26 | tokenizer: 27 | _component_: torchtune.models.llama3.llama3_tokenizer 28 | path: ./llama-3-models/Llama3.2-1B-Instruct/tokenizer.model 29 | 30 | instruct_template: null 31 | chat_format: null 32 | max_new_tokens: 2500 33 | temperature: 1.0 34 | # top_k: 300 35 | 36 | load_text_data: True 37 | load_image_data: True 38 | load_three_d_data: True 39 | load_text_target_data: True 40 | load_image_target_data: False 41 | load_three_d_target_data: False 42 | 43 | 44 | image_file: "./kyvo-datasets-and-codebooks/pretokenized-data/clevr/images/test_vqgan_indices_for_vqa.json" 45 | image_target_file: "./kyvo-datasets-and-codebooks/pretokenized-data/clevr/images/test_vqgan_indices_for_vqa.json" 46 | 47 | text_file: "./kyvo-datasets-and-codebooks/pretokenized-data/clevr/text/test_vqa_questions.json" 48 | text_target_file: "./kyvo-datasets-and-codebooks/pretokenized-data/clevr/text/test_vqa_answers.json" 49 | three_d_file: "./kyvo-datasets-and-codebooks/pretokenized-data/clevr/3d-scenes/test_tokenized_scenes_for_vqa.json" 50 | three_d_target_file: "./kyvo-datasets-and-codebooks/pretokenized-data/clevr/3d-scenes/test_tokenized_scenes_for_vqa.json" 51 | 52 | run_identifier: "clevr_question-answering_inference" 53 | image_embeddings_output_folder: "./checkpoints/clevr/question-answering-inference/image_embeddings" 54 | three_d_json_output_folder: "./checkpoints/clevr/question-answering-inference/three_d_json" 55 | 56 | dataset_name: "CLEVR" 57 | vqgan_type: "clevr" 58 | vqgan_row_col_size: 16 59 | reorder_image_tokens: False 60 | task_type: "I+3+Q-A" 61 | num_samples: -1 62 | sample_start_idx: 0 63 | image_token_offset: 129471 64 | 65 | # enable_kv_cache: True 66 | 67 | quantizer: null 68 | 69 | -------------------------------------------------------------------------------- /configs/llama3_2/eval/clevr/rendering.yaml: -------------------------------------------------------------------------------- 1 | # Config for running the InferenceRecipe in generate.py to generate output from an LLM 2 | # 3 | # To launch, run the following command from root torchtune directory: 4 | # tune run generate --config generation 5 | 6 | # Model arguments 7 | model: 8 | _component_: 
models.llama3_2_1b_clevr3d_sin_cos_plus_learned_numbers 9 | 10 | checkpointer: 11 | _component_: checkpointer.FullModelMetaCheckpointer3D 12 | convert_weights_type: 3d_sin_cos_plus_learned_num 13 | checkpoint_dir: ./checkpoints/clevr/rendering/ 14 | checkpoint_files: [ 15 | meta_model_9.pt 16 | ] 17 | output_dir: ./checkpoints/clevr/rendering-inference 18 | model_type: LLAMA3_2 19 | 20 | device: cuda 21 | dtype: bf16 22 | 23 | seed: 1234 24 | 25 | # Tokenizer arguments 26 | tokenizer: 27 | _component_: torchtune.models.llama3.llama3_tokenizer 28 | path: ./llama-3-models/Llama3.2-1B-Instruct/tokenizer.model 29 | 30 | 31 | instruct_template: null 32 | chat_format: null 33 | max_new_tokens: 2500 34 | temperature: 1.0 35 | # top_k: 300 36 | 37 | load_text_data: False 38 | load_image_data: True 39 | load_three_d_data: True 40 | load_text_target_data: False 41 | load_image_target_data: False 42 | load_three_d_target_data: False 43 | 44 | image_file: "./kyvo-datasets-and-codebooks/pretokenized-data/clevr/images/test_vqgan_indices_merged_for_rendering_and_recognition.json" 45 | image_target_file: "./kyvo-datasets-and-codebooks/pretokenized-data/clevr/images/test_vqgan_indices_merged_for_rendering_and_recognition.json" 46 | 47 | text_file: "./kyvo-datasets-and-codebooks/pretokenized-data/clevr/3d-scenes/test_tokenized_scenes_merged_for_rendering_and_recognition.json" 48 | three_d_file: "./kyvo-datasets-and-codebooks/pretokenized-data/clevr/3d-scenes/test_tokenized_scenes_merged_for_rendering_and_recognition.json" 49 | three_d_target_file: "./kyvo-datasets-and-codebooks/pretokenized-data/clevr/3d-scenes/test_tokenized_scenes_merged_for_rendering_and_recognition.json" 50 | 51 | run_identifier: "clevr_rendering_inference" 52 | image_embeddings_output_folder: "./checkpoints/clevr/rendering-inference/image_embeddings" 53 | three_d_json_output_folder: "./checkpoints/clevr/rendering-inference/three_d_json" 54 | 55 | dataset_name: "CLEVR" 56 | vqgan_type: "clevr" 57 | vqgan_row_col_size: 16 58 | reorder_image_tokens: True 59 | task_type: "3-I" 60 | num_samples: -1 61 | sample_start_idx: 0 62 | image_token_offset: 129471 63 | 64 | # enable_kv_cache: True 65 | 66 | quantizer: null 67 | 68 | -------------------------------------------------------------------------------- /configs/llama3_2/eval/clevr/recognition.yaml: -------------------------------------------------------------------------------- 1 | # Config for running the InferenceRecipe in generate.py to generate output from an LLM 2 | # 3 | # To launch, run the following command from root torchtune directory: 4 | # tune run generate --config generation 5 | 6 | # Model arguments 7 | model: 8 | _component_: models.llama3_2_1b_clevr3d_sin_cos_plus_learned_numbers 9 | 10 | checkpointer: 11 | _component_: checkpointer.FullModelMetaCheckpointer3D 12 | convert_weights_type: 3d_sin_cos_plus_learned_num 13 | checkpoint_dir: ./checkpoints/clevr/recognition/ 14 | checkpoint_files: [ 15 | meta_model_9.pt 16 | ] 17 | output_dir: ./checkpoints/clevr/recognition-inference 18 | model_type: LLAMA3_2 19 | 20 | device: cuda 21 | dtype: bf16 22 | 23 | seed: 1234 24 | 25 | # Tokenizer arguments 26 | tokenizer: 27 | _component_: torchtune.models.llama3.llama3_tokenizer 28 | path: ./llama-3-models/Llama3.2-1B-Instruct/tokenizer.model 29 | 30 | 31 | instruct_template: null 32 | chat_format: null 33 | max_new_tokens: 2500 34 | temperature: 1.0 35 | # top_k: 300 36 | 37 | load_text_data: False 38 | load_image_data: True 39 | load_three_d_data: True 40 | load_text_target_data: 
False 41 | load_image_target_data: False 42 | load_three_d_target_data: False 43 | 44 | image_file: "./kyvo-datasets-and-codebooks/pretokenized-data/clevr/images/test_vqgan_indices_merged_for_rendering_and_recognition.json" 45 | image_target_file: "./kyvo-datasets-and-codebooks/pretokenized-data/clevr/images/test_vqgan_indices_merged_for_rendering_and_recognition.json" 46 | 47 | text_file: "./kyvo-datasets-and-codebooks/pretokenized-data/clevr/3d-scenes/test_tokenized_scenes_merged_for_rendering_and_recognition.json" 48 | three_d_file: "./kyvo-datasets-and-codebooks/pretokenized-data/clevr/3d-scenes/test_tokenized_scenes_merged_for_rendering_and_recognition.json" 49 | three_d_target_file: "./kyvo-datasets-and-codebooks/pretokenized-data/clevr/3d-scenes/test_tokenized_scenes_merged_for_rendering_and_recognition.json" 50 | 51 | run_identifier: "clevr_recognition_inference" 52 | image_embeddings_output_folder: "./checkpoints/clevr/recognition-inference/image_embeddings" 53 | three_d_json_output_folder: "./checkpoints/clevr/recognition-inference/three_d_json" 54 | 55 | dataset_name: "CLEVR" 56 | vqgan_type: "clevr" 57 | vqgan_row_col_size: 16 58 | reorder_image_tokens: False 59 | task_type: "I-3" 60 | num_samples: -1 61 | sample_start_idx: 0 62 | image_token_offset: 129471 63 | 64 | # enable_kv_cache: True 65 | 66 | quantizer: null 67 | 68 | -------------------------------------------------------------------------------- /configs/llama3_2/eval/clevr/recognition-colab-demo.yaml: -------------------------------------------------------------------------------- 1 | # Config for running the InferenceRecipe in generate.py to generate output from an LLM 2 | # 3 | # To launch, run the following command from root torchtune directory: 4 | # tune run generate --config generation 5 | 6 | # Model arguments 7 | model: 8 | _component_: models.llama3_2_1b_clevr3d_sin_cos_plus_learned_numbers 9 | 10 | checkpointer: 11 | _component_: checkpointer.FullModelMetaCheckpointer3D 12 | convert_weights_type: 3d_sin_cos_plus_learned_num 13 | checkpoint_dir: ./checkpoints/clevr/recognition/ 14 | checkpoint_files: [ 15 | meta_model_9.pt 16 | ] 17 | output_dir: ./checkpoints/clevr/recognition-inference-colab-demo 18 | model_type: LLAMA3_2 19 | 20 | device: cuda 21 | dtype: bf16 22 | 23 | seed: 1234 24 | 25 | # Tokenizer arguments 26 | tokenizer: 27 | _component_: torchtune.models.llama3.llama3_tokenizer 28 | path: ./llama-3-models/Llama3.2-1B-Instruct/tokenizer.model 29 | 30 | 31 | instruct_template: null 32 | chat_format: null 33 | max_new_tokens: 2500 34 | temperature: 1.0 35 | # top_k: 300 36 | 37 | load_text_data: False 38 | load_image_data: True 39 | load_three_d_data: True 40 | load_text_target_data: False 41 | load_image_target_data: False 42 | load_three_d_target_data: False 43 | 44 | image_file: "./kyvo-datasets-and-codebooks/pretokenized-data/clevr/images/test_vqgan_indices_merged_for_rendering_and_recognition.json" 45 | image_target_file: "./kyvo-datasets-and-codebooks/pretokenized-data/clevr/images/test_vqgan_indices_merged_for_rendering_and_recognition.json" 46 | 47 | text_file: "./kyvo-datasets-and-codebooks/pretokenized-data/clevr/3d-scenes/test_tokenized_scenes_merged_for_rendering_and_recognition.json" 48 | three_d_file: "./kyvo-datasets-and-codebooks/pretokenized-data/clevr/3d-scenes/test_tokenized_scenes_merged_for_rendering_and_recognition.json" 49 | three_d_target_file: "./kyvo-datasets-and-codebooks/pretokenized-data/clevr/3d-scenes/test_tokenized_scenes_merged_for_rendering_and_recognition.json" 50 | 
51 | run_identifier: "clevr_recognition_inference" 52 | image_embeddings_output_folder: "./checkpoints/clevr/recognition-inference-colab-demo/image_embeddings" 53 | three_d_json_output_folder: "./checkpoints/clevr/recognition-inference-colab-demo/three_d_json" 54 | 55 | dataset_name: "CLEVR" 56 | vqgan_type: "clevr" 57 | vqgan_row_col_size: 16 58 | reorder_image_tokens: False 59 | task_type: "I-3" 60 | num_samples: -1 61 | sample_start_idx: 0 62 | image_token_offset: 129471 63 | 64 | # enable_kv_cache: True 65 | 66 | quantizer: null 67 | 68 | -------------------------------------------------------------------------------- /configs/llama3_2/eval/clevr/instruction-following/moving-with-relation.yaml: -------------------------------------------------------------------------------- 1 | # Config for running the InferenceRecipe in generate.py to generate output from an LLM 2 | # 3 | # To launch, run the following command from root torchtune directory: 4 | # tune run generate --config generation 5 | 6 | # Model arguments 7 | model: 8 | _component_: models.llama3_2_1b_clevr3d_sin_cos_plus_learned_numbers 9 | 10 | checkpointer: 11 | _component_: checkpointer.FullModelMetaCheckpointer3D 12 | convert_weights_type: 3d_sin_cos_plus_learned_num 13 | checkpoint_dir: ./checkpoints/clevr/instruction-following/ 14 | checkpoint_files: [ 15 | meta_model_9.pt 16 | ] 17 | output_dir: ./checkpoints/clevr/instruction-following-inference 18 | model_type: LLAMA3_2 19 | 20 | device: cuda 21 | dtype: bf16 22 | 23 | seed: 1234 24 | 25 | # Tokenizer arguments 26 | tokenizer: 27 | _component_: torchtune.models.llama3.llama3_tokenizer 28 | path: ./llama-3-models/Llama3.2-1B-Instruct/tokenizer.model 29 | 30 | 31 | instruct_template: null 32 | chat_format: null 33 | max_new_tokens: 2500 34 | temperature: 1.0 35 | # top_k: 300 36 | 37 | load_text_data: True 38 | load_image_data: True 39 | load_three_d_data: True 40 | load_text_target_data: False 41 | load_image_target_data: True 42 | load_three_d_target_data: True 43 | 44 | image_file: "./kyvo-datasets-and-codebooks/pretokenized-data/clevr/images/instruction-following-eval-source/test_vqgan_indices_moving_with_relation.json" 45 | image_target_file: "./kyvo-datasets-and-codebooks/pretokenized-data/clevr/images/instruction-following-eval-target/test_vqgan_indices_moving_with_relation.json" 46 | 47 | text_file: "./kyvo-datasets-and-codebooks/pretokenized-data/clevr/text/instruction-following-eval/test_text_instructions_moving_with_relation.json" 48 | three_d_file: "./kyvo-datasets-and-codebooks/pretokenized-data/clevr/3d-scenes/instruction-following-eval-source/test_tokenized_scenes_moving_with_relation.json" 49 | three_d_target_file: "./kyvo-datasets-and-codebooks/pretokenized-data/clevr/3d-scenes/instruction-following-eval-target/test_tokenized_scenes_moving_with_relation.json" 50 | 51 | run_identifier: "clevr_instruction-following-moving-with-relation_inference" 52 | image_embeddings_output_folder: "./checkpoints/clevr/instruction-following-inference/image_embeddings" 53 | three_d_json_output_folder: "./checkpoints/clevr/instruction-following-inference/three_d_json" 54 | 55 | dataset_name: "CLEVR" 56 | vqgan_type: "clevr" 57 | vqgan_row_col_size: 16 58 | reorder_image_tokens: True 59 | task_type: "I+3+T-I+3" 60 | num_samples: -1 61 | sample_start_idx: 0 62 | image_token_offset: 129471 63 | 64 | # enable_kv_cache: True 65 | 66 | quantizer: null 67 | -------------------------------------------------------------------------------- 
/configs/llama3_2/eval/clevr/instruction-following/removal-with-relation.yaml: -------------------------------------------------------------------------------- 1 | # Config for running the InferenceRecipe in generate.py to generate output from an LLM 2 | # 3 | # To launch, run the following command from root torchtune directory: 4 | # tune run generate --config generation 5 | 6 | # Model arguments 7 | model: 8 | _component_: models.llama3_2_1b_clevr3d_sin_cos_plus_learned_numbers 9 | 10 | checkpointer: 11 | _component_: checkpointer.FullModelMetaCheckpointer3D 12 | convert_weights_type: 3d_sin_cos_plus_learned_num 13 | checkpoint_dir: ./checkpoints/clevr/instruction-following/ 14 | checkpoint_files: [ 15 | meta_model_9.pt 16 | ] 17 | output_dir: ./checkpoints/clevr/instruction-following-inference 18 | model_type: LLAMA3_2 19 | 20 | device: cuda 21 | dtype: bf16 22 | 23 | seed: 1234 24 | 25 | # Tokenizer arguments 26 | tokenizer: 27 | _component_: torchtune.models.llama3.llama3_tokenizer 28 | path: ./llama-3-models/Llama3.2-1B-Instruct/tokenizer.model 29 | 30 | 31 | instruct_template: null 32 | chat_format: null 33 | max_new_tokens: 2500 34 | temperature: 1.0 35 | # top_k: 300 36 | 37 | load_text_data: True 38 | load_image_data: True 39 | load_three_d_data: True 40 | load_text_target_data: False 41 | load_image_target_data: True 42 | load_three_d_target_data: True 43 | 44 | image_file: "./kyvo-datasets-and-codebooks/pretokenized-data/clevr/images/instruction-following-eval-source/test_vqgan_indices_removal_with_relation.json" 45 | image_target_file: "./kyvo-datasets-and-codebooks/pretokenized-data/clevr/images/instruction-following-eval-target/test_vqgan_indices_removal_with_relation.json" 46 | 47 | text_file: "./kyvo-datasets-and-codebooks/pretokenized-data/clevr/text/instruction-following-eval/test_text_instructions_removal_with_relation.json" 48 | three_d_file: "./kyvo-datasets-and-codebooks/pretokenized-data/clevr/3d-scenes/instruction-following-eval-source/test_tokenized_scenes_removal_with_relation.json" 49 | three_d_target_file: "./kyvo-datasets-and-codebooks/pretokenized-data/clevr/3d-scenes/instruction-following-eval-target/test_tokenized_scenes_removal_with_relation.json" 50 | 51 | run_identifier: "clevr_instruction-following-removal-with-relation_inference" 52 | image_embeddings_output_folder: "./checkpoints/clevr/instruction-following-inference/image_embeddings" 53 | three_d_json_output_folder: "./checkpoints/clevr/instruction-following-inference/three_d_json" 54 | 55 | dataset_name: "CLEVR" 56 | vqgan_type: "clevr" 57 | vqgan_row_col_size: 16 58 | reorder_image_tokens: True 59 | task_type: "I+3+T-I+3" 60 | num_samples: -1 61 | sample_start_idx: 0 62 | image_token_offset: 129471 63 | 64 | # enable_kv_cache: True 65 | 66 | quantizer: null 67 | -------------------------------------------------------------------------------- /configs/llama3_2/eval/clevr/instruction-following/appearance-no-relation.yaml: -------------------------------------------------------------------------------- 1 | # Config for running the InferenceRecipe in generate.py to generate output from an LLM 2 | # 3 | # To launch, run the following command from root torchtune directory: 4 | # tune run generate --config generation 5 | 6 | # Model arguments 7 | model: 8 | _component_: models.llama3_2_1b_clevr3d_sin_cos_plus_learned_numbers 9 | 10 | checkpointer: 11 | _component_: checkpointer.FullModelMetaCheckpointer3D 12 | convert_weights_type: 3d_sin_cos_plus_learned_num 13 | checkpoint_dir: 
./checkpoints/clevr/instruction-following/ 14 | checkpoint_files: [ 15 | meta_model_9.pt 16 | ] 17 | output_dir: ./checkpoints/clevr/instruction-following-inference 18 | model_type: LLAMA3_2 19 | 20 | device: cuda 21 | dtype: bf16 22 | 23 | seed: 1234 24 | 25 | # Tokenizer arguments 26 | tokenizer: 27 | _component_: torchtune.models.llama3.llama3_tokenizer 28 | path: ./llama-3-models/Llama3.2-1B-Instruct/tokenizer.model 29 | 30 | 31 | instruct_template: null 32 | chat_format: null 33 | max_new_tokens: 2500 34 | temperature: 1.0 35 | # top_k: 300 36 | 37 | load_text_data: True 38 | load_image_data: True 39 | load_three_d_data: True 40 | load_text_target_data: False 41 | load_image_target_data: True 42 | load_three_d_target_data: True 43 | 44 | image_file: "./kyvo-datasets-and-codebooks/pretokenized-data/clevr/images/instruction-following-eval-source/test_vqgan_indices_appearance_no_relation.json" 45 | image_target_file: "./kyvo-datasets-and-codebooks/pretokenized-data/clevr/images/instruction-following-eval-target/test_vqgan_indices_appearance_no_relation.json" 46 | 47 | text_file: "./kyvo-datasets-and-codebooks/pretokenized-data/clevr/text/instruction-following-eval/test_text_instructions_appearance_no_relation.json" 48 | three_d_file: "./kyvo-datasets-and-codebooks/pretokenized-data/clevr/3d-scenes/instruction-following-eval-source/test_tokenized_scenes_appearance_no_relation.json" 49 | three_d_target_file: "./kyvo-datasets-and-codebooks/pretokenized-data/clevr/3d-scenes/instruction-following-eval-target/test_tokenized_scenes_appearance_no_relation.json" 50 | 51 | run_identifier: "clevr_instruction-following-appearance-no-relation_inference" 52 | image_embeddings_output_folder: "./checkpoints/clevr/instruction-following-inference/image_embeddings" 53 | three_d_json_output_folder: "./checkpoints/clevr/instruction-following-inference/three_d_json" 54 | 55 | dataset_name: "CLEVR" 56 | vqgan_type: "clevr" 57 | vqgan_row_col_size: 16 58 | reorder_image_tokens: True 59 | task_type: "I+3+T-I+3" 60 | num_samples: -1 61 | sample_start_idx: 0 62 | image_token_offset: 129471 63 | 64 | # enable_kv_cache: True 65 | 66 | quantizer: null 67 | -------------------------------------------------------------------------------- /configs/llama3_2/eval/clevr/instruction-following/insertion-with-relation.yaml: -------------------------------------------------------------------------------- 1 | # Config for running the InferenceRecipe in generate.py to generate output from an LLM 2 | # 3 | # To launch, run the following command from root torchtune directory: 4 | # tune run generate --config generation 5 | 6 | # Model arguments 7 | model: 8 | _component_: models.llama3_2_1b_clevr3d_sin_cos_plus_learned_numbers 9 | 10 | checkpointer: 11 | _component_: checkpointer.FullModelMetaCheckpointer3D 12 | convert_weights_type: 3d_sin_cos_plus_learned_num 13 | checkpoint_dir: ./checkpoints/clevr/instruction-following/ 14 | checkpoint_files: [ 15 | meta_model_9.pt 16 | ] 17 | output_dir: ./checkpoints/clevr/instruction-following-inference 18 | model_type: LLAMA3_2 19 | 20 | device: cuda 21 | dtype: bf16 22 | 23 | seed: 1234 24 | 25 | # Tokenizer arguments 26 | tokenizer: 27 | _component_: torchtune.models.llama3.llama3_tokenizer 28 | path: ./llama-3-models/Llama3.2-1B-Instruct/tokenizer.model 29 | 30 | 31 | instruct_template: null 32 | chat_format: null 33 | max_new_tokens: 2500 34 | temperature: 1.0 35 | # top_k: 300 36 | 37 | load_text_data: True 38 | load_image_data: True 39 | load_three_d_data: True 40 | 
load_text_target_data: False 41 | load_image_target_data: True 42 | load_three_d_target_data: True 43 | 44 | image_file: "./kyvo-datasets-and-codebooks/pretokenized-data/clevr/images/instruction-following-eval-source/test_vqgan_indices_insertion_with_relation.json" 45 | image_target_file: "./kyvo-datasets-and-codebooks/pretokenized-data/clevr/images/instruction-following-eval-target/test_vqgan_indices_insertion_with_relation.json" 46 | 47 | text_file: "./kyvo-datasets-and-codebooks/pretokenized-data/clevr/text/instruction-following-eval/test_text_instructions_insertion_with_relation.json" 48 | three_d_file: "./kyvo-datasets-and-codebooks/pretokenized-data/clevr/3d-scenes/instruction-following-eval-source/test_tokenized_scenes_insertion_with_relation.json" 49 | three_d_target_file: "./kyvo-datasets-and-codebooks/pretokenized-data/clevr/3d-scenes/instruction-following-eval-target/test_tokenized_scenes_insertion_with_relation.json" 50 | 51 | run_identifier: "clevr_instruction-following-insertion-with-relation_inference" 52 | image_embeddings_output_folder: "./checkpoints/clevr/instruction-following-inference/image_embeddings" 53 | three_d_json_output_folder: "./checkpoints/clevr/instruction-following-inference/three_d_json" 54 | 55 | dataset_name: "CLEVR" 56 | vqgan_type: "clevr" 57 | vqgan_row_col_size: 16 58 | reorder_image_tokens: True 59 | task_type: "I+3+T-I+3" 60 | num_samples: -1 61 | sample_start_idx: 0 62 | image_token_offset: 129471 63 | 64 | # enable_kv_cache: True 65 | 66 | quantizer: null 67 | -------------------------------------------------------------------------------- /configs/llama3_2/eval/clevr/instruction-following/appearance-with-relation.yaml: -------------------------------------------------------------------------------- 1 | # Config for running the InferenceRecipe in generate.py to generate output from an LLM 2 | # 3 | # To launch, run the following command from root torchtune directory: 4 | # tune run generate --config generation 5 | 6 | # Model arguments 7 | model: 8 | _component_: models.llama3_2_1b_clevr3d_sin_cos_plus_learned_numbers 9 | 10 | checkpointer: 11 | _component_: checkpointer.FullModelMetaCheckpointer3D 12 | convert_weights_type: 3d_sin_cos_plus_learned_num 13 | checkpoint_dir: ./checkpoints/clevr/instruction-following/ 14 | checkpoint_files: [ 15 | meta_model_9.pt 16 | ] 17 | output_dir: ./checkpoints/clevr/instruction-following-inference 18 | model_type: LLAMA3_2 19 | 20 | device: cuda 21 | dtype: bf16 22 | 23 | seed: 1234 24 | 25 | # Tokenizer arguments 26 | tokenizer: 27 | _component_: torchtune.models.llama3.llama3_tokenizer 28 | path: ./llama-3-models/Llama3.2-1B-Instruct/tokenizer.model 29 | 30 | 31 | instruct_template: null 32 | chat_format: null 33 | max_new_tokens: 2500 34 | temperature: 1.0 35 | # top_k: 300 36 | 37 | load_text_data: True 38 | load_image_data: True 39 | load_three_d_data: True 40 | load_text_target_data: False 41 | load_image_target_data: True 42 | load_three_d_target_data: True 43 | 44 | image_file: "./kyvo-datasets-and-codebooks/pretokenized-data/clevr/images/instruction-following-eval-source/test_vqgan_indices_appearance_with_relation.json" 45 | image_target_file: "./kyvo-datasets-and-codebooks/pretokenized-data/clevr/images/instruction-following-eval-target/test_vqgan_indices_appearance_with_relation.json" 46 | 47 | text_file: "./kyvo-datasets-and-codebooks/pretokenized-data/clevr/text/instruction-following-eval/test_text_instructions_appearance_with_relation.json" 48 | three_d_file: 
"./kyvo-datasets-and-codebooks/pretokenized-data/clevr/3d-scenes/instruction-following-eval-source/test_tokenized_scenes_appearance_with_relation.json" 49 | three_d_target_file: "./kyvo-datasets-and-codebooks/pretokenized-data/clevr/3d-scenes/instruction-following-eval-target/test_tokenized_scenes_appearance_with_relation.json" 50 | 51 | run_identifier: "clevr_instruction-following-appearance-with-relation_inference" 52 | image_embeddings_output_folder: "./checkpoints/clevr/instruction-following-inference/image_embeddings" 53 | three_d_json_output_folder: "./checkpoints/clevr/instruction-following-inference/three_d_json" 54 | 55 | dataset_name: "CLEVR" 56 | vqgan_type: "clevr" 57 | vqgan_row_col_size: 16 58 | reorder_image_tokens: True 59 | task_type: "I+3+T-I+3" 60 | num_samples: -1 61 | sample_start_idx: 0 62 | image_token_offset: 129471 63 | 64 | # enable_kv_cache: True 65 | 66 | quantizer: null 67 | -------------------------------------------------------------------------------- /scripts/tokenizer.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoTokenizer 2 | import numpy as np 3 | 4 | 5 | def get_tokenizer(model_name="meta-llama/Llama-3.2-1B-Instruct"): 6 | tokenizer = AutoTokenizer.from_pretrained(model_name) 7 | 8 | special_tokens_dict = { 9 | "additional_special_tokens": [ 10 | "[SCENE-START]", 11 | "[SCENE-END]", 12 | "[OBJECT-START]", 13 | "[OBJECT-END]", 14 | "[SIZE]", 15 | "[COLOR]", 16 | "[MATERIAL]", 17 | "[SHAPE]", 18 | "[LOCATION]", 19 | "[IMAGE-START]", 20 | "[IMAGE-END]", 21 | "[TEXT-START]", 22 | "[TEXT-END]", 23 | "[OUTPUT-START]", 24 | ] 25 | } 26 | 27 | number_tokens = np.arange(-3, 3 + 0.005, 0.005).tolist() 28 | # convert the number tokens to strings 29 | number_tokens = [str(format(round(token, 3), ".3f")) for token in number_tokens] 30 | # replace "-0.000" with "0.000" in number tokens and update the list 31 | number_tokens = [token.replace("-0.000", "0.000") for token in number_tokens] 32 | 33 | tokenizer.add_tokens(number_tokens) 34 | tokenizer.add_special_tokens(special_tokens_dict) 35 | 36 | return tokenizer 37 | 38 | 39 | def get_tokenizer_omni3d_objectron(model_name="meta-llama/Llama-3.2-1B-Instruct"): 40 | tokenizer = AutoTokenizer.from_pretrained(model_name) 41 | 42 | special_tokens_dict = { 43 | "additional_special_tokens": [ 44 | "[SCENE-START]", 45 | "[SCENE-END]", 46 | "[OBJECT-START]", 47 | "[OBJECT-END]", 48 | "[CATEGORY]", 49 | "[CENTER_CAM]", 50 | "[DIMENSIONS]", 51 | "[IMAGE-START]", 52 | "[IMAGE-END]", 53 | "[TEXT-START]", 54 | "[TEXT-END]", 55 | "[OUTPUT-START]", 56 | ] 57 | } 58 | 59 | center_cam_x_number_tokens = [0.01 * i for i in range(-20, 21)] 60 | center_cam_y_number_tokens = [0.01 * i for i in range(-20, 11)] 61 | center_cam_z_number_tokens = [0.05 * i for i in range(61)] 62 | center_cam_z_number_tokens.extend([5.0, 7.5, 10.0, 12.5, 15.0, 17.5]) 63 | 64 | dimensions_length_number_tokens = [0.05 * i for i in range(21)] 65 | dimensions_width_number_tokens = [0.05 * i for i in range(21)] 66 | dimensions_height_number_tokens = [0.05 * i for i in range(25)] 67 | 68 | # merge, unique and sort the number tokens 69 | number_tokens = list( 70 | set( 71 | center_cam_x_number_tokens 72 | + center_cam_y_number_tokens 73 | + center_cam_z_number_tokens 74 | + dimensions_length_number_tokens 75 | + dimensions_width_number_tokens 76 | + dimensions_height_number_tokens 77 | ) 78 | ) 79 | number_tokens.sort() 80 | # convert the number tokens to strings 81 | number_tokens = 
[str(format(round(token, 3), ".2f")) for token in number_tokens] 82 | 83 | tokenizer.add_tokens(number_tokens) 84 | tokenizer.add_special_tokens(special_tokens_dict) 85 | 86 | return tokenizer 87 | -------------------------------------------------------------------------------- /scripts/compute_text_answer_accuracy.py: -------------------------------------------------------------------------------- 1 | import json 2 | import re 3 | from tqdm import tqdm 4 | import argparse 5 | 6 | # Parser for command line arguments 7 | parser = argparse.ArgumentParser(description="Compute Text Answer Accuracy") 8 | 9 | parser.add_argument( 10 | "--groundtruth_file", 11 | required=True, 12 | type=str, 13 | help="Path to the ground truth answers JSON file.", 14 | ) 15 | 16 | parser.add_argument( 17 | "--predicted_file", 18 | required=True, 19 | type=str, 20 | help="Path to the predicted answers JSON file.", 21 | ) 22 | 23 | args = parser.parse_args() 24 | 25 | 26 | def clean_answer(answer): 27 | """ 28 | Clean the answer by removing special tokens like [TEXT-START], [TEXT-END], and <|end_of_text|>. 29 | """ 30 | return re.sub(r"\[TEXT-START\]|\[TEXT-END\]|<\|end_of_text\|>", "", answer).strip() 31 | 32 | 33 | def evaluate_accuracy(groundtruth_file, predicted_file): 34 | """ 35 | Evaluate the accuracy of predicted answers against ground truth answers. 36 | 37 | Args: 38 | - groundtruth_file (str): Path to the ground truth answers JSON file. 39 | - predicted_file (str): Path to the predicted answers JSON file. 40 | 41 | Returns: 42 | - float: Accuracy of the model (in percentage). 43 | """ 44 | # Load ground truth and predicted answers 45 | with open(groundtruth_file, "r") as gt_file: 46 | groundtruth_data = json.load(gt_file) 47 | with open(predicted_file, "r") as pred_file: 48 | predicted_data = json.load(pred_file) 49 | 50 | # Assertions to ensure data integrity 51 | assert "answers" in groundtruth_data, "Ground truth JSON missing 'answers' key." 52 | assert "answers" in predicted_data, "Predicted answers JSON missing 'answers' key." 53 | assert len(groundtruth_data["answers"]) == len( 54 | predicted_data["answers"] 55 | ), "Mismatch in number of ground truth and predicted answers." 
56 | 57 | groundtruth_answers = groundtruth_data["answers"] 58 | predicted_answers = predicted_data["answers"] 59 | 60 | correct_count = 0 61 | 62 | # Progress bar for evaluation 63 | for gt, pred in tqdm( 64 | zip(groundtruth_answers, predicted_answers), 65 | total=len(groundtruth_answers), 66 | desc="Evaluating", 67 | ): 68 | # Assertions for data consistency 69 | assert ( 70 | gt["image_filename"].split("_scene_")[1] 71 | == pred["image_filename"].split("_scene_")[1] 72 | ), f"Image filename mismatch: {gt['image_filename']} vs {pred['image_filename']}" 73 | 74 | # Clean answers to extract meaningful text 75 | gt_answer = clean_answer(gt["answer"]) 76 | pred_answer = clean_answer(pred["answer"]) 77 | 78 | # Compare and count correct predictions 79 | if gt_answer == pred_answer: 80 | correct_count += 1 81 | 82 | # Calculate accuracy 83 | accuracy = (correct_count / len(groundtruth_answers)) * 100 84 | return accuracy 85 | 86 | 87 | # Evaluate and print accuracy 88 | accuracy = evaluate_accuracy(args.groundtruth_file, args.predicted_file) 89 | print(f"Model Text Answer Accuracy: {accuracy:.2f}%") 90 | -------------------------------------------------------------------------------- /configs/llama3_2/train/objaworld/rendering.yaml: -------------------------------------------------------------------------------- 1 | # Tokenizer 2 | tokenizer: 3 | _component_: torchtune.models.llama3.llama3_tokenizer 4 | path: ./llama-3-models/Llama3.2-1B-Instruct/tokenizer.model 5 | max_seq_len: null 6 | 7 | # Dataset 8 | dataset: 9 | _component_: dataset.threed_mllm_dataset 10 | task_type: "3-I" 11 | dataset_name: "ObjaWorld" 12 | text_source: "" 13 | image_source: "./kyvo-datasets-and-codebooks/pretokenized-data/objaworld/images/train_vqgan_indices.json" 14 | three_d_source: "./kyvo-datasets-and-codebooks/pretokenized-data/objaworld/3d-scenes/train_tokenized_scenes.json" 15 | image_target: "" 16 | three_d_target: "" 17 | image_token_offset: 129471 18 | no_loss_on_input: True 19 | reorder_image_tokens: True 20 | load_text_source: False 21 | load_image_source: True 22 | load_three_d_source: True 23 | load_text_target: False 24 | load_image_target: False 25 | load_three_d_target: False 26 | 27 | seed: null 28 | shuffle: True 29 | 30 | # Model Arguments 31 | model: 32 | _component_: models.llama3_2_1b_clevr3d_sin_cos_plus_learned_numbers_SYNTHETIC_LIVINGROOM_PARK_LARGE_EP100 33 | freeze_llama3_token_embeddings: True 34 | start_from_original_llama3: True 35 | 36 | checkpointer: 37 | _component_: checkpointer.FullModelMetaCheckpointer3D 38 | checkpoint_dir: ./llama-3-models/Llama3.2-1B-Instruct/ 39 | checkpoint_files: [ 40 | consolidated.00.pth 41 | ] 42 | recipe_checkpoint: null 43 | output_dir: ./checkpoints/objaworld/rendering/ 44 | model_type: LLAMA3_2 45 | convert_weights_type: 3d_sin_cos_plus_learned_num 46 | resume_from_checkpoint: False 47 | 48 | # Fine-tuning arguments 49 | batch_size: 32 50 | epochs: 10 51 | optimizer: 52 | _component_: bitsandbytes.optim.PagedAdamW8bit 53 | lr: 1e-4 54 | loss: 55 | _component_: torch.nn.CrossEntropyLoss 56 | reduction: none 57 | 58 | weight_first_positions: 5 59 | weight_first_positions_with_weight: 10.0 60 | 61 | max_steps_per_epoch: null 62 | gradient_accumulation_steps: 1 63 | optimizer_in_bwd: False 64 | compile: False 65 | 66 | # Training environment 67 | device: cuda 68 | 69 | # Memory management 70 | enable_activation_checkpointing: False 71 | 72 | # Reduced precision 73 | dtype: bf16 74 | 75 | # Logging 76 | metric_logger: 77 | _component_: 
torchtune.training.metric_logging.DiskLogger 78 | log_dir: ${output_dir} 79 | output_dir: ./checkpoints/objaworld/rendering/ 80 | log_every_n_steps: 1 81 | log_peak_memory_stats: False 82 | 83 | # Profiler (disabled) 84 | profiler: 85 | _component_: torchtune.training.setup_torch_profiler 86 | enabled: False 87 | 88 | #Output directory of trace artifacts 89 | output_dir: ${output_dir}/profiling_outputs 90 | 91 | #`torch.profiler.ProfilerActivity` types to trace 92 | cpu: True 93 | cuda: True 94 | 95 | #trace options passed to `torch.profiler.profile` 96 | profile_memory: True 97 | with_stack: False 98 | record_shapes: True 99 | with_flops: False 100 | 101 | # `torch.profiler.schedule` options: 102 | # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat 103 | wait_steps: 1 104 | warmup_steps: 2 105 | active_steps: 1 106 | num_cycles: 1 107 | -------------------------------------------------------------------------------- /configs/llama3_2/train/objectron/recognition.yaml: -------------------------------------------------------------------------------- 1 | # Tokenizer 2 | tokenizer: 3 | _component_: torchtune.models.llama3.llama3_tokenizer 4 | path: ./llama-3-models/Llama3.2-1B-Instruct/tokenizer.model 5 | max_seq_len: null 6 | 7 | # Dataset 8 | dataset: 9 | _component_: dataset.threed_mllm_dataset 10 | dataset_name: "Omni3D-Objectron" 11 | task_type: "I-3" 12 | text_source: "" 13 | image_source: "./kyvo-datasets-and-codebooks/pretokenized-data/objectron/images/train_vqgan_indices.json" 14 | three_d_source: "./kyvo-datasets-and-codebooks/pretokenized-data/objectron/3d-scenes/train_tokenized_scenes.json" 15 | image_target: "" 16 | three_d_target: "" 17 | image_token_offset: 128372 18 | no_loss_on_input: True 19 | reorder_image_tokens: False 20 | load_text_source: False 21 | load_image_source: True 22 | load_three_d_source: True 23 | load_text_target: False 24 | load_image_target: False 25 | load_three_d_target: False 26 | 27 | seed: null 28 | shuffle: True 29 | 30 | # Model Arguments 31 | model: 32 | _component_: models.llama3_2_1b_clevr3d_sin_cos_plus_learned_numbers_omni3d_objectron_custom_finer 33 | freeze_llama3_token_embeddings: True 34 | start_from_original_llama3: True 35 | 36 | checkpointer: 37 | _component_: checkpointer.FullModelMetaCheckpointer3D 38 | checkpoint_dir: ./llama-3-models/Llama3.2-1B-Instruct/ 39 | checkpoint_files: [ 40 | consolidated.00.pth 41 | ] 42 | recipe_checkpoint: null 43 | output_dir: ./checkpoints/objectron/recognition/ 44 | model_type: LLAMA3_2 45 | convert_weights_type: 3d_sin_cos_plus_learned_num 46 | resume_from_checkpoint: False 47 | 48 | # Fine-tuning arguments 49 | batch_size: 32 50 | epochs: 20 51 | optimizer: 52 | _component_: bitsandbytes.optim.PagedAdamW8bit 53 | lr: 1e-4 54 | loss: 55 | _component_: torch.nn.CrossEntropyLoss 56 | reduction: none 57 | 58 | weight_first_positions: 0 59 | weight_first_positions_with_weight: 1.0 60 | 61 | max_steps_per_epoch: null 62 | gradient_accumulation_steps: 1 63 | optimizer_in_bwd: False 64 | compile: False 65 | 66 | # Training environment 67 | device: cuda 68 | 69 | # Memory management 70 | enable_activation_checkpointing: False 71 | 72 | # Reduced precision 73 | dtype: bf16 74 | 75 | # Logging 76 | metric_logger: 77 | _component_: torchtune.training.metric_logging.DiskLogger 78 | log_dir: ${output_dir} 79 | output_dir: ./checkpoints/objectron/recognition/ 80 | log_every_n_steps: 1 81 | log_peak_memory_stats: False 82 | 83 | # Profiler (disabled) 84 | profiler: 85 | 
_component_: torchtune.training.setup_torch_profiler 86 | enabled: False 87 | 88 | #Output directory of trace artifacts 89 | output_dir: ${output_dir}/profiling_outputs 90 | 91 | #`torch.profiler.ProfilerActivity` types to trace 92 | cpu: True 93 | cuda: True 94 | 95 | #trace options passed to `torch.profiler.profile` 96 | profile_memory: True 97 | with_stack: False 98 | record_shapes: True 99 | with_flops: False 100 | 101 | # `torch.profiler.schedule` options: 102 | # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat 103 | wait_steps: 1 104 | warmup_steps: 2 105 | active_steps: 1 106 | num_cycles: 1 107 | -------------------------------------------------------------------------------- /configs/llama3_2/train/objaworld/recognition.yaml: -------------------------------------------------------------------------------- 1 | # Tokenizer 2 | tokenizer: 3 | _component_: torchtune.models.llama3.llama3_tokenizer 4 | path: ./llama-3-models/Llama3.2-1B-Instruct/tokenizer.model 5 | max_seq_len: null 6 | 7 | # Dataset 8 | dataset: 9 | _component_: dataset.threed_mllm_dataset 10 | task_type: "I-3" 11 | dataset_name: "ObjaWorld" 12 | text_source: "" 13 | image_source: "./kyvo-datasets-and-codebooks/pretokenized-data/objaworld/images/train_vqgan_indices.json" 14 | three_d_source: "./kyvo-datasets-and-codebooks/pretokenized-data/objaworld/3d-scenes/train_tokenized_scenes.json" 15 | image_target: "" 16 | three_d_target: "" 17 | image_token_offset: 129471 18 | no_loss_on_input: True 19 | reorder_image_tokens: False 20 | load_text_source: False 21 | load_image_source: True 22 | load_three_d_source: True 23 | load_text_target: False 24 | load_image_target: False 25 | load_three_d_target: False 26 | 27 | seed: null 28 | shuffle: True 29 | 30 | # Model Arguments 31 | model: 32 | _component_: models.llama3_2_1b_clevr3d_sin_cos_plus_learned_numbers_SYNTHETIC_LIVINGROOM_PARK_LARGE_EP100 33 | freeze_llama3_token_embeddings: True 34 | start_from_original_llama3: True 35 | 36 | checkpointer: 37 | _component_: checkpointer.FullModelMetaCheckpointer3D 38 | checkpoint_dir: ./llama-3-models/Llama3.2-1B-Instruct/ 39 | checkpoint_files: [ 40 | consolidated.00.pth 41 | ] 42 | recipe_checkpoint: null 43 | output_dir: ./checkpoints/objaworld/recognition/ 44 | model_type: LLAMA3_2 45 | convert_weights_type: 3d_sin_cos_plus_learned_num 46 | resume_from_checkpoint: False 47 | 48 | # Fine-tuning arguments 49 | batch_size: 32 50 | epochs: 10 51 | optimizer: 52 | _component_: bitsandbytes.optim.PagedAdamW8bit 53 | lr: 1e-4 54 | loss: 55 | _component_: torch.nn.CrossEntropyLoss 56 | reduction: none 57 | 58 | weight_first_positions: 0 59 | weight_first_positions_with_weight: 1.0 60 | 61 | max_steps_per_epoch: null 62 | gradient_accumulation_steps: 1 63 | optimizer_in_bwd: False 64 | compile: False 65 | 66 | # Training environment 67 | device: cuda 68 | 69 | # Memory management 70 | enable_activation_checkpointing: False 71 | 72 | # Reduced precision 73 | dtype: bf16 74 | 75 | # Logging 76 | metric_logger: 77 | _component_: torchtune.training.metric_logging.DiskLogger 78 | log_dir: ${output_dir} 79 | output_dir: ./checkpoints/objaworld/recognition/ 80 | log_every_n_steps: 1 81 | log_peak_memory_stats: False 82 | 83 | # Profiler (disabled) 84 | profiler: 85 | _component_: torchtune.training.setup_torch_profiler 86 | enabled: False 87 | 88 | #Output directory of trace artifacts 89 | output_dir: ${output_dir}/profiling_outputs 90 | 91 | #`torch.profiler.ProfilerActivity` types to trace 92 | cpu: True 
93 | cuda: True 94 | 95 | #trace options passed to `torch.profiler.profile` 96 | profile_memory: True 97 | with_stack: False 98 | record_shapes: True 99 | with_flops: False 100 | 101 | # `torch.profiler.schedule` options: 102 | # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat 103 | wait_steps: 1 104 | warmup_steps: 2 105 | active_steps: 1 106 | num_cycles: 1 107 | -------------------------------------------------------------------------------- /configs/llama3_2/train/clevr/rendering.yaml: -------------------------------------------------------------------------------- 1 | # Tokenizer 2 | tokenizer: 3 | _component_: torchtune.models.llama3.llama3_tokenizer 4 | path: ./llama-3-models/Llama3.2-1B-Instruct/tokenizer.model 5 | max_seq_len: null 6 | 7 | # Dataset 8 | dataset: 9 | _component_: dataset.threed_mllm_dataset 10 | task_type: "3-I" 11 | dataset_name: "CLEVR" 12 | text_source: "" 13 | image_source: "./kyvo-datasets-and-codebooks/pretokenized-data/clevr/images/train_vqgan_indices_merged_for_rendering_and_recognition.json" 14 | three_d_source: "./kyvo-datasets-and-codebooks/pretokenized-data/clevr/3d-scenes/train_tokenized_scenes_merged_for_rendering_and_recognition.json" 15 | image_target: "" 16 | three_d_target: "" 17 | image_token_offset: 129471 18 | no_loss_on_input: True 19 | reorder_image_tokens: True 20 | load_text_source: False 21 | load_image_source: True 22 | load_three_d_source: True 23 | load_text_target: False 24 | load_image_target: False 25 | load_three_d_target: False 26 | 27 | seed: null 28 | shuffle: True 29 | 30 | # Model Arguments 31 | model: 32 | _component_: models.llama3_2_1b_clevr3d_sin_cos_plus_learned_numbers 33 | freeze_llama3_token_embeddings: True 34 | start_from_original_llama3: True 35 | 36 | checkpointer: 37 | _component_: checkpointer.FullModelMetaCheckpointer3D 38 | checkpoint_dir: ./llama-3-models/Llama3.2-1B-Instruct/ 39 | checkpoint_files: [ 40 | consolidated.00.pth 41 | ] 42 | recipe_checkpoint: null 43 | output_dir: ./checkpoints/clevr/rendering/ 44 | model_type: LLAMA3_2 45 | convert_weights_type: 3d_sin_cos_plus_learned_num 46 | resume_from_checkpoint: False 47 | 48 | # Fine-tuning arguments 49 | batch_size: 32 50 | epochs: 10 51 | optimizer: 52 | _component_: bitsandbytes.optim.PagedAdamW8bit 53 | lr: 1e-4 54 | loss: 55 | _component_: torch.nn.CrossEntropyLoss 56 | reduction: none 57 | 58 | weight_first_positions: 5 59 | weight_first_positions_with_weight: 10.0 60 | 61 | max_steps_per_epoch: null 62 | gradient_accumulation_steps: 1 63 | optimizer_in_bwd: False 64 | compile: False 65 | 66 | # Training environment 67 | device: cuda 68 | 69 | # Memory management 70 | enable_activation_checkpointing: False 71 | 72 | # Reduced precision 73 | dtype: bf16 74 | 75 | # Logging 76 | metric_logger: 77 | _component_: torchtune.training.metric_logging.DiskLogger 78 | log_dir: ${output_dir} 79 | output_dir: ./checkpoints/clevr/rendering/ 80 | log_every_n_steps: 1 81 | log_peak_memory_stats: False 82 | 83 | # Profiler (disabled) 84 | profiler: 85 | _component_: torchtune.training.setup_torch_profiler 86 | enabled: False 87 | 88 | #Output directory of trace artifacts 89 | output_dir: ${output_dir}/profiling_outputs 90 | 91 | #`torch.profiler.ProfilerActivity` types to trace 92 | cpu: True 93 | cuda: True 94 | 95 | #trace options passed to `torch.profiler.profile` 96 | profile_memory: True 97 | with_stack: False 98 | record_shapes: True 99 | with_flops: False 100 | 101 | # `torch.profiler.schedule` options: 102 | # 
wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat 103 | wait_steps: 1 104 | warmup_steps: 2 105 | active_steps: 1 106 | num_cycles: 1 107 | -------------------------------------------------------------------------------- /configs/llama3_2/train/clevr/recognition.yaml: -------------------------------------------------------------------------------- 1 | # Tokenizer 2 | tokenizer: 3 | _component_: torchtune.models.llama3.llama3_tokenizer 4 | path: ./llama-3-models/Llama3.2-1B-Instruct/tokenizer.model 5 | max_seq_len: null 6 | 7 | # Dataset 8 | dataset: 9 | _component_: dataset.threed_mllm_dataset 10 | task_type: "I-3" 11 | dataset_name: "CLEVR" 12 | text_source: "" 13 | image_source: "./kyvo-datasets-and-codebooks/pretokenized-data/clevr/images/train_vqgan_indices_merged_for_rendering_and_recognition.json" 14 | three_d_source: "./kyvo-datasets-and-codebooks/pretokenized-data/clevr/3d-scenes/train_tokenized_scenes_merged_for_rendering_and_recognition.json" 15 | image_target: "" 16 | three_d_target: "" 17 | image_token_offset: 129471 18 | no_loss_on_input: True 19 | reorder_image_tokens: False 20 | load_text_source: False 21 | load_image_source: True 22 | load_three_d_source: True 23 | load_text_target: False 24 | load_image_target: False 25 | load_three_d_target: False 26 | 27 | seed: null 28 | shuffle: True 29 | 30 | # Model Arguments 31 | model: 32 | _component_: models.llama3_2_1b_clevr3d_sin_cos_plus_learned_numbers 33 | freeze_llama3_token_embeddings: True 34 | start_from_original_llama3: True 35 | 36 | checkpointer: 37 | _component_: checkpointer.FullModelMetaCheckpointer3D 38 | checkpoint_dir: ./llama-3-models/Llama3.2-1B-Instruct/ 39 | checkpoint_files: [ 40 | consolidated.00.pth 41 | ] 42 | recipe_checkpoint: null 43 | output_dir: ./checkpoints/clevr/recognition/ 44 | model_type: LLAMA3_2 45 | convert_weights_type: 3d_sin_cos_plus_learned_num 46 | resume_from_checkpoint: False 47 | 48 | # Fine-tuning arguments 49 | batch_size: 32 50 | epochs: 10 51 | optimizer: 52 | _component_: bitsandbytes.optim.PagedAdamW8bit 53 | lr: 1e-4 54 | loss: 55 | _component_: torch.nn.CrossEntropyLoss 56 | reduction: none 57 | 58 | weight_first_positions: 0 59 | weight_first_positions_with_weight: 1.0 60 | 61 | max_steps_per_epoch: null 62 | gradient_accumulation_steps: 1 63 | optimizer_in_bwd: False 64 | compile: False 65 | 66 | # Training environment 67 | device: cuda 68 | 69 | # Memory management 70 | enable_activation_checkpointing: False 71 | 72 | # Reduced precision 73 | dtype: bf16 74 | 75 | # Logging 76 | metric_logger: 77 | _component_: torchtune.training.metric_logging.DiskLogger 78 | log_dir: ${output_dir} 79 | output_dir: ./checkpoints/clevr/recognition/ 80 | log_every_n_steps: 1 81 | log_peak_memory_stats: False 82 | 83 | # Profiler (disabled) 84 | profiler: 85 | _component_: torchtune.training.setup_torch_profiler 86 | enabled: False 87 | 88 | #Output directory of trace artifacts 89 | output_dir: ${output_dir}/profiling_outputs 90 | 91 | #`torch.profiler.ProfilerActivity` types to trace 92 | cpu: True 93 | cuda: True 94 | 95 | #trace options passed to `torch.profiler.profile` 96 | profile_memory: True 97 | with_stack: False 98 | record_shapes: True 99 | with_flops: False 100 | 101 | # `torch.profiler.schedule` options: 102 | # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat 103 | wait_steps: 1 104 | warmup_steps: 2 105 | active_steps: 1 106 | num_cycles: 1 107 | 
-------------------------------------------------------------------------------- /configs/llama3_2/train/clevr/question-answering.yaml: -------------------------------------------------------------------------------- 1 | # Tokenizer 2 | tokenizer: 3 | _component_: torchtune.models.llama3.llama3_tokenizer 4 | path: ./llama-3-models/Llama3.2-1B-Instruct/tokenizer.model 5 | max_seq_len: null 6 | 7 | # Dataset 8 | dataset: 9 | _component_: dataset.threed_mllm_dataset 10 | task_type: "I+3+Q-A" 11 | dataset_name: "CLEVR" 12 | text_source: "./kyvo-datasets-and-codebooks/pretokenized-data/clevr/text/train_vqa_questions.json" 13 | image_source: "./kyvo-datasets-and-codebooks/pretokenized-data/clevr/images/train_vqgan_indices_for_vqa.json" 14 | three_d_source: "./kyvo-datasets-and-codebooks/pretokenized-data/clevr/3d-scenes/train_tokenized_scenes_for_vqa.json" 15 | text_target: "./kyvo-datasets-and-codebooks/pretokenized-data/clevr/text/train_vqa_answers.json" 16 | image_target: "" 17 | three_d_target: "" 18 | image_token_offset: 129471 19 | no_loss_on_input: True 20 | reorder_image_tokens: False 21 | load_text_source: True 22 | load_image_source: True 23 | load_three_d_source: True 24 | load_text_target: True 25 | load_image_target: False 26 | load_three_d_target: False 27 | 28 | seed: null 29 | shuffle: True 30 | 31 | # Model Arguments 32 | model: 33 | _component_: models.llama3_2_1b_clevr3d_sin_cos_plus_learned_numbers 34 | freeze_llama3_token_embeddings: True 35 | start_from_original_llama3: True 36 | 37 | checkpointer: 38 | _component_: checkpointer.FullModelMetaCheckpointer3D 39 | checkpoint_dir: ./llama-3-models/Llama3.2-1B-Instruct/ 40 | checkpoint_files: [ 41 | consolidated.00.pth 42 | ] 43 | recipe_checkpoint: null 44 | output_dir: ./checkpoints/clevr/question-answering/ 45 | model_type: LLAMA3_2 46 | convert_weights_type: 3d_sin_cos_plus_learned_num 47 | resume_from_checkpoint: False 48 | 49 | # Fine-tuning arguments 50 | batch_size: 32 51 | epochs: 10 52 | optimizer: 53 | _component_: bitsandbytes.optim.PagedAdamW8bit 54 | lr: 1e-4 55 | loss: 56 | _component_: torch.nn.CrossEntropyLoss 57 | reduction: none 58 | 59 | weight_first_positions: 0 60 | weight_first_positions_with_weight: 1.0 61 | 62 | max_steps_per_epoch: null 63 | gradient_accumulation_steps: 1 64 | optimizer_in_bwd: False 65 | compile: False 66 | 67 | # Training environment 68 | device: cuda 69 | 70 | # Memory management 71 | enable_activation_checkpointing: False 72 | 73 | # Reduced precision 74 | dtype: bf16 75 | 76 | # Logging 77 | metric_logger: 78 | _component_: torchtune.training.metric_logging.DiskLogger 79 | log_dir: ${output_dir} 80 | output_dir: ./checkpoints/clevr/question-answering/ 81 | log_every_n_steps: 1 82 | log_peak_memory_stats: False 83 | 84 | # Profiler (disabled) 85 | profiler: 86 | _component_: torchtune.training.setup_torch_profiler 87 | enabled: False 88 | 89 | #Output directory of trace artifacts 90 | output_dir: ${output_dir}/profiling_outputs 91 | 92 | #`torch.profiler.ProfilerActivity` types to trace 93 | cpu: True 94 | cuda: True 95 | 96 | #trace options passed to `torch.profiler.profile` 97 | profile_memory: True 98 | with_stack: False 99 | record_shapes: True 100 | with_flops: False 101 | 102 | # `torch.profiler.schedule` options: 103 | # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat 104 | wait_steps: 1 105 | warmup_steps: 2 106 | active_steps: 1 107 | num_cycles: 1 108 | -------------------------------------------------------------------------------- 
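The `image_token_offset: 129471` entry that recurs in these configs shifts VQGAN codebook indices into a dedicated range of the shared vocabulary, presumably so that image tokens never collide with the Llama-3.2 text and special tokens (or any extra 3D/number tokens the model adds). A minimal sketch of that assumed mapping, using the CLEVR codebook size suggested by the codebook filename in `DATA.md` (1024 entries):

```python
IMAGE_TOKEN_OFFSET = 129471   # value used in the CLEVR configs
CODEBOOK_SIZE = 1024          # CLEVR VQGAN codebook entries (per DATA.md filename)

def image_indices_to_tokens(vqgan_indices):
    """Map raw VQGAN codebook indices into the shared token space."""
    return [IMAGE_TOKEN_OFFSET + i for i in vqgan_indices]

def tokens_to_image_indices(tokens):
    """Invert the mapping, keeping only ids that fall in the image range."""
    return [t - IMAGE_TOKEN_OFFSET for t in tokens
            if IMAGE_TOKEN_OFFSET <= t < IMAGE_TOKEN_OFFSET + CODEBOOK_SIZE]

example = [0, 17, 1023]
shifted = image_indices_to_tokens(example)
print(shifted)                            # [129471, 129488, 130494]
print(tokens_to_image_indices(shifted))   # [0, 17, 1023]
```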
/taming-kyvo.yml: -------------------------------------------------------------------------------- 1 | name: taming-kyvo 2 | channels: 3 | - pytorch 4 | - conda-forge 5 | - defaults 6 | dependencies: 7 | - _libgcc_mutex=0.1=main 8 | - _openmp_mutex=5.1=1_gnu 9 | - blas=1.0=mkl 10 | - bzip2=1.0.8=h4bc722e_7 11 | - ca-certificates=2025.2.25=h06a4308_0 12 | - cudatoolkit=11.3.1=hb98b00a_13 13 | - ffmpeg=4.3=hf484d3e_0 14 | - freetype=2.10.4=h0708190_1 15 | - gmp=6.3.0=hac33072_2 16 | - gnutls=3.6.13=h85f3911_1 17 | - icu=73.2=h59595ed_0 18 | - intel-openmp=2023.1.0=hdb19cb5_46306 19 | - jbig=2.1=h7f98852_2003 20 | - jpeg=9e=h0b41bf4_3 21 | - lame=3.100=h166bdaf_1003 22 | - lcms2=2.12=hddcbb42_0 23 | - ld_impl_linux-64=2.40=h12ee557_0 24 | - lerc=2.2.1=h9c3ff4c_0 25 | - libblas=3.9.0=1_h86c2bf4_netlib 26 | - libcblas=3.9.0=8_h3b12eaf_netlib 27 | - libdeflate=1.7=h7f98852_5 28 | - libffi=3.4.4=h6a678d5_1 29 | - libgcc=14.2.0=h767d61c_2 30 | - libgcc-ng=14.2.0=h69a702a_2 31 | - libgfortran=14.2.0=h69a702a_2 32 | - libgfortran-ng=14.2.0=h69a702a_2 33 | - libgfortran5=14.2.0=hf1ad2bd_2 34 | - libgomp=14.2.0=h767d61c_2 35 | - libhwloc=2.11.2=default_h0d58e46_1001 36 | - libiconv=1.18=h4ce23a2_1 37 | - liblapack=3.9.0=8_h3b12eaf_netlib 38 | - libpng=1.6.37=h21135ba_2 39 | - libstdcxx=14.2.0=h8f9b012_2 40 | - libstdcxx-ng=14.2.0=h4852527_2 41 | - libtiff=4.3.0=hf544144_1 42 | - libuv=1.50.0=hb9d3cd8_0 43 | - libwebp-base=1.5.0=h851e524_0 44 | - libxml2=2.13.5=hfdd30dd_0 45 | - lz4-c=1.9.3=h9c3ff4c_1 46 | - mkl=2023.1.0=h213fc3f_46344 47 | - ncurses=6.4=h6a678d5_0 48 | - nettle=3.6=he412f7d_0 49 | - numpy=1.24.4=py38h59b608b_0 50 | - olefile=0.47=pyhd8ed1ab_0 51 | - openh264=2.1.1=h780b84a_0 52 | - openjpeg=2.4.0=hb52868f_1 53 | - openssl=3.4.1=h7b32b05_0 54 | - pillow=8.3.2=py38h8e6f84c_0 55 | - pip=24.2=py38h06a4308_0 56 | - python=3.8.20=he870216_0 57 | - python_abi=3.8=2_cp38 58 | - pytorch=1.10.1=py3.8_cuda11.3_cudnn8.2.0_0 59 | - pytorch-mutex=1.0=cuda 60 | - readline=8.2=h5eee18b_0 61 | - setuptools=75.1.0=py38h06a4308_0 62 | - sqlite=3.45.3=h5eee18b_0 63 | - tbb=2021.13.0=hceb3a55_1 64 | - tk=8.6.14=h39e8969_0 65 | - torchaudio=0.10.1=py38_cu113 66 | - torchvision=0.11.2=py38_cu113 67 | - typing_extensions=4.12.2=pyha770c72_0 68 | - wheel=0.44.0=py38h06a4308_0 69 | - xz=5.6.4=h5eee18b_1 70 | - zlib=1.2.13=h5eee18b_1 71 | - zstd=1.5.0=ha95c52a_0 72 | - pip: 73 | - absl-py==2.1.0 74 | - cachetools==5.5.2 75 | - certifi==2025.1.31 76 | - charset-normalizer==3.4.1 77 | - einops==0.3.0 78 | - fsspec==2025.2.0 79 | - future==1.0.0 80 | - google-auth==2.38.0 81 | - google-auth-oauthlib==1.0.0 82 | - grpcio==1.70.0 83 | - idna==3.10 84 | - importlib-metadata==8.5.0 85 | - markdown==3.7 86 | - markupsafe==2.1.5 87 | - oauthlib==3.2.2 88 | - omegaconf==2.0.0 89 | - protobuf==5.29.3 90 | - pyasn1==0.6.1 91 | - pyasn1-modules==0.4.1 92 | - pytorch-lightning==1.0.8 93 | - pyyaml==6.0.2 94 | - requests==2.32.3 95 | - requests-oauthlib==2.0.0 96 | - rsa==4.9 97 | - six==1.17.0 98 | - tensorboard==2.14.0 99 | - tensorboard-data-server==0.7.2 100 | - tqdm==4.67.1 101 | - urllib3==2.2.3 102 | - werkzeug==3.0.6 103 | - zipp==3.20.2 104 | prefix: /home/aadarsh/anaconda3/envs/taming-kyvo 105 | -------------------------------------------------------------------------------- /kyvo.yml: -------------------------------------------------------------------------------- 1 | name: kyvo 2 | channels: 3 | - defaults 4 | dependencies: 5 | - _libgcc_mutex=0.1=main 6 | - _openmp_mutex=5.1=1_gnu 7 | - bzip2=1.0.8=h5eee18b_6 8 | 
- ca-certificates=2024.12.31=h06a4308_0 9 | - ld_impl_linux-64=2.40=h12ee557_0 10 | - libffi=3.4.4=h6a678d5_1 11 | - libgcc-ng=11.2.0=h1234567_1 12 | - libgomp=11.2.0=h1234567_1 13 | - libstdcxx-ng=11.2.0=h1234567_1 14 | - libuuid=1.41.5=h5eee18b_0 15 | - ncurses=6.4=h6a678d5_0 16 | - openssl=3.0.15=h5eee18b_0 17 | - pip=25.0=py311h06a4308_0 18 | - python=3.11.11=he870216_0 19 | - readline=8.2=h5eee18b_0 20 | - setuptools=75.8.0=py311h06a4308_0 21 | - sqlite=3.45.3=h5eee18b_0 22 | - tk=8.6.14=h39e8969_0 23 | - wheel=0.45.1=py311h06a4308_0 24 | - xz=5.6.4=h5eee18b_1 25 | - zlib=1.2.13=h5eee18b_1 26 | - pip: 27 | - absl-py==2.1.0 28 | - aiohappyeyeballs==2.4.6 29 | - aiohttp==3.11.12 30 | - aiosignal==1.3.2 31 | - antlr4-python3-runtime==4.9.3 32 | - attrs==25.1.0 33 | - bitsandbytes==0.43.1 34 | - blobfile==3.0.0 35 | - certifi==2025.1.31 36 | - charset-normalizer==3.4.1 37 | - datasets==3.3.1 38 | - dill==0.3.8 39 | - filelock==3.17.0 40 | - frozenlist==1.5.0 41 | - fsspec==2024.12.0 42 | - future==1.0.0 43 | - grpcio==1.70.0 44 | - huggingface-hub==0.28.1 45 | - idna==3.10 46 | - imageio==2.37.0 47 | - jinja2==3.1.5 48 | - lazy-loader==0.4 49 | - lxml==5.3.1 50 | - markdown==3.7 51 | - markupsafe==3.0.2 52 | - mpmath==1.3.0 53 | - multidict==6.1.0 54 | - multiprocess==0.70.16 55 | - networkx==3.4.2 56 | - numpy==2.2.3 57 | - nvidia-cublas-cu12==12.1.3.1 58 | - nvidia-cuda-cupti-cu12==12.1.105 59 | - nvidia-cuda-nvrtc-cu12==12.1.105 60 | - nvidia-cuda-runtime-cu12==12.1.105 61 | - nvidia-cudnn-cu12==9.1.0.70 62 | - nvidia-cufft-cu12==11.0.2.54 63 | - nvidia-curand-cu12==10.3.2.106 64 | - nvidia-cusolver-cu12==11.4.5.107 65 | - nvidia-cusparse-cu12==12.1.0.106 66 | - nvidia-cusparselt-cu12==0.6.2 67 | - nvidia-nccl-cu12==2.20.5 68 | - nvidia-nvjitlink-cu12==12.4.127 69 | - nvidia-nvtx-cu12==12.1.105 70 | - omegaconf==2.3.0 71 | - packaging==24.2 72 | - pandas==2.2.3 73 | - pillow==11.1.0 74 | - propcache==0.2.1 75 | - protobuf==6.30.0 76 | - psutil==7.0.0 77 | - pyarrow==19.0.1 78 | - pycryptodomex==3.21.0 79 | - python-dateutil==2.9.0.post0 80 | - pytorch-lightning==1.0.8 81 | - pytz==2025.1 82 | - pyyaml==6.0.2 83 | - regex==2024.11.6 84 | - requests==2.32.3 85 | - safetensors==0.5.2 86 | - scikit-image==0.25.2 87 | - scipy==1.15.2 88 | - sentencepiece==0.2.0 89 | - six==1.17.0 90 | - sympy==1.13.1 91 | - tensorboard==2.19.0 92 | - tensorboard-data-server==0.7.2 93 | - tifffile==2025.2.18 94 | - tiktoken==0.9.0 95 | - tokenizers==0.19.1 96 | - torch==2.4.1 97 | - torchao==0.5.0 98 | - torchvision==0.19.1 99 | - tqdm==4.67.1 100 | - transformers==4.41.2 101 | - triton==3.0.0 102 | - typing-extensions==4.12.2 103 | - tzdata==2025.1 104 | - urllib3==2.3.0 105 | - werkzeug==3.1.3 106 | - xxhash==3.5.0 107 | - yarl==1.18.3 108 | prefix: /home/aadarsh/anaconda3/envs/kyvo 109 | -------------------------------------------------------------------------------- /configs/llama3_2/train/clevr/instruction-following.yaml: -------------------------------------------------------------------------------- 1 | # Tokenizer 2 | tokenizer: 3 | _component_: torchtune.models.llama3.llama3_tokenizer 4 | path: ./llama-3-models/Llama3.2-1B-Instruct/tokenizer.model 5 | max_seq_len: null 6 | 7 | # Dataset 8 | dataset: 9 | _component_: dataset.threed_mllm_dataset 10 | task_type: "I+3+T-I+3" 11 | dataset_name: "CLEVR" 12 | text_source: "./kyvo-datasets-and-codebooks/pretokenized-data/clevr/text/train_text_instructions.json" 13 | image_source: 
"./kyvo-datasets-and-codebooks/pretokenized-data/clevr/images/train_vqgan_indices_source.json" 14 | three_d_source: "./kyvo-datasets-and-codebooks/pretokenized-data/clevr/3d-scenes/train_tokenized_scenes_source.json" 15 | image_target: "./kyvo-datasets-and-codebooks/pretokenized-data/clevr/images/train_vqgan_indices_target.json" 16 | three_d_target: "./kyvo-datasets-and-codebooks/pretokenized-data/clevr/3d-scenes/train_tokenized_scenes_target.json" 17 | image_token_offset: 129471 18 | no_loss_on_input: True 19 | reorder_image_tokens: True 20 | load_text_source: True 21 | load_image_source: True 22 | load_three_d_source: True 23 | load_text_target: False 24 | load_image_target: True 25 | load_three_d_target: True 26 | 27 | 28 | seed: null 29 | shuffle: True 30 | 31 | # Model Arguments 32 | model: 33 | _component_: models.llama3_2_1b_clevr3d_sin_cos_plus_learned_numbers 34 | freeze_llama3_token_embeddings: True 35 | start_from_original_llama3: True 36 | 37 | checkpointer: 38 | _component_: checkpointer.FullModelMetaCheckpointer3D 39 | checkpoint_dir: ./llama-3-models/Llama3.2-1B-Instruct/ 40 | checkpoint_files: [ 41 | consolidated.00.pth 42 | ] 43 | recipe_checkpoint: null 44 | output_dir: ./checkpoints/clevr/instruction-following/ 45 | model_type: LLAMA3_2 46 | convert_weights_type: 3d_sin_cos_plus_learned_num 47 | resume_from_checkpoint: False 48 | 49 | # Fine-tuning arguments 50 | batch_size: 16 51 | epochs: 10 52 | optimizer: 53 | _component_: bitsandbytes.optim.PagedAdamW8bit 54 | lr: 1e-4 55 | loss: 56 | _component_: torch.nn.CrossEntropyLoss 57 | reduction: none 58 | 59 | weight_first_positions: 5 60 | weight_first_positions_with_weight: 10.0 61 | 62 | max_steps_per_epoch: null 63 | gradient_accumulation_steps: 1 64 | optimizer_in_bwd: False 65 | compile: False 66 | 67 | # Training environment 68 | device: cuda 69 | 70 | # Memory management 71 | enable_activation_checkpointing: False 72 | 73 | # Reduced precision 74 | dtype: bf16 75 | 76 | # Logging 77 | metric_logger: 78 | _component_: torchtune.training.metric_logging.DiskLogger 79 | log_dir: ${output_dir} 80 | output_dir: ./checkpoints/clevr/instruction-following/ 81 | log_every_n_steps: 1 82 | log_peak_memory_stats: False 83 | 84 | # Profiler (disabled) 85 | profiler: 86 | _component_: torchtune.training.setup_torch_profiler 87 | enabled: False 88 | 89 | #Output directory of trace artifacts 90 | output_dir: ${output_dir}/profiling_outputs 91 | 92 | #`torch.profiler.ProfilerActivity` types to trace 93 | cpu: True 94 | cuda: True 95 | 96 | #trace options passed to `torch.profiler.profile` 97 | profile_memory: True 98 | with_stack: False 99 | record_shapes: True 100 | with_flops: False 101 | 102 | # `torch.profiler.schedule` options: 103 | # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat 104 | wait_steps: 1 105 | warmup_steps: 2 106 | active_steps: 1 107 | num_cycles: 1 108 | -------------------------------------------------------------------------------- /scripts/taming_transformers_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import yaml 3 | from omegaconf import OmegaConf 4 | from taming.models.vqgan import VQModel, GumbelVQ 5 | import io 6 | import requests 7 | import PIL 8 | from PIL import Image 9 | import numpy as np 10 | 11 | import torch 12 | import torch.nn.functional as F 13 | import torchvision.transforms as T 14 | import torchvision.transforms.functional as TF 15 | 16 | 17 | def load_config(config_path, 
display=False): 18 | config = OmegaConf.load(config_path) 19 | if display: 20 | print(yaml.dump(OmegaConf.to_container(config))) 21 | return config 22 | 23 | 24 | def load_vqgan(config, ckpt_path=None, is_gumbel=False): 25 | if is_gumbel: 26 | model = GumbelVQ(**config.model.params) 27 | else: 28 | model = VQModel(**config.model.params) 29 | if ckpt_path is not None: 30 | sd = torch.load(ckpt_path, map_location="cpu")["state_dict"] 31 | missing, unexpected = model.load_state_dict(sd, strict=False) 32 | return model.eval() 33 | 34 | 35 | def preprocess_vqgan(x): 36 | x = 2.0 * x - 1.0 37 | return x 38 | 39 | 40 | def custom_to_pil(x): 41 | x = x.detach().cpu() 42 | x = torch.clamp(x, -1.0, 1.0) 43 | x = (x + 1.0) / 2.0 44 | x = x.permute(1, 2, 0).numpy() 45 | x = (255 * x).astype(np.uint8) 46 | x = Image.fromarray(x) 47 | if not x.mode == "RGB": 48 | x = x.convert("RGB") 49 | return x 50 | 51 | 52 | def reconstruct_with_vqgan(x, model, reconstruct=False): 53 | # could also use model(x) for reconstruction but use explicit encoding and decoding here 54 | z, _, [_, _, indices] = model.encode(x) 55 | if reconstruct: 56 | xrec = model.decode(z) 57 | return xrec, indices 58 | else: 59 | return z, indices 60 | 61 | 62 | def download_image(url): 63 | resp = requests.get(url) 64 | resp.raise_for_status() 65 | return PIL.Image.open(io.BytesIO(resp.content)) 66 | 67 | 68 | def get_local_image(path, make_square=True, size=320, horizontal_flip=False): 69 | img = PIL.Image.open(path) 70 | if img.mode == "RGBA": 71 | img = img.convert("RGB") 72 | if make_square: 73 | img = img.resize((size, size)) 74 | if horizontal_flip: 75 | img = img.transpose(PIL.Image.FLIP_LEFT_RIGHT) 76 | return img 77 | 78 | 79 | def preprocess(img, target_image_size=320): 80 | s = min(img.size) 81 | 82 | if s < target_image_size: 83 | raise ValueError(f"min dim for image {s} < {target_image_size}") 84 | 85 | r = target_image_size / s 86 | s = (round(r * img.size[1]), round(r * img.size[0])) 87 | img = TF.resize(img, s, interpolation=PIL.Image.LANCZOS) 88 | img = TF.center_crop(img, output_size=2 * [target_image_size]) 89 | img = torch.unsqueeze(T.ToTensor()(img), 0) 90 | 91 | return img 92 | 93 | 94 | def reconstruction_pipeline( 95 | path, 96 | vqgan_model, 97 | size=320, 98 | DEVICE="cuda:0", 99 | reconstruct=False, 100 | make_square=True, 101 | horizontal_flip=False, 102 | ): 103 | x_vqgan = preprocess( 104 | get_local_image( 105 | path, make_square=make_square, size=size, horizontal_flip=horizontal_flip 106 | ), 107 | target_image_size=size, 108 | ) 109 | x_vqgan = x_vqgan.to(DEVICE) 110 | vqgan_embedding, vqgan_indices = reconstruct_with_vqgan( 111 | preprocess_vqgan(x_vqgan), vqgan_model, reconstruct=reconstruct 112 | ) 113 | return vqgan_embedding, vqgan_indices 114 | -------------------------------------------------------------------------------- /scripts/compute_ssim_l2loss.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from skimage.io import imread 4 | from skimage.metrics import structural_similarity as ssim 5 | import glob 6 | from tqdm import tqdm 7 | import re 8 | from PIL import Image 9 | import argparse 10 | 11 | parser = argparse.ArgumentParser(description="Compute L2 Loss and SSIM") 12 | 13 | parser.add_argument( 14 | "--generated_folder", 15 | required=True, 16 | type=str, 17 | help="Path to the folder containing the generated images", 18 | ) 19 | 20 | parser.add_argument( 21 | "--groundtruth_folder", 22 | required=True, 23 | 
type=str, 24 | help="Path to the folder containing the ground truth images", 25 | ) 26 | 27 | parser.add_argument( 28 | "--output_folder", 29 | required=True, 30 | type=str, 31 | help="Path to the folder where the results will be saved", 32 | ) 33 | 34 | 35 | args = parser.parse_args() 36 | 37 | 38 | def compute_l2_loss(img1, img2): 39 | return np.mean((img1 - img2) ** 2) 40 | 41 | 42 | groundtruth_folder = args.groundtruth_folder 43 | generated_folder = args.generated_folder 44 | 45 | # Initialize accumulators for L2 loss and SSIM 46 | total_l2_loss = 0 47 | total_ssim = 0 48 | image_count = 0 49 | 50 | # Get sorted list of image files 51 | groundtruth_files = sorted(glob.glob(os.path.join(groundtruth_folder, "*.png"))) 52 | generated_files = sorted(glob.glob(os.path.join(generated_folder, "*.png"))) 53 | 54 | groundtruth_files = groundtruth_files[: len(generated_files)] 55 | 56 | 57 | if len(generated_files) == 0: 58 | raise ValueError("No images found in GENERATED folder.") 59 | 60 | 61 | # Iterate through the image pairs 62 | print("Computing L2 Loss and SSIM for image pairs...") 63 | for gt_file, gen_file in tqdm( 64 | zip(groundtruth_files, generated_files), total=len(groundtruth_files) 65 | ): 66 | # Assert filenames match up to the folder and prefix differences 67 | gt_filename = os.path.basename(gt_file) 68 | gen_filename = os.path.basename(gen_file) 69 | 70 | assert ( 71 | gt_filename == gen_filename.split("generated_")[1] 72 | ), f"File names do not match: {gt_filename} and {gen_filename}" 73 | 74 | # Load images 75 | groundtruth_img = imread(gt_file, as_gray=True) # Load as grayscale for SSIM 76 | generated_img = imread(gen_file, as_gray=True) 77 | 78 | # resize groundtruth image to generated image size 79 | groundtruth_img = np.array( 80 | Image.fromarray(groundtruth_img).resize( 81 | (generated_img.shape[0], generated_img.shape[1]) 82 | ) 83 | ) 84 | 85 | # Ensure the images have the same shape 86 | if groundtruth_img.shape != generated_img.shape: 87 | raise ValueError(f"Image shapes do not match: {gt_file} and {gen_file}") 88 | 89 | # Compute L2 loss and SSIM 90 | l2_loss = compute_l2_loss(groundtruth_img, generated_img) 91 | image_ssim = ssim( 92 | groundtruth_img, 93 | generated_img, 94 | data_range=generated_img.max() - generated_img.min(), 95 | ) 96 | 97 | # Accumulate results 98 | total_l2_loss += l2_loss 99 | total_ssim += image_ssim 100 | image_count += 1 101 | 102 | # Compute averages 103 | average_l2_loss = total_l2_loss / image_count 104 | average_ssim = total_ssim / image_count 105 | 106 | # Print results 107 | print("\nComputation Complete!") 108 | print(f"Average L2 Loss: {average_l2_loss}") 109 | print(f"Average SSIM: {average_ssim}") 110 | 111 | # Write results to a file in the same folder 112 | with open(f"{args.output_folder}/ssim_l2.txt", "w") as f: 113 | f.write(f"Average L2 Loss, Average SSIM:\n") 114 | f.write(f"{average_l2_loss}, {average_ssim}\n") 115 | f.write(f"Total L2 Loss, Total SSIM:\n") 116 | f.write(f"{total_l2_loss}, {total_ssim}\n") 117 | f.write(f"Image Count: {image_count}\n") 118 | -------------------------------------------------------------------------------- /scripts/decode_image_embeddings.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import yaml 3 | from omegaconf import OmegaConf 4 | import argparse 5 | import os 6 | import torch 7 | from tqdm import tqdm 8 | import json 9 | import numpy as np 10 | 11 | from taming_transformers_utils import * 12 | 13 | import argparse 14 | 
15 | parser = argparse.ArgumentParser(description="Decode embeddings to images") 16 | 17 | parser.add_argument( 18 | "--folder_path", 19 | required=True, 20 | type=str, 21 | help="Path to the folder containing the embeddings", 22 | ) 23 | 24 | parser.add_argument( 25 | "--vqgan_type", 26 | required=True, 27 | type=str, 28 | help="Type of VQGAN model used for training: choose from [clevr, objaworld, objectron, domain-agnostic]", 29 | ) 30 | 31 | parser.add_argument( 32 | "--image_output_path", 33 | required=True, 34 | type=str, 35 | help="Path to the folder where the decoded images will be saved", 36 | ) 37 | 38 | args = parser.parse_args() 39 | 40 | # add arguments 41 | DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 42 | 43 | torch.set_grad_enabled(False) 44 | 45 | is_gumbel = False 46 | if args.vqgan_type == "clevr": 47 | config_path = "./kyvo-datasets-and-codebooks/vqgan-models-and-codebooks/clevr/2024-10-10T09-21-36_custom_vqgan_CLEVR-LARGE/configs/2024-10-10T09-21-36-project.yaml" 48 | ckpt_path = "./kyvo-datasets-and-codebooks/vqgan-models-and-codebooks/clevr/2024-10-10T09-21-36_custom_vqgan_CLEVR-LARGE/checkpoints/last.ckpt" 49 | elif args.vqgan_type == "objaworld": 50 | config_path = "./kyvo-datasets-and-codebooks/vqgan-models-and-codebooks/objaworld/2025-01-17T09-02-22_custom_vqgan_SYNTHETIC_LIVINGROOM_PARK_LARGE_EP100/configs/2025-01-17T09-02-22-project.yaml" 51 | ckpt_path = "./kyvo-datasets-and-codebooks/vqgan-models-and-codebooks/objaworld/2025-01-17T09-02-22_custom_vqgan_SYNTHETIC_LIVINGROOM_PARK_LARGE_EP100/checkpoints/last.ckpt" 52 | elif args.vqgan_type == "objectron": 53 | config_path = "./kyvo-datasets-and-codebooks/vqgan-models-and-codebooks/objectron/2024-11-03T05-41-42_custom_vqgan_OMNI3D_OBJECTRON_ep200/configs/2024-11-03T05-41-42-project.yaml" 54 | ckpt_path = "./kyvo-datasets-and-codebooks/vqgan-models-and-codebooks/objectron/2024-11-03T05-41-42_custom_vqgan_OMNI3D_OBJECTRON_ep200/checkpoints/last.ckpt" 55 | elif args.vqgan_type == "domain-agnostic": 56 | config_path = "./kyvo-datasets-and-codebooks/vqgan-models-and-codebooks/domain-agnostic/vqgan_gumbel_f8/configs/model.yaml" 57 | ckpt_path = "./kyvo-datasets-and-codebooks/vqgan-models-and-codebooks/domain-agnostic/vqgan_gumbel_f8/checkpoints/last.ckpt" 58 | is_gumbel = True 59 | 60 | config = load_config( 61 | config_path, 62 | display=False, 63 | ) 64 | model = load_vqgan( 65 | config, 66 | ckpt_path=ckpt_path, 67 | is_gumbel=is_gumbel, 68 | ).to(DEVICE) 69 | 70 | 71 | folder_path = args.folder_path 72 | 73 | groundtruth_names = [ 74 | f for f in os.listdir(folder_path) if "ground_truth" in f and f.endswith(".npy") 75 | ] 76 | generated_names = [ 77 | f for f in os.listdir(folder_path) if "generated" in f and f.endswith(".npy") 78 | ] 79 | 80 | # natural sort 81 | groundtruth_names.sort(key=lambda x: int(x.split("_")[-1].split(".")[0])) 82 | generated_names.sort(key=lambda x: int(x.split("_")[-1].split(".")[0])) 83 | 84 | folder_identifier = os.path.basename(folder_path) 85 | groundtruth_key_names = [] 86 | for name in groundtruth_names: 87 | key = name.split(".png")[0].split(folder_identifier)[-1] 88 | groundtruth_key_names.append(str(key)) 89 | 90 | generated_key_names = [] 91 | for name in generated_names: 92 | key = name.split(".png")[0].split(folder_identifier)[-1] 93 | generated_key_names.append(str(key)) 94 | 95 | assert len(groundtruth_names) == len(generated_names) 96 | print("Total number of embeddings to decode: ", len(groundtruth_names)) 97 | 98 | for img_idx in 
tqdm(range(len(groundtruth_names)), desc="Decoding images"): 99 | # load embeddings from npy files and convert to torch tensors 100 | groundtruth_embeddings = torch.tensor( 101 | np.load(os.path.join(folder_path, groundtruth_names[img_idx])) 102 | ).to(DEVICE) 103 | generated_embeddings = torch.tensor( 104 | np.load(os.path.join(folder_path, generated_names[img_idx])) 105 | ).to(DEVICE) 106 | 107 | xrec_groundtruth = model.decode(groundtruth_embeddings.unsqueeze(0)) 108 | xrec_generated = model.decode(generated_embeddings.unsqueeze(0)) 109 | 110 | if not os.path.exists(os.path.join(args.image_output_path, "GROUNDTRUTH")): 111 | os.makedirs(os.path.join(args.image_output_path, "GROUNDTRUTH")) 112 | 113 | if not os.path.exists(os.path.join(args.image_output_path, "GENERATED")): 114 | os.makedirs(os.path.join(args.image_output_path, "GENERATED")) 115 | 116 | custom_to_pil(xrec_groundtruth[0]).save( 117 | os.path.join( 118 | args.image_output_path, 119 | "GROUNDTRUTH", 120 | f"ground_truth{groundtruth_key_names[img_idx]}.png", 121 | ) 122 | ) 123 | custom_to_pil(xrec_generated[0]).save( 124 | os.path.join( 125 | args.image_output_path, 126 | "GENERATED", 127 | f"generated{generated_key_names[img_idx]}.png", 128 | ) 129 | ) 130 | -------------------------------------------------------------------------------- /DATA.md: -------------------------------------------------------------------------------- 1 | # Kyvo Dataset and Codebooks Details 2 | 3 | This document provides details about the dataset and codebooks provided in the `kyvo-datasets-and-codebooks` repository. We will provide the details about each of the folders in the repository and the contents of each folder. 4 | 5 | ## Data Generation Pipeline 6 | 7 | The pipeline that we follow to generate the pre-tokenized data is as follows: 8 | 9 | * **3D Scenes**: 3D Scene JSON --> Serialized 3D Scene --> Tokenized 3D Scene 10 | * **Images**: Image --> VQGAN Codebook Indices --> Tokenized Image 11 | * **Text**: Text --> Tokenized Text 12 | 13 | 14 | 15 | 16 | ## Pre-tokenized Data 17 | 18 | The `pretokenized-data` folder contains all the pre-tokenized data for the datasets used in the Kyvo project. The pre-tokenized data is stored in the following structure: 19 | 20 | ```python 21 | pretokenized-data/ 22 | |-- clevr/ 23 | | |-- 3d-scenes/ # contains all pre-tokenized 3D scenes for CLEVR for all tasks 24 | | |-- images/ # contains all pre-tokenized images for CLEVR for all tasks 25 | | |-- text/ # contains all pre-tokenized text for CLEVR for all tasks 26 | |-- objaworld/ 27 | | |-- 3d-scenes/ # contains all pre-tokenized 3D scenes for ObjaWorld for all tasks 28 | | |-- images/ # contains all pre-tokenized images for ObjaWorld for all tasks 29 | |-- objectron/ 30 | | |-- 3d-scenes/ # contains all pre-tokenized 3D scenes for Objectron for all tasks 31 | | |-- images/ # contains all pre-tokenized images for Objectron for all tasks 32 | ``` 33 | 34 | 35 | For a given task, an input can be any combination of 3d-scenes, images, and text. The output can be any combination of images, text, and 3d-scenes. In the following table we outline the tasks for each dataset and the corresponding input and output data that are needed for each task. 
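Before the table, here is a minimal sketch of how such pre-tokenized pieces might be stitched into a single training sequence for the recognition task (`task_type: "I-3"` in the configs). Both the JSON layout assumed here (a mapping from scene/image key to a flat list of token ids) and the concatenation order are illustrative assumptions; the dataset builder shipped with this repository defines the actual format:

```python
import json

# Paths follow the pretokenized-data layout above and the CLEVR training configs.
IMAGE_TOKENS = "./kyvo-datasets-and-codebooks/pretokenized-data/clevr/images/train_vqgan_indices_merged_for_rendering_and_recognition.json"
SCENE_TOKENS = "./kyvo-datasets-and-codebooks/pretokenized-data/clevr/3d-scenes/train_tokenized_scenes_merged_for_rendering_and_recognition.json"
IMAGE_TOKEN_OFFSET = 129471  # offset used in the CLEVR configs

def build_recognition_example(key: str) -> list[int]:
    """Image tokens (shifted into the shared vocabulary) followed by the target
    3D-scene tokens -- assumed schema: key -> flat list of token ids."""
    with open(IMAGE_TOKENS) as f:
        image_tokens = json.load(f)[key]
    with open(SCENE_TOKENS) as f:
        scene_tokens = json.load(f)[key]
    return [IMAGE_TOKEN_OFFSET + t for t in image_tokens] + scene_tokens
```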
36 | 37 | | **Task** | **Input Image** | **Input 3D Scene** | **Input Text** | **Output Image** | **Output 3D Scene** | **Output Text** | 38 | |:----------------------:|:------------------:|:----------------------:|:-----------------:|:------------------:|:-----------------------:|:-----------------:| 39 | | **CLEVR** | | | | | | | 40 | | Rendering | 𐄂 | ✓ | 𐄂 | ✓ | 𐄂 | 𐄂 | 41 | | Recognition | ✓ | 𐄂 | 𐄂 | 𐄂 | ✓ | 𐄂 | 42 | | Instruction-Following | ✓ | ✓ | ✓ | ✓ | ✓ | 𐄂 | 43 | | Question-Answering | ✓ | ✓ | ✓ | 𐄂 | 𐄂 | ✓ | 44 | | | | | | | | | 45 | | **ObjaWorld** | | | | | | | 46 | | Rendering | 𐄂 | ✓ | 𐄂 | ✓ | 𐄂 | 𐄂 | 47 | | Recognition | ✓ | 𐄂 | 𐄂 | 𐄂 | ✓ | 𐄂 | 48 | | | | | | | | | 49 | | **Objectron** | | | | | | | 50 | | Recognition | ✓ | 𐄂 | 𐄂 | 𐄂 | ✓ | 𐄂 | 51 | 52 | For the exact files that correspond to the input and output data for each task, please refer to the corresponding configuration files in the `configs/llama3_2/train` folder. 53 | 54 | 55 | ## VQGAN Models and Codebooks 56 | 57 | The `vqgan-models-and-codebooks` folder contains all the VQGAN model checkpoints and codebooks for the datasets used in the Kyvo project. The VQGAN model checkpoints and codebooks are stored in the following structure: 58 | 59 | ```python 60 | vqgan-models-and-codebooks/ 61 | |-- clevr/ 62 | | |-- 2024-10-10T09-21-36_custom_vqgan_CLEVR-LARGE/ # contains the VQGAN model checkpoint for CLEVR 63 | | |-- custom_vqgan_embedding_1024CLEVRLARGE_256dim.npy # contains the VQGAN codebook for CLEVR 64 | |-- objaworld/ 65 | | |-- 2025-01-17T09-02-22_custom_vqgan_SYNTHETIC_LIVINGROOM_PARK_LARGE_EP100/ # contains the VQGAN model checkpoint for ObjaWorld 66 | | |-- custom_vqgan_embedding_256SYNTHETIC_LIVINGROOM_PARK_LARGE_EP100_256dim.npy # contains the VQGAN codebook for ObjaWorld 67 | |-- objectron/ 68 | | |-- 2024-11-03T05-41-42_custom_vqgan_OMNI3D_OBJECTRON_ep200/ # contains the VQGAN model checkpoint for Objectron 69 | | |-- custom_vqgan_embedding_256Omni3D-OBJECTRON_256dim.npy # contains the VQGAN codebook for Objectron 70 | ``` 71 | 72 | ## Images and Scenes for Evaluation 73 | 74 | The `images-and-scenes-for-evaluation` folder contains all the groundtruth images and scenes for the datasets used in the Kyvo project. The images and scenes are used to compute the evaluation metrics for the different tasks. 
The images and scenes are stored in the following structure: 75 | 76 | ```python 77 | images-and-scenes-for-evaluation/ 78 | |-- clevr/ # contains all images and scenes for evaluation for CLEVR 79 | |-- objaworld/ # contains all images and scenes for evaluation for ObjaWorld 80 | |-- objectron/ # contains all scenes for evaluation for Objectron 81 | ``` 82 | 83 | 84 | 85 | -------------------------------------------------------------------------------- /scripts/compute_jaccard_index.py: -------------------------------------------------------------------------------- 1 | import json 2 | import numpy as np 3 | import os 4 | from tqdm import tqdm 5 | import argparse 6 | 7 | parser = argparse.ArgumentParser(description="Compute Jaccard Index") 8 | 9 | parser.add_argument( 10 | "--generated_folder", 11 | required=True, 12 | type=str, 13 | help="Path to the folder containing the predicted scenes", 14 | ) 15 | 16 | parser.add_argument( 17 | "--groundtruth_folder", 18 | required=True, 19 | type=str, 20 | help="Path to the folder containing the ground truth scenes", 21 | ) 22 | 23 | parser.add_argument( 24 | "--tau", 25 | required=True, 26 | type=float, 27 | default=0.25, 28 | help="Distance threshold for matching objects", 29 | ) 30 | 31 | parser.add_argument( 32 | "--dataset", 33 | required=True, 34 | type=str, 35 | help="Dataset used for training: choose from [clevr, objaworld]", 36 | ) 37 | 38 | args = parser.parse_args() 39 | 40 | 41 | # Function to compute Euclidean distance between two 2D coordinates 42 | def dist(coords1, coords2): 43 | return np.linalg.norm(np.array(coords1) - np.array(coords2)) 44 | 45 | 46 | # Function to compute the Jaccard Index 47 | def compute_jaccard_index(ground_truth_scenes, predicted_scenes, tau): 48 | 49 | total_number_of_scenes = len(ground_truth_scenes) 50 | assert total_number_of_scenes == len(predicted_scenes) 51 | total_jaccard_index = 0 52 | 53 | for gt_scene, pred_scene in zip(ground_truth_scenes, predicted_scenes): 54 | true_positives = 0 55 | false_positives = 0 56 | false_negatives = 0 57 | 58 | if args.dataset == "clevr": 59 | assert ( 60 | gt_scene["image_filename"] == pred_scene["key"] 61 | or gt_scene["image_filename"].split(".png")[0] 62 | == pred_scene["key"].split(".png")[0] 63 | ), f"Ground truth and predicted scene mismatch: {gt_scene['image_filename']} vs {pred_scene['key']}" 64 | elif args.dataset == "objaworld": 65 | assert ( 66 | gt_scene["image_filename"].split(".png")[0] 67 | == pred_scene["key"].split(".png_")[0] 68 | ), f"Ground truth and predicted scene mismatch: {gt_scene['image_filename']} vs {pred_scene['key']}" 69 | 70 | gt_objects = gt_scene["objects"] 71 | pred_objects = pred_scene["objects"] 72 | 73 | # Create flags to track matched ground truth objects 74 | gt_matched = [False] * len(gt_objects) 75 | 76 | # Match predictions to ground truth 77 | for pred_obj in pred_objects: 78 | matched = False 79 | for i, gt_obj in enumerate(gt_objects): 80 | if not gt_matched[i]: # Check if ground truth object is unmatched 81 | # Check attribute equality and distance condition 82 | try: 83 | if args.dataset == "clevr": 84 | if ( 85 | pred_obj["size"] == gt_obj["size"] 86 | and pred_obj["color"] == gt_obj["color"] 87 | and pred_obj["material"] == gt_obj["material"] 88 | and pred_obj["shape"] == gt_obj["shape"] 89 | and dist( 90 | pred_obj["3d_coords"][:2], gt_obj["3d_coords"][:2] 91 | ) 92 | < tau 93 | ): 94 | gt_matched[i] = True # Mark ground truth as matched 95 | matched = True 96 | true_positives += 1 97 | break 98 | elif args.dataset == 
"objaworld": 99 | if ( 100 | pred_obj["shape"] == gt_obj["shape"].split("-")[0] 101 | and dist( 102 | pred_obj["3d_coords"][:2], gt_obj["3d_coords"][:2] 103 | ) 104 | < tau 105 | and abs( 106 | pred_obj["3d_coords"][2] - gt_obj["3d_coords"][2] 107 | ) 108 | < 0.05 109 | and abs(pred_obj["3d_coords"][3] - gt_obj["rotation"]) 110 | < 0.15 111 | ): 112 | gt_matched[i] = True # Mark ground truth as matched 113 | matched = True 114 | true_positives += 1 115 | break 116 | except: 117 | continue # Skip if any attribute is missing 118 | if not matched: 119 | false_positives += 1 120 | 121 | # Count unmatched ground truths 122 | false_negatives += gt_matched.count(False) 123 | 124 | if (true_positives + false_positives + false_negatives) == 0: 125 | jaccard_index = -1 126 | else: 127 | jaccard_index = true_positives / ( 128 | true_positives + false_positives + false_negatives 129 | ) 130 | total_jaccard_index += jaccard_index 131 | 132 | # Compute Jaccard Index 133 | 134 | jaccard_index = total_jaccard_index / total_number_of_scenes 135 | return jaccard_index 136 | 137 | 138 | # find the predicted scenes json file 139 | folder = os.path.basename(args.generated_folder) 140 | print(f"Computing Jaccard Index for {folder}") 141 | with open(f"{args.generated_folder}/predicted_scenes_{folder}.json", "r") as f: 142 | predicted_data = json.load(f) 143 | 144 | ground_truth_data = {"scenes": []} 145 | 146 | if args.dataset == "clevr": 147 | all_gt_jsons = os.listdir(args.groundtruth_folder) 148 | all_gt_jsons = sorted(all_gt_jsons) 149 | for gt_json in all_gt_jsons: 150 | with open(f"{args.groundtruth_folder}/{gt_json}", "r") as f: 151 | ground_truth_data["scenes"].append(json.load(f)) 152 | elif args.dataset == "objaworld": 153 | for pred_scene in predicted_data["scenes"]: 154 | with open( 155 | "{}/output-{}-large-10K-6/scenes/{}.json".format( 156 | args.groundtruth_folder, 157 | pred_scene["key"].split("png_")[1], 158 | pred_scene["key"].split(".png_")[0], 159 | ), 160 | "r", 161 | ) as f: 162 | ground_truth_data["scenes"].append(json.load(f)) 163 | 164 | ground_truth_scenes = ground_truth_data["scenes"] 165 | predicted_scenes = predicted_data["scenes"] 166 | 167 | ground_truth_scenes = ground_truth_scenes[: len(predicted_scenes)] 168 | 169 | jaccard_index = compute_jaccard_index( 170 | ground_truth_scenes, predicted_scenes, tau=args.tau 171 | ) 172 | 173 | print(f"Jaccard Index: {jaccard_index:.4f}") 174 | 175 | # print the Jaccard Index to a file in the same folder 176 | with open(f"{args.generated_folder}/jaccard_index_tau-{args.tau}.txt", "w") as f: 177 | f.write(f"Jaccard Index: {jaccard_index:.4f}") 178 | -------------------------------------------------------------------------------- /scripts/compute_jaccard_index_objectron.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | import matplotlib.pyplot as plt 4 | import os 5 | import json 6 | from tqdm import tqdm 7 | 8 | import argparse 9 | 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument("--tau", type=float, default=0.25) 12 | parser.add_argument("--dimensions_tau", type=float, default=0.05) 13 | parser.add_argument("--predictions_file", type=str, required=True) 14 | parser.add_argument("--groundtruth_file", type=str, required=True) 15 | parser.add_argument( 16 | "--method", 17 | type=str, 18 | required=True, 19 | help="Method used for predictions - 'cube-rcnn' or 'kyvo'", 20 | ) 21 | 22 | args = parser.parse_args() 23 | 24 | 25 | # load the file with all data 
26 | with open(args.groundtruth_file) as f: 27 | data = json.load(f) 28 | 29 | # load the prediction json 30 | with open( 31 | args.predictions_file, 32 | ) as f: 33 | predictions = json.load(f) 34 | 35 | # preprocess the data 36 | image_id_to_objects = {} 37 | for ann in data["annotations"]: 38 | image_id = ann["image_id"] 39 | if image_id not in image_id_to_objects: 40 | image_id_to_objects[image_id] = {"objects": []} 41 | image_id_to_objects[image_id]["objects"].append( 42 | { 43 | "category_id": ann["category_id"], 44 | "category": ann["category_name"], 45 | "center_cam": ann["center_cam"], 46 | "dimensions": ann["dimensions"], 47 | } 48 | ) 49 | 50 | 51 | data_filepath_to_image_id = {} 52 | for image in data["images"]: 53 | if image["file_path"] in data_filepath_to_image_id: 54 | raise ValueError("Duplicate file path found") 55 | data_filepath_to_image_id[image["file_path"]] = image["id"] 56 | 57 | 58 | final_preprocessed_data = {} 59 | for image_id, objects in image_id_to_objects.items(): 60 | if image_id not in final_preprocessed_data: 61 | final_preprocessed_data[image_id] = {} 62 | for obj in objects["objects"]: 63 | if obj["category"] not in final_preprocessed_data[image_id]: 64 | final_preprocessed_data[image_id][obj["category"]] = { 65 | "num": 0, 66 | "center_cam_3d": [], 67 | "dimensions": [], 68 | } 69 | final_preprocessed_data[image_id][obj["category"]]["num"] += 1 70 | final_preprocessed_data[image_id][obj["category"]]["center_cam_3d"].append( 71 | np.array(obj["center_cam"]) 72 | ) 73 | final_preprocessed_data[image_id][obj["category"]]["dimensions"].append( 74 | np.array(obj["dimensions"]) 75 | ) 76 | 77 | data_category_id_to_name = {} 78 | for cat in data["categories"]: 79 | data_category_id_to_name[cat["id"]] = cat["name"] 80 | 81 | 82 | if args.method == "cube-rcnn": 83 | cube_rcnn_predictions_preprocessed = {} 84 | for pred in predictions: 85 | image_id = pred["image_id"] 86 | if image_id not in cube_rcnn_predictions_preprocessed: 87 | cube_rcnn_predictions_preprocessed[image_id] = {"objects": []} 88 | cube_rcnn_predictions_preprocessed[image_id]["objects"].append( 89 | { 90 | "category_id": pred["category_id"], 91 | "category": data_category_id_to_name[pred["category_id"]], 92 | "center_cam": pred["center_cam"], 93 | "dimensions": pred["dimensions"], 94 | } 95 | ) 96 | 97 | final_preprocessed_prediction_data = {} 98 | for image_id, objects in cube_rcnn_predictions_preprocessed.items(): 99 | if image_id not in final_preprocessed_prediction_data: 100 | final_preprocessed_prediction_data[image_id] = {} 101 | for obj in objects["objects"]: 102 | if obj["category"] not in final_preprocessed_prediction_data[image_id]: 103 | final_preprocessed_prediction_data[image_id][obj["category"]] = { 104 | "num": 0, 105 | "center_cam_3d": [], 106 | "dimensions": [], 107 | } 108 | final_preprocessed_prediction_data[image_id][obj["category"]]["num"] += 1 109 | final_preprocessed_prediction_data[image_id][obj["category"]][ 110 | "center_cam_3d" 111 | ].append(np.array(obj["center_cam"])) 112 | final_preprocessed_prediction_data[image_id][obj["category"]][ 113 | "dimensions" 114 | ].append(np.array(obj["dimensions"])) 115 | elif args.method == "kyvo": 116 | final_preprocessed_prediction_data = {} 117 | for pred_scene in predictions["scenes"]: 118 | pred_objects = pred_scene["objects"] 119 | image_id = data_filepath_to_image_id[pred_scene["key"]] 120 | if image_id not in final_preprocessed_prediction_data: 121 | final_preprocessed_prediction_data[image_id] = {} 122 | for obj in 
pred_objects: 123 | if obj["category"] not in final_preprocessed_prediction_data[image_id]: 124 | final_preprocessed_prediction_data[image_id][obj["category"]] = { 125 | "num": 0, 126 | "center_cam_3d": [], 127 | "dimensions": [], 128 | } 129 | final_preprocessed_prediction_data[image_id][obj["category"]]["num"] += 1 130 | final_preprocessed_prediction_data[image_id][obj["category"]][ 131 | "center_cam_3d" 132 | ].append(np.array(obj["center_cam"])) 133 | final_preprocessed_prediction_data[image_id][obj["category"]][ 134 | "dimensions" 135 | ].append(np.array(obj["dimensions"])) 136 | 137 | # only keep the common keys 138 | common_keys = set(final_preprocessed_data.keys()).intersection( 139 | set(final_preprocessed_prediction_data.keys()) 140 | ) 141 | 142 | final_preprocessed_data_common = {k: final_preprocessed_data[k] for k in common_keys} 143 | 144 | final_preprocessed_prediction_data_common = { 145 | k: final_preprocessed_prediction_data[k] for k in common_keys 146 | } 147 | 148 | print("Total common keys: ", len(final_preprocessed_data_common.keys())) 149 | 150 | 151 | import numpy as np 152 | from scipy.spatial.distance import euclidean 153 | 154 | 155 | def compute_jaccard(predictions, groundtruths, tau, dimensions_tau): 156 | """ 157 | Compute the Jaccard metric based on predictions and ground truths. 158 | 159 | Parameters: 160 | predictions (list of dict): List of dictionaries containing predicted objects with their category and coordinates. 161 | groundtruths (list of dict): List of dictionaries containing ground truth objects with their category and coordinates. 162 | tau (float): Distance threshold for matching predictions with ground truths. 163 | 164 | Returns: 165 | float: Jaccard metric (tp / (tp + fp + fn)). 166 | """ 167 | true_positives = 0 168 | false_positives = 0 169 | false_negatives = 0 170 | 171 | # Iterate over each prediction 172 | for prediction in predictions: 173 | matched = set() # Track matched ground truth objects 174 | 175 | # Iterate over each predicted object 176 | for category, pred_data in prediction.items(): 177 | pred_count = pred_data["num"] 178 | pred_centers = pred_data["center_cam_3d"] 179 | pred_dimensions = pred_data["dimensions"] 180 | 181 | # Check if the category exists in ground truths 182 | if category in groundtruths[0]: 183 | gt_data = groundtruths[0][category] 184 | gt_count = gt_data["num"] 185 | gt_centers = gt_data["center_cam_3d"] 186 | gt_dimensions = gt_data["dimensions"] 187 | matched_ground_truths = [False] * gt_count 188 | 189 | # Attempt to match each predicted object with a ground truth 190 | for j, pred_center in enumerate(pred_centers): 191 | best_match_index = -1 192 | best_distance = float("inf") 193 | 194 | # Find the closest unmatched ground truth 195 | for i, gt_center in enumerate(gt_centers): 196 | if not matched_ground_truths[i]: 197 | distance = euclidean(pred_center, gt_center) 198 | dimensions_errors = np.mean( 199 | np.abs(pred_dimensions[j] - gt_dimensions[i]) 200 | ) 201 | if ( 202 | distance < tau 203 | and distance < best_distance 204 | and dimensions_errors <= dimensions_tau 205 | ): 206 | best_distance = distance 207 | best_match_index = i 208 | 209 | # If a match is found, count it as a true positive 210 | if best_match_index != -1: 211 | true_positives += 1 212 | matched_ground_truths[best_match_index] = True 213 | matched.add(best_match_index) 214 | else: 215 | false_positives += 1 216 | 217 | # Add unmatched ground truth objects to false negatives 218 | false_negatives += gt_count - 
sum(matched_ground_truths) 219 | else: 220 | # All predicted objects of this category are false positives if no matching ground truths 221 | false_positives += pred_count 222 | 223 | # Account for any ground truths that had no corresponding predictions 224 | for category, gt_data in groundtruths[0].items(): 225 | if category not in prediction: 226 | false_negatives += gt_data["num"] 227 | 228 | # Calculate Jaccard metric 229 | jaccard_metric = true_positives / ( 230 | true_positives + false_positives + false_negatives 231 | ) 232 | 233 | print("true_positives: ", true_positives) 234 | print("false_positives: ", false_positives) 235 | print("false_negatives: ", false_negatives) 236 | 237 | return jaccard_metric, true_positives, false_positives, false_negatives 238 | 239 | 240 | jaccard_metric_sum = 0 241 | total_tp = 0 242 | total_fp = 0 243 | total_fn = 0 244 | total_number_of_scenes = len(final_preprocessed_data_common.keys()) 245 | 246 | for key in final_preprocessed_data_common.keys(): 247 | jaccard_metric, true_positives, false_positives, false_negatives = compute_jaccard( 248 | [final_preprocessed_prediction_data_common[key]], 249 | [final_preprocessed_data_common[key]], 250 | args.tau, 251 | args.dimensions_tau, 252 | ) 253 | 254 | total_tp += true_positives 255 | total_fp += false_positives 256 | total_fn += false_negatives 257 | jaccard_metric_sum += jaccard_metric 258 | 259 | 260 | print( 261 | "Average Jaccard metric for scene matching: {}".format( 262 | round(jaccard_metric_sum / total_number_of_scenes, 4) 263 | ) 264 | ) 265 | 266 | print("Average TP: ", total_tp / total_number_of_scenes) 267 | print("Average FP: ", total_fp / total_number_of_scenes) 268 | print("Average FN: ", total_fn / total_number_of_scenes) 269 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 
34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Aligning Text, Images, and 3D Structure Token-by-Token 2 | 3 | This repository contains the official code for the paper [**Aligning Text, Images, and 3D Structure Token-by-Token**](https://glab-caltech.github.io/kyvo/). 4 | **Authors:** [Aadarsh Sahoo](https://aadsah.github.io/), [Vansh Tibrewal](https://vanshtibrewal.github.io/), [Georgia Gkioxari](https://gkioxari.github.io/) 5 | 6 | 7 |

8 | 9 | Project Page 10 | 11 | 12 | arXiv Paper 13 | 14 | 15 | Dataset on Hugging Face 16 | 17 | 18 | BibTeX Citation 19 | 20 |

21 | 22 | 23 | 24 | 25 | 26 |

27 | Teaser Image 28 |

29 |

Kyvo: a decoder-only transformer aligns a structured 3D modality with language and vision. 30 | This 3D modality represents scenes as lists of objects, each defined by its 3D shape, type, 3D position, 31 | pose and size parameters. Kyvo unifies the token space of images, text, and 3D to enable a variety of 32 | complex visual 3D tasks.
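To make the structured 3D modality concrete, the sketch below shows the kind of object-list scene description the caption refers to. The field names and values here are illustrative assumptions only, not the repository's actual schema; see [`DATA.md`](DATA.md) and the pre-tokenized `3d-scenes` JSON files for the real format.

```python
# Illustrative only: an object-centric scene of the kind Kyvo serializes into tokens.
# Exact field names and values are assumptions, not the repository's actual schema.
example_scene = {
    "objects": [
        {"shape": "cube",   "size": "small", "position": [1.2, -0.8, 0.35], "rotation": 45.0},
        {"shape": "sphere", "size": "large", "position": [-2.0, 1.5, 0.70], "rotation": 0.0},
    ]
}
```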

33 | 34 | 35 | 36 | --- 37 | 38 | ## Table of Contents 39 | 1. [Setup](#setup) 40 | 2. [Dataset and VQGAN Codebooks](#dataset-and-vqgan-codebooks) 41 | 3. [Download Llama-3.2-1B Models](#download-llama-3.2-1b-models) 42 | 4. [Training](#training) 43 | 5. [Evaluation](#evaluation) 44 | 45 | --- 46 | 47 | ## 📢 News 48 | - **2025-06-09**: Kyvo is on arXiv! 49 | 50 | ## 📋 TODO 51 | - [ ] Release code and data for ARKitScenes and ObjaWorld with explicit shape representations. 52 | - [ ] HuggingFace 🤗 demo. 53 | - [x] Release training code and data for CLEVR, ObjaWorld (complex object shapes), Objectron. 54 | 55 | ## Setup 56 | 57 | This repository uses [torchtune](https://github.com/pytorch/torchtune) for training, included as a submodule. 58 | 59 | 1. **Clone this repository**: 60 | 61 | ```bash 62 | git clone --recurse-submodules https://github.com/AadSah/kyvo.git 63 | cd kyvo 64 | ``` 65 | 66 | 2. **Set up the environment:** 67 | 68 | **Option A:** Create a new conda environment and install the required dependencies: 69 | 70 | ```bash 71 | cd kyvo 72 | conda create -n kyvo python=3.11 73 | conda activate kyvo 74 | cd torchtune 75 | pip install -e . 76 | cd .. 77 | pip install -r requirements.txt 78 | ``` 79 | 80 | **Option B:** Use the provided conda environment file to create the environment: 81 | 82 | ```bash 83 | conda env create -f kyvo.yml 84 | conda activate kyvo 85 | cd torchtune 86 | pip install -e . 87 | cd .. 88 | ``` 89 | 90 | ## Dataset and VQGAN Codebooks 91 | 92 | Download the pre-tokenized data, VQGAN checkpoints, and codebooks from Hugging Face: 93 | 94 | ```bash 95 | git clone git@hf.co:datasets/aadarsh99/kyvo-datasets-and-codebooks 96 | cd kyvo-datasets-and-codebooks 97 | git lfs install 98 | git pull 99 | ``` 100 | 101 | This will create a folder `./kyvo-datasets-and-codebooks` with the following structure: 102 | 103 | ```python 104 | kyvo-datasets-and-codebooks/ 105 | |-- images-and-scenes-for-evaluation/ # contains all images and scenes for evaluation 106 | | |-- clevr/ # all CLEVR related files 107 | | |-- objaworld/ # all ObjaWorld related files 108 | | | ... 109 | |-- pretokenized-data/ # contains all pre-tokenized data for all the datasets 110 | | |-- clevr/ # all CLEVR related files 111 | | |-- objaworld/ # all ObjaWorld related files 112 | | | ... 113 | |-- vqgan-models-and-codebooks/ # contains all VQGAN model checkpoints and codebooks 114 | | |-- clevr/ # all CLEVR related files 115 | | |-- objaworld/ # all ObjaWorld related files 116 | | | ... 117 | ``` 118 | 119 | More details about the dataset and VQGAN codebooks can be found in the [`DATA.md`](DATA.md) file. 120 | 121 | ## Download Llama-3.2-1B Models 122 | 123 | Use `tune download` to fetch the Llama-3.2-1B models: 124 | 125 | ```bash 126 | tune download meta-llama/Llama-3.2-1B --output-dir ./llama-3-models/Llama3.2-1B/ 127 | 128 | tune download meta-llama/Llama-3.2-1B-Instruct --output-dir ./llama-3-models/Llama3.2-1B-Instruct/ 129 | ``` 130 | 131 | For convenience and compatibility with the provided config files, you may need to restructure the downloaded model directories (otherwise, update paths in your config files accordingly): 132 | 133 | ```bash 134 | mv ./llama-3-models/Llama3.2-1B/original/* ./llama-3-models/Llama3.2-1B/ 135 | 136 | mv ./llama-3-models/Llama3.2-1B-Instruct/original/* ./llama-3-models/Llama3.2-1B-Instruct/ 137 | ``` 138 | 139 | ## Training 140 | 141 | All training configuration files are located in `./configs/llama3_2/train/`. 
Below are sample commands for different datasets and tasks: 142 | 143 | ### CLEVR: 144 | 145 | ```bash 146 | # Rendering 147 | python3 scripts/full_finetune_single_device_3d.py --config ./kyvo/configs/llama3_2/train/clevr/rendering.yaml 148 | 149 | # Recognition 150 | python3 scripts/full_finetune_single_device_3d.py --config ./kyvo/configs/llama3_2/train/clevr/recognition.yaml 151 | 152 | # Instruction-Following 153 | python3 scripts/full_finetune_single_device_3d.py --config ./kyvo/configs/llama3_2/train/clevr/instruction-following.yaml 154 | 155 | # Question-Answering 156 | python3 scripts/full_finetune_single_device_3d.py --config ./kyvo/configs/llama3_2/train/clevr/question-answering.yaml 157 | ``` 158 | 159 | ### ObjaWorld: 160 | 161 | ```bash 162 | # Rendering 163 | python3 scripts/full_finetune_single_device_3d.py --config ./kyvo/configs/llama3_2/train/objaworld/rendering.yaml 164 | 165 | # Recognition 166 | python3 scripts/full_finetune_single_device_3d.py --config ./kyvo/configs/llama3_2/train/objaworld/recognition.yaml 167 | ``` 168 | 169 | ### Objectron: 170 | 171 | ```bash 172 | # Recognition 173 | python3 scripts/full_finetune_single_device_3d.py --config ./kyvo/configs/llama3_2/train/objectron/recognition.yaml 174 | ``` 175 | 176 | After training, the models will be saved in the `./checkpoints/` directory. 177 | 178 | ## Evaluation 179 | 180 | Evaluation configuration files are located in `./configs/llama3_2/eval/`. Below are sample commands: 181 | 182 | ### CLEVR: 183 | 184 | ```bash 185 | # Rendering 186 | python3 scripts/generate_3d.py --config ./kyvo/configs/llama3_2/eval/clevr/rendering.yaml 187 | 188 | # Recognition 189 | python3 scripts/generate_3d.py --config ./kyvo/configs/llama3_2/eval/clevr/recognition.yaml 190 | 191 | # Instruction-Following (5 sub-tasks are individually evaluated) 192 | python3 scripts/generate_3d.py --config ./kyvo/configs/llama3_2/eval/clevr/instruction-following/appearance-no-relation.yaml 193 | python3 scripts/generate_3d.py --config ./kyvo/configs/llama3_2/eval/clevr/instruction-following/appearance-with-relation.yaml 194 | python3 scripts/generate_3d.py --config ./kyvo/configs/llama3_2/eval/clevr/instruction-following/insertion-with-relation.yaml 195 | python3 scripts/generate_3d.py --config ./kyvo/configs/llama3_2/eval/clevr/instruction-following/moving-with-relation.yaml 196 | python3 scripts/generate_3d.py --config ./kyvo/configs/llama3_2/eval/clevr/instruction-following/removal-with-relation.yaml 197 | 198 | # Question-Answering 199 | python3 scripts/generate_3d.py --config ./kyvo/configs/llama3_2/eval/clevr/question-answering.yaml 200 | ``` 201 | 202 | ### ObjaWorld: 203 | 204 | ```bash 205 | # Rendering 206 | python3 scripts/generate_3d.py --config ./kyvo/configs/llama3_2/eval/objaworld/rendering.yaml 207 | 208 | # Recognition 209 | python3 scripts/generate_3d.py --config ./kyvo/configs/llama3_2/eval/objaworld/recognition.yaml 210 | ``` 211 | 212 | ### Objectron: 213 | 214 | ```bash 215 | # Recognition 216 | python3 scripts/generate_3d.py --config ./kyvo/configs/llama3_2/eval/objectron/recognition.yaml 217 | ``` 218 | 219 | Evaluation outputs are saved (by default) to `./checkpoints/` or a directory specified in the config files. Depending on the task, outputs may include 3D scene JSONs, image embeddings, or text predictions. 220 | 221 | 222 | ### Metrics Computation 223 | 224 | Below are scripts for computing the evaluation metrics used in the paper. 225 | 226 | 1. 
**Jaccard Index on 3D Scenes** 227 | 228 | ```bash 229 | # CLEVR 230 | python3 scripts/compute_jaccard_index.py --tau 0.05 --generated_folder ./checkpoints/clevr/recognition-inference/three_d_json/clevr_recognition_inference --groundtruth_folder ./kyvo-datasets-and-codebooks/images-and-scenes-for-evaluation/clevr/original_images/scenes --dataset clevr 231 | 232 | # ObjaWorld 233 | python3 scripts/compute_jaccard_index.py --tau 0.05 --generated_folder ./checkpoints/objaworld/recognition-inference/three_d_json/objaworld_recognition_inference --groundtruth_folder ./kyvo-datasets-and-codebooks/images-and-scenes-for-evaluation/objaworld/original_images/ --dataset objaworld 234 | 235 | # Objectron (kyvo) 236 | python3 scripts/compute_jaccard_index.py --tau 0.05 --dimension_tau 0.05 --predictions_file ./checkpoints/objectron/recognition-inference/three_d_json/objectron_recognition_inference/predicted_scenes_objectron_recognition_inference.json --groundtruth_file kyvo-datasets-and-codebooks/images-and-scenes-for-evaluation/objectron/Objectron_test.json --dataset objaworld --method kyvo 237 | 238 | # Objectron (Cube-RCNN) 239 | python3 scripts/compute_jaccard_index.py --tau 0.05 --dimension_tau 0.05 --predictions_file kyvo-datasets-and-codebooks/images-and-scenes-for-evaluation/objectron/omni_instance_results_resnet34.json --groundtruth_file kyvo-datasets-and-codebooks/images-and-scenes-for-evaluation/objectron/Objectron_test.json --dataset objaworld --method cube-rcnn 240 | ``` 241 | * `--tau`: Threshold for Jaccard index. 242 | * `--generated_folder`: Path to predicted 3D scenes. 243 | * `--groundtruth_folder`: Path to ground truth 3D scenes. 244 | * `--dataset`: Dataset name. 245 | 246 | 2. **Decoding Image Embeddings and Computing SSIM / L2-Loss** 247 | 248 | Create a separate environment to install [`taming-transformers`](https://github.com/CompVis/taming-transformers/tree/master): 249 | 250 | ```bash 251 | conda env create -f taming-kyvo.yml 252 | conda activate taming-kyvo 253 | cd taming-transformers 254 | pip install -e . 255 | cd .. 256 | ``` 257 | 258 | **Decode Image Embeddings:** 259 | 260 | ```bash 261 | # CLEVR 262 | python3 scripts/decode_image_embeddings.py --vqgan_type clevr --folder_path ./checkpoints/clevr/rendering-inference/image_embeddings/clevr_rendering_inference --image_output_path ./checkpoints/clevr/rendering-inference/decoded_images/ 263 | 264 | # ObjaWorld 265 | python3 scripts/decode_image_embeddings.py --vqgan_type objaworld --folder_path ./checkpoints/objaworld/rendering-inference/image_embeddings/objaworld_rendering_inference --image_output_path ./checkpoints/objaworld/rendering-inference/decoded_images/ 266 | ``` 267 | 268 | * `--folder_path`: Path to image embeddings. 269 | * `--vqgan_type`: Dataset name. 270 | * `--image_output_path`: Path to save decoded images. 271 | 272 | **Compute SSIM and L2-Loss:** 273 | 274 | ```bash 275 | # CLEVR 276 | python3 scripts/compute_ssim_l2loss.py --generated_folder ./checkpoints/clevr/rendering-inference/decoded_images/GENERATED --groundtruth_folder ./kyvo-datasets-and-codebooks/images-and-scenes-for-evaluation/clevr/original_images/images --output_folder ./checkpoints/clevr/rendering-inference/decoded_images 277 | ``` 278 | 279 | * `--generated_folder`: Path to predicted images. 280 | * `--groundtruth_folder`: Path to ground truth images. 281 | * `--output_folder`: Path to save computed SSIM and L2-loss values. 282 | 283 | 3. 
**Text Output Accuracy** 284 | 285 | ```bash 286 | python3 scripts/compute_text_answer_accuracy.py --predicted_file ./checkpoints/clevr/question-answering-inference/three_d_json/clevr_question-answering_inference/predicted_answers_clevr_question-answering_inference.json --groundtruth_file ./kyvo-datasets-and-codebooks/pretokenized-data/clevr/text/test_vqa_answers.json 287 | ``` 288 | 289 | * `--predicted_file`: Path to predicted text answers. 290 | * `--groundtruth_file`: Path to ground truth text answers. 291 | 292 | --- 293 | 294 | ## Citation 295 | 296 | If you find this work useful, please consider citing: 297 | 298 | ```bibtex 299 | @misc{sahoo2025aligningtextimages3d, 300 | title={Aligning Text, Images, and 3D Structure Token-by-Token}, 301 | author={Aadarsh Sahoo and Vansh Tibrewal and Georgia Gkioxari}, 302 | year={2025}, 303 | eprint={2506.08002}, 304 | archivePrefix={arXiv}, 305 | primaryClass={cs.CV}, 306 | url={https://arxiv.org/abs/2506.08002}, 307 | } 308 | ``` 309 | 310 | --- 311 | 312 | For any questions or issues, please open a GitHub issue or contact [Aadarsh](mailto:aadarsh.sahoo.99@gmail.com). Thank you for your interest in our work! 313 | 314 | 315 | 316 | 317 | -------------------------------------------------------------------------------- /checkpointer/convert_weights.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import re 8 | 9 | from typing import Any, Dict 10 | 11 | import torch 12 | 13 | _FROM_META_3D_SIN_COS_NUM = { 14 | "tok_embeddings.token_embedding.weight": "tok_embeddings.token_embedding.weight", 15 | "tok_embeddings.numbers_embedding.weight": "tok_embeddings.numbers_embedding.weight", 16 | "tok_embeddings.added_embedding.weight": "tok_embeddings.added_embedding.weight", 17 | "tok_embeddings.vqgan_codebook.weight": "tok_embeddings.vqgan_codebook.weight", 18 | "tok_embeddings.vqgan_embed_proj.weight": "tok_embeddings.vqgan_embed_proj.weight", 19 | "norm.weight": "norm.scale", 20 | "output.weight": "output.weight", 21 | "layers.{}.attention.wk.weight": "layers.{}.attn.k_proj.weight", 22 | "layers.{}.attention.wq.weight": "layers.{}.attn.q_proj.weight", 23 | "layers.{}.attention.wv.weight": "layers.{}.attn.v_proj.weight", 24 | "layers.{}.attention.wo.weight": "layers.{}.attn.output_proj.weight", 25 | "layers.{}.attention_norm.weight": "layers.{}.sa_norm.scale", 26 | "layers.{}.ffn_norm.weight": "layers.{}.mlp_norm.scale", 27 | "layers.{}.feed_forward.w1.weight": "layers.{}.mlp.w1.weight", 28 | "layers.{}.feed_forward.w2.weight": "layers.{}.mlp.w2.weight", 29 | "layers.{}.feed_forward.w3.weight": "layers.{}.mlp.w3.weight", 30 | } 31 | 32 | _FROM_META_3D_SIN_COS_PLUS_LEARNED_NUM = { 33 | "tok_embeddings.token_embedding.weight": "tok_embeddings.token_embedding.weight", 34 | "tok_embeddings.numbers_embedding.weight": "tok_embeddings.numbers_embedding.weight", 35 | "tok_embeddings.static_sin_cos_embedding.weight": "tok_embeddings.static_sin_cos_embedding.weight", 36 | "tok_embeddings.added_embedding.weight": "tok_embeddings.added_embedding.weight", 37 | "tok_embeddings.vqgan_codebook.weight": "tok_embeddings.vqgan_codebook.weight", 38 | "tok_embeddings.vqgan_embed_proj.weight": "tok_embeddings.vqgan_embed_proj.weight", 39 | "norm.weight": "norm.scale", 40 | "output.weight": "output.weight", 41 
| "layers.{}.attention.wk.weight": "layers.{}.attn.k_proj.weight", 42 | "layers.{}.attention.wq.weight": "layers.{}.attn.q_proj.weight", 43 | "layers.{}.attention.wv.weight": "layers.{}.attn.v_proj.weight", 44 | "layers.{}.attention.wo.weight": "layers.{}.attn.output_proj.weight", 45 | "layers.{}.attention_norm.weight": "layers.{}.sa_norm.scale", 46 | "layers.{}.ffn_norm.weight": "layers.{}.mlp_norm.scale", 47 | "layers.{}.feed_forward.w1.weight": "layers.{}.mlp.w1.weight", 48 | "layers.{}.feed_forward.w2.weight": "layers.{}.mlp.w2.weight", 49 | "layers.{}.feed_forward.w3.weight": "layers.{}.mlp.w3.weight", 50 | } 51 | 52 | _FROM_META_3D_SIN_COS_PLUS_LEARNED_NUM_WITH_MLP = { 53 | "tok_embeddings.token_embedding.weight": "tok_embeddings.token_embedding.weight", 54 | "tok_embeddings.numbers_embedding.weight": "tok_embeddings.numbers_embedding.weight", 55 | "tok_embeddings.static_sin_cos_embedding.weight": "tok_embeddings.static_sin_cos_embedding.weight", 56 | "tok_embeddings.added_embedding.weight": "tok_embeddings.added_embedding.weight", 57 | "tok_embeddings.vqgan_codebook.weight": "tok_embeddings.vqgan_codebook.weight", 58 | "tok_embeddings.vqgan_embed_proj.{}.weight": "tok_embeddings.vqgan_embed_proj.{}.weight", 59 | "tok_embeddings.vqgan_embed_proj.{}.bias": "tok_embeddings.vqgan_embed_proj.{}.bias", 60 | "tok_embeddings.vqgan_embed_proj.{}.weight": "tok_embeddings.vqgan_embed_proj.{}.weight", 61 | "tok_embeddings.vqgan_embed_proj.{}.bias": "tok_embeddings.vqgan_embed_proj.{}.bias", 62 | "norm.weight": "norm.scale", 63 | "output.weight": "output.weight", 64 | "layers.{}.attention.wk.weight": "layers.{}.attn.k_proj.weight", 65 | "layers.{}.attention.wq.weight": "layers.{}.attn.q_proj.weight", 66 | "layers.{}.attention.wv.weight": "layers.{}.attn.v_proj.weight", 67 | "layers.{}.attention.wo.weight": "layers.{}.attn.output_proj.weight", 68 | "layers.{}.attention_norm.weight": "layers.{}.sa_norm.scale", 69 | "layers.{}.ffn_norm.weight": "layers.{}.mlp_norm.scale", 70 | "layers.{}.feed_forward.w1.weight": "layers.{}.mlp.w1.weight", 71 | "layers.{}.feed_forward.w2.weight": "layers.{}.mlp.w2.weight", 72 | "layers.{}.feed_forward.w3.weight": "layers.{}.mlp.w3.weight", 73 | } 74 | 75 | _FROM_META_3D_SIN_COS_NUM_WITH_MLP = { 76 | "tok_embeddings.token_embedding.weight": "tok_embeddings.token_embedding.weight", 77 | "tok_embeddings.numbers_embedding.weight": "tok_embeddings.numbers_embedding.weight", 78 | "tok_embeddings.added_embedding.weight": "tok_embeddings.added_embedding.weight", 79 | "tok_embeddings.vqgan_codebook.weight": "tok_embeddings.vqgan_codebook.weight", 80 | "tok_embeddings.vqgan_embed_proj.{}.weight": "tok_embeddings.vqgan_embed_proj.{}.weight", 81 | "tok_embeddings.vqgan_embed_proj.{}.bias": "tok_embeddings.vqgan_embed_proj.{}.bias", 82 | "tok_embeddings.vqgan_embed_proj.{}.weight": "tok_embeddings.vqgan_embed_proj.{}.weight", 83 | "tok_embeddings.vqgan_embed_proj.{}.bias": "tok_embeddings.vqgan_embed_proj.{}.bias", 84 | "norm.weight": "norm.scale", 85 | "output.weight": "output.weight", 86 | "layers.{}.attention.wk.weight": "layers.{}.attn.k_proj.weight", 87 | "layers.{}.attention.wq.weight": "layers.{}.attn.q_proj.weight", 88 | "layers.{}.attention.wv.weight": "layers.{}.attn.v_proj.weight", 89 | "layers.{}.attention.wo.weight": "layers.{}.attn.output_proj.weight", 90 | "layers.{}.attention_norm.weight": "layers.{}.sa_norm.scale", 91 | "layers.{}.ffn_norm.weight": "layers.{}.mlp_norm.scale", 92 | "layers.{}.feed_forward.w1.weight": "layers.{}.mlp.w1.weight", 93 | 
"layers.{}.feed_forward.w2.weight": "layers.{}.mlp.w2.weight", 94 | "layers.{}.feed_forward.w3.weight": "layers.{}.mlp.w3.weight", 95 | } 96 | 97 | _FROM_META_3D = { 98 | "tok_embeddings.token_embedding.weight": "tok_embeddings.token_embedding.weight", 99 | "tok_embeddings.added_embedding.weight": "tok_embeddings.added_embedding.weight", 100 | "tok_embeddings.vqgan_codebook.weight": "tok_embeddings.vqgan_codebook.weight", 101 | "tok_embeddings.vqgan_embed_proj.weight": "tok_embeddings.vqgan_embed_proj.weight", 102 | "norm.weight": "norm.scale", 103 | "output.weight": "output.weight", 104 | "layers.{}.attention.wk.weight": "layers.{}.attn.k_proj.weight", 105 | "layers.{}.attention.wq.weight": "layers.{}.attn.q_proj.weight", 106 | "layers.{}.attention.wv.weight": "layers.{}.attn.v_proj.weight", 107 | "layers.{}.attention.wo.weight": "layers.{}.attn.output_proj.weight", 108 | "layers.{}.attention_norm.weight": "layers.{}.sa_norm.scale", 109 | "layers.{}.ffn_norm.weight": "layers.{}.mlp_norm.scale", 110 | "layers.{}.feed_forward.w1.weight": "layers.{}.mlp.w1.weight", 111 | "layers.{}.feed_forward.w2.weight": "layers.{}.mlp.w2.weight", 112 | "layers.{}.feed_forward.w3.weight": "layers.{}.mlp.w3.weight", 113 | } 114 | 115 | # state dict key mappings from Meta's format to torchtune's format 116 | _FROM_META = { 117 | "tok_embeddings.weight": "tok_embeddings.weight", 118 | "norm.weight": "norm.scale", 119 | "output.weight": "output.weight", 120 | "layers.{}.attention.wk.weight": "layers.{}.attn.k_proj.weight", 121 | "layers.{}.attention.wq.weight": "layers.{}.attn.q_proj.weight", 122 | "layers.{}.attention.wv.weight": "layers.{}.attn.v_proj.weight", 123 | "layers.{}.attention.wo.weight": "layers.{}.attn.output_proj.weight", 124 | "layers.{}.attention_norm.weight": "layers.{}.sa_norm.scale", 125 | "layers.{}.ffn_norm.weight": "layers.{}.mlp_norm.scale", 126 | "layers.{}.feed_forward.w1.weight": "layers.{}.mlp.w1.weight", 127 | "layers.{}.feed_forward.w2.weight": "layers.{}.mlp.w2.weight", 128 | "layers.{}.feed_forward.w3.weight": "layers.{}.mlp.w3.weight", 129 | } 130 | 131 | # state dict key mappings from HF's format to torchtune's format 132 | _FROM_HF = { 133 | "model.embed_tokens.weight": "tok_embeddings.weight", 134 | "model.layers.{}.self_attn.q_proj.weight": "layers.{}.attn.q_proj.weight", 135 | "model.layers.{}.self_attn.k_proj.weight": "layers.{}.attn.k_proj.weight", 136 | "model.layers.{}.self_attn.v_proj.weight": "layers.{}.attn.v_proj.weight", 137 | "model.layers.{}.self_attn.o_proj.weight": "layers.{}.attn.output_proj.weight", 138 | "model.layers.{}.self_attn.rotary_emb.inv_freq": None, 139 | "model.layers.{}.mlp.gate_proj.weight": "layers.{}.mlp.w1.weight", 140 | "model.layers.{}.mlp.up_proj.weight": "layers.{}.mlp.w3.weight", 141 | "model.layers.{}.mlp.down_proj.weight": "layers.{}.mlp.w2.weight", 142 | "model.layers.{}.input_layernorm.weight": "layers.{}.sa_norm.scale", 143 | "model.layers.{}.post_attention_layernorm.weight": "layers.{}.mlp_norm.scale", 144 | "model.norm.weight": "norm.scale", 145 | "lm_head.weight": "output.weight", 146 | } 147 | 148 | 149 | def get_mapped_key(key: str, mapping_dict: Dict[str, str]) -> str: 150 | try: 151 | # Checks if there is a layer # in the key 152 | if any(k.isdigit() for k in key.split(".")): 153 | # Replace layer number with "{}" to create key for lookup 154 | abstract_key = re.sub(r"(\.\d+)", ".{}", key) 155 | layer_num = re.search(r"\d+", key).group(0) 156 | new_key = mapping_dict[abstract_key] 157 | new_key = new_key.format(layer_num) 158 
| else: 159 | new_key = mapping_dict[key] 160 | except KeyError as e: 161 | raise Exception( 162 | f'Error converting the state dict. Found unexpected key: "{key}". ' 163 | "Please make sure you're loading a checkpoint with the right format. " 164 | ) from e 165 | 166 | return new_key 167 | 168 | 169 | def meta_to_tune(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: 170 | """ 171 | Convert a state dict from Meta's format to torchtune's format. State dicts 172 | from multiple checkpoint files should be consolidated into a single state dict 173 | before calling this function. 174 | 175 | Eg of Meta-format state dict can be found in the ``meta-llama/Llama-2-7b`` 176 | repo in HF (https://huggingface.co/meta-llama/Llama-2-7b). 177 | 178 | Args: 179 | state_dict (Dict[str, torch.Tensor]): State dict in Meta's format. 180 | 181 | Returns: 182 | Dict[str, torch.Tensor]: State dict in torchtune's format. 183 | """ 184 | converted_state_dict = {} 185 | for key, value in state_dict.items(): 186 | if key not in ["rope.freqs"]: # Skip loading the position embeddings 187 | new_key = get_mapped_key(key, _FROM_META) 188 | converted_state_dict[new_key] = value 189 | 190 | return converted_state_dict 191 | 192 | 193 | def meta_to_tune_3d( 194 | state_dict: Dict[str, torch.Tensor], convert_weights_type: str 195 | ) -> Dict[str, torch.Tensor]: 196 | """ 197 | Convert a state dict from Meta's format to torchtune's format. State dicts 198 | from multiple checkpoint files should be consolidated into a single state dict 199 | before calling this function. 200 | 201 | Eg of Meta-format state dict can be found in the ``meta-llama/Llama-2-7b`` 202 | repo in HF (https://huggingface.co/meta-llama/Llama-2-7b). 203 | 204 | Args: 205 | state_dict (Dict[str, torch.Tensor]): State dict in Meta's format. 206 | 207 | Returns: 208 | Dict[str, torch.Tensor]: State dict in torchtune's format. 209 | """ 210 | if convert_weights_type == "3d": 211 | dictionary_to_use = _FROM_META_3D 212 | elif convert_weights_type == "3d_sin_cos_num": 213 | dictionary_to_use = _FROM_META_3D_SIN_COS_NUM 214 | elif convert_weights_type == "3d_sin_cos_plus_learned_num": 215 | dictionary_to_use = _FROM_META_3D_SIN_COS_PLUS_LEARNED_NUM 216 | elif convert_weights_type == "3d_sin_cos_plus_learned_num_with_mlp": 217 | dictionary_to_use = _FROM_META_3D_SIN_COS_PLUS_LEARNED_NUM_WITH_MLP 218 | elif convert_weights_type == "3d_sin_cos_num_with_mlp": 219 | dictionary_to_use = _FROM_META_3D_SIN_COS_NUM_WITH_MLP 220 | else: 221 | raise ValueError( 222 | "convert_weights_type should be one of '3d', '3d_sin_cos_num', '3d_sin_cos_plus_learned_num', '3d_sin_cos_plus_learned_num_with_mlp', '3d_sin_cos_num_with_mlp'" 223 | ) 224 | converted_state_dict = {} 225 | for key, value in state_dict.items(): 226 | if key not in ["rope.freqs"]: # Skip loading the position embeddings 227 | new_key = get_mapped_key(key, dictionary_to_use) 228 | converted_state_dict[new_key] = value 229 | 230 | return converted_state_dict 231 | 232 | 233 | def tune_to_meta(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: 234 | """ 235 | Convert a state dict from torchtune's format to Meta's format. This function 236 | doesn't handle any sharding or splitting of state dicts. It follows the 237 | state_dict IN -> state_dict OUT pattern. 238 | 239 | Args: 240 | state_dict (Dict[str, torch.Tensor]): State dict in torchtune's format. 241 | 242 | Returns: 243 | Dict[str, torch.Tensor]: State dict in Meta's format. 
244 | """ 245 | converted_state_dict = {} 246 | inverted_mapping_dict = {v: k for k, v in _FROM_META.items()} 247 | 248 | for key, value in state_dict.items(): 249 | new_key = get_mapped_key(key, inverted_mapping_dict) 250 | converted_state_dict[new_key] = value 251 | 252 | return converted_state_dict 253 | 254 | 255 | def tune_to_meta_3d( 256 | state_dict: Dict[str, torch.Tensor], convert_weights_type: str 257 | ) -> Dict[str, torch.Tensor]: 258 | """ 259 | Convert a state dict from torchtune's format to Meta's format. This function 260 | doesn't handle any sharding or splitting of state dicts. It follows the 261 | state_dict IN -> state_dict OUT pattern. 262 | 263 | Args: 264 | state_dict (Dict[str, torch.Tensor]): State dict in torchtune's format. 265 | 266 | Returns: 267 | Dict[str, torch.Tensor]: State dict in Meta's format. 268 | """ 269 | if convert_weights_type == "3d": 270 | dictionary_to_use = _FROM_META_3D 271 | elif convert_weights_type == "3d_sin_cos_num": 272 | dictionary_to_use = _FROM_META_3D_SIN_COS_NUM 273 | elif convert_weights_type == "3d_sin_cos_plus_learned_num": 274 | dictionary_to_use = _FROM_META_3D_SIN_COS_PLUS_LEARNED_NUM 275 | elif convert_weights_type == "3d_sin_cos_plus_learned_num_with_mlp": 276 | dictionary_to_use = _FROM_META_3D_SIN_COS_PLUS_LEARNED_NUM_WITH_MLP 277 | elif convert_weights_type == "3d_sin_cos_num_with_mlp": 278 | dictionary_to_use = _FROM_META_3D_SIN_COS_NUM_WITH_MLP 279 | else: 280 | raise ValueError( 281 | "convert_weights_type should be one of '3d', '3d_sin_cos_num', '3d_sin_cos_plus_learned_num', '3d_sin_cos_plus_learned_num_with_mlp', '3d_sin_cos_num_with_mlp'" 282 | ) 283 | converted_state_dict = {} 284 | inverted_mapping_dict = {v: k for k, v in dictionary_to_use.items()} 285 | 286 | for key, value in state_dict.items(): 287 | new_key = get_mapped_key(key, inverted_mapping_dict) 288 | converted_state_dict[new_key] = value 289 | 290 | return converted_state_dict 291 | -------------------------------------------------------------------------------- /checkpointer/checkpointer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | import os 7 | 8 | from pathlib import Path 9 | from typing import Any, Dict, List, Optional, Protocol 10 | 11 | import torch 12 | from torchtune import training 13 | 14 | import checkpointer.convert_weights as convert_weights 15 | from torchtune.training.checkpointing._utils import ( 16 | get_path, 17 | ModelType, 18 | safe_torch_load, 19 | ) 20 | from torchtune.utils._logging import get_logger 21 | 22 | logger = get_logger("DEBUG") 23 | 24 | 25 | class _CheckpointerInterface(Protocol): 26 | """ 27 | Interface implemented by Checkpointers in torchtune. 28 | 29 | torchtune checkpointers are designed to be composable components which can be plugged 30 | into any training recipe. Each checkpointer supports a specific set of models and training 31 | scenarios making these easy to understand, debug and extend. For example, the 32 | ``FullModelCheckpointer``s are used for loading and saving all of the model weights. 33 | This checkpointer can be used for Full-Finetuning scenarios or PEFT where the output is a 34 | merged checkpoint. 
In case the current suite of checkpointers are inadequate, 35 | users are encouraged to implement their own and contribute back to torchtune. 36 | 37 | torchtune is also designed to be "state-dict invariant". This means the checkpointer 38 | ensures that the output checkpoint has the same format as the original checkpoint i.e. 39 | the output checkpoint has the same keys split across the same number of files as the original 40 | checkpoint. Being "state-dict invariant" allows users to seamlessly use torchtune checkpoints 41 | with their favorite post-training tools from the open-source ecosystem without writing 42 | torchtune-specific convertors. To be "state-dict invariant", the ``load_checkpoint`` and 43 | ``save_checkpoint`` methods make use of the weight convertors available in 44 | ``torchtune/models/``. 45 | 46 | torchtune Checkpointers support two checkpointing scenarios: 47 | * End-of-training Checkpointing. The model weights at the end of a completed training 48 | run are written out to file. The checkpointer ensures that the output checkpoint 49 | files have the same keys as the input checkpoint file used to begin training. The 50 | checkpointer also ensures that the keys are partitioned across the same number of 51 | files as the original checkpoint. This ensures that the original metadata files can 52 | be used as is, and the output checkpoint can be used with any tool that understands 53 | the original checkpoint format. This includes popular inference engines such as 54 | ``llama.cpp`` and ``gpt-fast``. The output state dict has the following format: 55 | { 56 | "key_1": weight 57 | ... 58 | } 59 | 60 | 61 | Mid-training Chekpointing. In addition to the model checkpoint files, we output an 62 | additional "recipe_state.pt" file for intermediate checkpoints. These are currently 63 | output at the end of each epoch, and contain information such as optimizer state, 64 | number of epochs completed etc which is needed to correctly resume a previously 65 | interrupted training run. The recipe is responsible for constructing the state dict 66 | with the information it needs. The checkpointer extracts the model state dict 67 | (key = "model") and writes everything else out to "recipe_state.pt". To prevent us 68 | from flooding ``output_dir`` with checkpoint files, the recipe state is overwritten 69 | at the end of each epoch. The output state dicts have the following formats: 70 | 71 | Model: 72 | { 73 | "key_1": weight 74 | ... 75 | } 76 | 77 | Recipe State: 78 | { 79 | "optimizer": ..., 80 | "epoch": ..., 81 | ... 82 | } 83 | 84 | """ 85 | 86 | def load_checkpoint(self, **kwargs) -> Dict[str, Any]: ... 87 | 88 | def save_checkpoint(self, state_dict: Dict[str, Any], **kwargs) -> None: ... 89 | 90 | 91 | class FullModelMetaCheckpointer3D(_CheckpointerInterface): 92 | """ 93 | Checkpointer which reads and writes checkpoints in Meta's format. Examples include 94 | the Llama-2-7b model from the meta-llama repo (https://huggingface.co/meta-llama/Llama-2-7b) 95 | 96 | Currently we support reading from a single checkpoint file only. Support for reading from 97 | sharded checkpoints is WIP. 98 | 99 | Args: 100 | checkpoint_dir (str): Directory containing the checkpoint files 101 | checkpoint_files (List[str]): List of checkpoint files to load. Currently this checkpointer only 102 | supports loading a single checkpoint file. 
103 | model_type (ModelType): Model type of the model for which the checkpointer is being loaded 104 | output_dir (str): Directory to save the checkpoint files 105 | adapter_checkpoint (Optional[str]): Path to the adapter weights. Default is None 106 | recipe_checkpoint (Optional[str]): Path to the recipe state checkpoint file. Default is None 107 | resume_from_checkpoint (bool): If True, the checkpointer will load the additional checkpoint files to 108 | resume training from a previous run. Default is False 109 | 110 | Raises: 111 | ValueError: If ``checkpoint_files`` is not a list of length 1 112 | ValueError: If ``resume_from_checkpoint`` is True but ``recipe_checkpoint`` is None 113 | """ 114 | 115 | def __init__( 116 | self, 117 | checkpoint_dir: str, 118 | checkpoint_files: List[str], 119 | model_type: ModelType, 120 | output_dir: str, 121 | adapter_checkpoint: Optional[str] = None, 122 | recipe_checkpoint: Optional[str] = None, 123 | resume_from_checkpoint: bool = False, 124 | convert_weights_type: str = "meta_to_tune_3d", 125 | ) -> None: 126 | # Fail fast if ``checkpoint_files`` is invalid 127 | if len(checkpoint_files) != 1: 128 | raise ValueError( 129 | "Currently we only support reading from a single torchtune checkpoint file. " 130 | f"Got {len(checkpoint_files)} files instead." 131 | ) 132 | 133 | self._checkpoint_dir = Path(checkpoint_dir) 134 | self._checkpoint_path = get_path(self._checkpoint_dir, checkpoint_files[0]) 135 | 136 | self._adapter_checkpoint = ( 137 | get_path(self._checkpoint_dir, adapter_checkpoint) 138 | if adapter_checkpoint 139 | else None 140 | ) 141 | 142 | self._resume_from_checkpoint = resume_from_checkpoint 143 | self._model_type = ModelType[model_type] 144 | self._output_dir = Path(output_dir) 145 | 146 | # recipe_checkpoint contains the recipe state. This should be available if 147 | # resume_from_checkpoint is True 148 | self._recipe_checkpoint = None 149 | if self._resume_from_checkpoint: 150 | if recipe_checkpoint is None: 151 | raise ValueError( 152 | "If resume_from_checkpoint is True, recipe_checkpoint file must be provided." 153 | ) 154 | self._recipe_checkpoint = get_path(self._checkpoint_dir, recipe_checkpoint) 155 | 156 | self._convert_weights_type = convert_weights_type 157 | 158 | def load_checkpoint(self) -> Dict[str, Any]: 159 | """ 160 | Load Meta checkpoint from file. Currently only loading from a single file is supported. 161 | """ 162 | state_dict: Dict[str:Any] = {} 163 | model_state_dict = safe_torch_load(self._checkpoint_path) 164 | if self._model_type == ModelType.LLAMA3_VISION: 165 | from torchtune.models.llama3_2_vision._convert_weights import ( 166 | llama3_vision_meta_to_tune, 167 | ) 168 | 169 | state_dict[training.MODEL_KEY] = llama3_vision_meta_to_tune( 170 | model_state_dict 171 | ) 172 | else: 173 | state_dict[training.MODEL_KEY] = convert_weights.meta_to_tune( 174 | model_state_dict 175 | ) 176 | 177 | # llama3_2 has tied weights, so we need to remove the output.weight key 178 | if self._model_type == ModelType.LLAMA3_2: 179 | logger.info( 180 | "Identified model_type = Llama3_2. Ignoring output.weight in" 181 | " checkpoint in favor of the tok_embedding.weight" 182 | " tied weights." 
183 | ) 184 | state_dict[training.MODEL_KEY].pop("output.weight") 185 | 186 | if self._adapter_checkpoint: 187 | adapter_state_dict = safe_torch_load(self._adapter_checkpoint) 188 | state_dict[training.ADAPTER_KEY] = adapter_state_dict 189 | 190 | if self._resume_from_checkpoint: 191 | recipe_state = safe_torch_load(self._recipe_checkpoint, mmap=False) 192 | state_dict.update(recipe_state) 193 | return state_dict 194 | 195 | def load_checkpoint_3d(self) -> Dict[str, Any]: 196 | """ 197 | Load Meta checkpoint from file. Currently only loading from a single file is supported. 198 | """ 199 | state_dict: Dict[str:Any] = {} 200 | model_state_dict = safe_torch_load(self._checkpoint_path) 201 | if self._model_type == ModelType.LLAMA3_VISION: 202 | from torchtune.models.llama3_2_vision._convert_weights import ( 203 | llama3_vision_meta_to_tune, 204 | ) 205 | 206 | state_dict[training.MODEL_KEY] = llama3_vision_meta_to_tune( 207 | model_state_dict 208 | ) 209 | else: 210 | state_dict[training.MODEL_KEY] = convert_weights.meta_to_tune_3d( 211 | model_state_dict, self._convert_weights_type 212 | ) 213 | 214 | # llama3_2 has tied weights, so we need to remove the output.weight key 215 | if self._model_type == ModelType.LLAMA3_2: 216 | logger.info( 217 | "Identified model_type = Llama3_2. Ignoring output.weight in" 218 | " checkpoint in favor of the tok_embedding.weight" 219 | " tied weights." 220 | ) 221 | # state_dict[training.MODEL_KEY].pop("output.weight") 222 | # We dont need to do this for 3D models 223 | 224 | if self._adapter_checkpoint: 225 | adapter_state_dict = safe_torch_load(self._adapter_checkpoint) 226 | state_dict[training.ADAPTER_KEY] = adapter_state_dict 227 | 228 | if self._resume_from_checkpoint: 229 | recipe_state = safe_torch_load(self._recipe_checkpoint, mmap=False) 230 | state_dict.update(recipe_state) 231 | return state_dict 232 | 233 | def save_checkpoint( 234 | self, 235 | state_dict: Dict[str, Any], 236 | epoch: int, 237 | intermediate_checkpoint: bool = False, 238 | adapter_only: bool = False, 239 | ) -> None: 240 | """ 241 | Save Meta checkpoint to file. If ``intermediate_checkpoint`` is True, an additional 242 | checkpoint file ``recipe_state.pt`` is created in ``_output_dir`` which contains the recipe 243 | state. 244 | 245 | Args: 246 | state_dict (Dict[str, Any]): Checkpoint state dict to be written out to file 247 | epoch (int): Epoch number. Used to create the checkpoint file name 248 | intermediate_checkpoint (bool): If True, an additional checkpoint files for recipe state 249 | and (if applicable) adapter weights are created. Default is False 250 | adapter_only (bool): If True, only save the adapter weights. Default is False 251 | 252 | Raises: 253 | ValueError: if ``adapter_only`` is True and adapter checkpoint not found in state_dict. 
254 | """ 255 | self._output_dir.mkdir(exist_ok=True) 256 | 257 | if not adapter_only: 258 | model_state_dict = state_dict[training.MODEL_KEY] 259 | if self._model_type == ModelType.LLAMA3_VISION: 260 | from torchtune.models.llama3_2_vision._convert_weights import ( 261 | llama3_vision_tune_to_meta, 262 | ) 263 | 264 | state_dict[training.MODEL_KEY] = llama3_vision_tune_to_meta( 265 | model_state_dict 266 | ) 267 | else: 268 | # llama3_2 has tied weights, so we need to add the output.weight key 269 | if ( 270 | self._model_type == ModelType.LLAMA3_2 271 | and "output.weight" not in model_state_dict 272 | ): 273 | model_state_dict["output.weight"] = model_state_dict[ 274 | "tok_embeddings.weight" 275 | ] 276 | 277 | state_dict[training.MODEL_KEY] = convert_weights.tune_to_meta_3d( 278 | model_state_dict, self._convert_weights_type 279 | ) 280 | 281 | # Output file is always a .pt file with the epoch number in the name 282 | checkpoint_file = Path.joinpath( 283 | self._output_dir, f"meta_model_{epoch}" 284 | ).with_suffix(".pt") 285 | torch.save(state_dict[training.MODEL_KEY], checkpoint_file) 286 | logger.info( 287 | "Model checkpoint of size " 288 | f"{os.path.getsize(checkpoint_file) / 1000**3:.2f} GB " 289 | f"saved to {checkpoint_file}" 290 | ) 291 | 292 | if training.ADAPTER_KEY in state_dict: 293 | output_path = Path.joinpath( 294 | self._output_dir, f"adapter_{epoch}" 295 | ).with_suffix(".pt") 296 | torch.save(state_dict[training.ADAPTER_KEY], output_path) 297 | logger.info( 298 | "Adapter checkpoint of size " 299 | f"{os.path.getsize(output_path) / 1000**3:.2f} GB " 300 | f"saved to {output_path}" 301 | ) 302 | elif adapter_only: 303 | raise ValueError( 304 | "Adapter checkpoint not found in state_dict. Please ensure that the state_dict contains adapter weights." 305 | ) 306 | 307 | # If the recipe state needs to be output, first remove the model state dict 308 | # and if it exists, remove the adapter state dict as well 309 | if intermediate_checkpoint: 310 | _ = state_dict.pop(training.MODEL_KEY) 311 | _ = state_dict.pop(training.ADAPTER_KEY, None) 312 | _ = state_dict.pop(training.ADAPTER_CONFIG, None) 313 | output_path = Path.joinpath(self._output_dir, "recipe_state.pt") 314 | torch.save(state_dict, output_path) 315 | logger.info( 316 | "Recipe checkpoint of size " 317 | f"{os.path.getsize(output_path) / 1000**3:.2f} GB " 318 | f"saved to {output_path}" 319 | ) 320 | else: 321 | logger.info("Saving final epoch checkpoint.") 322 | if adapter_only: 323 | logger.info( 324 | "Please note that you have set adapter_only=True, so only adapter weights will be saved." 325 | "You need to merge the adapter weights into your base model for further use. " 326 | f"See {self.__class__.__name__}.save_checkpoint for more details." 327 | ) 328 | else: 329 | logger.info( 330 | "The full model checkpoint, including all weights and configurations, has been saved successfully." 331 | "You can now use this checkpoint for further training or inference." 332 | ) 333 | -------------------------------------------------------------------------------- /models/model_builders.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 
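# NOTE: The builders in this module pre-configure the Llama-3.2 backbone together with
# Kyvo's extra embedding tables (number embeddings, added 3D-structure tokens, and a VQGAN
# codebook with a projection into the model dimension). The main knobs that vary across
# builders are the VQGAN codebook path, the vqgan_start_index/vqgan_end_index token range,
# the image_token_offset, and the number-embedding scheme (learned, sin-cos, or sin-cos
# plus learned). A builder is typically selected from a YAML config by its import path,
# e.g. (illustrative):
#
#   model:
#     _component_: models.llama3_2_1b_clevr3d_sin_cos_plus_learned_numbers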
6 | from typing import List 7 | from functools import partial 8 | 9 | from models.model_component_builders import ( 10 | llama3_2_clevr3d, 11 | lora_llama3_2_clevr3d, 12 | ) 13 | 14 | from torchtune.modules import TransformerDecoder 15 | from torchtune.modules.peft import LORA_ATTN_MODULES 16 | 17 | """ 18 | Model builders build specific instantiations using component builders. For example 19 | the llama3_2_1b model builder uses the llama3_2 component builder to create the 20 | Llama3.2 1B model. 21 | """ 22 | 23 | DOMAIN_AGNOSTIC_VQGAN_CODEBOOK_PATH = "./kyvo-datasets-and-codebooks/vqgan-models-and-codebooks/domain-agnostic/quantize_weight_8192.npy" 24 | CLEVR_VQGAN_CODEBOOK_PATH = "./kyvo-datasets-and-codebooks/vqgan-models-and-codebooks/clevr/custom_vqgan_embedding_1024CLEVRLARGE_256dim.npy" 25 | OBJAWORLD_VQGAN_CODEBOOK_PATH = "./kyvo-datasets-and-codebooks/vqgan-models-and-codebooks/objaworld/custom_vqgan_embedding_256SYNTHETIC_LIVINGROOM_PARK_LARGE_EP100_256dim.npy" 26 | OBJECTRON_VQGAN_CODEBOOK_PATH = "./kyvo-datasets-and-codebooks/vqgan-models-and-codebooks/objectron/custom_vqgan_embedding_256Omni3D-OBJECTRON_256dim.npy" 27 | 28 | 29 | def lora_llama3_2_1b_clevr3d_sin_cos_numbers( 30 | lora_attn_modules: List[LORA_ATTN_MODULES], 31 | apply_lora_to_mlp: bool = False, 32 | apply_lora_to_output: bool = False, 33 | lora_rank: int = 8, 34 | lora_alpha: float = 16, 35 | lora_dropout: float = 0.0, 36 | use_dora: bool = False, 37 | quantize_base: bool = False, 38 | ) -> TransformerDecoder: 39 | """ 40 | Builder for creating a Llama3.2 1B model with LoRA enabled. 41 | The Llama3.2 defaults are the same as in :func:`~torchtune.models.llama3_2.llama3_2_1b`, 42 | while LoRA default params are based on 43 | https://github.com/tloen/alpaca-lora/blob/8bb8579e403dc78e37fe81ffbb253c413007323f/finetune.py#L41-L43. 44 | 45 | Args: 46 | lora_attn_modules (List[LORA_ATTN_MODULES]): list of which linear layers 47 | LoRA should be applied to in each self-attention block. Options are 48 | ``{"q_proj", "k_proj", "v_proj", "output_proj"}``. 49 | apply_lora_to_mlp (bool): whether to apply LoRA to the MLP in each transformer layer. 50 | Default: False 51 | apply_lora_to_output (bool): whether to apply LoRA to the model's final output projection. 52 | Default: False 53 | lora_rank (int): rank of each low-rank approximation 54 | lora_alpha (float): scaling factor for the low-rank approximation 55 | lora_dropout (float): dropout probability for the low-rank approximation 56 | use_dora (bool): Decompose the LoRA weight into magnitude and direction, as 57 | introduced in "DoRA: Weight-Decomposed Low-Rank Adaptation" (https://arxiv.org/abs/2402.09353). 
58 | quantize_base (bool): Whether to quantize base model weights 59 | 60 | Returns: 61 | TransformerDecoder: Instantiation of Llama3.2 1B model with LoRA applied 62 | """ 63 | return lora_llama3_2_clevr3d( 64 | lora_attn_modules=lora_attn_modules, 65 | apply_lora_to_mlp=apply_lora_to_mlp, 66 | apply_lora_to_output=apply_lora_to_output, 67 | vocab_size=128_256, 68 | num_layers=16, 69 | num_heads=32, 70 | num_kv_heads=8, 71 | embed_dim=2048, 72 | max_seq_len=131072, 73 | intermediate_dim=8192, 74 | attn_dropout=0.0, 75 | norm_eps=1e-5, 76 | rope_base=500_000, 77 | scale_factor=32, 78 | lora_rank=lora_rank, 79 | lora_alpha=lora_alpha, 80 | lora_dropout=lora_dropout, 81 | use_dora=use_dora, 82 | quantize_base=quantize_base, 83 | added_tokens_offset=128256, 84 | vqgan_embed_dim=256, 85 | vqgan_start_index=129471, 86 | vqgan_end_index=130495, 87 | vqgan_vocab_size=1024, 88 | vqgan_codebook_path=CLEVR_VQGAN_CODEBOOK_PATH, 89 | image_token_offset=129471, 90 | use_sin_cos_numbers=True, 91 | ) 92 | 93 | def llama3_2_1b_clevr3d_sin_cos_plus_learned_numbers_NO_FINDINGS() -> ( 94 | TransformerDecoder 95 | ): 96 | """ 97 | Builder for creating a Llama3.2 model initialized w/ the default 1b parameter values. 98 | 99 | Returns: 100 | TransformerDecoder: Instantiation of Llama3.2 1B model 101 | """ 102 | return llama3_2_clevr3d( 103 | vocab_size=128_256, 104 | num_layers=16, 105 | num_heads=32, 106 | num_kv_heads=8, 107 | embed_dim=2048, 108 | max_seq_len=131072, 109 | intermediate_dim=8192, 110 | attn_dropout=0.0, 111 | norm_eps=1e-5, 112 | rope_base=500_000, 113 | scale_factor=32, 114 | added_tokens_offset=128256, 115 | vqgan_embed_dim=256, 116 | vqgan_start_index=128256, 117 | vqgan_end_index=136448, 118 | vqgan_vocab_size=8192, 119 | vqgan_codebook_path=DOMAIN_AGNOSTIC_VQGAN_CODEBOOK_PATH, 120 | image_token_offset=128256, 121 | use_sin_cos_numbers=False, 122 | sin_cos_numbers_offset=0, 123 | use_sin_cos_plus_learned=False, 124 | no_independent_numbers=False, 125 | no_independent_numbers_no_3d_tokens=True, 126 | ) 127 | 128 | 129 | 130 | def lora_llama3_2_1b_clevr3d_sin_cos_plus_learned_numbers( 131 | lora_attn_modules: List[LORA_ATTN_MODULES], 132 | apply_lora_to_mlp: bool = False, 133 | apply_lora_to_output: bool = False, 134 | lora_rank: int = 8, 135 | lora_alpha: float = 16, 136 | lora_dropout: float = 0.0, 137 | use_dora: bool = False, 138 | quantize_base: bool = False, 139 | ) -> TransformerDecoder: 140 | """ 141 | Builder for creating a Llama3.2 1B model with LoRA enabled. 142 | The Llama3.2 defaults are the same as in :func:`~torchtune.models.llama3_2.llama3_2_1b`, 143 | while LoRA default params are based on 144 | https://github.com/tloen/alpaca-lora/blob/8bb8579e403dc78e37fe81ffbb253c413007323f/finetune.py#L41-L43. 145 | 146 | Args: 147 | lora_attn_modules (List[LORA_ATTN_MODULES]): list of which linear layers 148 | LoRA should be applied to in each self-attention block. Options are 149 | ``{"q_proj", "k_proj", "v_proj", "output_proj"}``. 150 | apply_lora_to_mlp (bool): whether to apply LoRA to the MLP in each transformer layer. 151 | Default: False 152 | apply_lora_to_output (bool): whether to apply LoRA to the model's final output projection. 
153 | Default: False 154 | lora_rank (int): rank of each low-rank approximation 155 | lora_alpha (float): scaling factor for the low-rank approximation 156 | lora_dropout (float): dropout probability for the low-rank approximation 157 | use_dora (bool): Decompose the LoRA weight into magnitude and direction, as 158 | introduced in "DoRA: Weight-Decomposed Low-Rank Adaptation" (https://arxiv.org/abs/2402.09353). 159 | quantize_base (bool): Whether to quantize base model weights 160 | 161 | Returns: 162 | TransformerDecoder: Instantiation of Llama3.2 1B model with LoRA applied 163 | """ 164 | return lora_llama3_2_clevr3d( 165 | lora_attn_modules=lora_attn_modules, 166 | apply_lora_to_mlp=apply_lora_to_mlp, 167 | apply_lora_to_output=apply_lora_to_output, 168 | vocab_size=128_256, 169 | num_layers=16, 170 | num_heads=32, 171 | num_kv_heads=8, 172 | embed_dim=2048, 173 | max_seq_len=131072, 174 | intermediate_dim=8192, 175 | attn_dropout=0.0, 176 | norm_eps=1e-5, 177 | rope_base=500_000, 178 | scale_factor=32, 179 | lora_rank=lora_rank, 180 | lora_alpha=lora_alpha, 181 | lora_dropout=lora_dropout, 182 | use_dora=use_dora, 183 | quantize_base=quantize_base, 184 | added_tokens_offset=128256, 185 | vqgan_embed_dim=256, 186 | vqgan_start_index=129471, 187 | vqgan_end_index=130495, 188 | vqgan_vocab_size=1024, 189 | vqgan_codebook_path=CLEVR_VQGAN_CODEBOOK_PATH, 190 | image_token_offset=129471, 191 | use_sin_cos_numbers=False, 192 | use_sin_cos_plus_learned=True, 193 | ) 194 | 195 | 196 | def llama3_2_1b_clevr3d_sin_cos_numbers() -> TransformerDecoder: 197 | """ 198 | Builder for creating a Llama3.2 model initialized w/ the default 1b parameter values. 199 | 200 | Returns: 201 | TransformerDecoder: Instantiation of Llama3.2 1B model 202 | """ 203 | return llama3_2_clevr3d( 204 | vocab_size=128_256, 205 | num_layers=16, 206 | num_heads=32, 207 | num_kv_heads=8, 208 | embed_dim=2048, 209 | max_seq_len=131072, 210 | intermediate_dim=8192, 211 | attn_dropout=0.0, 212 | norm_eps=1e-5, 213 | rope_base=500_000, 214 | scale_factor=32, 215 | added_tokens_offset=128256, 216 | vqgan_embed_dim=256, 217 | vqgan_start_index=129471, 218 | vqgan_end_index=130495, 219 | vqgan_vocab_size=1024, 220 | vqgan_codebook_path=CLEVR_VQGAN_CODEBOOK_PATH, 221 | image_token_offset=129471, 222 | use_sin_cos_numbers=True, 223 | ) 224 | 225 | 226 | def llama3_2_1b_clevr3d_sin_cos_plus_learned_numbers_SYNTHETIC_LIVINGROOM_PARK_LARGE_EP100() -> ( 227 | TransformerDecoder 228 | ): 229 | """ 230 | Builder for creating a Llama3.2 model initialized w/ the default 1b parameter values. 231 | 232 | Returns: 233 | TransformerDecoder: Instantiation of Llama3.2 1B model 234 | """ 235 | return llama3_2_clevr3d( 236 | vocab_size=128_256, 237 | num_layers=16, 238 | num_heads=32, 239 | num_kv_heads=8, 240 | embed_dim=2048, 241 | max_seq_len=131072, 242 | intermediate_dim=8192, 243 | attn_dropout=0.0, 244 | norm_eps=1e-5, 245 | rope_base=500_000, 246 | scale_factor=32, 247 | added_tokens_offset=128256, 248 | vqgan_embed_dim=256, 249 | vqgan_start_index=129471, 250 | vqgan_end_index=130495, 251 | vqgan_vocab_size=1024, 252 | vqgan_codebook_path=OBJAWORLD_VQGAN_CODEBOOK_PATH, 253 | image_token_offset=129471, 254 | use_sin_cos_numbers=False, 255 | use_sin_cos_plus_learned=True, 256 | ) 257 | 258 | 259 | def llama3_2_1b_clevr3d_sin_cos_plus_learned_numbers() -> TransformerDecoder: 260 | """ 261 | Builder for creating a Llama3.2 model initialized w/ the default 1b parameter values. 
262 | 263 | Returns: 264 | TransformerDecoder: Instantiation of Llama3.2 1B model 265 | """ 266 | return llama3_2_clevr3d( 267 | vocab_size=128_256, 268 | num_layers=16, 269 | num_heads=32, 270 | num_kv_heads=8, 271 | embed_dim=2048, 272 | max_seq_len=131072, 273 | intermediate_dim=8192, 274 | attn_dropout=0.0, 275 | norm_eps=1e-5, 276 | rope_base=500_000, 277 | scale_factor=32, 278 | added_tokens_offset=128256, 279 | vqgan_embed_dim=256, 280 | vqgan_start_index=129471, 281 | vqgan_end_index=130495, 282 | vqgan_vocab_size=1024, 283 | vqgan_codebook_path=CLEVR_VQGAN_CODEBOOK_PATH, 284 | image_token_offset=129471, 285 | use_sin_cos_numbers=False, 286 | use_sin_cos_plus_learned=True, 287 | ) 288 | 289 | 290 | def llama3_2_1b_clevr3d_sin_cos_plus_learned_numbers_domain_agnostic_vqgan() -> ( 291 | TransformerDecoder 292 | ): 293 | """ 294 | Builder for creating a Llama3.2 model initialized w/ the default 1b parameter values. 295 | 296 | Returns: 297 | TransformerDecoder: Instantiation of Llama3.2 1B model 298 | """ 299 | return llama3_2_clevr3d( 300 | vocab_size=128_256, 301 | num_layers=16, 302 | num_heads=32, 303 | num_kv_heads=8, 304 | embed_dim=2048, 305 | max_seq_len=131072, 306 | intermediate_dim=8192, 307 | attn_dropout=0.0, 308 | norm_eps=1e-5, 309 | rope_base=500_000, 310 | scale_factor=32, 311 | added_tokens_offset=128256, 312 | vqgan_embed_dim=256, 313 | vqgan_start_index=129471, 314 | vqgan_end_index=137663, 315 | vqgan_vocab_size=8192, 316 | vqgan_codebook_path=DOMAIN_AGNOSTIC_VQGAN_CODEBOOK_PATH, 317 | image_token_offset=129471, 318 | use_sin_cos_numbers=False, 319 | use_sin_cos_plus_learned=True, 320 | ) 321 | 322 | 323 | def llama3_2_1b_clevr3d_sin_cos_plus_learned_numbers_mlp_projector() -> ( 324 | TransformerDecoder 325 | ): 326 | """ 327 | Builder for creating a Llama3.2 model initialized w/ the default 1b parameter values. 328 | 329 | Returns: 330 | TransformerDecoder: Instantiation of Llama3.2 1B model 331 | """ 332 | return llama3_2_clevr3d( 333 | vocab_size=128_256, 334 | num_layers=16, 335 | num_heads=32, 336 | num_kv_heads=8, 337 | embed_dim=2048, 338 | max_seq_len=131072, 339 | intermediate_dim=8192, 340 | attn_dropout=0.0, 341 | norm_eps=1e-5, 342 | rope_base=500_000, 343 | scale_factor=32, 344 | added_tokens_offset=128256, 345 | vqgan_embed_dim=256, 346 | vqgan_start_index=129471, 347 | vqgan_end_index=130495, 348 | vqgan_vocab_size=1024, 349 | vqgan_codebook_path=CLEVR_VQGAN_CODEBOOK_PATH, 350 | image_token_offset=129471, 351 | use_sin_cos_numbers=False, 352 | use_sin_cos_plus_learned=True, 353 | use_mlp_projector=True, 354 | ) 355 | 356 | 357 | def llama3_2_3b_clevr3d_sin_cos_plus_learned_numbers() -> TransformerDecoder: 358 | """ 359 | Builder for creating a Llama3.2 model initialized w/ the default 1b parameter values. 
360 | 361 | Returns: 362 | TransformerDecoder: Instantiation of Llama3.2 1B model 363 | """ 364 | return llama3_2_clevr3d( 365 | vocab_size=128_256, 366 | num_layers=28, 367 | num_heads=24, 368 | num_kv_heads=8, 369 | embed_dim=3072, 370 | max_seq_len=131072, 371 | intermediate_dim=8192, 372 | attn_dropout=0.0, 373 | norm_eps=1e-5, 374 | rope_base=500_000, 375 | scale_factor=32, 376 | added_tokens_offset=128256, 377 | vqgan_embed_dim=256, 378 | vqgan_start_index=129471, 379 | vqgan_end_index=130495, 380 | vqgan_vocab_size=1024, 381 | vqgan_codebook_path=CLEVR_VQGAN_CODEBOOK_PATH, 382 | image_token_offset=129471, 383 | use_sin_cos_numbers=False, 384 | use_sin_cos_plus_learned=True, 385 | ) 386 | 387 | 388 | def llama3_2_1b_clevr3d_NO_sin_cos_numbers() -> TransformerDecoder: 389 | """ 390 | Builder for creating a Llama3.2 model initialized w/ the default 1b parameter values. 391 | 392 | Returns: 393 | TransformerDecoder: Instantiation of Llama3.2 1B model 394 | """ 395 | return llama3_2_clevr3d( 396 | vocab_size=128_256, 397 | num_layers=16, 398 | num_heads=32, 399 | num_kv_heads=8, 400 | embed_dim=2048, 401 | max_seq_len=131072, 402 | intermediate_dim=8192, 403 | attn_dropout=0.0, 404 | norm_eps=1e-5, 405 | rope_base=500_000, 406 | scale_factor=32, 407 | added_tokens_offset=128256, 408 | vqgan_embed_dim=256, 409 | vqgan_start_index=129471, 410 | vqgan_end_index=130495, 411 | vqgan_vocab_size=1024, 412 | vqgan_codebook_path=CLEVR_VQGAN_CODEBOOK_PATH, 413 | image_token_offset=129471, 414 | use_sin_cos_numbers=False, 415 | ) 416 | 417 | 418 | def llama3_2_1b_clevr3d_sin_cos_plus_learned_numbers_omni3d_objectron_custom_finer() -> ( 419 | TransformerDecoder 420 | ): 421 | """ 422 | Builder for creating a Llama3.2 model initialized w/ the default 1b parameter values. 423 | 424 | Returns: 425 | TransformerDecoder: Instantiation of Llama3.2 1B model 426 | """ 427 | return llama3_2_clevr3d( 428 | vocab_size=128_256, 429 | num_layers=16, 430 | num_heads=32, 431 | num_kv_heads=8, 432 | embed_dim=2048, 433 | max_seq_len=131072, 434 | intermediate_dim=8192, 435 | attn_dropout=0.0, 436 | norm_eps=1e-5, 437 | rope_base=500_000, 438 | scale_factor=32, 439 | added_tokens_offset=128256, 440 | vqgan_embed_dim=256, 441 | vqgan_start_index=128372, 442 | vqgan_end_index=129396, 443 | vqgan_vocab_size=1024, 444 | vqgan_codebook_path=OBJECTRON_VQGAN_CODEBOOK_PATH, 445 | image_token_offset=128372, 446 | use_sin_cos_numbers=False, 447 | sin_cos_numbers_offset=104, 448 | use_sin_cos_plus_learned=True, 449 | ) 450 | -------------------------------------------------------------------------------- /scripts/generate_function.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 
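# NOTE: These generation utilities mirror torchtune's generation helpers. As written,
# multinomial_sample_one() below returns torch.argmax over the probabilities (the true
# multinomial draw is commented out), so decoding is effectively greedy: temperature and
# top_k reshape the distribution but do not change which token has the highest probability.
# If stochastic sampling is desired, the commented-out `probs / q` line restores it.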
6 | 7 | from typing import Callable, List, Optional, Tuple 8 | 9 | import torch 10 | from torchtune.modules.transformer import TransformerDecoder 11 | 12 | 13 | def multinomial_sample_one(probs: torch.Tensor, q: torch.Tensor) -> torch.Tensor: 14 | """Samples from a multinomial distribution.""" 15 | # return torch.argmax(probs / q, dim=-1, keepdim=True).to(dtype=torch.int) 16 | return torch.argmax(probs, dim=-1, keepdim=True).to(dtype=torch.int) 17 | 18 | 19 | def sample( 20 | logits: torch.Tensor, 21 | *, 22 | temperature: float = 1.0, 23 | top_k: Optional[int] = None, 24 | q: Optional[torch.Tensor] = None, 25 | ) -> torch.Tensor: 26 | """Generic sample from a probability distribution. Includes support for Top-K sampling 27 | and Temperature. 28 | 29 | Args: 30 | logits (torch.Tensor): logits from which to sample 31 | temperature (float): value to scale the predicted logits by, default 1.0. 32 | top_k (Optional[int]): If specified, we prune the sampling to only token ids within the top_k probabilities 33 | q (Optional[torch.Tensor]): randomly sampled tensor for softmax sampling trick. If None, 34 | we use the default softmax sampling trick. Default None. 35 | 36 | Example: 37 | >>> from torchtune.generation import sample 38 | >>> logits = torch.empty(3, 3).uniform_(0, 1) 39 | >>> sample(logits) 40 | tensor([[1], 41 | [2], 42 | [0]], dtype=torch.int32) 43 | 44 | Returns: 45 | torch.Tensor: sampled token id 46 | """ 47 | # scale the logits based on temperature 48 | logits = logits / max(temperature, 1e-5) 49 | 50 | if top_k is not None: 51 | v, _ = torch.topk(logits, min(top_k, logits.size(-1))) 52 | # select the very last value from the top_k above as the pivot 53 | pivot = v.select(-1, -1).unsqueeze(-1) 54 | # set everything smaller than pivot value to inf since these 55 | # should be pruned 56 | logits = torch.where(logits < pivot, -float("Inf"), logits) 57 | 58 | # change logits into probabilities 59 | probs = torch.nn.functional.softmax(logits, dim=-1) 60 | 61 | # if q is None, we use the default softmax sampling trick 62 | if q is None: 63 | q = torch.empty_like(probs).exponential_(1) 64 | 65 | return multinomial_sample_one(probs, q) 66 | 67 | 68 | def generate_next_token( 69 | model: TransformerDecoder, 70 | input_pos: torch.Tensor, 71 | x: torch.Tensor, 72 | q: torch.Tensor, 73 | *, 74 | mask: Optional[torch.Tensor] = None, 75 | temperature: float = 1.0, 76 | top_k: Optional[int] = None, 77 | ) -> Tuple[torch.Tensor, torch.Tensor]: 78 | """ 79 | Generates the next tokens given a prompt, and also returns the corresponding logits. 80 | 81 | Args: 82 | model (TransformerDecoder): model used for generation 83 | input_pos (torch.Tensor): tensor with the positional encodings associated with the given prompt, 84 | with shape [bsz x seq_length]. 85 | x (torch.Tensor): tensor with the token IDs associated with the given prompt, 86 | with shape [bsz x seq_length]. 87 | q (torch.Tensor): randomly sampled tensor for softmax sampling trick. 88 | See https://github.com/pytorch-labs/gpt-fast/blob/32971d3129541c5bfb4f715abc33d1c5f408d204/generate.py#L40 89 | mask (Optional[torch.Tensor]): attention mask with shape [bsz x seq_length x seq_length], 90 | default None. 91 | temperature (float): value to scale the predicted logits by, default 1.0. 92 | top_k (Optional[int]): Top-k value to use for sampling, default None. 93 | 94 | Returns: 95 | Tuple[torch.Tensor, torch.Tensor]: tuple of two tensors: 96 | - tokens (torch.Tensor): tensor with the generated tokens, 97 | with shape [bsz x 1]. 
98 | - logits (torch.Tensor): tensor with the logits associated with the generated tokens, 99 | with shape [bsz x seq_length x vocab_size]. 100 | 101 | """ 102 | # model produces logits in [bsz, seq_length, vocab_size] 103 | # we want to take the last token's logits as the input to the next model call 104 | logits = model(x, input_pos=input_pos, mask=mask) 105 | return ( 106 | sample(logits[:, -1].clone(), temperature=temperature, top_k=top_k, q=q), 107 | logits, 108 | ) 109 | 110 | 111 | def update_stop_tokens_tracker( 112 | tokens: torch.Tensor, stop_tokens: torch.Tensor, stop_token_reached: torch.Tensor 113 | ) -> torch.Tensor: 114 | """Updates which sequences have reached a stop token.""" 115 | # tokens: [bsz, 1] 116 | # stop_tokens: [num_stop_tokens] 117 | # stop_token_reached: [bsz] 118 | stop_token_reached_curr = torch.isin(tokens, stop_tokens).flatten() 119 | stop_token_reached |= stop_token_reached_curr 120 | return stop_token_reached 121 | 122 | 123 | def get_causal_mask_from_padding_mask( 124 | padding_mask: torch.Tensor, target_seq_len: Optional[int] = None 125 | ) -> torch.Tensor: 126 | """ 127 | Converts a padding mask of shape ``[bsz, seq_len]`` to a ``[bsz, seq_len, seq_len]`` causal attention mask suitable for 128 | consumption by :func:`~torch.nn.functional.scaled_dot_product_attention`. If ``target_seq_len`` 129 | is provided, this will return a mask of shape ``[bsz, seq_len, target_seq_len]``. This is useful 130 | when generating masks for static KV caches where the maximum length the caches have been setup with 131 | are longer than the current sequence. 132 | 133 | Args: 134 | padding_mask (torch.Tensor): Boolean tensor where False indicates the corresponding token in the sequence 135 | is a padding token and should be masked out in attention, with shape [bsz x seq_length] 136 | target_seq_len (Optional[int]): target sequence length to create attention mask with. Default None. 137 | 138 | Returns: 139 | torch.Tensor: Boolean causal mask with shape 140 | - [bsz, seq_length, seq_length] or 141 | - [bsz, seq_length, target_seq_len] if ``target_seq_len`` was specified. 142 | 143 | Raises: 144 | AssertionError: if ``target_seq_len > seq_len``, the sequence length of the padding mask. 145 | 146 | Example: 147 | >>> padding_mask = torch.tensor([[False, True, True, True]]) 148 | >>> get_causal_mask_from_padding_mask(padding_mask, target_seq_len=5) 149 | tensor([[[ True, False, False, False, False], 150 | [False, True, False, False, False], 151 | [False, True, True, False, False], 152 | [False, True, True, True, False]]]) 153 | ]) 154 | """ 155 | bsz, seq_len = padding_mask.shape 156 | target_seq_len = seq_len if target_seq_len is None else target_seq_len 157 | 158 | if target_seq_len < seq_len: 159 | raise AssertionError( 160 | "target_seq_len cannot be shorter than the sequence length of the padding mask." 161 | ) 162 | 163 | mask = torch.tril( 164 | torch.ones(seq_len, target_seq_len, device=padding_mask.device, dtype=bool), 165 | diagonal=0, 166 | ).repeat(bsz, 1, 1) 167 | mask.narrow(2, 0, seq_len).mul_(padding_mask[:, None, :].expand(-1, seq_len, -1)) 168 | mask.diagonal(dim1=1, dim2=2).copy_(torch.Tensor([True])) 169 | return mask 170 | 171 | 172 | def get_position_ids_from_padding_mask( 173 | padding_mask: torch.Tensor, 174 | ): 175 | """ 176 | Calculates position ids given a padding mask which right-shifts position ids to start 177 | from the first valid token. 
178 | 179 | Args: 180 | padding_mask (torch.Tensor): Boolean tensor where False indicates the corresponding token in the sequence 181 | is a padding token and should be masked out in attention. Shape [bsz, seq_len] 182 | 183 | Returns: 184 | torch.Tensor: position ids which are appropriately shifted according to any padding values. 185 | 186 | Example: 187 | >>> padding_mask = torch.tensor([False, False, False, True, True, True, True, True]) 188 | >>> get_position_ids_from_padding_mask(padding_mask) 189 | torch.Tensor([0, 0, 0, 0, 1, 2, 3, 4]) 190 | """ 191 | return ((padding_mask.cumsum(-1) - 1) * padding_mask).to(torch.int) 192 | 193 | 194 | @torch.inference_mode() 195 | def generate( 196 | model: TransformerDecoder, 197 | prompt: torch.Tensor, 198 | *, 199 | max_generated_tokens: int, 200 | pad_id: int = 0, 201 | temperature: float = 1.0, 202 | top_k: Optional[int] = None, 203 | stop_tokens: Optional[List[int]] = None, 204 | rng: Optional[torch.Generator] = None, 205 | custom_generate_next_token: Optional[Callable] = None, 206 | ) -> Tuple[torch.Tensor, torch.Tensor]: 207 | """ 208 | Generates tokens from a model conditioned on a prompt, and also returns logits for the generations. 209 | 210 | Args: 211 | model (TransformerDecoder): model used for generation 212 | prompt (torch.Tensor): tensor with the token IDs associated with the given prompt, 213 | with shape either [seq_length] or [bsz x seq_length]. 214 | max_generated_tokens (int): number of tokens to be generated 215 | pad_id (int): token ID to use for padding, default 0. 216 | temperature (float): value to scale the predicted logits by, default 1.0. 217 | top_k (Optional[int]): If specified, we prune the sampling to only token ids within the top_k probabilities, 218 | default None. 219 | stop_tokens (Optional[List[int]]): If specified, generation is stopped when any of these tokens are generated, 220 | default None. 221 | rng (Optional[torch.Generator]): random number generator, default None. 222 | custom_generate_next_token (Optional[Callable]): If specified, we'll use the 223 | ``custom_generate_next_token function``. This is generally only useful if 224 | you want to specify a ``torch.compile`` version of the generate next token for 225 | performance reasons. If None, we use the default :func:`generate_next_token`. 226 | Default is None. 227 | 228 | Note: 229 | This function has only been tested with decoder-only models. 230 | 231 | Examples: 232 | >>> model = torchtune.models.llama3.llama3_8b() 233 | >>> tokenizer = torchtune.models.llama3.llama3_tokenizer() 234 | >>> prompt = tokenizer.encode("Hi my name is") 235 | >>> rng.manual_seed(42) 236 | >>> output, logits = generate(model, torch.tensor(prompt), max_generated_tokens=100, pad_id=0) 237 | >>> print(tokenizer.decode(output[0].tolist())) 238 | Hi my name is Jeremy and I'm a friendly language model assistant! 239 | 240 | Returns: 241 | Tuple[torch.Tensor, torch.Tensor]: tuple of two tensors: 242 | - tokens (torch.Tensor): tensor with the generated tokens, 243 | with shape ``[bsz x seq_len + num_generated_tokens]`` where ``num_generated_tokens`` 244 | may be less than ``max_generated_tokens`` if ``stop_tokens`` are provided. 245 | - logits (torch.Tensor): tensor with the logits associated with the generated tokens, 246 | with shape ``[bsz x seq_len + num_generated_tokens x vocab_size]``. 
247 | """ 248 | prompt = prompt.view(1, -1) if prompt.ndim == 1 else prompt 249 | 250 | stop_tokens = ( 251 | torch.tensor(stop_tokens, device=prompt.device) if stop_tokens else None 252 | ) 253 | 254 | if custom_generate_next_token is None: 255 | custom_generate_next_token = generate_next_token 256 | 257 | bsz, prompt_length = prompt.size() 258 | total_response_length = prompt_length + max_generated_tokens 259 | 260 | generated_tokens = prompt.clone() 261 | incremental_decoding = model.caches_are_enabled() 262 | 263 | # grab the correct max_seq_len to generate full causal masks/position ids 264 | # this is the model's max cache len if incremental decoding, or the sequence 265 | # length otherwise 266 | max_seq_len = ( 267 | total_response_length 268 | if not incremental_decoding 269 | else model.decoder_max_cache_seq_len 270 | ) 271 | 272 | padding_masks = generated_tokens != pad_id 273 | 274 | if not padding_masks.all(): 275 | # we have padding in the prompt due to varying-length sequences in a batch 276 | # extend padding masks out to the correct seq len 277 | padding_masks = torch.nn.functional.pad( 278 | padding_masks, (0, max_generated_tokens), value=True 279 | ) 280 | 281 | # generate the full causal mask for the whole padding mask with padding ignored 282 | masks = get_causal_mask_from_padding_mask( 283 | padding_masks, target_seq_len=max_seq_len 284 | ) 285 | 286 | # right-shift position IDs to account for padding 287 | input_pos = get_position_ids_from_padding_mask(padding_masks) 288 | else: 289 | # just use a regular causal mask if there is no padding 290 | masks = torch.tril( 291 | torch.ones( 292 | total_response_length, 293 | max_seq_len, 294 | dtype=torch.bool, 295 | device=prompt.device, 296 | ) 297 | ).unsqueeze(0) 298 | input_pos = torch.arange( 299 | 0, total_response_length, device=generated_tokens.device 300 | ).unsqueeze(0) 301 | 302 | if incremental_decoding: 303 | # if KV-caches are enabled, we need a causal mask of shape [bsz, prompt_length, max_cache_len] 304 | # to match the key/value cache tensor shapes 305 | curr_masks = masks[:, :prompt_length] 306 | else: 307 | # otherwise the causal mask is shape [bsz, prompt_length, prompt_length] because key/value 308 | # tensors are of identical shape to the prompt 309 | curr_masks = masks[:, :prompt_length, :prompt_length] 310 | 311 | # q = torch.empty( 312 | # (bsz, model.tok_embeddings.num_embeddings), device=prompt.device 313 | # ).exponential_(1, generator=rng) 314 | q = None 315 | tokens, generated_logits = generate_next_token( 316 | model, 317 | input_pos=input_pos[:, :prompt_length].squeeze(), 318 | mask=curr_masks, 319 | x=prompt, 320 | temperature=temperature, 321 | top_k=top_k, 322 | q=q, 323 | ) 324 | 325 | generated_tokens = torch.cat([generated_tokens, tokens], dim=-1) 326 | 327 | curr_pos = prompt_length 328 | 329 | # keeps track at a high level if we've already hit a stop token in a sequence so we can early stop 330 | stop_token_reached = torch.zeros(bsz, dtype=torch.bool, device=prompt.device) 331 | 332 | # everything in stop_token_mask starts as 1s, and we'll set them to 0 for sequences 333 | # that already hit a stop token 334 | stop_token_mask = torch.ones( 335 | (bsz, prompt_length + 1), dtype=torch.int32, device=prompt.device 336 | ) 337 | 338 | # stop early if we reach a stop token in every seq 339 | if stop_tokens is not None: 340 | stop_token_reached = update_stop_tokens_tracker( 341 | tokens, stop_tokens, stop_token_reached 342 | ) 343 | if stop_token_reached.all().item(): 344 | return 
generated_tokens, generated_logits 345 | 346 | for _ in range(max_generated_tokens - 1): 347 | # update stop_token_mask if we reached a stop token in a previous step 348 | # by appending the logical not of stop_token_reached to the end of the mask 349 | # reshaped to be bsz first 350 | if stop_tokens is not None: 351 | stop_token_mask = torch.cat( 352 | [stop_token_mask, ~stop_token_reached.reshape(bsz, 1)], dim=-1 353 | ) 354 | 355 | # if incremental decoding is enabled, we can use the current position 356 | # otherwise, we take the whole sequence up to the current position 357 | if incremental_decoding: 358 | curr_input_pos = input_pos[:, curr_pos] 359 | curr_masks = masks[:, curr_pos, None, :] 360 | else: 361 | tokens = generated_tokens.clone() 362 | curr_input_pos = input_pos[:, : curr_pos + 1] 363 | curr_masks = masks[:, : curr_pos + 1, : curr_pos + 1] 364 | 365 | # q = torch.empty( 366 | # (bsz, model.tok_embeddings.num_embeddings), device=prompt.device 367 | # ).exponential_(1, generator=rng) 368 | q = None 369 | tokens, logits = custom_generate_next_token( 370 | model, 371 | input_pos=curr_input_pos, 372 | x=tokens, 373 | mask=curr_masks, 374 | temperature=temperature, 375 | top_k=top_k, 376 | q=q, 377 | ) 378 | generated_tokens = torch.cat([generated_tokens, tokens], dim=-1) 379 | curr_pos += 1 380 | if incremental_decoding: 381 | generated_logits = torch.cat([generated_logits, logits], dim=1) 382 | else: 383 | generated_logits = logits 384 | 385 | if stop_tokens is not None: 386 | stop_token_reached = update_stop_tokens_tracker( 387 | tokens, stop_tokens, stop_token_reached 388 | ) 389 | if stop_token_reached.all(): 390 | break 391 | 392 | # mask out generated tokens in seqs that already hit a stop token 393 | if stop_tokens is not None: 394 | generated_tokens *= stop_token_mask 395 | generated_logits *= stop_token_mask[:, :-1, None] 396 | 397 | return generated_tokens, generated_logits 398 | -------------------------------------------------------------------------------- /dataset/multimodal_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from typing import Any, Dict, List, Mapping, Optional 8 | 9 | from torch.utils.data import Dataset 10 | import json 11 | import numpy as np 12 | import math 13 | from torchtune.modules.tokenizers import ModelTokenizer 14 | from tqdm import tqdm 15 | 16 | 17 | class ThreeDMLLMDataset(Dataset): 18 | """ 19 | Dataset class for Kyvo data. This dataset is used for multimodal tasks that require text, image, and 3D data. 20 | 21 | Args: 22 | task_type (str): The type of task to perform. Options are "I-3", "3-I", "I+Q-A", "I+3+Q-A", "3+I+Q-A", "3+T-3", etc 23 | text_source (str): The path to the text source file. 24 | image_source (str): The path to the image source file. 25 | three_d_source (str): The path to the 3d source file. 26 | text_target (str): The path to the text target file. 27 | image_target (str): The path to the image target file. 28 | three_d_target (str): The path to the 3d target file. 29 | max_seq_len (int): The maximum sequence length to use for the model. Default is None. 30 | image_token_offset (int): The offset to add to the image tokens. Default is 0. 
31 | load_dataset_kwargs (Dict[str, Any]): Additional keyword arguments to pass to the 32 | dataset loader. 33 | """ 34 | 35 | def __init__( 36 | self, 37 | task_type: str, # "I-3", "3-I", "I+Q-A", "I+3+Q-A", "3+I+Q-A", "3+T-3", etc 38 | text_source: str, # json file path 39 | image_source: str = None, # json file path 40 | three_d_source: str = None, # json file path 41 | text_target: str = None, # json file path 42 | image_target: str = None, # json file path 43 | three_d_target: str = None, # json file path 44 | max_seq_len: Optional[int] = None, 45 | image_token_offset: int = 129471, # tokenizer dependent 46 | **load_dataset_kwargs: Dict[str, Any], 47 | ) -> None: 48 | 49 | self.task_type = task_type # "vqa", "edit", "difference" 50 | if self.task_type == "UNIFIED": 51 | self.unified_task = True 52 | else: 53 | self.unified_task = False 54 | self.no_loss_on_input = load_dataset_kwargs.get("no_loss_on_input", False) 55 | self.reorder_image_tokens = load_dataset_kwargs.get( 56 | "reorder_image_tokens", False 57 | ) 58 | self.image_token_offset = image_token_offset 59 | self.max_seq_len = max_seq_len 60 | self.num_samples = load_dataset_kwargs.get("num_samples", None) 61 | 62 | self.dataset_name = load_dataset_kwargs.get("dataset_name", None) 63 | assert self.dataset_name is not None, "Dataset name is required" 64 | 65 | self.load_text_source = load_dataset_kwargs.get("load_text_source", False) 66 | self.load_image_source = load_dataset_kwargs.get("load_image_source", False) 67 | self.load_three_d_source = load_dataset_kwargs.get("load_three_d_source", False) 68 | self.load_text_target = load_dataset_kwargs.get("load_text_target", False) 69 | self.load_image_target = load_dataset_kwargs.get("load_image_target", False) 70 | self.load_three_d_target = load_dataset_kwargs.get("load_three_d_target", False) 71 | 72 | self._text_data = None 73 | self._text_data_keys = None 74 | self._image_data = None 75 | self._image_data_keys = None 76 | self._three_d_data = None 77 | self._three_d_data_keys = None 78 | self._text_target = None 79 | self._text_target_keys = None 80 | self._image_target = None 81 | self._image_target_keys = None 82 | self._three_d_target = None 83 | self._three_d_target_keys = None 84 | 85 | if self.load_text_source: 86 | print("Loading text source") 87 | with open(text_source, "r") as f: 88 | text_data = json.load(f) 89 | print("Preparing text data") 90 | self._text_data = { 91 | k: v["token_ids"] for k, v in tqdm(sorted(text_data.items())) 92 | } 93 | self._text_data_keys = list(sorted(self._text_data.keys())) 94 | print("Number of text samples: ", len(self._text_data_keys)) 95 | 96 | if self.load_image_source: 97 | print("Loading image source") 98 | with open(image_source, "r") as f: 99 | image_data = json.load(f) 100 | 101 | if ( 102 | self.reorder_image_tokens 103 | and self.task_type != "I+3+T-I+3" 104 | and self.task_type != "3+I+T-3+I" 105 | ): 106 | print("Reordering image tokens (source)") 107 | self._image_data = { 108 | k: self.reorder_list_optimized(image_data[k]) 109 | for k in tqdm(sorted(image_data.keys())) 110 | } 111 | self._image_data_keys = list(sorted(self._image_data.keys())) 112 | else: 113 | self._image_data = { 114 | k: image_data[k] for k in tqdm(sorted(image_data.keys())) 115 | } 116 | self._image_data_keys = list(sorted(self._image_data.keys())) 117 | print("Number of image samples: ", len(self._image_data_keys)) 118 | 119 | if self.load_three_d_source: 120 | print("Loading 3D source") 121 | with open(three_d_source, "r") as f: 122 | three_d_data = 
json.load(f) 123 | print("Preparing 3D data") 124 | self._three_d_data = { 125 | k: v["token_ids"] for k, v in tqdm(sorted(three_d_data.items())) 126 | } 127 | self._three_d_data_keys = list(sorted(self._three_d_data.keys())) 128 | print("Number of 3D samples: ", len(self._three_d_data_keys)) 129 | 130 | if self.load_text_target: 131 | print("Loading text target") 132 | with open(text_target, "r") as f: 133 | text_target = json.load(f) 134 | self._text_target = { 135 | k: v["token_ids"] for k, v in sorted(text_target.items()) 136 | } 137 | self._text_target_keys = list(sorted(self._text_target.keys())) 138 | print("Number of text target samples: ", len(self._text_target)) 139 | 140 | if self.load_image_target: 141 | print("Loading image target") 142 | with open(image_target, "r") as f: 143 | image_target = json.load(f) 144 | if self.reorder_image_tokens: 145 | print("Reordering image tokens (target)") 146 | self._image_target = { 147 | k: self.reorder_list_optimized(image_target[k]) 148 | for k in tqdm(sorted(image_target.keys())) 149 | } 150 | self._image_target_keys = list(sorted(self._image_target.keys())) 151 | else: 152 | self._image_target = { 153 | k: image_target[k] for k in tqdm(sorted(image_target.keys())) 154 | } 155 | self._image_target_keys = list(sorted(self._image_target.keys())) 156 | print("Number of image target samples: ", len(self._image_target)) 157 | 158 | if self.load_three_d_target: 159 | print("Loading 3D target") 160 | with open(three_d_target, "r") as f: 161 | three_d_target = json.load(f) 162 | self._three_d_target = { 163 | k: v["token_ids"] for k, v in sorted(three_d_target.items()) 164 | } 165 | self._three_d_target_keys = list(sorted(self._three_d_target.keys())) 166 | print("Number of 3D target samples: ", len(self._three_d_target)) 167 | 168 | # get the common keys for all the data which are not None 169 | common_keys = set() 170 | if self._text_data is not None: 171 | if len(common_keys) == 0: 172 | common_keys = set(self._text_data.keys()) 173 | # common_keys = common_keys.union(set(self._text_data.keys())) 174 | if self._image_data is not None: 175 | if len(common_keys) == 0: 176 | common_keys = set(self._image_data.keys()) 177 | common_keys = common_keys.intersection(set(self._image_data.keys())) 178 | if self._three_d_data is not None: 179 | if len(common_keys) == 0: 180 | common_keys = set(self._three_d_data.keys()) 181 | common_keys = common_keys.intersection(set(self._three_d_data.keys())) 182 | if self._text_target is not None: 183 | if len(common_keys) == 0: 184 | common_keys = set(self._text_target.keys()) 185 | common_keys = common_keys.intersection(set(self._text_target.keys())) 186 | if self._image_target is not None: 187 | if len(common_keys) == 0: 188 | common_keys = set(self._image_target.keys()) 189 | common_keys = common_keys.intersection(set(self._image_target.keys())) 190 | if self._three_d_target is not None: 191 | if len(common_keys) == 0: 192 | common_keys = set(self._three_d_target.keys()) 193 | common_keys = common_keys.intersection(set(self._three_d_target.keys())) 194 | 195 | if len(common_keys) == 0: 196 | raise ValueError("No common keys found for all the data") 197 | 198 | if self.num_samples is not None: 199 | common_keys = list(common_keys)[: self.num_samples] 200 | 201 | if self._text_data is not None: 202 | self._text_data = {k: self._text_data[k] for k in common_keys} 203 | self._text_data_keys = list(sorted(self._text_data.keys())) 204 | 205 | if self._image_data is not None: 206 | self._image_data = {k: 
self._image_data[k] for k in common_keys} 207 | self._image_data_keys = list(sorted(self._image_data.keys())) 208 | 209 | if self._three_d_data is not None: 210 | self._three_d_data = {k: self._three_d_data[k] for k in common_keys} 211 | self._three_d_data_keys = list(sorted(self._three_d_data.keys())) 212 | 213 | if self._text_target is not None: 214 | self._text_target = {k: self._text_target[k] for k in common_keys} 215 | self._text_target_keys = list(sorted(self._text_target.keys())) 216 | 217 | if self._image_target is not None: 218 | self._image_target = {k: self._image_target[k] for k in common_keys} 219 | self._image_target_keys = list(sorted(self._image_target.keys())) 220 | 221 | if self._three_d_target is not None: 222 | self._three_d_target = {k: self._three_d_target[k] for k in common_keys} 223 | self._three_d_target_keys = list(sorted(self._three_d_target.keys())) 224 | 225 | print("Total number of samples: ", len(common_keys)) 226 | print(f"Initialization of Kyvo dataset complete for {self.dataset_name}!") 227 | 228 | def reorder_list_optimized(self, data): 229 | n = len(data) 230 | center = int((n // 2) - (math.sqrt(n) // 2)) # center of the image 231 | reordered = [0] * n 232 | reordered[0] = data[center] 233 | left, right = center - 1, center + 1 234 | index = 1 235 | 236 | while left >= 0 or right < n: 237 | if left >= 0: 238 | reordered[index] = data[left] 239 | left -= 1 240 | index += 1 241 | if right < n: 242 | reordered[index] = data[right] 243 | right += 1 244 | index += 1 245 | 246 | return reordered 247 | 248 | def __len__(self): 249 | try: 250 | return len(self._text_data) 251 | except: 252 | return len(self._three_d_data) 253 | 254 | def __getitem__(self, index: int) -> Dict[str, List[int]]: 255 | 256 | # assert that the keys are the same for all the data both source and target which are not None 257 | assert_keys = [] 258 | if self._text_data is not None: 259 | assert_keys.append(self._text_data_keys[index]) 260 | text_sample = self._text_data[self._text_data_keys[index]] 261 | if self._image_data is not None: 262 | assert_keys.append(self._image_data_keys[index]) 263 | image_sample = self._image_data[self._image_data_keys[index]] 264 | image_sample = [x + self.image_token_offset for x in image_sample] 265 | if self._three_d_data is not None: 266 | assert_keys.append(self._three_d_data_keys[index]) 267 | three_d_sample = self._three_d_data[self._three_d_data_keys[index]] 268 | if self._text_target is not None: 269 | assert_keys.append(self._text_target_keys[index]) 270 | text_target = self._text_target[self._text_target_keys[index]] 271 | if self._image_target is not None: 272 | assert_keys.append(self._image_target_keys[index]) 273 | image_target = self._image_target[self._image_target_keys[index]] 274 | image_target = [x + self.image_token_offset for x in image_target] 275 | if self._three_d_target is not None: 276 | assert_keys.append(self._three_d_target_keys[index]) 277 | three_d_target = self._three_d_target[self._three_d_target_keys[index]] 278 | 279 | # assert that the keys are the same for all the data both source and target 280 | assert all(x == assert_keys[0] for x in assert_keys), f"Keys are not the same" 281 | 282 | BOS = [128000] 283 | EOS = [128001] 284 | 285 | if self.dataset_name == "CLEVR" or self.dataset_name == "ObjaWorld": 286 | BOIMG = [129466] 287 | EOIMG = [129467] 288 | OUTSEP = [129470] 289 | elif self.dataset_name == "Omni3D-Objectron": 290 | BOIMG = [128366] 291 | EOIMG = [128367] 292 | OUTSEP = [128370] 293 | 294 | if 
self.unified_task: 295 | # choose a random task type 296 | task_types = [ 297 | "3-I", 298 | # "I+Q-A", 299 | # "I+3+Q-A", 300 | "3+T-3", 301 | "I+3+T-3", 302 | "I+3+T-I+3", 303 | "I+T-I+3", 304 | "I+T-3", 305 | "I+T-I", 306 | "I-3", 307 | ] 308 | self.task_type = np.random.choice(task_types) 309 | 310 | if self.task_type == "3-I": 311 | # 3D --> Image 312 | tokens = BOS + three_d_sample + OUTSEP + BOIMG + image_sample + EOIMG + EOS 313 | if self.no_loss_on_input: 314 | labels = ( 315 | [-100] * len(BOS + three_d_sample + OUTSEP) 316 | + BOIMG 317 | + image_sample 318 | + EOIMG 319 | + EOS 320 | ) 321 | else: 322 | labels = tokens.copy() 323 | 324 | elif self.task_type == "I+Q-A": 325 | # 3D --> Image 326 | tokens = ( 327 | BOS 328 | + BOIMG 329 | + image_sample 330 | + EOIMG 331 | + text_sample 332 | + OUTSEP 333 | + text_target 334 | + EOS 335 | ) 336 | if self.no_loss_on_input: 337 | labels = ( 338 | [-100] 339 | * len(BOS + BOIMG + image_sample + EOIMG + text_sample + OUTSEP) 340 | + text_target 341 | + EOS 342 | ) 343 | else: 344 | labels = tokens.copy() 345 | 346 | elif self.task_type == "I+3+Q-A": 347 | # 3D --> Image 348 | tokens = ( 349 | BOS 350 | + BOIMG 351 | + image_sample 352 | + EOIMG 353 | + three_d_sample 354 | + text_sample 355 | + OUTSEP 356 | + text_target 357 | + EOS 358 | ) 359 | if self.no_loss_on_input: 360 | labels = ( 361 | [-100] 362 | * len( 363 | BOS 364 | + BOIMG 365 | + image_sample 366 | + EOIMG 367 | + three_d_sample 368 | + text_sample 369 | + OUTSEP 370 | ) 371 | + text_target 372 | + EOS 373 | ) 374 | else: 375 | labels = tokens.copy() 376 | 377 | elif self.task_type == "3+I+Q-A": 378 | # 3D --> Image 379 | tokens = ( 380 | BOS 381 | + three_d_sample 382 | + BOIMG 383 | + image_sample 384 | + EOIMG 385 | + text_sample 386 | + OUTSEP 387 | + text_target 388 | + EOS 389 | ) 390 | if self.no_loss_on_input: 391 | labels = ( 392 | [-100] 393 | * len( 394 | BOS 395 | + three_d_sample 396 | + BOIMG 397 | + image_sample 398 | + EOIMG 399 | + text_sample 400 | + OUTSEP 401 | ) 402 | + text_target 403 | + EOS 404 | ) 405 | else: 406 | labels = tokens.copy() 407 | 408 | elif self.task_type == "3-I-first3output": 409 | # 3D --> Image 410 | tokens = BOS + three_d_sample + OUTSEP + BOIMG + image_sample + EOIMG + EOS 411 | if self.no_loss_on_input: 412 | labels = ( 413 | [-100] 414 | * len(BOS + three_d_sample + OUTSEP + BOIMG + image_sample[:3]) 415 | + image_sample[3:] 416 | + EOIMG 417 | + EOS 418 | ) 419 | else: 420 | labels = tokens.copy() 421 | 422 | elif self.task_type == "3+T-3": 423 | # 3D + Text --> 3D 424 | tokens = BOS + three_d_sample + text_sample + OUTSEP + three_d_target + EOS 425 | if self.no_loss_on_input: 426 | labels = ( 427 | [-100] * len(BOS + three_d_sample + text_sample + OUTSEP) 428 | + three_d_target 429 | + EOS 430 | ) 431 | else: 432 | labels = tokens.copy() 433 | 434 | elif self.task_type == "I+3+T-3": 435 | # Image + 3D + Text --> 3D 436 | tokens = ( 437 | BOS 438 | + BOIMG 439 | + image_sample 440 | + EOIMG 441 | + three_d_sample 442 | + text_sample 443 | + OUTSEP 444 | + three_d_target 445 | + EOS 446 | ) 447 | if self.no_loss_on_input: 448 | labels = ( 449 | [-100] 450 | * len( 451 | BOS 452 | + BOIMG 453 | + image_sample 454 | + EOIMG 455 | + three_d_sample 456 | + text_sample 457 | + OUTSEP 458 | ) 459 | + three_d_target 460 | + EOS 461 | ) 462 | else: 463 | labels = tokens.copy() 464 | 465 | elif self.task_type == "I+3+T-I+3": 466 | # Image + 3D + Text --> Image + 3D 467 | tokens = ( 468 | BOS 469 | + BOIMG 470 | + image_sample 471 | 
+ EOIMG 472 | + three_d_sample 473 | + text_sample 474 | + OUTSEP 475 | + BOIMG 476 | + image_target 477 | + EOIMG 478 | + three_d_target 479 | + EOS 480 | ) 481 | if self.no_loss_on_input: 482 | labels = ( 483 | [-100] 484 | * len( 485 | BOS 486 | + BOIMG 487 | + image_sample 488 | + EOIMG 489 | + three_d_sample 490 | + text_sample 491 | + OUTSEP 492 | ) 493 | + BOIMG 494 | + image_target 495 | + EOIMG 496 | + three_d_target 497 | + EOS 498 | ) 499 | else: 500 | labels = tokens.copy() 501 | 502 | elif self.task_type == "3+I+T-3+I": 503 | # 3D + Image + Text --> 3D + Image 504 | tokens = ( 505 | BOS 506 | + three_d_sample 507 | + BOIMG 508 | + image_sample 509 | + EOIMG 510 | + text_sample 511 | + OUTSEP 512 | + three_d_target 513 | + BOIMG 514 | + image_target 515 | + EOIMG 516 | + EOS 517 | ) 518 | if self.no_loss_on_input: 519 | labels = ( 520 | [-100] 521 | * len( 522 | BOS 523 | + three_d_sample 524 | + BOIMG 525 | + image_sample 526 | + EOIMG 527 | + text_sample 528 | + OUTSEP 529 | ) 530 | + three_d_target 531 | + BOIMG 532 | + image_target 533 | + EOIMG 534 | + EOS 535 | ) 536 | else: 537 | labels = tokens.copy() 538 | 539 | elif self.task_type == "I+T-I+3": 540 | # Image + Text --> Image + 3D 541 | tokens = ( 542 | BOS 543 | + BOIMG 544 | + image_sample 545 | + EOIMG 546 | + text_sample 547 | + OUTSEP 548 | + BOIMG 549 | + image_target 550 | + EOIMG 551 | + three_d_target 552 | + EOS 553 | ) 554 | if self.no_loss_on_input: 555 | labels = ( 556 | [-100] 557 | * len(BOS + BOIMG + image_sample + EOIMG + text_sample + OUTSEP) 558 | + BOIMG 559 | + image_target 560 | + EOIMG 561 | + three_d_target 562 | + EOS 563 | ) 564 | else: 565 | labels = tokens.copy() 566 | 567 | elif self.task_type == "I+T-3": 568 | # Image + Text --> 3D 569 | tokens = ( 570 | BOS 571 | + BOIMG 572 | + image_sample 573 | + EOIMG 574 | + text_sample 575 | + OUTSEP 576 | + three_d_target 577 | + EOS 578 | ) 579 | if self.no_loss_on_input: 580 | labels = ( 581 | [-100] 582 | * len(BOS + BOIMG + image_sample + EOIMG + text_sample + OUTSEP) 583 | + three_d_target 584 | + EOS 585 | ) 586 | else: 587 | labels = tokens.copy() 588 | 589 | elif self.task_type == "I+T-I": 590 | # Image + Text --> Image 591 | tokens = ( 592 | BOS 593 | + BOIMG 594 | + image_sample 595 | + EOIMG 596 | + text_sample 597 | + OUTSEP 598 | + BOIMG 599 | + image_target 600 | + EOIMG 601 | + EOS 602 | ) 603 | if self.no_loss_on_input: 604 | labels = ( 605 | [-100] 606 | * len(BOS + BOIMG + image_sample + EOIMG + text_sample + OUTSEP) 607 | + BOIMG 608 | + image_target 609 | + EOIMG 610 | + EOS 611 | ) 612 | else: 613 | labels = tokens.copy() 614 | 615 | elif self.task_type == "I-3": 616 | # Image --> 3D 617 | tokens = BOS + BOIMG + image_sample + EOIMG + OUTSEP + three_d_sample + EOS 618 | if self.no_loss_on_input: 619 | labels = ( 620 | [-100] * len(BOS + BOIMG + image_sample + EOIMG + OUTSEP) 621 | + three_d_sample 622 | + EOS 623 | ) 624 | else: 625 | labels = tokens.copy() 626 | 627 | elif self.task_type == "I+T-T": 628 | # Image + Text --> Text 629 | tokens = ( 630 | BOS 631 | + BOIMG 632 | + image_sample 633 | + EOIMG 634 | + text_sample 635 | + OUTSEP 636 | + text_target 637 | + EOS 638 | ) 639 | if self.no_loss_on_input: 640 | labels = ( 641 | [-100] 642 | * len(BOS + BOIMG + image_sample + EOIMG + text_sample + OUTSEP) 643 | + text_target 644 | + EOS 645 | ) 646 | else: 647 | labels = tokens.copy() 648 | 649 | elif self.task_type == "I+3+T-T": 650 | # Image + 3D + Text --> Text 651 | tokens = ( 652 | BOS 653 | + BOIMG 654 | + image_sample 
655 | + EOIMG 656 | + three_d_sample 657 | + text_sample 658 | + OUTSEP 659 | + text_target 660 | + EOS 661 | ) 662 | if self.no_loss_on_input: 663 | labels = ( 664 | [-100] 665 | * len( 666 | BOS 667 | + BOIMG 668 | + image_sample 669 | + EOIMG 670 | + three_d_sample 671 | + text_sample 672 | + OUTSEP 673 | ) 674 | + text_target 675 | + EOS 676 | ) 677 | else: 678 | labels = tokens.copy() 679 | 680 | else: 681 | raise ValueError(f"Task type {self.task_type} not recognized") 682 | 683 | return { 684 | "tokens": tokens, 685 | "labels": labels, 686 | } 687 | 688 | 689 | def threed_mllm_dataset( 690 | task_type: str, 691 | text_source: str, 692 | image_source: str, 693 | three_d_source: str, 694 | max_seq_len: Optional[int] = None, 695 | image_token_offset: int = 0, 696 | **load_dataset_kwargs: Dict[str, Any], 697 | ) -> ThreeDMLLMDataset: 698 | return ThreeDMLLMDataset( 699 | task_type=task_type, 700 | text_source=text_source, 701 | image_source=image_source, 702 | three_d_source=three_d_source, 703 | max_seq_len=max_seq_len, 704 | image_token_offset=image_token_offset, 705 | **load_dataset_kwargs, 706 | ) 707 | --------------------------------------------------------------------------------
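A minimal usage sketch for the dataset builder above (editorial addition, not part of the repository): the JSON paths are placeholders, the keyword flags are the ones ThreeDMLLMDataset reads from load_dataset_kwargs, and the token offset follows the CLEVR/ObjaWorld value used elsewhere in the code.

from dataset.multimodal_dataset import threed_mllm_dataset

# Build a "3-I" (3D scene tokens --> image tokens) dataset. Paths are placeholders.
ds = threed_mllm_dataset(
    task_type="3-I",
    text_source=None,                                       # text is not loaded for this task
    image_source="path/to/train_vqgan_indices.json",        # placeholder path
    three_d_source="path/to/train_tokenized_scenes.json",   # placeholder path
    image_token_offset=129471,   # shifts VQGAN indices into the extended LLM vocabulary
    dataset_name="CLEVR",
    load_image_source=True,
    load_three_d_source=True,
    reorder_image_tokens=True,   # center-out reordering via reorder_list_optimized
    no_loss_on_input=True,       # mask the prompt portion of "labels" with -100
)

sample = ds[0]
# sample["tokens"]: BOS + 3D tokens + OUTSEP + BOIMG + (VQGAN indices + offset) + EOIMG + EOS
# sample["labels"]: the same sequence with everything before BOIMG replaced by -100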