├── dataset
│   ├── __init__.py
│   └── multimodal_dataset.py
├── docs
│   └── teaser.png
├── models
│   ├── __init__.py
│   └── model_builders.py
├── .gitignore
├── checkpointer
│   ├── __init__.py
│   ├── convert_weights.py
│   └── checkpointer.py
├── requirements.txt
├── .gitmodules
├── configs
│   └── llama3_2
│       ├── eval
│       │   ├── objaworld
│       │   │   ├── rendering.yaml
│       │   │   └── recognition.yaml
│       │   ├── objectron
│       │   │   └── recognition.yaml
│       │   └── clevr
│       │       ├── question-answering.yaml
│       │       ├── rendering.yaml
│       │       ├── recognition.yaml
│       │       ├── recognition-colab-demo.yaml
│       │       └── instruction-following
│       │           ├── moving-with-relation.yaml
│       │           ├── removal-with-relation.yaml
│       │           ├── appearance-no-relation.yaml
│       │           ├── insertion-with-relation.yaml
│       │           └── appearance-with-relation.yaml
│       └── train
│           ├── objaworld
│           │   ├── rendering.yaml
│           │   └── recognition.yaml
│           ├── objectron
│           │   └── recognition.yaml
│           └── clevr
│               ├── rendering.yaml
│               ├── recognition.yaml
│               ├── question-answering.yaml
│               └── instruction-following.yaml
├── scripts
│   ├── tokenizer.py
│   ├── compute_text_answer_accuracy.py
│   ├── taming_transformers_utils.py
│   ├── compute_ssim_l2loss.py
│   ├── decode_image_embeddings.py
│   ├── compute_jaccard_index.py
│   ├── compute_jaccard_index_objectron.py
│   └── generate_function.py
├── taming-kyvo.yml
├── kyvo.yml
├── DATA.md
├── LICENSE
└── README.md

/dataset/__init__.py:
--------------------------------------------------------------------------------
1 | from dataset.multimodal_dataset import *
2 | 
--------------------------------------------------------------------------------
/docs/teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AadSah/kyvo/HEAD/docs/teaser.png
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | from models.model_builders import *
2 | from models.model_component_builders import *
3 | 
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__/
2 | checkpoints/
3 | kyvo-datasets-and-codebooks/
4 | taming/
5 | taming-transformers/
--------------------------------------------------------------------------------
/checkpointer/__init__.py:
--------------------------------------------------------------------------------
1 | from checkpointer.checkpointer import *
2 | from checkpointer.convert_weights import *
3 | 
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch==2.4.1
2 | torchvision==0.19.1
3 | torchao==0.5.0
4 | transformers==4.41.2
5 | bitsandbytes==0.43.1
--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "torchtune"]
2 | path = torchtune
3 | url = https://github.com/pytorch/torchtune.git
4 | [submodule "taming-transformers"]
5 | path = taming-transformers
6 | url = https://github.com/CompVis/taming-transformers.git
7 | 
--------------------------------------------------------------------------------
/configs/llama3_2/eval/objaworld/rendering.yaml:
--------------------------------------------------------------------------------
1 | # Config for running the InferenceRecipe in generate.py to generate output from an LLM
2 | #
3 | # To launch, run 
the following command from root torchtune directory: 4 | # tune run generate --config generation 5 | 6 | # Model arguments 7 | model: 8 | _component_: models.llama3_2_1b_clevr3d_sin_cos_plus_learned_numbers_SYNTHETIC_LIVINGROOM_PARK_LARGE_EP100 9 | 10 | checkpointer: 11 | _component_: checkpointer.FullModelMetaCheckpointer3D 12 | convert_weights_type: 3d_sin_cos_plus_learned_num 13 | checkpoint_dir: ./checkpoints/objaworld/rendering/ 14 | checkpoint_files: [ 15 | meta_model_9.pt 16 | ] 17 | output_dir: ./checkpoints/objaworld/rendering-inference 18 | model_type: LLAMA3_2 19 | 20 | device: cuda 21 | dtype: bf16 22 | 23 | seed: 1234 24 | 25 | # Tokenizer arguments 26 | tokenizer: 27 | _component_: torchtune.models.llama3.llama3_tokenizer 28 | path: ./llama-3-models/Llama3.2-1B-Instruct/tokenizer.model 29 | 30 | 31 | instruct_template: null 32 | chat_format: null 33 | max_new_tokens: 2500 34 | temperature: 1.0 35 | # top_k: 300 36 | 37 | load_text_data: False 38 | load_image_data: True 39 | load_three_d_data: True 40 | load_text_target_data: False 41 | load_image_target_data: False 42 | load_three_d_target_data: False 43 | 44 | image_file: "./kyvo-datasets-and-codebooks/pretokenized-data/objaworld/images/test_vqgan_indices.json" 45 | image_target_file: "./kyvo-datasets-and-codebooks/pretokenized-data/objaworld/images/test_vqgan_indices.json" 46 | 47 | text_file: "./kyvo-datasets-and-codebooks/pretokenized-data/objaworld/3d-scenes/test_tokenized_scenes.json" 48 | three_d_file: "./kyvo-datasets-and-codebooks/pretokenized-data/objaworld/3d-scenes/test_tokenized_scenes.json" 49 | three_d_target_file: "./kyvo-datasets-and-codebooks/pretokenized-data/objaworld/3d-scenes/test_tokenized_scenes.json" 50 | 51 | run_identifier: "objaworld_rendering_inference" 52 | image_embeddings_output_folder: "./checkpoints/objaworld/rendering-inference/image_embeddings" 53 | three_d_json_output_folder: "./checkpoints/objaworld/rendering-inference/three_d_json" 54 | 55 | dataset_name: "ObjaWorld" 56 | vqgan_type: "objaworld" 57 | vqgan_row_col_size: 16 58 | reorder_image_tokens: True 59 | task_type: "3-I" 60 | num_samples: -1 61 | sample_start_idx: 0 62 | image_token_offset: 129471 63 | 64 | # enable_kv_cache: True 65 | 66 | quantizer: null 67 | 68 | -------------------------------------------------------------------------------- /configs/llama3_2/eval/objectron/recognition.yaml: -------------------------------------------------------------------------------- 1 | # Config for running the InferenceRecipe in generate.py to generate output from an LLM 2 | # 3 | # To launch, run the following command from root torchtune directory: 4 | # tune run generate --config generation 5 | 6 | # Model arguments 7 | model: 8 | _component_: models.llama3_2_1b_clevr3d_sin_cos_plus_learned_numbers_omni3d_objectron_custom_finer 9 | 10 | checkpointer: 11 | _component_: checkpointer.FullModelMetaCheckpointer3D 12 | convert_weights_type: 3d_sin_cos_plus_learned_num 13 | checkpoint_dir: ./checkpoints/objectron/recognition/ 14 | checkpoint_files: [ 15 | meta_model_19.pt 16 | ] 17 | output_dir: ./checkpoints/objectron/recognition-inference 18 | model_type: LLAMA3_2 19 | 20 | device: cuda 21 | dtype: bf16 22 | 23 | seed: 1234 24 | 25 | # Tokenizer arguments 26 | tokenizer: 27 | _component_: torchtune.models.llama3.llama3_tokenizer 28 | path: ./llama-3-models/Llama3.2-1B-Instruct/tokenizer.model 29 | 30 | 31 | instruct_template: null 32 | chat_format: null 33 | max_new_tokens: 2500 34 | temperature: 1.0 35 | # top_k: 300 36 | 37 | 
load_text_data: False 38 | load_image_data: True 39 | load_three_d_data: True 40 | load_text_target_data: False 41 | load_image_target_data: False 42 | load_three_d_target_data: False 43 | 44 | image_file: "./kyvo-datasets-and-codebooks/pretokenized-data/objectron/images/test_vqgan_indices.json" 45 | image_target_file: "./kyvo-datasets-and-codebooks/pretokenized-data/objectron/images/test_vqgan_indices.json" 46 | 47 | text_file: "./kyvo-datasets-and-codebooks/pretokenized-data/objectron/3d-scenes/test_tokenized_scenes.json" 48 | three_d_file: "./kyvo-datasets-and-codebooks/pretokenized-data/objectron/3d-scenes/test_tokenized_scenes.json" 49 | three_d_target_file: "./kyvo-datasets-and-codebooks/pretokenized-data/objectron/3d-scenes/test_tokenized_scenes.json" 50 | 51 | run_identifier: "objectron_recognition_inference" 52 | image_embeddings_output_folder: "./checkpoints/objectron/recognition-inference/image_embeddings" 53 | three_d_json_output_folder: "./checkpoints/objectron/recognition-inference/three_d_json" 54 | 55 | dataset_name: "Objectron" 56 | vqgan_type: "objectron" 57 | vqgan_row_col_size: 16 58 | reorder_image_tokens: False 59 | task_type: "I-3" 60 | num_samples: -1 61 | sample_start_idx: 0 62 | image_token_offset: 128372 63 | 64 | # enable_kv_cache: True 65 | 66 | quantizer: null 67 | 68 | -------------------------------------------------------------------------------- /configs/llama3_2/eval/objaworld/recognition.yaml: -------------------------------------------------------------------------------- 1 | # Config for running the InferenceRecipe in generate.py to generate output from an LLM 2 | # 3 | # To launch, run the following command from root torchtune directory: 4 | # tune run generate --config generation 5 | 6 | # Model arguments 7 | model: 8 | _component_: models.llama3_2_1b_clevr3d_sin_cos_plus_learned_numbers_SYNTHETIC_LIVINGROOM_PARK_LARGE_EP100 9 | 10 | checkpointer: 11 | _component_: checkpointer.FullModelMetaCheckpointer3D 12 | convert_weights_type: 3d_sin_cos_plus_learned_num 13 | checkpoint_dir: ./checkpoints/objaworld/recognition/ 14 | checkpoint_files: [ 15 | meta_model_9.pt 16 | ] 17 | output_dir: ./checkpoints/objaworld/recognition-inference 18 | model_type: LLAMA3_2 19 | 20 | device: cuda 21 | dtype: bf16 22 | 23 | seed: 1234 24 | 25 | # Tokenizer arguments 26 | tokenizer: 27 | _component_: torchtune.models.llama3.llama3_tokenizer 28 | path: ./llama-3-models/Llama3.2-1B-Instruct/tokenizer.model 29 | 30 | 31 | instruct_template: null 32 | chat_format: null 33 | max_new_tokens: 2500 34 | temperature: 1.0 35 | # top_k: 300 36 | 37 | load_text_data: False 38 | load_image_data: True 39 | load_three_d_data: True 40 | load_text_target_data: False 41 | load_image_target_data: False 42 | load_three_d_target_data: False 43 | 44 | image_file: "./kyvo-datasets-and-codebooks/pretokenized-data/objaworld/images/test_vqgan_indices.json" 45 | image_target_file: "./kyvo-datasets-and-codebooks/pretokenized-data/objaworld/images/test_vqgan_indices.json" 46 | 47 | text_file: "./kyvo-datasets-and-codebooks/pretokenized-data/objaworld/3d-scenes/test_tokenized_scenes.json" 48 | three_d_file: "./kyvo-datasets-and-codebooks/pretokenized-data/objaworld/3d-scenes/test_tokenized_scenes.json" 49 | three_d_target_file: "./kyvo-datasets-and-codebooks/pretokenized-data/objaworld/3d-scenes/test_tokenized_scenes.json" 50 | 51 | run_identifier: "objaworld_recognition_inference" 52 | image_embeddings_output_folder: "./checkpoints/objaworld/recognition-inference/image_embeddings" 53 | 
three_d_json_output_folder: "./checkpoints/objaworld/recognition-inference/three_d_json" 54 | 55 | dataset_name: "ObjaWorld" 56 | vqgan_type: "objaworld" 57 | vqgan_row_col_size: 16 58 | reorder_image_tokens: False 59 | task_type: "I-3" 60 | num_samples: -1 61 | sample_start_idx: 0 62 | image_token_offset: 129471 63 | 64 | # enable_kv_cache: True 65 | 66 | quantizer: null 67 | 68 | -------------------------------------------------------------------------------- /configs/llama3_2/eval/clevr/question-answering.yaml: -------------------------------------------------------------------------------- 1 | # Config for running the InferenceRecipe in generate.py to generate output from an LLM 2 | # 3 | # To launch, run the following command from root torchtune directory: 4 | # tune run generate --config generation 5 | 6 | # Model arguments 7 | model: 8 | _component_: models.llama3_2_1b_clevr3d_sin_cos_plus_learned_numbers 9 | 10 | checkpointer: 11 | _component_: checkpointer.FullModelMetaCheckpointer3D 12 | convert_weights_type: 3d_sin_cos_plus_learned_num 13 | checkpoint_dir: ./checkpoints/clevr/question-answering/ 14 | checkpoint_files: [ 15 | meta_model_9.pt 16 | ] 17 | output_dir: ./checkpoints/clevr/question-answering-inference 18 | model_type: LLAMA3_2 19 | 20 | device: cuda 21 | dtype: bf16 22 | 23 | seed: 1234 24 | 25 | # Tokenizer arguments 26 | tokenizer: 27 | _component_: torchtune.models.llama3.llama3_tokenizer 28 | path: ./llama-3-models/Llama3.2-1B-Instruct/tokenizer.model 29 | 30 | instruct_template: null 31 | chat_format: null 32 | max_new_tokens: 2500 33 | temperature: 1.0 34 | # top_k: 300 35 | 36 | load_text_data: True 37 | load_image_data: True 38 | load_three_d_data: True 39 | load_text_target_data: True 40 | load_image_target_data: False 41 | load_three_d_target_data: False 42 | 43 | 44 | image_file: "./kyvo-datasets-and-codebooks/pretokenized-data/clevr/images/test_vqgan_indices_for_vqa.json" 45 | image_target_file: "./kyvo-datasets-and-codebooks/pretokenized-data/clevr/images/test_vqgan_indices_for_vqa.json" 46 | 47 | text_file: "./kyvo-datasets-and-codebooks/pretokenized-data/clevr/text/test_vqa_questions.json" 48 | text_target_file: "./kyvo-datasets-and-codebooks/pretokenized-data/clevr/text/test_vqa_answers.json" 49 | three_d_file: "./kyvo-datasets-and-codebooks/pretokenized-data/clevr/3d-scenes/test_tokenized_scenes_for_vqa.json" 50 | three_d_target_file: "./kyvo-datasets-and-codebooks/pretokenized-data/clevr/3d-scenes/test_tokenized_scenes_for_vqa.json" 51 | 52 | run_identifier: "clevr_question-answering_inference" 53 | image_embeddings_output_folder: "./checkpoints/clevr/question-answering-inference/image_embeddings" 54 | three_d_json_output_folder: "./checkpoints/clevr/question-answering-inference/three_d_json" 55 | 56 | dataset_name: "CLEVR" 57 | vqgan_type: "clevr" 58 | vqgan_row_col_size: 16 59 | reorder_image_tokens: False 60 | task_type: "I+3+Q-A" 61 | num_samples: -1 62 | sample_start_idx: 0 63 | image_token_offset: 129471 64 | 65 | # enable_kv_cache: True 66 | 67 | quantizer: null 68 | 69 | -------------------------------------------------------------------------------- /configs/llama3_2/eval/clevr/rendering.yaml: -------------------------------------------------------------------------------- 1 | # Config for running the InferenceRecipe in generate.py to generate output from an LLM 2 | # 3 | # To launch, run the following command from root torchtune directory: 4 | # tune run generate --config generation 5 | 6 | # Model arguments 7 | model: 8 | _component_: 
models.llama3_2_1b_clevr3d_sin_cos_plus_learned_numbers 9 | 10 | checkpointer: 11 | _component_: checkpointer.FullModelMetaCheckpointer3D 12 | convert_weights_type: 3d_sin_cos_plus_learned_num 13 | checkpoint_dir: ./checkpoints/clevr/rendering/ 14 | checkpoint_files: [ 15 | meta_model_9.pt 16 | ] 17 | output_dir: ./checkpoints/clevr/rendering-inference 18 | model_type: LLAMA3_2 19 | 20 | device: cuda 21 | dtype: bf16 22 | 23 | seed: 1234 24 | 25 | # Tokenizer arguments 26 | tokenizer: 27 | _component_: torchtune.models.llama3.llama3_tokenizer 28 | path: ./llama-3-models/Llama3.2-1B-Instruct/tokenizer.model 29 | 30 | 31 | instruct_template: null 32 | chat_format: null 33 | max_new_tokens: 2500 34 | temperature: 1.0 35 | # top_k: 300 36 | 37 | load_text_data: False 38 | load_image_data: True 39 | load_three_d_data: True 40 | load_text_target_data: False 41 | load_image_target_data: False 42 | load_three_d_target_data: False 43 | 44 | image_file: "./kyvo-datasets-and-codebooks/pretokenized-data/clevr/images/test_vqgan_indices_merged_for_rendering_and_recognition.json" 45 | image_target_file: "./kyvo-datasets-and-codebooks/pretokenized-data/clevr/images/test_vqgan_indices_merged_for_rendering_and_recognition.json" 46 | 47 | text_file: "./kyvo-datasets-and-codebooks/pretokenized-data/clevr/3d-scenes/test_tokenized_scenes_merged_for_rendering_and_recognition.json" 48 | three_d_file: "./kyvo-datasets-and-codebooks/pretokenized-data/clevr/3d-scenes/test_tokenized_scenes_merged_for_rendering_and_recognition.json" 49 | three_d_target_file: "./kyvo-datasets-and-codebooks/pretokenized-data/clevr/3d-scenes/test_tokenized_scenes_merged_for_rendering_and_recognition.json" 50 | 51 | run_identifier: "clevr_rendering_inference" 52 | image_embeddings_output_folder: "./checkpoints/clevr/rendering-inference/image_embeddings" 53 | three_d_json_output_folder: "./checkpoints/clevr/rendering-inference/three_d_json" 54 | 55 | dataset_name: "CLEVR" 56 | vqgan_type: "clevr" 57 | vqgan_row_col_size: 16 58 | reorder_image_tokens: True 59 | task_type: "3-I" 60 | num_samples: -1 61 | sample_start_idx: 0 62 | image_token_offset: 129471 63 | 64 | # enable_kv_cache: True 65 | 66 | quantizer: null 67 | 68 | -------------------------------------------------------------------------------- /configs/llama3_2/eval/clevr/recognition.yaml: -------------------------------------------------------------------------------- 1 | # Config for running the InferenceRecipe in generate.py to generate output from an LLM 2 | # 3 | # To launch, run the following command from root torchtune directory: 4 | # tune run generate --config generation 5 | 6 | # Model arguments 7 | model: 8 | _component_: models.llama3_2_1b_clevr3d_sin_cos_plus_learned_numbers 9 | 10 | checkpointer: 11 | _component_: checkpointer.FullModelMetaCheckpointer3D 12 | convert_weights_type: 3d_sin_cos_plus_learned_num 13 | checkpoint_dir: ./checkpoints/clevr/recognition/ 14 | checkpoint_files: [ 15 | meta_model_9.pt 16 | ] 17 | output_dir: ./checkpoints/clevr/recognition-inference 18 | model_type: LLAMA3_2 19 | 20 | device: cuda 21 | dtype: bf16 22 | 23 | seed: 1234 24 | 25 | # Tokenizer arguments 26 | tokenizer: 27 | _component_: torchtune.models.llama3.llama3_tokenizer 28 | path: ./llama-3-models/Llama3.2-1B-Instruct/tokenizer.model 29 | 30 | 31 | instruct_template: null 32 | chat_format: null 33 | max_new_tokens: 2500 34 | temperature: 1.0 35 | # top_k: 300 36 | 37 | load_text_data: False 38 | load_image_data: True 39 | load_three_d_data: True 40 | load_text_target_data: 
False 41 | load_image_target_data: False 42 | load_three_d_target_data: False 43 | 44 | image_file: "./kyvo-datasets-and-codebooks/pretokenized-data/clevr/images/test_vqgan_indices_merged_for_rendering_and_recognition.json" 45 | image_target_file: "./kyvo-datasets-and-codebooks/pretokenized-data/clevr/images/test_vqgan_indices_merged_for_rendering_and_recognition.json" 46 | 47 | text_file: "./kyvo-datasets-and-codebooks/pretokenized-data/clevr/3d-scenes/test_tokenized_scenes_merged_for_rendering_and_recognition.json" 48 | three_d_file: "./kyvo-datasets-and-codebooks/pretokenized-data/clevr/3d-scenes/test_tokenized_scenes_merged_for_rendering_and_recognition.json" 49 | three_d_target_file: "./kyvo-datasets-and-codebooks/pretokenized-data/clevr/3d-scenes/test_tokenized_scenes_merged_for_rendering_and_recognition.json" 50 | 51 | run_identifier: "clevr_recognition_inference" 52 | image_embeddings_output_folder: "./checkpoints/clevr/recognition-inference/image_embeddings" 53 | three_d_json_output_folder: "./checkpoints/clevr/recognition-inference/three_d_json" 54 | 55 | dataset_name: "CLEVR" 56 | vqgan_type: "clevr" 57 | vqgan_row_col_size: 16 58 | reorder_image_tokens: False 59 | task_type: "I-3" 60 | num_samples: -1 61 | sample_start_idx: 0 62 | image_token_offset: 129471 63 | 64 | # enable_kv_cache: True 65 | 66 | quantizer: null 67 | 68 | -------------------------------------------------------------------------------- /configs/llama3_2/eval/clevr/recognition-colab-demo.yaml: -------------------------------------------------------------------------------- 1 | # Config for running the InferenceRecipe in generate.py to generate output from an LLM 2 | # 3 | # To launch, run the following command from root torchtune directory: 4 | # tune run generate --config generation 5 | 6 | # Model arguments 7 | model: 8 | _component_: models.llama3_2_1b_clevr3d_sin_cos_plus_learned_numbers 9 | 10 | checkpointer: 11 | _component_: checkpointer.FullModelMetaCheckpointer3D 12 | convert_weights_type: 3d_sin_cos_plus_learned_num 13 | checkpoint_dir: ./checkpoints/clevr/recognition/ 14 | checkpoint_files: [ 15 | meta_model_9.pt 16 | ] 17 | output_dir: ./checkpoints/clevr/recognition-inference-colab-demo 18 | model_type: LLAMA3_2 19 | 20 | device: cuda 21 | dtype: bf16 22 | 23 | seed: 1234 24 | 25 | # Tokenizer arguments 26 | tokenizer: 27 | _component_: torchtune.models.llama3.llama3_tokenizer 28 | path: ./llama-3-models/Llama3.2-1B-Instruct/tokenizer.model 29 | 30 | 31 | instruct_template: null 32 | chat_format: null 33 | max_new_tokens: 2500 34 | temperature: 1.0 35 | # top_k: 300 36 | 37 | load_text_data: False 38 | load_image_data: True 39 | load_three_d_data: True 40 | load_text_target_data: False 41 | load_image_target_data: False 42 | load_three_d_target_data: False 43 | 44 | image_file: "./kyvo-datasets-and-codebooks/pretokenized-data/clevr/images/test_vqgan_indices_merged_for_rendering_and_recognition.json" 45 | image_target_file: "./kyvo-datasets-and-codebooks/pretokenized-data/clevr/images/test_vqgan_indices_merged_for_rendering_and_recognition.json" 46 | 47 | text_file: "./kyvo-datasets-and-codebooks/pretokenized-data/clevr/3d-scenes/test_tokenized_scenes_merged_for_rendering_and_recognition.json" 48 | three_d_file: "./kyvo-datasets-and-codebooks/pretokenized-data/clevr/3d-scenes/test_tokenized_scenes_merged_for_rendering_and_recognition.json" 49 | three_d_target_file: "./kyvo-datasets-and-codebooks/pretokenized-data/clevr/3d-scenes/test_tokenized_scenes_merged_for_rendering_and_recognition.json" 50 | 
51 | run_identifier: "clevr_recognition_inference" 52 | image_embeddings_output_folder: "./checkpoints/clevr/recognition-inference-colab-demo/image_embeddings" 53 | three_d_json_output_folder: "./checkpoints/clevr/recognition-inference-colab-demo/three_d_json" 54 | 55 | dataset_name: "CLEVR" 56 | vqgan_type: "clevr" 57 | vqgan_row_col_size: 16 58 | reorder_image_tokens: False 59 | task_type: "I-3" 60 | num_samples: -1 61 | sample_start_idx: 0 62 | image_token_offset: 129471 63 | 64 | # enable_kv_cache: True 65 | 66 | quantizer: null 67 | 68 | -------------------------------------------------------------------------------- /configs/llama3_2/eval/clevr/instruction-following/moving-with-relation.yaml: -------------------------------------------------------------------------------- 1 | # Config for running the InferenceRecipe in generate.py to generate output from an LLM 2 | # 3 | # To launch, run the following command from root torchtune directory: 4 | # tune run generate --config generation 5 | 6 | # Model arguments 7 | model: 8 | _component_: models.llama3_2_1b_clevr3d_sin_cos_plus_learned_numbers 9 | 10 | checkpointer: 11 | _component_: checkpointer.FullModelMetaCheckpointer3D 12 | convert_weights_type: 3d_sin_cos_plus_learned_num 13 | checkpoint_dir: ./checkpoints/clevr/instruction-following/ 14 | checkpoint_files: [ 15 | meta_model_9.pt 16 | ] 17 | output_dir: ./checkpoints/clevr/instruction-following-inference 18 | model_type: LLAMA3_2 19 | 20 | device: cuda 21 | dtype: bf16 22 | 23 | seed: 1234 24 | 25 | # Tokenizer arguments 26 | tokenizer: 27 | _component_: torchtune.models.llama3.llama3_tokenizer 28 | path: ./llama-3-models/Llama3.2-1B-Instruct/tokenizer.model 29 | 30 | 31 | instruct_template: null 32 | chat_format: null 33 | max_new_tokens: 2500 34 | temperature: 1.0 35 | # top_k: 300 36 | 37 | load_text_data: True 38 | load_image_data: True 39 | load_three_d_data: True 40 | load_text_target_data: False 41 | load_image_target_data: True 42 | load_three_d_target_data: True 43 | 44 | image_file: "./kyvo-datasets-and-codebooks/pretokenized-data/clevr/images/instruction-following-eval-source/test_vqgan_indices_moving_with_relation.json" 45 | image_target_file: "./kyvo-datasets-and-codebooks/pretokenized-data/clevr/images/instruction-following-eval-target/test_vqgan_indices_moving_with_relation.json" 46 | 47 | text_file: "./kyvo-datasets-and-codebooks/pretokenized-data/clevr/text/instruction-following-eval/test_text_instructions_moving_with_relation.json" 48 | three_d_file: "./kyvo-datasets-and-codebooks/pretokenized-data/clevr/3d-scenes/instruction-following-eval-source/test_tokenized_scenes_moving_with_relation.json" 49 | three_d_target_file: "./kyvo-datasets-and-codebooks/pretokenized-data/clevr/3d-scenes/instruction-following-eval-target/test_tokenized_scenes_moving_with_relation.json" 50 | 51 | run_identifier: "clevr_instruction-following-moving-with-relation_inference" 52 | image_embeddings_output_folder: "./checkpoints/clevr/instruction-following-inference/image_embeddings" 53 | three_d_json_output_folder: "./checkpoints/clevr/instruction-following-inference/three_d_json" 54 | 55 | dataset_name: "CLEVR" 56 | vqgan_type: "clevr" 57 | vqgan_row_col_size: 16 58 | reorder_image_tokens: True 59 | task_type: "I+3+T-I+3" 60 | num_samples: -1 61 | sample_start_idx: 0 62 | image_token_offset: 129471 63 | 64 | # enable_kv_cache: True 65 | 66 | quantizer: null 67 | -------------------------------------------------------------------------------- 
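
All of the eval configs in this dump share the same torchtune-style layout: `_component_` keys name the model, checkpointer, and tokenizer builders; `task_type` encodes the modality flow (for example `I+3+T-I+3` pairs an image, a 3D scene, and a text instruction on the input side with an image and a 3D scene on the output side); and the `load_*` flags plus the `*_file` paths tell the recipe which pretokenized streams to read. The short sketch below is illustrative only -- the filename `inspect_eval_config.py` is made up, and it assumes the repo root as the working directory with `omegaconf` (a torchtune dependency) installed -- but it shows how the fields that drive a `tune run generate` launch can be checked beforehand.

# inspect_eval_config.py -- illustrative helper, not part of the repo.
# Assumes the working directory is the repo root and omegaconf is installed.
from omegaconf import OmegaConf

CFG_PATH = "configs/llama3_2/eval/clevr/instruction-following/moving-with-relation.yaml"

cfg = OmegaConf.load(CFG_PATH)

# task_type encodes the modality flow, e.g. "I+3+T-I+3":
# image + 3D scene + text instruction in, image + 3D scene out.
print("task type      :", cfg.task_type)
print("checkpoint dir :", cfg.checkpointer.checkpoint_dir)
print("3D source file :", cfg.three_d_file)
print("3D target file :", cfg.three_d_target_file)

# The load_* flags have to agree with task_type, otherwise the recipe would
# assemble prompts with missing (or spurious) modalities.
for key in ("load_text_data", "load_image_data", "load_three_d_data",
            "load_image_target_data", "load_three_d_target_data"):
    print(f"{key} = {cfg[key]}")

The remaining instruction-following eval configs below differ only in the source/target file paths and the `run_identifier`.
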
/configs/llama3_2/eval/clevr/instruction-following/removal-with-relation.yaml: -------------------------------------------------------------------------------- 1 | # Config for running the InferenceRecipe in generate.py to generate output from an LLM 2 | # 3 | # To launch, run the following command from root torchtune directory: 4 | # tune run generate --config generation 5 | 6 | # Model arguments 7 | model: 8 | _component_: models.llama3_2_1b_clevr3d_sin_cos_plus_learned_numbers 9 | 10 | checkpointer: 11 | _component_: checkpointer.FullModelMetaCheckpointer3D 12 | convert_weights_type: 3d_sin_cos_plus_learned_num 13 | checkpoint_dir: ./checkpoints/clevr/instruction-following/ 14 | checkpoint_files: [ 15 | meta_model_9.pt 16 | ] 17 | output_dir: ./checkpoints/clevr/instruction-following-inference 18 | model_type: LLAMA3_2 19 | 20 | device: cuda 21 | dtype: bf16 22 | 23 | seed: 1234 24 | 25 | # Tokenizer arguments 26 | tokenizer: 27 | _component_: torchtune.models.llama3.llama3_tokenizer 28 | path: ./llama-3-models/Llama3.2-1B-Instruct/tokenizer.model 29 | 30 | 31 | instruct_template: null 32 | chat_format: null 33 | max_new_tokens: 2500 34 | temperature: 1.0 35 | # top_k: 300 36 | 37 | load_text_data: True 38 | load_image_data: True 39 | load_three_d_data: True 40 | load_text_target_data: False 41 | load_image_target_data: True 42 | load_three_d_target_data: True 43 | 44 | image_file: "./kyvo-datasets-and-codebooks/pretokenized-data/clevr/images/instruction-following-eval-source/test_vqgan_indices_removal_with_relation.json" 45 | image_target_file: "./kyvo-datasets-and-codebooks/pretokenized-data/clevr/images/instruction-following-eval-target/test_vqgan_indices_removal_with_relation.json" 46 | 47 | text_file: "./kyvo-datasets-and-codebooks/pretokenized-data/clevr/text/instruction-following-eval/test_text_instructions_removal_with_relation.json" 48 | three_d_file: "./kyvo-datasets-and-codebooks/pretokenized-data/clevr/3d-scenes/instruction-following-eval-source/test_tokenized_scenes_removal_with_relation.json" 49 | three_d_target_file: "./kyvo-datasets-and-codebooks/pretokenized-data/clevr/3d-scenes/instruction-following-eval-target/test_tokenized_scenes_removal_with_relation.json" 50 | 51 | run_identifier: "clevr_instruction-following-removal-with-relation_inference" 52 | image_embeddings_output_folder: "./checkpoints/clevr/instruction-following-inference/image_embeddings" 53 | three_d_json_output_folder: "./checkpoints/clevr/instruction-following-inference/three_d_json" 54 | 55 | dataset_name: "CLEVR" 56 | vqgan_type: "clevr" 57 | vqgan_row_col_size: 16 58 | reorder_image_tokens: True 59 | task_type: "I+3+T-I+3" 60 | num_samples: -1 61 | sample_start_idx: 0 62 | image_token_offset: 129471 63 | 64 | # enable_kv_cache: True 65 | 66 | quantizer: null 67 | -------------------------------------------------------------------------------- /configs/llama3_2/eval/clevr/instruction-following/appearance-no-relation.yaml: -------------------------------------------------------------------------------- 1 | # Config for running the InferenceRecipe in generate.py to generate output from an LLM 2 | # 3 | # To launch, run the following command from root torchtune directory: 4 | # tune run generate --config generation 5 | 6 | # Model arguments 7 | model: 8 | _component_: models.llama3_2_1b_clevr3d_sin_cos_plus_learned_numbers 9 | 10 | checkpointer: 11 | _component_: checkpointer.FullModelMetaCheckpointer3D 12 | convert_weights_type: 3d_sin_cos_plus_learned_num 13 | checkpoint_dir: 
./checkpoints/clevr/instruction-following/ 14 | checkpoint_files: [ 15 | meta_model_9.pt 16 | ] 17 | output_dir: ./checkpoints/clevr/instruction-following-inference 18 | model_type: LLAMA3_2 19 | 20 | device: cuda 21 | dtype: bf16 22 | 23 | seed: 1234 24 | 25 | # Tokenizer arguments 26 | tokenizer: 27 | _component_: torchtune.models.llama3.llama3_tokenizer 28 | path: ./llama-3-models/Llama3.2-1B-Instruct/tokenizer.model 29 | 30 | 31 | instruct_template: null 32 | chat_format: null 33 | max_new_tokens: 2500 34 | temperature: 1.0 35 | # top_k: 300 36 | 37 | load_text_data: True 38 | load_image_data: True 39 | load_three_d_data: True 40 | load_text_target_data: False 41 | load_image_target_data: True 42 | load_three_d_target_data: True 43 | 44 | image_file: "./kyvo-datasets-and-codebooks/pretokenized-data/clevr/images/instruction-following-eval-source/test_vqgan_indices_appearance_no_relation.json" 45 | image_target_file: "./kyvo-datasets-and-codebooks/pretokenized-data/clevr/images/instruction-following-eval-target/test_vqgan_indices_appearance_no_relation.json" 46 | 47 | text_file: "./kyvo-datasets-and-codebooks/pretokenized-data/clevr/text/instruction-following-eval/test_text_instructions_appearance_no_relation.json" 48 | three_d_file: "./kyvo-datasets-and-codebooks/pretokenized-data/clevr/3d-scenes/instruction-following-eval-source/test_tokenized_scenes_appearance_no_relation.json" 49 | three_d_target_file: "./kyvo-datasets-and-codebooks/pretokenized-data/clevr/3d-scenes/instruction-following-eval-target/test_tokenized_scenes_appearance_no_relation.json" 50 | 51 | run_identifier: "clevr_instruction-following-appearance-no-relation_inference" 52 | image_embeddings_output_folder: "./checkpoints/clevr/instruction-following-inference/image_embeddings" 53 | three_d_json_output_folder: "./checkpoints/clevr/instruction-following-inference/three_d_json" 54 | 55 | dataset_name: "CLEVR" 56 | vqgan_type: "clevr" 57 | vqgan_row_col_size: 16 58 | reorder_image_tokens: True 59 | task_type: "I+3+T-I+3" 60 | num_samples: -1 61 | sample_start_idx: 0 62 | image_token_offset: 129471 63 | 64 | # enable_kv_cache: True 65 | 66 | quantizer: null 67 | -------------------------------------------------------------------------------- /configs/llama3_2/eval/clevr/instruction-following/insertion-with-relation.yaml: -------------------------------------------------------------------------------- 1 | # Config for running the InferenceRecipe in generate.py to generate output from an LLM 2 | # 3 | # To launch, run the following command from root torchtune directory: 4 | # tune run generate --config generation 5 | 6 | # Model arguments 7 | model: 8 | _component_: models.llama3_2_1b_clevr3d_sin_cos_plus_learned_numbers 9 | 10 | checkpointer: 11 | _component_: checkpointer.FullModelMetaCheckpointer3D 12 | convert_weights_type: 3d_sin_cos_plus_learned_num 13 | checkpoint_dir: ./checkpoints/clevr/instruction-following/ 14 | checkpoint_files: [ 15 | meta_model_9.pt 16 | ] 17 | output_dir: ./checkpoints/clevr/instruction-following-inference 18 | model_type: LLAMA3_2 19 | 20 | device: cuda 21 | dtype: bf16 22 | 23 | seed: 1234 24 | 25 | # Tokenizer arguments 26 | tokenizer: 27 | _component_: torchtune.models.llama3.llama3_tokenizer 28 | path: ./llama-3-models/Llama3.2-1B-Instruct/tokenizer.model 29 | 30 | 31 | instruct_template: null 32 | chat_format: null 33 | max_new_tokens: 2500 34 | temperature: 1.0 35 | # top_k: 300 36 | 37 | load_text_data: True 38 | load_image_data: True 39 | load_three_d_data: True 40 | 
load_text_target_data: False 41 | load_image_target_data: True 42 | load_three_d_target_data: True 43 | 44 | image_file: "./kyvo-datasets-and-codebooks/pretokenized-data/clevr/images/instruction-following-eval-source/test_vqgan_indices_insertion_with_relation.json" 45 | image_target_file: "./kyvo-datasets-and-codebooks/pretokenized-data/clevr/images/instruction-following-eval-target/test_vqgan_indices_insertion_with_relation.json" 46 | 47 | text_file: "./kyvo-datasets-and-codebooks/pretokenized-data/clevr/text/instruction-following-eval/test_text_instructions_insertion_with_relation.json" 48 | three_d_file: "./kyvo-datasets-and-codebooks/pretokenized-data/clevr/3d-scenes/instruction-following-eval-source/test_tokenized_scenes_insertion_with_relation.json" 49 | three_d_target_file: "./kyvo-datasets-and-codebooks/pretokenized-data/clevr/3d-scenes/instruction-following-eval-target/test_tokenized_scenes_insertion_with_relation.json" 50 | 51 | run_identifier: "clevr_instruction-following-insertion-with-relation_inference" 52 | image_embeddings_output_folder: "./checkpoints/clevr/instruction-following-inference/image_embeddings" 53 | three_d_json_output_folder: "./checkpoints/clevr/instruction-following-inference/three_d_json" 54 | 55 | dataset_name: "CLEVR" 56 | vqgan_type: "clevr" 57 | vqgan_row_col_size: 16 58 | reorder_image_tokens: True 59 | task_type: "I+3+T-I+3" 60 | num_samples: -1 61 | sample_start_idx: 0 62 | image_token_offset: 129471 63 | 64 | # enable_kv_cache: True 65 | 66 | quantizer: null 67 | -------------------------------------------------------------------------------- /configs/llama3_2/eval/clevr/instruction-following/appearance-with-relation.yaml: -------------------------------------------------------------------------------- 1 | # Config for running the InferenceRecipe in generate.py to generate output from an LLM 2 | # 3 | # To launch, run the following command from root torchtune directory: 4 | # tune run generate --config generation 5 | 6 | # Model arguments 7 | model: 8 | _component_: models.llama3_2_1b_clevr3d_sin_cos_plus_learned_numbers 9 | 10 | checkpointer: 11 | _component_: checkpointer.FullModelMetaCheckpointer3D 12 | convert_weights_type: 3d_sin_cos_plus_learned_num 13 | checkpoint_dir: ./checkpoints/clevr/instruction-following/ 14 | checkpoint_files: [ 15 | meta_model_9.pt 16 | ] 17 | output_dir: ./checkpoints/clevr/instruction-following-inference 18 | model_type: LLAMA3_2 19 | 20 | device: cuda 21 | dtype: bf16 22 | 23 | seed: 1234 24 | 25 | # Tokenizer arguments 26 | tokenizer: 27 | _component_: torchtune.models.llama3.llama3_tokenizer 28 | path: ./llama-3-models/Llama3.2-1B-Instruct/tokenizer.model 29 | 30 | 31 | instruct_template: null 32 | chat_format: null 33 | max_new_tokens: 2500 34 | temperature: 1.0 35 | # top_k: 300 36 | 37 | load_text_data: True 38 | load_image_data: True 39 | load_three_d_data: True 40 | load_text_target_data: False 41 | load_image_target_data: True 42 | load_three_d_target_data: True 43 | 44 | image_file: "./kyvo-datasets-and-codebooks/pretokenized-data/clevr/images/instruction-following-eval-source/test_vqgan_indices_appearance_with_relation.json" 45 | image_target_file: "./kyvo-datasets-and-codebooks/pretokenized-data/clevr/images/instruction-following-eval-target/test_vqgan_indices_appearance_with_relation.json" 46 | 47 | text_file: "./kyvo-datasets-and-codebooks/pretokenized-data/clevr/text/instruction-following-eval/test_text_instructions_appearance_with_relation.json" 48 | three_d_file: 
"./kyvo-datasets-and-codebooks/pretokenized-data/clevr/3d-scenes/instruction-following-eval-source/test_tokenized_scenes_appearance_with_relation.json" 49 | three_d_target_file: "./kyvo-datasets-and-codebooks/pretokenized-data/clevr/3d-scenes/instruction-following-eval-target/test_tokenized_scenes_appearance_with_relation.json" 50 | 51 | run_identifier: "clevr_instruction-following-appearance-with-relation_inference" 52 | image_embeddings_output_folder: "./checkpoints/clevr/instruction-following-inference/image_embeddings" 53 | three_d_json_output_folder: "./checkpoints/clevr/instruction-following-inference/three_d_json" 54 | 55 | dataset_name: "CLEVR" 56 | vqgan_type: "clevr" 57 | vqgan_row_col_size: 16 58 | reorder_image_tokens: True 59 | task_type: "I+3+T-I+3" 60 | num_samples: -1 61 | sample_start_idx: 0 62 | image_token_offset: 129471 63 | 64 | # enable_kv_cache: True 65 | 66 | quantizer: null 67 | -------------------------------------------------------------------------------- /scripts/tokenizer.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoTokenizer 2 | import numpy as np 3 | 4 | 5 | def get_tokenizer(model_name="meta-llama/Llama-3.2-1B-Instruct"): 6 | tokenizer = AutoTokenizer.from_pretrained(model_name) 7 | 8 | special_tokens_dict = { 9 | "additional_special_tokens": [ 10 | "[SCENE-START]", 11 | "[SCENE-END]", 12 | "[OBJECT-START]", 13 | "[OBJECT-END]", 14 | "[SIZE]", 15 | "[COLOR]", 16 | "[MATERIAL]", 17 | "[SHAPE]", 18 | "[LOCATION]", 19 | "[IMAGE-START]", 20 | "[IMAGE-END]", 21 | "[TEXT-START]", 22 | "[TEXT-END]", 23 | "[OUTPUT-START]", 24 | ] 25 | } 26 | 27 | number_tokens = np.arange(-3, 3 + 0.005, 0.005).tolist() 28 | # convert the number tokens to strings 29 | number_tokens = [str(format(round(token, 3), ".3f")) for token in number_tokens] 30 | # replace "-0.000" with "0.000" in number tokens and update the list 31 | number_tokens = [token.replace("-0.000", "0.000") for token in number_tokens] 32 | 33 | tokenizer.add_tokens(number_tokens) 34 | tokenizer.add_special_tokens(special_tokens_dict) 35 | 36 | return tokenizer 37 | 38 | 39 | def get_tokenizer_omni3d_objectron(model_name="meta-llama/Llama-3.2-1B-Instruct"): 40 | tokenizer = AutoTokenizer.from_pretrained(model_name) 41 | 42 | special_tokens_dict = { 43 | "additional_special_tokens": [ 44 | "[SCENE-START]", 45 | "[SCENE-END]", 46 | "[OBJECT-START]", 47 | "[OBJECT-END]", 48 | "[CATEGORY]", 49 | "[CENTER_CAM]", 50 | "[DIMENSIONS]", 51 | "[IMAGE-START]", 52 | "[IMAGE-END]", 53 | "[TEXT-START]", 54 | "[TEXT-END]", 55 | "[OUTPUT-START]", 56 | ] 57 | } 58 | 59 | center_cam_x_number_tokens = [0.01 * i for i in range(-20, 21)] 60 | center_cam_y_number_tokens = [0.01 * i for i in range(-20, 11)] 61 | center_cam_z_number_tokens = [0.05 * i for i in range(61)] 62 | center_cam_z_number_tokens.extend([5.0, 7.5, 10.0, 12.5, 15.0, 17.5]) 63 | 64 | dimensions_length_number_tokens = [0.05 * i for i in range(21)] 65 | dimensions_width_number_tokens = [0.05 * i for i in range(21)] 66 | dimensions_height_number_tokens = [0.05 * i for i in range(25)] 67 | 68 | # merge, unique and sort the number tokens 69 | number_tokens = list( 70 | set( 71 | center_cam_x_number_tokens 72 | + center_cam_y_number_tokens 73 | + center_cam_z_number_tokens 74 | + dimensions_length_number_tokens 75 | + dimensions_width_number_tokens 76 | + dimensions_height_number_tokens 77 | ) 78 | ) 79 | number_tokens.sort() 80 | # convert the number tokens to strings 81 | number_tokens = 
[str(format(round(token, 3), ".2f")) for token in number_tokens] 82 | 83 | tokenizer.add_tokens(number_tokens) 84 | tokenizer.add_special_tokens(special_tokens_dict) 85 | 86 | return tokenizer 87 | -------------------------------------------------------------------------------- /scripts/compute_text_answer_accuracy.py: -------------------------------------------------------------------------------- 1 | import json 2 | import re 3 | from tqdm import tqdm 4 | import argparse 5 | 6 | # Parser for command line arguments 7 | parser = argparse.ArgumentParser(description="Compute Text Answer Accuracy") 8 | 9 | parser.add_argument( 10 | "--groundtruth_file", 11 | required=True, 12 | type=str, 13 | help="Path to the ground truth answers JSON file.", 14 | ) 15 | 16 | parser.add_argument( 17 | "--predicted_file", 18 | required=True, 19 | type=str, 20 | help="Path to the predicted answers JSON file.", 21 | ) 22 | 23 | args = parser.parse_args() 24 | 25 | 26 | def clean_answer(answer): 27 | """ 28 | Clean the answer by removing special tokens like [TEXT-START], [TEXT-END], and <|end_of_text|>. 29 | """ 30 | return re.sub(r"\[TEXT-START\]|\[TEXT-END\]|<\|end_of_text\|>", "", answer).strip() 31 | 32 | 33 | def evaluate_accuracy(groundtruth_file, predicted_file): 34 | """ 35 | Evaluate the accuracy of predicted answers against ground truth answers. 36 | 37 | Args: 38 | - groundtruth_file (str): Path to the ground truth answers JSON file. 39 | - predicted_file (str): Path to the predicted answers JSON file. 40 | 41 | Returns: 42 | - float: Accuracy of the model (in percentage). 43 | """ 44 | # Load ground truth and predicted answers 45 | with open(groundtruth_file, "r") as gt_file: 46 | groundtruth_data = json.load(gt_file) 47 | with open(predicted_file, "r") as pred_file: 48 | predicted_data = json.load(pred_file) 49 | 50 | # Assertions to ensure data integrity 51 | assert "answers" in groundtruth_data, "Ground truth JSON missing 'answers' key." 52 | assert "answers" in predicted_data, "Predicted answers JSON missing 'answers' key." 53 | assert len(groundtruth_data["answers"]) == len( 54 | predicted_data["answers"] 55 | ), "Mismatch in number of ground truth and predicted answers." 
56 | 57 | groundtruth_answers = groundtruth_data["answers"] 58 | predicted_answers = predicted_data["answers"] 59 | 60 | correct_count = 0 61 | 62 | # Progress bar for evaluation 63 | for gt, pred in tqdm( 64 | zip(groundtruth_answers, predicted_answers), 65 | total=len(groundtruth_answers), 66 | desc="Evaluating", 67 | ): 68 | # Assertions for data consistency 69 | assert ( 70 | gt["image_filename"].split("_scene_")[1] 71 | == pred["image_filename"].split("_scene_")[1] 72 | ), f"Image filename mismatch: {gt['image_filename']} vs {pred['image_filename']}" 73 | 74 | # Clean answers to extract meaningful text 75 | gt_answer = clean_answer(gt["answer"]) 76 | pred_answer = clean_answer(pred["answer"]) 77 | 78 | # Compare and count correct predictions 79 | if gt_answer == pred_answer: 80 | correct_count += 1 81 | 82 | # Calculate accuracy 83 | accuracy = (correct_count / len(groundtruth_answers)) * 100 84 | return accuracy 85 | 86 | 87 | # Evaluate and print accuracy 88 | accuracy = evaluate_accuracy(args.groundtruth_file, args.predicted_file) 89 | print(f"Model Text Answer Accuracy: {accuracy:.2f}%") 90 | -------------------------------------------------------------------------------- /configs/llama3_2/train/objaworld/rendering.yaml: -------------------------------------------------------------------------------- 1 | # Tokenizer 2 | tokenizer: 3 | _component_: torchtune.models.llama3.llama3_tokenizer 4 | path: ./llama-3-models/Llama3.2-1B-Instruct/tokenizer.model 5 | max_seq_len: null 6 | 7 | # Dataset 8 | dataset: 9 | _component_: dataset.threed_mllm_dataset 10 | task_type: "3-I" 11 | dataset_name: "ObjaWorld" 12 | text_source: "" 13 | image_source: "./kyvo-datasets-and-codebooks/pretokenized-data/objaworld/images/train_vqgan_indices.json" 14 | three_d_source: "./kyvo-datasets-and-codebooks/pretokenized-data/objaworld/3d-scenes/train_tokenized_scenes.json" 15 | image_target: "" 16 | three_d_target: "" 17 | image_token_offset: 129471 18 | no_loss_on_input: True 19 | reorder_image_tokens: True 20 | load_text_source: False 21 | load_image_source: True 22 | load_three_d_source: True 23 | load_text_target: False 24 | load_image_target: False 25 | load_three_d_target: False 26 | 27 | seed: null 28 | shuffle: True 29 | 30 | # Model Arguments 31 | model: 32 | _component_: models.llama3_2_1b_clevr3d_sin_cos_plus_learned_numbers_SYNTHETIC_LIVINGROOM_PARK_LARGE_EP100 33 | freeze_llama3_token_embeddings: True 34 | start_from_original_llama3: True 35 | 36 | checkpointer: 37 | _component_: checkpointer.FullModelMetaCheckpointer3D 38 | checkpoint_dir: ./llama-3-models/Llama3.2-1B-Instruct/ 39 | checkpoint_files: [ 40 | consolidated.00.pth 41 | ] 42 | recipe_checkpoint: null 43 | output_dir: ./checkpoints/objaworld/rendering/ 44 | model_type: LLAMA3_2 45 | convert_weights_type: 3d_sin_cos_plus_learned_num 46 | resume_from_checkpoint: False 47 | 48 | # Fine-tuning arguments 49 | batch_size: 32 50 | epochs: 10 51 | optimizer: 52 | _component_: bitsandbytes.optim.PagedAdamW8bit 53 | lr: 1e-4 54 | loss: 55 | _component_: torch.nn.CrossEntropyLoss 56 | reduction: none 57 | 58 | weight_first_positions: 5 59 | weight_first_positions_with_weight: 10.0 60 | 61 | max_steps_per_epoch: null 62 | gradient_accumulation_steps: 1 63 | optimizer_in_bwd: False 64 | compile: False 65 | 66 | # Training environment 67 | device: cuda 68 | 69 | # Memory management 70 | enable_activation_checkpointing: False 71 | 72 | # Reduced precision 73 | dtype: bf16 74 | 75 | # Logging 76 | metric_logger: 77 | _component_: 
torchtune.training.metric_logging.DiskLogger 78 | log_dir: ${output_dir} 79 | output_dir: ./checkpoints/objaworld/rendering/ 80 | log_every_n_steps: 1 81 | log_peak_memory_stats: False 82 | 83 | # Profiler (disabled) 84 | profiler: 85 | _component_: torchtune.training.setup_torch_profiler 86 | enabled: False 87 | 88 | #Output directory of trace artifacts 89 | output_dir: ${output_dir}/profiling_outputs 90 | 91 | #`torch.profiler.ProfilerActivity` types to trace 92 | cpu: True 93 | cuda: True 94 | 95 | #trace options passed to `torch.profiler.profile` 96 | profile_memory: True 97 | with_stack: False 98 | record_shapes: True 99 | with_flops: False 100 | 101 | # `torch.profiler.schedule` options: 102 | # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat 103 | wait_steps: 1 104 | warmup_steps: 2 105 | active_steps: 1 106 | num_cycles: 1 107 | -------------------------------------------------------------------------------- /configs/llama3_2/train/objectron/recognition.yaml: -------------------------------------------------------------------------------- 1 | # Tokenizer 2 | tokenizer: 3 | _component_: torchtune.models.llama3.llama3_tokenizer 4 | path: ./llama-3-models/Llama3.2-1B-Instruct/tokenizer.model 5 | max_seq_len: null 6 | 7 | # Dataset 8 | dataset: 9 | _component_: dataset.threed_mllm_dataset 10 | dataset_name: "Omni3D-Objectron" 11 | task_type: "I-3" 12 | text_source: "" 13 | image_source: "./kyvo-datasets-and-codebooks/pretokenized-data/objectron/images/train_vqgan_indices.json" 14 | three_d_source: "./kyvo-datasets-and-codebooks/pretokenized-data/objectron/3d-scenes/train_tokenized_scenes.json" 15 | image_target: "" 16 | three_d_target: "" 17 | image_token_offset: 128372 18 | no_loss_on_input: True 19 | reorder_image_tokens: False 20 | load_text_source: False 21 | load_image_source: True 22 | load_three_d_source: True 23 | load_text_target: False 24 | load_image_target: False 25 | load_three_d_target: False 26 | 27 | seed: null 28 | shuffle: True 29 | 30 | # Model Arguments 31 | model: 32 | _component_: models.llama3_2_1b_clevr3d_sin_cos_plus_learned_numbers_omni3d_objectron_custom_finer 33 | freeze_llama3_token_embeddings: True 34 | start_from_original_llama3: True 35 | 36 | checkpointer: 37 | _component_: checkpointer.FullModelMetaCheckpointer3D 38 | checkpoint_dir: ./llama-3-models/Llama3.2-1B-Instruct/ 39 | checkpoint_files: [ 40 | consolidated.00.pth 41 | ] 42 | recipe_checkpoint: null 43 | output_dir: ./checkpoints/objectron/recognition/ 44 | model_type: LLAMA3_2 45 | convert_weights_type: 3d_sin_cos_plus_learned_num 46 | resume_from_checkpoint: False 47 | 48 | # Fine-tuning arguments 49 | batch_size: 32 50 | epochs: 20 51 | optimizer: 52 | _component_: bitsandbytes.optim.PagedAdamW8bit 53 | lr: 1e-4 54 | loss: 55 | _component_: torch.nn.CrossEntropyLoss 56 | reduction: none 57 | 58 | weight_first_positions: 0 59 | weight_first_positions_with_weight: 1.0 60 | 61 | max_steps_per_epoch: null 62 | gradient_accumulation_steps: 1 63 | optimizer_in_bwd: False 64 | compile: False 65 | 66 | # Training environment 67 | device: cuda 68 | 69 | # Memory management 70 | enable_activation_checkpointing: False 71 | 72 | # Reduced precision 73 | dtype: bf16 74 | 75 | # Logging 76 | metric_logger: 77 | _component_: torchtune.training.metric_logging.DiskLogger 78 | log_dir: ${output_dir} 79 | output_dir: ./checkpoints/objectron/recognition/ 80 | log_every_n_steps: 1 81 | log_peak_memory_stats: False 82 | 83 | # Profiler (disabled) 84 | profiler: 85 | 
_component_: torchtune.training.setup_torch_profiler 86 | enabled: False 87 | 88 | #Output directory of trace artifacts 89 | output_dir: ${output_dir}/profiling_outputs 90 | 91 | #`torch.profiler.ProfilerActivity` types to trace 92 | cpu: True 93 | cuda: True 94 | 95 | #trace options passed to `torch.profiler.profile` 96 | profile_memory: True 97 | with_stack: False 98 | record_shapes: True 99 | with_flops: False 100 | 101 | # `torch.profiler.schedule` options: 102 | # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat 103 | wait_steps: 1 104 | warmup_steps: 2 105 | active_steps: 1 106 | num_cycles: 1 107 | -------------------------------------------------------------------------------- /configs/llama3_2/train/objaworld/recognition.yaml: -------------------------------------------------------------------------------- 1 | # Tokenizer 2 | tokenizer: 3 | _component_: torchtune.models.llama3.llama3_tokenizer 4 | path: ./llama-3-models/Llama3.2-1B-Instruct/tokenizer.model 5 | max_seq_len: null 6 | 7 | # Dataset 8 | dataset: 9 | _component_: dataset.threed_mllm_dataset 10 | task_type: "I-3" 11 | dataset_name: "ObjaWorld" 12 | text_source: "" 13 | image_source: "./kyvo-datasets-and-codebooks/pretokenized-data/objaworld/images/train_vqgan_indices.json" 14 | three_d_source: "./kyvo-datasets-and-codebooks/pretokenized-data/objaworld/3d-scenes/train_tokenized_scenes.json" 15 | image_target: "" 16 | three_d_target: "" 17 | image_token_offset: 129471 18 | no_loss_on_input: True 19 | reorder_image_tokens: False 20 | load_text_source: False 21 | load_image_source: True 22 | load_three_d_source: True 23 | load_text_target: False 24 | load_image_target: False 25 | load_three_d_target: False 26 | 27 | seed: null 28 | shuffle: True 29 | 30 | # Model Arguments 31 | model: 32 | _component_: models.llama3_2_1b_clevr3d_sin_cos_plus_learned_numbers_SYNTHETIC_LIVINGROOM_PARK_LARGE_EP100 33 | freeze_llama3_token_embeddings: True 34 | start_from_original_llama3: True 35 | 36 | checkpointer: 37 | _component_: checkpointer.FullModelMetaCheckpointer3D 38 | checkpoint_dir: ./llama-3-models/Llama3.2-1B-Instruct/ 39 | checkpoint_files: [ 40 | consolidated.00.pth 41 | ] 42 | recipe_checkpoint: null 43 | output_dir: ./checkpoints/objaworld/recognition/ 44 | model_type: LLAMA3_2 45 | convert_weights_type: 3d_sin_cos_plus_learned_num 46 | resume_from_checkpoint: False 47 | 48 | # Fine-tuning arguments 49 | batch_size: 32 50 | epochs: 10 51 | optimizer: 52 | _component_: bitsandbytes.optim.PagedAdamW8bit 53 | lr: 1e-4 54 | loss: 55 | _component_: torch.nn.CrossEntropyLoss 56 | reduction: none 57 | 58 | weight_first_positions: 0 59 | weight_first_positions_with_weight: 1.0 60 | 61 | max_steps_per_epoch: null 62 | gradient_accumulation_steps: 1 63 | optimizer_in_bwd: False 64 | compile: False 65 | 66 | # Training environment 67 | device: cuda 68 | 69 | # Memory management 70 | enable_activation_checkpointing: False 71 | 72 | # Reduced precision 73 | dtype: bf16 74 | 75 | # Logging 76 | metric_logger: 77 | _component_: torchtune.training.metric_logging.DiskLogger 78 | log_dir: ${output_dir} 79 | output_dir: ./checkpoints/objaworld/recognition/ 80 | log_every_n_steps: 1 81 | log_peak_memory_stats: False 82 | 83 | # Profiler (disabled) 84 | profiler: 85 | _component_: torchtune.training.setup_torch_profiler 86 | enabled: False 87 | 88 | #Output directory of trace artifacts 89 | output_dir: ${output_dir}/profiling_outputs 90 | 91 | #`torch.profiler.ProfilerActivity` types to trace 92 | cpu: True 
93 | cuda: True 94 | 95 | #trace options passed to `torch.profiler.profile` 96 | profile_memory: True 97 | with_stack: False 98 | record_shapes: True 99 | with_flops: False 100 | 101 | # `torch.profiler.schedule` options: 102 | # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat 103 | wait_steps: 1 104 | warmup_steps: 2 105 | active_steps: 1 106 | num_cycles: 1 107 | -------------------------------------------------------------------------------- /configs/llama3_2/train/clevr/rendering.yaml: -------------------------------------------------------------------------------- 1 | # Tokenizer 2 | tokenizer: 3 | _component_: torchtune.models.llama3.llama3_tokenizer 4 | path: ./llama-3-models/Llama3.2-1B-Instruct/tokenizer.model 5 | max_seq_len: null 6 | 7 | # Dataset 8 | dataset: 9 | _component_: dataset.threed_mllm_dataset 10 | task_type: "3-I" 11 | dataset_name: "CLEVR" 12 | text_source: "" 13 | image_source: "./kyvo-datasets-and-codebooks/pretokenized-data/clevr/images/train_vqgan_indices_merged_for_rendering_and_recognition.json" 14 | three_d_source: "./kyvo-datasets-and-codebooks/pretokenized-data/clevr/3d-scenes/train_tokenized_scenes_merged_for_rendering_and_recognition.json" 15 | image_target: "" 16 | three_d_target: "" 17 | image_token_offset: 129471 18 | no_loss_on_input: True 19 | reorder_image_tokens: True 20 | load_text_source: False 21 | load_image_source: True 22 | load_three_d_source: True 23 | load_text_target: False 24 | load_image_target: False 25 | load_three_d_target: False 26 | 27 | seed: null 28 | shuffle: True 29 | 30 | # Model Arguments 31 | model: 32 | _component_: models.llama3_2_1b_clevr3d_sin_cos_plus_learned_numbers 33 | freeze_llama3_token_embeddings: True 34 | start_from_original_llama3: True 35 | 36 | checkpointer: 37 | _component_: checkpointer.FullModelMetaCheckpointer3D 38 | checkpoint_dir: ./llama-3-models/Llama3.2-1B-Instruct/ 39 | checkpoint_files: [ 40 | consolidated.00.pth 41 | ] 42 | recipe_checkpoint: null 43 | output_dir: ./checkpoints/clevr/rendering/ 44 | model_type: LLAMA3_2 45 | convert_weights_type: 3d_sin_cos_plus_learned_num 46 | resume_from_checkpoint: False 47 | 48 | # Fine-tuning arguments 49 | batch_size: 32 50 | epochs: 10 51 | optimizer: 52 | _component_: bitsandbytes.optim.PagedAdamW8bit 53 | lr: 1e-4 54 | loss: 55 | _component_: torch.nn.CrossEntropyLoss 56 | reduction: none 57 | 58 | weight_first_positions: 5 59 | weight_first_positions_with_weight: 10.0 60 | 61 | max_steps_per_epoch: null 62 | gradient_accumulation_steps: 1 63 | optimizer_in_bwd: False 64 | compile: False 65 | 66 | # Training environment 67 | device: cuda 68 | 69 | # Memory management 70 | enable_activation_checkpointing: False 71 | 72 | # Reduced precision 73 | dtype: bf16 74 | 75 | # Logging 76 | metric_logger: 77 | _component_: torchtune.training.metric_logging.DiskLogger 78 | log_dir: ${output_dir} 79 | output_dir: ./checkpoints/clevr/rendering/ 80 | log_every_n_steps: 1 81 | log_peak_memory_stats: False 82 | 83 | # Profiler (disabled) 84 | profiler: 85 | _component_: torchtune.training.setup_torch_profiler 86 | enabled: False 87 | 88 | #Output directory of trace artifacts 89 | output_dir: ${output_dir}/profiling_outputs 90 | 91 | #`torch.profiler.ProfilerActivity` types to trace 92 | cpu: True 93 | cuda: True 94 | 95 | #trace options passed to `torch.profiler.profile` 96 | profile_memory: True 97 | with_stack: False 98 | record_shapes: True 99 | with_flops: False 100 | 101 | # `torch.profiler.schedule` options: 102 | # 
wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat 103 | wait_steps: 1 104 | warmup_steps: 2 105 | active_steps: 1 106 | num_cycles: 1 107 | -------------------------------------------------------------------------------- /configs/llama3_2/train/clevr/recognition.yaml: -------------------------------------------------------------------------------- 1 | # Tokenizer 2 | tokenizer: 3 | _component_: torchtune.models.llama3.llama3_tokenizer 4 | path: ./llama-3-models/Llama3.2-1B-Instruct/tokenizer.model 5 | max_seq_len: null 6 | 7 | # Dataset 8 | dataset: 9 | _component_: dataset.threed_mllm_dataset 10 | task_type: "I-3" 11 | dataset_name: "CLEVR" 12 | text_source: "" 13 | image_source: "./kyvo-datasets-and-codebooks/pretokenized-data/clevr/images/train_vqgan_indices_merged_for_rendering_and_recognition.json" 14 | three_d_source: "./kyvo-datasets-and-codebooks/pretokenized-data/clevr/3d-scenes/train_tokenized_scenes_merged_for_rendering_and_recognition.json" 15 | image_target: "" 16 | three_d_target: "" 17 | image_token_offset: 129471 18 | no_loss_on_input: True 19 | reorder_image_tokens: False 20 | load_text_source: False 21 | load_image_source: True 22 | load_three_d_source: True 23 | load_text_target: False 24 | load_image_target: False 25 | load_three_d_target: False 26 | 27 | seed: null 28 | shuffle: True 29 | 30 | # Model Arguments 31 | model: 32 | _component_: models.llama3_2_1b_clevr3d_sin_cos_plus_learned_numbers 33 | freeze_llama3_token_embeddings: True 34 | start_from_original_llama3: True 35 | 36 | checkpointer: 37 | _component_: checkpointer.FullModelMetaCheckpointer3D 38 | checkpoint_dir: ./llama-3-models/Llama3.2-1B-Instruct/ 39 | checkpoint_files: [ 40 | consolidated.00.pth 41 | ] 42 | recipe_checkpoint: null 43 | output_dir: ./checkpoints/clevr/recognition/ 44 | model_type: LLAMA3_2 45 | convert_weights_type: 3d_sin_cos_plus_learned_num 46 | resume_from_checkpoint: False 47 | 48 | # Fine-tuning arguments 49 | batch_size: 32 50 | epochs: 10 51 | optimizer: 52 | _component_: bitsandbytes.optim.PagedAdamW8bit 53 | lr: 1e-4 54 | loss: 55 | _component_: torch.nn.CrossEntropyLoss 56 | reduction: none 57 | 58 | weight_first_positions: 0 59 | weight_first_positions_with_weight: 1.0 60 | 61 | max_steps_per_epoch: null 62 | gradient_accumulation_steps: 1 63 | optimizer_in_bwd: False 64 | compile: False 65 | 66 | # Training environment 67 | device: cuda 68 | 69 | # Memory management 70 | enable_activation_checkpointing: False 71 | 72 | # Reduced precision 73 | dtype: bf16 74 | 75 | # Logging 76 | metric_logger: 77 | _component_: torchtune.training.metric_logging.DiskLogger 78 | log_dir: ${output_dir} 79 | output_dir: ./checkpoints/clevr/recognition/ 80 | log_every_n_steps: 1 81 | log_peak_memory_stats: False 82 | 83 | # Profiler (disabled) 84 | profiler: 85 | _component_: torchtune.training.setup_torch_profiler 86 | enabled: False 87 | 88 | #Output directory of trace artifacts 89 | output_dir: ${output_dir}/profiling_outputs 90 | 91 | #`torch.profiler.ProfilerActivity` types to trace 92 | cpu: True 93 | cuda: True 94 | 95 | #trace options passed to `torch.profiler.profile` 96 | profile_memory: True 97 | with_stack: False 98 | record_shapes: True 99 | with_flops: False 100 | 101 | # `torch.profiler.schedule` options: 102 | # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat 103 | wait_steps: 1 104 | warmup_steps: 2 105 | active_steps: 1 106 | num_cycles: 1 107 | 
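
A note on the loss settings used by every training config in this dump: `torch.nn.CrossEntropyLoss` is built with `reduction: none` so the recipe can re-weight individual target positions before averaging -- the rendering configs set `weight_first_positions: 5` with `weight_first_positions_with_weight: 10.0`, while the recognition and question-answering configs leave the weighting disabled (`0` / `1.0`). The weighting itself is implemented in the training recipe, which is not included here; the sketch below is only a plausible reading of those two knobs, with all tensor names and shapes assumed.

# Illustrative sketch only: the real position weighting lives in the Kyvo
# training recipe, which is not part of this file dump. Names and shapes are
# assumptions chosen to mirror the config keys.
import torch
import torch.nn as nn

def weighted_cross_entropy(logits, labels, weight_first_positions=5,
                           weight_first_positions_with_weight=10.0,
                           ignore_index=-100):
    # logits: [batch, seq_len, vocab], labels: [batch, seq_len]
    loss_fn = nn.CrossEntropyLoss(reduction="none", ignore_index=ignore_index)
    # Keep the per-token losses unreduced so they can be re-weighted.
    per_token = loss_fn(logits.transpose(1, 2), labels)   # [batch, seq_len]

    weights = torch.ones_like(per_token)
    if weight_first_positions > 0:
        # Up-weight the first N supervised positions of every sequence.
        weights[:, :weight_first_positions] = weight_first_positions_with_weight

    mask = (labels != ignore_index).float()
    return (per_token * weights * mask).sum() / mask.sum().clamp(min=1.0)
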
-------------------------------------------------------------------------------- /configs/llama3_2/train/clevr/question-answering.yaml: -------------------------------------------------------------------------------- 1 | # Tokenizer 2 | tokenizer: 3 | _component_: torchtune.models.llama3.llama3_tokenizer 4 | path: ./llama-3-models/Llama3.2-1B-Instruct/tokenizer.model 5 | max_seq_len: null 6 | 7 | # Dataset 8 | dataset: 9 | _component_: dataset.threed_mllm_dataset 10 | task_type: "I+3+Q-A" 11 | dataset_name: "CLEVR" 12 | text_source: "./kyvo-datasets-and-codebooks/pretokenized-data/clevr/text/train_vqa_questions.json" 13 | image_source: "./kyvo-datasets-and-codebooks/pretokenized-data/clevr/images/train_vqgan_indices_for_vqa.json" 14 | three_d_source: "./kyvo-datasets-and-codebooks/pretokenized-data/clevr/3d-scenes/train_tokenized_scenes_for_vqa.json" 15 | text_target: "./kyvo-datasets-and-codebooks/pretokenized-data/clevr/text/train_vqa_answers.json" 16 | image_target: "" 17 | three_d_target: "" 18 | image_token_offset: 129471 19 | no_loss_on_input: True 20 | reorder_image_tokens: False 21 | load_text_source: True 22 | load_image_source: True 23 | load_three_d_source: True 24 | load_text_target: True 25 | load_image_target: False 26 | load_three_d_target: False 27 | 28 | seed: null 29 | shuffle: True 30 | 31 | # Model Arguments 32 | model: 33 | _component_: models.llama3_2_1b_clevr3d_sin_cos_plus_learned_numbers 34 | freeze_llama3_token_embeddings: True 35 | start_from_original_llama3: True 36 | 37 | checkpointer: 38 | _component_: checkpointer.FullModelMetaCheckpointer3D 39 | checkpoint_dir: ./llama-3-models/Llama3.2-1B-Instruct/ 40 | checkpoint_files: [ 41 | consolidated.00.pth 42 | ] 43 | recipe_checkpoint: null 44 | output_dir: ./checkpoints/clevr/question-answering/ 45 | model_type: LLAMA3_2 46 | convert_weights_type: 3d_sin_cos_plus_learned_num 47 | resume_from_checkpoint: False 48 | 49 | # Fine-tuning arguments 50 | batch_size: 32 51 | epochs: 10 52 | optimizer: 53 | _component_: bitsandbytes.optim.PagedAdamW8bit 54 | lr: 1e-4 55 | loss: 56 | _component_: torch.nn.CrossEntropyLoss 57 | reduction: none 58 | 59 | weight_first_positions: 0 60 | weight_first_positions_with_weight: 1.0 61 | 62 | max_steps_per_epoch: null 63 | gradient_accumulation_steps: 1 64 | optimizer_in_bwd: False 65 | compile: False 66 | 67 | # Training environment 68 | device: cuda 69 | 70 | # Memory management 71 | enable_activation_checkpointing: False 72 | 73 | # Reduced precision 74 | dtype: bf16 75 | 76 | # Logging 77 | metric_logger: 78 | _component_: torchtune.training.metric_logging.DiskLogger 79 | log_dir: ${output_dir} 80 | output_dir: ./checkpoints/clevr/question-answering/ 81 | log_every_n_steps: 1 82 | log_peak_memory_stats: False 83 | 84 | # Profiler (disabled) 85 | profiler: 86 | _component_: torchtune.training.setup_torch_profiler 87 | enabled: False 88 | 89 | #Output directory of trace artifacts 90 | output_dir: ${output_dir}/profiling_outputs 91 | 92 | #`torch.profiler.ProfilerActivity` types to trace 93 | cpu: True 94 | cuda: True 95 | 96 | #trace options passed to `torch.profiler.profile` 97 | profile_memory: True 98 | with_stack: False 99 | record_shapes: True 100 | with_flops: False 101 | 102 | # `torch.profiler.schedule` options: 103 | # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat 104 | wait_steps: 1 105 | warmup_steps: 2 106 | active_steps: 1 107 | num_cycles: 1 108 | -------------------------------------------------------------------------------- 
/taming-kyvo.yml: -------------------------------------------------------------------------------- 1 | name: taming-kyvo 2 | channels: 3 | - pytorch 4 | - conda-forge 5 | - defaults 6 | dependencies: 7 | - _libgcc_mutex=0.1=main 8 | - _openmp_mutex=5.1=1_gnu 9 | - blas=1.0=mkl 10 | - bzip2=1.0.8=h4bc722e_7 11 | - ca-certificates=2025.2.25=h06a4308_0 12 | - cudatoolkit=11.3.1=hb98b00a_13 13 | - ffmpeg=4.3=hf484d3e_0 14 | - freetype=2.10.4=h0708190_1 15 | - gmp=6.3.0=hac33072_2 16 | - gnutls=3.6.13=h85f3911_1 17 | - icu=73.2=h59595ed_0 18 | - intel-openmp=2023.1.0=hdb19cb5_46306 19 | - jbig=2.1=h7f98852_2003 20 | - jpeg=9e=h0b41bf4_3 21 | - lame=3.100=h166bdaf_1003 22 | - lcms2=2.12=hddcbb42_0 23 | - ld_impl_linux-64=2.40=h12ee557_0 24 | - lerc=2.2.1=h9c3ff4c_0 25 | - libblas=3.9.0=1_h86c2bf4_netlib 26 | - libcblas=3.9.0=8_h3b12eaf_netlib 27 | - libdeflate=1.7=h7f98852_5 28 | - libffi=3.4.4=h6a678d5_1 29 | - libgcc=14.2.0=h767d61c_2 30 | - libgcc-ng=14.2.0=h69a702a_2 31 | - libgfortran=14.2.0=h69a702a_2 32 | - libgfortran-ng=14.2.0=h69a702a_2 33 | - libgfortran5=14.2.0=hf1ad2bd_2 34 | - libgomp=14.2.0=h767d61c_2 35 | - libhwloc=2.11.2=default_h0d58e46_1001 36 | - libiconv=1.18=h4ce23a2_1 37 | - liblapack=3.9.0=8_h3b12eaf_netlib 38 | - libpng=1.6.37=h21135ba_2 39 | - libstdcxx=14.2.0=h8f9b012_2 40 | - libstdcxx-ng=14.2.0=h4852527_2 41 | - libtiff=4.3.0=hf544144_1 42 | - libuv=1.50.0=hb9d3cd8_0 43 | - libwebp-base=1.5.0=h851e524_0 44 | - libxml2=2.13.5=hfdd30dd_0 45 | - lz4-c=1.9.3=h9c3ff4c_1 46 | - mkl=2023.1.0=h213fc3f_46344 47 | - ncurses=6.4=h6a678d5_0 48 | - nettle=3.6=he412f7d_0 49 | - numpy=1.24.4=py38h59b608b_0 50 | - olefile=0.47=pyhd8ed1ab_0 51 | - openh264=2.1.1=h780b84a_0 52 | - openjpeg=2.4.0=hb52868f_1 53 | - openssl=3.4.1=h7b32b05_0 54 | - pillow=8.3.2=py38h8e6f84c_0 55 | - pip=24.2=py38h06a4308_0 56 | - python=3.8.20=he870216_0 57 | - python_abi=3.8=2_cp38 58 | - pytorch=1.10.1=py3.8_cuda11.3_cudnn8.2.0_0 59 | - pytorch-mutex=1.0=cuda 60 | - readline=8.2=h5eee18b_0 61 | - setuptools=75.1.0=py38h06a4308_0 62 | - sqlite=3.45.3=h5eee18b_0 63 | - tbb=2021.13.0=hceb3a55_1 64 | - tk=8.6.14=h39e8969_0 65 | - torchaudio=0.10.1=py38_cu113 66 | - torchvision=0.11.2=py38_cu113 67 | - typing_extensions=4.12.2=pyha770c72_0 68 | - wheel=0.44.0=py38h06a4308_0 69 | - xz=5.6.4=h5eee18b_1 70 | - zlib=1.2.13=h5eee18b_1 71 | - zstd=1.5.0=ha95c52a_0 72 | - pip: 73 | - absl-py==2.1.0 74 | - cachetools==5.5.2 75 | - certifi==2025.1.31 76 | - charset-normalizer==3.4.1 77 | - einops==0.3.0 78 | - fsspec==2025.2.0 79 | - future==1.0.0 80 | - google-auth==2.38.0 81 | - google-auth-oauthlib==1.0.0 82 | - grpcio==1.70.0 83 | - idna==3.10 84 | - importlib-metadata==8.5.0 85 | - markdown==3.7 86 | - markupsafe==2.1.5 87 | - oauthlib==3.2.2 88 | - omegaconf==2.0.0 89 | - protobuf==5.29.3 90 | - pyasn1==0.6.1 91 | - pyasn1-modules==0.4.1 92 | - pytorch-lightning==1.0.8 93 | - pyyaml==6.0.2 94 | - requests==2.32.3 95 | - requests-oauthlib==2.0.0 96 | - rsa==4.9 97 | - six==1.17.0 98 | - tensorboard==2.14.0 99 | - tensorboard-data-server==0.7.2 100 | - tqdm==4.67.1 101 | - urllib3==2.2.3 102 | - werkzeug==3.0.6 103 | - zipp==3.20.2 104 | prefix: /home/aadarsh/anaconda3/envs/taming-kyvo 105 | -------------------------------------------------------------------------------- /kyvo.yml: -------------------------------------------------------------------------------- 1 | name: kyvo 2 | channels: 3 | - defaults 4 | dependencies: 5 | - _libgcc_mutex=0.1=main 6 | - _openmp_mutex=5.1=1_gnu 7 | - bzip2=1.0.8=h5eee18b_6 8 | 
- ca-certificates=2024.12.31=h06a4308_0 9 | - ld_impl_linux-64=2.40=h12ee557_0 10 | - libffi=3.4.4=h6a678d5_1 11 | - libgcc-ng=11.2.0=h1234567_1 12 | - libgomp=11.2.0=h1234567_1 13 | - libstdcxx-ng=11.2.0=h1234567_1 14 | - libuuid=1.41.5=h5eee18b_0 15 | - ncurses=6.4=h6a678d5_0 16 | - openssl=3.0.15=h5eee18b_0 17 | - pip=25.0=py311h06a4308_0 18 | - python=3.11.11=he870216_0 19 | - readline=8.2=h5eee18b_0 20 | - setuptools=75.8.0=py311h06a4308_0 21 | - sqlite=3.45.3=h5eee18b_0 22 | - tk=8.6.14=h39e8969_0 23 | - wheel=0.45.1=py311h06a4308_0 24 | - xz=5.6.4=h5eee18b_1 25 | - zlib=1.2.13=h5eee18b_1 26 | - pip: 27 | - absl-py==2.1.0 28 | - aiohappyeyeballs==2.4.6 29 | - aiohttp==3.11.12 30 | - aiosignal==1.3.2 31 | - antlr4-python3-runtime==4.9.3 32 | - attrs==25.1.0 33 | - bitsandbytes==0.43.1 34 | - blobfile==3.0.0 35 | - certifi==2025.1.31 36 | - charset-normalizer==3.4.1 37 | - datasets==3.3.1 38 | - dill==0.3.8 39 | - filelock==3.17.0 40 | - frozenlist==1.5.0 41 | - fsspec==2024.12.0 42 | - future==1.0.0 43 | - grpcio==1.70.0 44 | - huggingface-hub==0.28.1 45 | - idna==3.10 46 | - imageio==2.37.0 47 | - jinja2==3.1.5 48 | - lazy-loader==0.4 49 | - lxml==5.3.1 50 | - markdown==3.7 51 | - markupsafe==3.0.2 52 | - mpmath==1.3.0 53 | - multidict==6.1.0 54 | - multiprocess==0.70.16 55 | - networkx==3.4.2 56 | - numpy==2.2.3 57 | - nvidia-cublas-cu12==12.1.3.1 58 | - nvidia-cuda-cupti-cu12==12.1.105 59 | - nvidia-cuda-nvrtc-cu12==12.1.105 60 | - nvidia-cuda-runtime-cu12==12.1.105 61 | - nvidia-cudnn-cu12==9.1.0.70 62 | - nvidia-cufft-cu12==11.0.2.54 63 | - nvidia-curand-cu12==10.3.2.106 64 | - nvidia-cusolver-cu12==11.4.5.107 65 | - nvidia-cusparse-cu12==12.1.0.106 66 | - nvidia-cusparselt-cu12==0.6.2 67 | - nvidia-nccl-cu12==2.20.5 68 | - nvidia-nvjitlink-cu12==12.4.127 69 | - nvidia-nvtx-cu12==12.1.105 70 | - omegaconf==2.3.0 71 | - packaging==24.2 72 | - pandas==2.2.3 73 | - pillow==11.1.0 74 | - propcache==0.2.1 75 | - protobuf==6.30.0 76 | - psutil==7.0.0 77 | - pyarrow==19.0.1 78 | - pycryptodomex==3.21.0 79 | - python-dateutil==2.9.0.post0 80 | - pytorch-lightning==1.0.8 81 | - pytz==2025.1 82 | - pyyaml==6.0.2 83 | - regex==2024.11.6 84 | - requests==2.32.3 85 | - safetensors==0.5.2 86 | - scikit-image==0.25.2 87 | - scipy==1.15.2 88 | - sentencepiece==0.2.0 89 | - six==1.17.0 90 | - sympy==1.13.1 91 | - tensorboard==2.19.0 92 | - tensorboard-data-server==0.7.2 93 | - tifffile==2025.2.18 94 | - tiktoken==0.9.0 95 | - tokenizers==0.19.1 96 | - torch==2.4.1 97 | - torchao==0.5.0 98 | - torchvision==0.19.1 99 | - tqdm==4.67.1 100 | - transformers==4.41.2 101 | - triton==3.0.0 102 | - typing-extensions==4.12.2 103 | - tzdata==2025.1 104 | - urllib3==2.3.0 105 | - werkzeug==3.1.3 106 | - xxhash==3.5.0 107 | - yarl==1.18.3 108 | prefix: /home/aadarsh/anaconda3/envs/kyvo 109 | -------------------------------------------------------------------------------- /configs/llama3_2/train/clevr/instruction-following.yaml: -------------------------------------------------------------------------------- 1 | # Tokenizer 2 | tokenizer: 3 | _component_: torchtune.models.llama3.llama3_tokenizer 4 | path: ./llama-3-models/Llama3.2-1B-Instruct/tokenizer.model 5 | max_seq_len: null 6 | 7 | # Dataset 8 | dataset: 9 | _component_: dataset.threed_mllm_dataset 10 | task_type: "I+3+T-I+3" 11 | dataset_name: "CLEVR" 12 | text_source: "./kyvo-datasets-and-codebooks/pretokenized-data/clevr/text/train_text_instructions.json" 13 | image_source: 
"./kyvo-datasets-and-codebooks/pretokenized-data/clevr/images/train_vqgan_indices_source.json" 14 | three_d_source: "./kyvo-datasets-and-codebooks/pretokenized-data/clevr/3d-scenes/train_tokenized_scenes_source.json" 15 | image_target: "./kyvo-datasets-and-codebooks/pretokenized-data/clevr/images/train_vqgan_indices_target.json" 16 | three_d_target: "./kyvo-datasets-and-codebooks/pretokenized-data/clevr/3d-scenes/train_tokenized_scenes_target.json" 17 | image_token_offset: 129471 18 | no_loss_on_input: True 19 | reorder_image_tokens: True 20 | load_text_source: True 21 | load_image_source: True 22 | load_three_d_source: True 23 | load_text_target: False 24 | load_image_target: True 25 | load_three_d_target: True 26 | 27 | 28 | seed: null 29 | shuffle: True 30 | 31 | # Model Arguments 32 | model: 33 | _component_: models.llama3_2_1b_clevr3d_sin_cos_plus_learned_numbers 34 | freeze_llama3_token_embeddings: True 35 | start_from_original_llama3: True 36 | 37 | checkpointer: 38 | _component_: checkpointer.FullModelMetaCheckpointer3D 39 | checkpoint_dir: ./llama-3-models/Llama3.2-1B-Instruct/ 40 | checkpoint_files: [ 41 | consolidated.00.pth 42 | ] 43 | recipe_checkpoint: null 44 | output_dir: ./checkpoints/clevr/instruction-following/ 45 | model_type: LLAMA3_2 46 | convert_weights_type: 3d_sin_cos_plus_learned_num 47 | resume_from_checkpoint: False 48 | 49 | # Fine-tuning arguments 50 | batch_size: 16 51 | epochs: 10 52 | optimizer: 53 | _component_: bitsandbytes.optim.PagedAdamW8bit 54 | lr: 1e-4 55 | loss: 56 | _component_: torch.nn.CrossEntropyLoss 57 | reduction: none 58 | 59 | weight_first_positions: 5 60 | weight_first_positions_with_weight: 10.0 61 | 62 | max_steps_per_epoch: null 63 | gradient_accumulation_steps: 1 64 | optimizer_in_bwd: False 65 | compile: False 66 | 67 | # Training environment 68 | device: cuda 69 | 70 | # Memory management 71 | enable_activation_checkpointing: False 72 | 73 | # Reduced precision 74 | dtype: bf16 75 | 76 | # Logging 77 | metric_logger: 78 | _component_: torchtune.training.metric_logging.DiskLogger 79 | log_dir: ${output_dir} 80 | output_dir: ./checkpoints/clevr/instruction-following/ 81 | log_every_n_steps: 1 82 | log_peak_memory_stats: False 83 | 84 | # Profiler (disabled) 85 | profiler: 86 | _component_: torchtune.training.setup_torch_profiler 87 | enabled: False 88 | 89 | #Output directory of trace artifacts 90 | output_dir: ${output_dir}/profiling_outputs 91 | 92 | #`torch.profiler.ProfilerActivity` types to trace 93 | cpu: True 94 | cuda: True 95 | 96 | #trace options passed to `torch.profiler.profile` 97 | profile_memory: True 98 | with_stack: False 99 | record_shapes: True 100 | with_flops: False 101 | 102 | # `torch.profiler.schedule` options: 103 | # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat 104 | wait_steps: 1 105 | warmup_steps: 2 106 | active_steps: 1 107 | num_cycles: 1 108 | -------------------------------------------------------------------------------- /scripts/taming_transformers_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import yaml 3 | from omegaconf import OmegaConf 4 | from taming.models.vqgan import VQModel, GumbelVQ 5 | import io 6 | import requests 7 | import PIL 8 | from PIL import Image 9 | import numpy as np 10 | 11 | import torch 12 | import torch.nn.functional as F 13 | import torchvision.transforms as T 14 | import torchvision.transforms.functional as TF 15 | 16 | 17 | def load_config(config_path, 
display=False): 18 | config = OmegaConf.load(config_path) 19 | if display: 20 | print(yaml.dump(OmegaConf.to_container(config))) 21 | return config 22 | 23 | 24 | def load_vqgan(config, ckpt_path=None, is_gumbel=False): 25 | if is_gumbel: 26 | model = GumbelVQ(**config.model.params) 27 | else: 28 | model = VQModel(**config.model.params) 29 | if ckpt_path is not None: 30 | sd = torch.load(ckpt_path, map_location="cpu")["state_dict"] 31 | missing, unexpected = model.load_state_dict(sd, strict=False) 32 | return model.eval() 33 | 34 | 35 | def preprocess_vqgan(x): 36 | x = 2.0 * x - 1.0 37 | return x 38 | 39 | 40 | def custom_to_pil(x): 41 | x = x.detach().cpu() 42 | x = torch.clamp(x, -1.0, 1.0) 43 | x = (x + 1.0) / 2.0 44 | x = x.permute(1, 2, 0).numpy() 45 | x = (255 * x).astype(np.uint8) 46 | x = Image.fromarray(x) 47 | if not x.mode == "RGB": 48 | x = x.convert("RGB") 49 | return x 50 | 51 | 52 | def reconstruct_with_vqgan(x, model, reconstruct=False): 53 | # could also use model(x) for reconstruction but use explicit encoding and decoding here 54 | z, _, [_, _, indices] = model.encode(x) 55 | if reconstruct: 56 | xrec = model.decode(z) 57 | return xrec, indices 58 | else: 59 | return z, indices 60 | 61 | 62 | def download_image(url): 63 | resp = requests.get(url) 64 | resp.raise_for_status() 65 | return PIL.Image.open(io.BytesIO(resp.content)) 66 | 67 | 68 | def get_local_image(path, make_square=True, size=320, horizontal_flip=False): 69 | img = PIL.Image.open(path) 70 | if img.mode == "RGBA": 71 | img = img.convert("RGB") 72 | if make_square: 73 | img = img.resize((size, size)) 74 | if horizontal_flip: 75 | img = img.transpose(PIL.Image.FLIP_LEFT_RIGHT) 76 | return img 77 | 78 | 79 | def preprocess(img, target_image_size=320): 80 | s = min(img.size) 81 | 82 | if s < target_image_size: 83 | raise ValueError(f"min dim for image {s} < {target_image_size}") 84 | 85 | r = target_image_size / s 86 | s = (round(r * img.size[1]), round(r * img.size[0])) 87 | img = TF.resize(img, s, interpolation=PIL.Image.LANCZOS) 88 | img = TF.center_crop(img, output_size=2 * [target_image_size]) 89 | img = torch.unsqueeze(T.ToTensor()(img), 0) 90 | 91 | return img 92 | 93 | 94 | def reconstruction_pipeline( 95 | path, 96 | vqgan_model, 97 | size=320, 98 | DEVICE="cuda:0", 99 | reconstruct=False, 100 | make_square=True, 101 | horizontal_flip=False, 102 | ): 103 | x_vqgan = preprocess( 104 | get_local_image( 105 | path, make_square=make_square, size=size, horizontal_flip=horizontal_flip 106 | ), 107 | target_image_size=size, 108 | ) 109 | x_vqgan = x_vqgan.to(DEVICE) 110 | vqgan_embedding, vqgan_indices = reconstruct_with_vqgan( 111 | preprocess_vqgan(x_vqgan), vqgan_model, reconstruct=reconstruct 112 | ) 113 | return vqgan_embedding, vqgan_indices 114 | -------------------------------------------------------------------------------- /scripts/compute_ssim_l2loss.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from skimage.io import imread 4 | from skimage.metrics import structural_similarity as ssim 5 | import glob 6 | from tqdm import tqdm 7 | import re 8 | from PIL import Image 9 | import argparse 10 | 11 | parser = argparse.ArgumentParser(description="Compute L2 Loss and SSIM") 12 | 13 | parser.add_argument( 14 | "--generated_folder", 15 | required=True, 16 | type=str, 17 | help="Path to the folder containing the generated images", 18 | ) 19 | 20 | parser.add_argument( 21 | "--groundtruth_folder", 22 | required=True, 23 | 
type=str, 24 | help="Path to the folder containing the ground truth images", 25 | ) 26 | 27 | parser.add_argument( 28 | "--output_folder", 29 | required=True, 30 | type=str, 31 | help="Path to the folder where the results will be saved", 32 | ) 33 | 34 | 35 | args = parser.parse_args() 36 | 37 | 38 | def compute_l2_loss(img1, img2): 39 | return np.mean((img1 - img2) ** 2) 40 | 41 | 42 | groundtruth_folder = args.groundtruth_folder 43 | generated_folder = args.generated_folder 44 | 45 | # Initialize accumulators for L2 loss and SSIM 46 | total_l2_loss = 0 47 | total_ssim = 0 48 | image_count = 0 49 | 50 | # Get sorted list of image files 51 | groundtruth_files = sorted(glob.glob(os.path.join(groundtruth_folder, "*.png"))) 52 | generated_files = sorted(glob.glob(os.path.join(generated_folder, "*.png"))) 53 | 54 | groundtruth_files = groundtruth_files[: len(generated_files)] 55 | 56 | 57 | if len(generated_files) == 0: 58 | raise ValueError("No images found in GENERATED folder.") 59 | 60 | 61 | # Iterate through the image pairs 62 | print("Computing L2 Loss and SSIM for image pairs...") 63 | for gt_file, gen_file in tqdm( 64 | zip(groundtruth_files, generated_files), total=len(groundtruth_files) 65 | ): 66 | # Assert filenames match up to the folder and prefix differences 67 | gt_filename = os.path.basename(gt_file) 68 | gen_filename = os.path.basename(gen_file) 69 | 70 | assert ( 71 | gt_filename == gen_filename.split("generated_")[1] 72 | ), f"File names do not match: {gt_filename} and {gen_filename}" 73 | 74 | # Load images 75 | groundtruth_img = imread(gt_file, as_gray=True) # Load as grayscale for SSIM 76 | generated_img = imread(gen_file, as_gray=True) 77 | 78 | # resize groundtruth image to generated image size 79 | groundtruth_img = np.array( 80 | Image.fromarray(groundtruth_img).resize( 81 | (generated_img.shape[0], generated_img.shape[1]) 82 | ) 83 | ) 84 | 85 | # Ensure the images have the same shape 86 | if groundtruth_img.shape != generated_img.shape: 87 | raise ValueError(f"Image shapes do not match: {gt_file} and {gen_file}") 88 | 89 | # Compute L2 loss and SSIM 90 | l2_loss = compute_l2_loss(groundtruth_img, generated_img) 91 | image_ssim = ssim( 92 | groundtruth_img, 93 | generated_img, 94 | data_range=generated_img.max() - generated_img.min(), 95 | ) 96 | 97 | # Accumulate results 98 | total_l2_loss += l2_loss 99 | total_ssim += image_ssim 100 | image_count += 1 101 | 102 | # Compute averages 103 | average_l2_loss = total_l2_loss / image_count 104 | average_ssim = total_ssim / image_count 105 | 106 | # Print results 107 | print("\nComputation Complete!") 108 | print(f"Average L2 Loss: {average_l2_loss}") 109 | print(f"Average SSIM: {average_ssim}") 110 | 111 | # Write results to a file in the same folder 112 | with open(f"{args.output_folder}/ssim_l2.txt", "w") as f: 113 | f.write(f"Average L2 Loss, Average SSIM:\n") 114 | f.write(f"{average_l2_loss}, {average_ssim}\n") 115 | f.write(f"Total L2 Loss, Total SSIM:\n") 116 | f.write(f"{total_l2_loss}, {total_ssim}\n") 117 | f.write(f"Image Count: {image_count}\n") 118 | -------------------------------------------------------------------------------- /scripts/decode_image_embeddings.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import yaml 3 | from omegaconf import OmegaConf 4 | import argparse 5 | import os 6 | import torch 7 | from tqdm import tqdm 8 | import json 9 | import numpy as np 10 | 11 | from taming_transformers_utils import * 12 | 13 | import argparse 14 | 
15 | parser = argparse.ArgumentParser(description="Decode embeddings to images") 16 | 17 | parser.add_argument( 18 | "--folder_path", 19 | required=True, 20 | type=str, 21 | help="Path to the folder containing the embeddings", 22 | ) 23 | 24 | parser.add_argument( 25 | "--vqgan_type", 26 | required=True, 27 | type=str, 28 | help="Type of VQGAN model used for training: choose from [clevr, objaworld, objectron, domain-agnostic]", 29 | ) 30 | 31 | parser.add_argument( 32 | "--image_output_path", 33 | required=True, 34 | type=str, 35 | help="Path to the folder where the decoded images will be saved", 36 | ) 37 | 38 | args = parser.parse_args() 39 | 40 | # add arguments 41 | DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 42 | 43 | torch.set_grad_enabled(False) 44 | 45 | is_gumbel = False 46 | if args.vqgan_type == "clevr": 47 | config_path = "./kyvo-datasets-and-codebooks/vqgan-models-and-codebooks/clevr/2024-10-10T09-21-36_custom_vqgan_CLEVR-LARGE/configs/2024-10-10T09-21-36-project.yaml" 48 | ckpt_path = "./kyvo-datasets-and-codebooks/vqgan-models-and-codebooks/clevr/2024-10-10T09-21-36_custom_vqgan_CLEVR-LARGE/checkpoints/last.ckpt" 49 | elif args.vqgan_type == "objaworld": 50 | config_path = "./kyvo-datasets-and-codebooks/vqgan-models-and-codebooks/objaworld/2025-01-17T09-02-22_custom_vqgan_SYNTHETIC_LIVINGROOM_PARK_LARGE_EP100/configs/2025-01-17T09-02-22-project.yaml" 51 | ckpt_path = "./kyvo-datasets-and-codebooks/vqgan-models-and-codebooks/objaworld/2025-01-17T09-02-22_custom_vqgan_SYNTHETIC_LIVINGROOM_PARK_LARGE_EP100/checkpoints/last.ckpt" 52 | elif args.vqgan_type == "objectron": 53 | config_path = "./kyvo-datasets-and-codebooks/vqgan-models-and-codebooks/objectron/2024-11-03T05-41-42_custom_vqgan_OMNI3D_OBJECTRON_ep200/configs/2024-11-03T05-41-42-project.yaml" 54 | ckpt_path = "./kyvo-datasets-and-codebooks/vqgan-models-and-codebooks/objectron/2024-11-03T05-41-42_custom_vqgan_OMNI3D_OBJECTRON_ep200/checkpoints/last.ckpt" 55 | elif args.vqgan_type == "domain-agnostic": 56 | config_path = "./kyvo-datasets-and-codebooks/vqgan-models-and-codebooks/domain-agnostic/vqgan_gumbel_f8/configs/model.yaml" 57 | ckpt_path = "./kyvo-datasets-and-codebooks/vqgan-models-and-codebooks/domain-agnostic/vqgan_gumbel_f8/checkpoints/last.ckpt" 58 | is_gumbel = True 59 | 60 | config = load_config( 61 | config_path, 62 | display=False, 63 | ) 64 | model = load_vqgan( 65 | config, 66 | ckpt_path=ckpt_path, 67 | is_gumbel=is_gumbel, 68 | ).to(DEVICE) 69 | 70 | 71 | folder_path = args.folder_path 72 | 73 | groundtruth_names = [ 74 | f for f in os.listdir(folder_path) if "ground_truth" in f and f.endswith(".npy") 75 | ] 76 | generated_names = [ 77 | f for f in os.listdir(folder_path) if "generated" in f and f.endswith(".npy") 78 | ] 79 | 80 | # natural sort 81 | groundtruth_names.sort(key=lambda x: int(x.split("_")[-1].split(".")[0])) 82 | generated_names.sort(key=lambda x: int(x.split("_")[-1].split(".")[0])) 83 | 84 | folder_identifier = os.path.basename(folder_path) 85 | groundtruth_key_names = [] 86 | for name in groundtruth_names: 87 | key = name.split(".png")[0].split(folder_identifier)[-1] 88 | groundtruth_key_names.append(str(key)) 89 | 90 | generated_key_names = [] 91 | for name in generated_names: 92 | key = name.split(".png")[0].split(folder_identifier)[-1] 93 | generated_key_names.append(str(key)) 94 | 95 | assert len(groundtruth_names) == len(generated_names) 96 | print("Total number of embeddings to decode: ", len(groundtruth_names)) 97 | 98 | for img_idx in 
tqdm(range(len(groundtruth_names)), desc="Decoding images"): 99 | # load embeddings from npy files and convert to torch tensors 100 | groundtruth_embeddings = torch.tensor( 101 | np.load(os.path.join(folder_path, groundtruth_names[img_idx])) 102 | ).to(DEVICE) 103 | generated_embeddings = torch.tensor( 104 | np.load(os.path.join(folder_path, generated_names[img_idx])) 105 | ).to(DEVICE) 106 | 107 | xrec_groundtruth = model.decode(groundtruth_embeddings.unsqueeze(0)) 108 | xrec_generated = model.decode(generated_embeddings.unsqueeze(0)) 109 | 110 | if not os.path.exists(os.path.join(args.image_output_path, "GROUNDTRUTH")): 111 | os.makedirs(os.path.join(args.image_output_path, "GROUNDTRUTH")) 112 | 113 | if not os.path.exists(os.path.join(args.image_output_path, "GENERATED")): 114 | os.makedirs(os.path.join(args.image_output_path, "GENERATED")) 115 | 116 | custom_to_pil(xrec_groundtruth[0]).save( 117 | os.path.join( 118 | args.image_output_path, 119 | "GROUNDTRUTH", 120 | f"ground_truth{groundtruth_key_names[img_idx]}.png", 121 | ) 122 | ) 123 | custom_to_pil(xrec_generated[0]).save( 124 | os.path.join( 125 | args.image_output_path, 126 | "GENERATED", 127 | f"generated{generated_key_names[img_idx]}.png", 128 | ) 129 | ) 130 | -------------------------------------------------------------------------------- /DATA.md: -------------------------------------------------------------------------------- 1 | # Kyvo Dataset and Codebooks Details 2 | 3 | This document provides details about the dataset and codebooks provided in the `kyvo-datasets-and-codebooks` repository. We will provide the details about each of the folders in the repository and the contents of each folder. 4 | 5 | ## Data Generation Pipeline 6 | 7 | The pipeline that we follow to generate the pre-tokenized data is as follows: 8 | 9 | * **3D Scenes**: 3D Scene JSON --> Serialized 3D Scene --> Tokenized 3D Scene 10 | * **Images**: Image --> VQGAN Codebook Indices --> Tokenized Image 11 | * **Text**: Text --> Tokenized Text 12 | 13 | 14 | 15 | 16 | ## Pre-tokenized Data 17 | 18 | The `pretokenized-data` folder contains all the pre-tokenized data for the datasets used in the Kyvo project. The pre-tokenized data is stored in the following structure: 19 | 20 | ```python 21 | pretokenized-data/ 22 | |-- clevr/ 23 | | |-- 3d-scenes/ # contains all pre-tokenized 3D scenes for CLEVR for all tasks 24 | | |-- images/ # contains all pre-tokenized images for CLEVR for all tasks 25 | | |-- text/ # contains all pre-tokenized text for CLEVR for all tasks 26 | |-- objaworld/ 27 | | |-- 3d-scenes/ # contains all pre-tokenized 3D scenes for ObjaWorld for all tasks 28 | | |-- images/ # contains all pre-tokenized images for ObjaWorld for all tasks 29 | |-- objectron/ 30 | | |-- 3d-scenes/ # contains all pre-tokenized 3D scenes for Objectron for all tasks 31 | | |-- images/ # contains all pre-tokenized images for Objectron for all tasks 32 | ``` 33 | 34 | 35 | For a given task, an input can be any combination of 3d-scenes, images, and text. The output can be any combination of images, text, and 3d-scenes. In the following table we outline the tasks for each dataset and the corresponding input and output data that are needed for each task. 
36 | 37 | | **Task** | **Input Image** | **Input 3D Scene** | **Input Text** | **Output Image** | **Output 3D Scene** | **Output Text** | 38 | |:----------------------:|:------------------:|:----------------------:|:-----------------:|:------------------:|:-----------------------:|:-----------------:| 39 | | **CLEVR** | | | | | | | 40 | | Rendering | 𐄂 | ✓ | 𐄂 | ✓ | 𐄂 | 𐄂 | 41 | | Recognition | ✓ | 𐄂 | 𐄂 | 𐄂 | ✓ | 𐄂 | 42 | | Instruction-Following | ✓ | ✓ | ✓ | ✓ | ✓ | 𐄂 | 43 | | Question-Answering | ✓ | ✓ | ✓ | 𐄂 | 𐄂 | ✓ | 44 | | | | | | | | | 45 | | **ObjaWorld** | | | | | | | 46 | | Rendering | 𐄂 | ✓ | 𐄂 | ✓ | 𐄂 | 𐄂 | 47 | | Recognition | ✓ | 𐄂 | 𐄂 | 𐄂 | ✓ | 𐄂 | 48 | | | | | | | | | 49 | | **Objectron** | | | | | | | 50 | | Recognition | ✓ | 𐄂 | 𐄂 | 𐄂 | ✓ | 𐄂 | 51 | 52 | For the exact files that correspond to the input and output data for each task, please refer to the corresponding configuration files in the `configs/llama3_2/train` folder. 53 | 54 | 55 | ## VQGAN Models and Codebooks 56 | 57 | The `vqgan-models-and-codebooks` folder contains all the VQGAN model checkpoints and codebooks for the datasets used in the Kyvo project. The VQGAN model checkpoints and codebooks are stored in the following structure: 58 | 59 | ```python 60 | vqgan-models-and-codebooks/ 61 | |-- clevr/ 62 | | |-- 2024-10-10T09-21-36_custom_vqgan_CLEVR-LARGE/ # contains the VQGAN model checkpoint for CLEVR 63 | | |-- custom_vqgan_embedding_1024CLEVRLARGE_256dim.npy # contains the VQGAN codebook for CLEVR 64 | |-- objaworld/ 65 | | |-- 2025-01-17T09-02-22_custom_vqgan_SYNTHETIC_LIVINGROOM_PARK_LARGE_EP100/ # contains the VQGAN model checkpoint for ObjaWorld 66 | | |-- custom_vqgan_embedding_256SYNTHETIC_LIVINGROOM_PARK_LARGE_EP100_256dim.npy # contains the VQGAN codebook for ObjaWorld 67 | |-- objectron/ 68 | | |-- 2024-11-03T05-41-42_custom_vqgan_OMNI3D_OBJECTRON_ep200/ # contains the VQGAN model checkpoint for Objectron 69 | | |-- custom_vqgan_embedding_256Omni3D-OBJECTRON_256dim.npy # contains the VQGAN codebook for Objectron 70 | ``` 71 | 72 | ## Images and Scenes for Evaluation 73 | 74 | The `images-and-scenes-for-evaluation` folder contains all the groundtruth images and scenes for the datasets used in the Kyvo project. The images and scenes are used to compute the evaluation metrics for the different tasks. 
The images and scenes are stored in the following structure: 75 | 76 | ```python 77 | images-and-scenes-for-evaluation/ 78 | |-- clevr/ # contains all images and scenes for evaluation for CLEVR 79 | |-- objaworld/ # contains all images and scenes for evaluation for ObjaWorld 80 | |-- objectron/ # contains all scenes for evaluation for Objectron 81 | ``` 82 | 83 | 84 | 85 | -------------------------------------------------------------------------------- /scripts/compute_jaccard_index.py: -------------------------------------------------------------------------------- 1 | import json 2 | import numpy as np 3 | import os 4 | from tqdm import tqdm 5 | import argparse 6 | 7 | parser = argparse.ArgumentParser(description="Compute Jaccard Index") 8 | 9 | parser.add_argument( 10 | "--generated_folder", 11 | required=True, 12 | type=str, 13 | help="Path to the folder containing the predicted scenes", 14 | ) 15 | 16 | parser.add_argument( 17 | "--groundtruth_folder", 18 | required=True, 19 | type=str, 20 | help="Path to the folder containing the ground truth scenes", 21 | ) 22 | 23 | parser.add_argument( 24 | "--tau", 25 | required=True, 26 | type=float, 27 | default=0.25, 28 | help="Distance threshold for matching objects", 29 | ) 30 | 31 | parser.add_argument( 32 | "--dataset", 33 | required=True, 34 | type=str, 35 | help="Dataset used for training: choose from [clevr, objaworld]", 36 | ) 37 | 38 | args = parser.parse_args() 39 | 40 | 41 | # Function to compute Euclidean distance between two 2D coordinates 42 | def dist(coords1, coords2): 43 | return np.linalg.norm(np.array(coords1) - np.array(coords2)) 44 | 45 | 46 | # Function to compute the Jaccard Index 47 | def compute_jaccard_index(ground_truth_scenes, predicted_scenes, tau): 48 | 49 | total_number_of_scenes = len(ground_truth_scenes) 50 | assert total_number_of_scenes == len(predicted_scenes) 51 | total_jaccard_index = 0 52 | 53 | for gt_scene, pred_scene in zip(ground_truth_scenes, predicted_scenes): 54 | true_positives = 0 55 | false_positives = 0 56 | false_negatives = 0 57 | 58 | if args.dataset == "clevr": 59 | assert ( 60 | gt_scene["image_filename"] == pred_scene["key"] 61 | or gt_scene["image_filename"].split(".png")[0] 62 | == pred_scene["key"].split(".png")[0] 63 | ), f"Ground truth and predicted scene mismatch: {gt_scene['image_filename']} vs {pred_scene['key']}" 64 | elif args.dataset == "objaworld": 65 | assert ( 66 | gt_scene["image_filename"].split(".png")[0] 67 | == pred_scene["key"].split(".png_")[0] 68 | ), f"Ground truth and predicted scene mismatch: {gt_scene['image_filename']} vs {pred_scene['key']}" 69 | 70 | gt_objects = gt_scene["objects"] 71 | pred_objects = pred_scene["objects"] 72 | 73 | # Create flags to track matched ground truth objects 74 | gt_matched = [False] * len(gt_objects) 75 | 76 | # Match predictions to ground truth 77 | for pred_obj in pred_objects: 78 | matched = False 79 | for i, gt_obj in enumerate(gt_objects): 80 | if not gt_matched[i]: # Check if ground truth object is unmatched 81 | # Check attribute equality and distance condition 82 | try: 83 | if args.dataset == "clevr": 84 | if ( 85 | pred_obj["size"] == gt_obj["size"] 86 | and pred_obj["color"] == gt_obj["color"] 87 | and pred_obj["material"] == gt_obj["material"] 88 | and pred_obj["shape"] == gt_obj["shape"] 89 | and dist( 90 | pred_obj["3d_coords"][:2], gt_obj["3d_coords"][:2] 91 | ) 92 | < tau 93 | ): 94 | gt_matched[i] = True # Mark ground truth as matched 95 | matched = True 96 | true_positives += 1 97 | break 98 | elif args.dataset == 
"objaworld": 99 | if ( 100 | pred_obj["shape"] == gt_obj["shape"].split("-")[0] 101 | and dist( 102 | pred_obj["3d_coords"][:2], gt_obj["3d_coords"][:2] 103 | ) 104 | < tau 105 | and abs( 106 | pred_obj["3d_coords"][2] - gt_obj["3d_coords"][2] 107 | ) 108 | < 0.05 109 | and abs(pred_obj["3d_coords"][3] - gt_obj["rotation"]) 110 | < 0.15 111 | ): 112 | gt_matched[i] = True # Mark ground truth as matched 113 | matched = True 114 | true_positives += 1 115 | break 116 | except: 117 | continue # Skip if any attribute is missing 118 | if not matched: 119 | false_positives += 1 120 | 121 | # Count unmatched ground truths 122 | false_negatives += gt_matched.count(False) 123 | 124 | if (true_positives + false_positives + false_negatives) == 0: 125 | jaccard_index = -1 126 | else: 127 | jaccard_index = true_positives / ( 128 | true_positives + false_positives + false_negatives 129 | ) 130 | total_jaccard_index += jaccard_index 131 | 132 | # Compute Jaccard Index 133 | 134 | jaccard_index = total_jaccard_index / total_number_of_scenes 135 | return jaccard_index 136 | 137 | 138 | # find the predicted scenes json file 139 | folder = os.path.basename(args.generated_folder) 140 | print(f"Computing Jaccard Index for {folder}") 141 | with open(f"{args.generated_folder}/predicted_scenes_{folder}.json", "r") as f: 142 | predicted_data = json.load(f) 143 | 144 | ground_truth_data = {"scenes": []} 145 | 146 | if args.dataset == "clevr": 147 | all_gt_jsons = os.listdir(args.groundtruth_folder) 148 | all_gt_jsons = sorted(all_gt_jsons) 149 | for gt_json in all_gt_jsons: 150 | with open(f"{args.groundtruth_folder}/{gt_json}", "r") as f: 151 | ground_truth_data["scenes"].append(json.load(f)) 152 | elif args.dataset == "objaworld": 153 | for pred_scene in predicted_data["scenes"]: 154 | with open( 155 | "{}/output-{}-large-10K-6/scenes/{}.json".format( 156 | args.groundtruth_folder, 157 | pred_scene["key"].split("png_")[1], 158 | pred_scene["key"].split(".png_")[0], 159 | ), 160 | "r", 161 | ) as f: 162 | ground_truth_data["scenes"].append(json.load(f)) 163 | 164 | ground_truth_scenes = ground_truth_data["scenes"] 165 | predicted_scenes = predicted_data["scenes"] 166 | 167 | ground_truth_scenes = ground_truth_scenes[: len(predicted_scenes)] 168 | 169 | jaccard_index = compute_jaccard_index( 170 | ground_truth_scenes, predicted_scenes, tau=args.tau 171 | ) 172 | 173 | print(f"Jaccard Index: {jaccard_index:.4f}") 174 | 175 | # print the Jaccard Index to a file in the same folder 176 | with open(f"{args.generated_folder}/jaccard_index_tau-{args.tau}.txt", "w") as f: 177 | f.write(f"Jaccard Index: {jaccard_index:.4f}") 178 | -------------------------------------------------------------------------------- /scripts/compute_jaccard_index_objectron.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | import matplotlib.pyplot as plt 4 | import os 5 | import json 6 | from tqdm import tqdm 7 | 8 | import argparse 9 | 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument("--tau", type=float, default=0.25) 12 | parser.add_argument("--dimensions_tau", type=float, default=0.05) 13 | parser.add_argument("--predictions_file", type=str, required=True) 14 | parser.add_argument("--groundtruth_file", type=str, required=True) 15 | parser.add_argument( 16 | "--method", 17 | type=str, 18 | required=True, 19 | help="Method used for predictions - 'cube-rcnn' or 'kyvo'", 20 | ) 21 | 22 | args = parser.parse_args() 23 | 24 | 25 | # load the file with all data 
26 | with open(args.groundtruth_file) as f: 27 | data = json.load(f) 28 | 29 | # load the prediction json 30 | with open( 31 | args.predictions_file, 32 | ) as f: 33 | predictions = json.load(f) 34 | 35 | # preprocess the data 36 | image_id_to_objects = {} 37 | for ann in data["annotations"]: 38 | image_id = ann["image_id"] 39 | if image_id not in image_id_to_objects: 40 | image_id_to_objects[image_id] = {"objects": []} 41 | image_id_to_objects[image_id]["objects"].append( 42 | { 43 | "category_id": ann["category_id"], 44 | "category": ann["category_name"], 45 | "center_cam": ann["center_cam"], 46 | "dimensions": ann["dimensions"], 47 | } 48 | ) 49 | 50 | 51 | data_filepath_to_image_id = {} 52 | for image in data["images"]: 53 | if image["file_path"] in data_filepath_to_image_id: 54 | raise ValueError("Duplicate file path found") 55 | data_filepath_to_image_id[image["file_path"]] = image["id"] 56 | 57 | 58 | final_preprocessed_data = {} 59 | for image_id, objects in image_id_to_objects.items(): 60 | if image_id not in final_preprocessed_data: 61 | final_preprocessed_data[image_id] = {} 62 | for obj in objects["objects"]: 63 | if obj["category"] not in final_preprocessed_data[image_id]: 64 | final_preprocessed_data[image_id][obj["category"]] = { 65 | "num": 0, 66 | "center_cam_3d": [], 67 | "dimensions": [], 68 | } 69 | final_preprocessed_data[image_id][obj["category"]]["num"] += 1 70 | final_preprocessed_data[image_id][obj["category"]]["center_cam_3d"].append( 71 | np.array(obj["center_cam"]) 72 | ) 73 | final_preprocessed_data[image_id][obj["category"]]["dimensions"].append( 74 | np.array(obj["dimensions"]) 75 | ) 76 | 77 | data_category_id_to_name = {} 78 | for cat in data["categories"]: 79 | data_category_id_to_name[cat["id"]] = cat["name"] 80 | 81 | 82 | if args.method == "cube-rcnn": 83 | cube_rcnn_predictions_preprocessed = {} 84 | for pred in predictions: 85 | image_id = pred["image_id"] 86 | if image_id not in cube_rcnn_predictions_preprocessed: 87 | cube_rcnn_predictions_preprocessed[image_id] = {"objects": []} 88 | cube_rcnn_predictions_preprocessed[image_id]["objects"].append( 89 | { 90 | "category_id": pred["category_id"], 91 | "category": data_category_id_to_name[pred["category_id"]], 92 | "center_cam": pred["center_cam"], 93 | "dimensions": pred["dimensions"], 94 | } 95 | ) 96 | 97 | final_preprocessed_prediction_data = {} 98 | for image_id, objects in cube_rcnn_predictions_preprocessed.items(): 99 | if image_id not in final_preprocessed_prediction_data: 100 | final_preprocessed_prediction_data[image_id] = {} 101 | for obj in objects["objects"]: 102 | if obj["category"] not in final_preprocessed_prediction_data[image_id]: 103 | final_preprocessed_prediction_data[image_id][obj["category"]] = { 104 | "num": 0, 105 | "center_cam_3d": [], 106 | "dimensions": [], 107 | } 108 | final_preprocessed_prediction_data[image_id][obj["category"]]["num"] += 1 109 | final_preprocessed_prediction_data[image_id][obj["category"]][ 110 | "center_cam_3d" 111 | ].append(np.array(obj["center_cam"])) 112 | final_preprocessed_prediction_data[image_id][obj["category"]][ 113 | "dimensions" 114 | ].append(np.array(obj["dimensions"])) 115 | elif args.method == "kyvo": 116 | final_preprocessed_prediction_data = {} 117 | for pred_scene in predictions["scenes"]: 118 | pred_objects = pred_scene["objects"] 119 | image_id = data_filepath_to_image_id[pred_scene["key"]] 120 | if image_id not in final_preprocessed_prediction_data: 121 | final_preprocessed_prediction_data[image_id] = {} 122 | for obj in 
pred_objects: 123 | if obj["category"] not in final_preprocessed_prediction_data[image_id]: 124 | final_preprocessed_prediction_data[image_id][obj["category"]] = { 125 | "num": 0, 126 | "center_cam_3d": [], 127 | "dimensions": [], 128 | } 129 | final_preprocessed_prediction_data[image_id][obj["category"]]["num"] += 1 130 | final_preprocessed_prediction_data[image_id][obj["category"]][ 131 | "center_cam_3d" 132 | ].append(np.array(obj["center_cam"])) 133 | final_preprocessed_prediction_data[image_id][obj["category"]][ 134 | "dimensions" 135 | ].append(np.array(obj["dimensions"])) 136 | 137 | # only keep the common keys 138 | common_keys = set(final_preprocessed_data.keys()).intersection( 139 | set(final_preprocessed_prediction_data.keys()) 140 | ) 141 | 142 | final_preprocessed_data_common = {k: final_preprocessed_data[k] for k in common_keys} 143 | 144 | final_preprocessed_prediction_data_common = { 145 | k: final_preprocessed_prediction_data[k] for k in common_keys 146 | } 147 | 148 | print("Total common keys: ", len(final_preprocessed_data_common.keys())) 149 | 150 | 151 | import numpy as np 152 | from scipy.spatial.distance import euclidean 153 | 154 | 155 | def compute_jaccard(predictions, groundtruths, tau, dimensions_tau): 156 | """ 157 | Compute the Jaccard metric based on predictions and ground truths. 158 | 159 | Parameters: 160 | predictions (list of dict): List of dictionaries containing predicted objects with their category and coordinates. 161 | groundtruths (list of dict): List of dictionaries containing ground truth objects with their category and coordinates. 162 | tau (float): Distance threshold for matching predictions with ground truths. 163 | 164 | Returns: 165 | float: Jaccard metric (tp / (tp + fp + fn)). 166 | """ 167 | true_positives = 0 168 | false_positives = 0 169 | false_negatives = 0 170 | 171 | # Iterate over each prediction 172 | for prediction in predictions: 173 | matched = set() # Track matched ground truth objects 174 | 175 | # Iterate over each predicted object 176 | for category, pred_data in prediction.items(): 177 | pred_count = pred_data["num"] 178 | pred_centers = pred_data["center_cam_3d"] 179 | pred_dimensions = pred_data["dimensions"] 180 | 181 | # Check if the category exists in ground truths 182 | if category in groundtruths[0]: 183 | gt_data = groundtruths[0][category] 184 | gt_count = gt_data["num"] 185 | gt_centers = gt_data["center_cam_3d"] 186 | gt_dimensions = gt_data["dimensions"] 187 | matched_ground_truths = [False] * gt_count 188 | 189 | # Attempt to match each predicted object with a ground truth 190 | for j, pred_center in enumerate(pred_centers): 191 | best_match_index = -1 192 | best_distance = float("inf") 193 | 194 | # Find the closest unmatched ground truth 195 | for i, gt_center in enumerate(gt_centers): 196 | if not matched_ground_truths[i]: 197 | distance = euclidean(pred_center, gt_center) 198 | dimensions_errors = np.mean( 199 | np.abs(pred_dimensions[j] - gt_dimensions[i]) 200 | ) 201 | if ( 202 | distance < tau 203 | and distance < best_distance 204 | and dimensions_errors <= dimensions_tau 205 | ): 206 | best_distance = distance 207 | best_match_index = i 208 | 209 | # If a match is found, count it as a true positive 210 | if best_match_index != -1: 211 | true_positives += 1 212 | matched_ground_truths[best_match_index] = True 213 | matched.add(best_match_index) 214 | else: 215 | false_positives += 1 216 | 217 | # Add unmatched ground truth objects to false negatives 218 | false_negatives += gt_count - 
sum(matched_ground_truths) 219 | else: 220 | # All predicted objects of this category are false positives if no matching ground truths 221 | false_positives += pred_count 222 | 223 | # Account for any ground truths that had no corresponding predictions 224 | for category, gt_data in groundtruths[0].items(): 225 | if category not in prediction: 226 | false_negatives += gt_data["num"] 227 | 228 | # Calculate Jaccard metric 229 | jaccard_metric = true_positives / ( 230 | true_positives + false_positives + false_negatives 231 | ) 232 | 233 | print("true_positives: ", true_positives) 234 | print("false_positives: ", false_positives) 235 | print("false_negatives: ", false_negatives) 236 | 237 | return jaccard_metric, true_positives, false_positives, false_negatives 238 | 239 | 240 | jaccard_metric_sum = 0 241 | total_tp = 0 242 | total_fp = 0 243 | total_fn = 0 244 | total_number_of_scenes = len(final_preprocessed_data_common.keys()) 245 | 246 | for key in final_preprocessed_data_common.keys(): 247 | jaccard_metric, true_positives, false_positives, false_negatives = compute_jaccard( 248 | [final_preprocessed_prediction_data_common[key]], 249 | [final_preprocessed_data_common[key]], 250 | args.tau, 251 | args.dimensions_tau, 252 | ) 253 | 254 | total_tp += true_positives 255 | total_fp += false_positives 256 | total_fn += false_negatives 257 | jaccard_metric_sum += jaccard_metric 258 | 259 | 260 | print( 261 | "Average Jaccard metric for scene matching: {}".format( 262 | round(jaccard_metric_sum / total_number_of_scenes, 4) 263 | ) 264 | ) 265 | 266 | print("Average TP: ", total_tp / total_number_of_scenes) 267 | print("Average FP: ", total_fp / total_number_of_scenes) 268 | print("Average FN: ", total_fn / total_number_of_scenes) 269 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 
34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Aligning Text, Images, and 3D Structure Token-by-Token 2 | 3 | This repository contains the official code for the paper [**Aligning Text, Images, and 3D Structure Token-by-Token**](https://glab-caltech.github.io/kyvo/). 4 | **Authors:** [Aadarsh Sahoo](https://aadsah.github.io/), [Vansh Tibrewal](https://vanshtibrewal.github.io/), [Georgia Gkioxari](https://gkioxari.github.io/) 5 | 6 | 7 |
Kyvo: a decoder-only transformer that aligns a structured 3D modality with language and vision. 30 | This 3D modality represents scenes as lists of objects, each defined by its 3D shape, type, 3D position, 31 | pose, and size parameters. Kyvo unifies the token space of images, text, and 3D to enable a variety of 32 | complex visual 3D tasks.
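For intuition, a scene in this structured 3D modality can be written as a simple list of per-object records. The sketch below is illustrative only: the field names and values are assumptions for exposition, not the exact serialization used by the pre-tokenized data.

```python
# Illustrative sketch of the structured 3D modality: a scene as a list of objects,
# each carrying a shape, type, 3D position, pose, and size. Field names and values
# here are assumptions for exposition, not the repo's exact schema.
scene = [
    {"shape": "cube",     "type": "rubber", "position": [1.2, -0.4, 0.35], "pose": 90.0, "size": 0.70},
    {"shape": "cylinder", "type": "metal",  "position": [-2.0, 1.1, 0.35], "pose": 0.0,  "size": 0.35},
]

for obj in scene:
    print(f'{obj["shape"]:<9} at {obj["position"]} (pose={obj["pose"]}, size={obj["size"]})')
```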
33 | 34 | 35 | 36 | --- 37 | 38 | ## Table of Contents 39 | 1. [Setup](#setup) 40 | 2. [Dataset and VQGAN Codebooks](#dataset-and-vqgan-codebooks) 41 | 3. [Download Llama-3.2-1B Models](#download-llama-3.2-1b-models) 42 | 4. [Training](#training) 43 | 5. [Evaluation](#evaluation) 44 | 45 | --- 46 | 47 | ## 📢 News 48 | - **2025-06-09**: Kyvo is on arXiv! 49 | 50 | ## 📋 TODO 51 | - [ ] Release code and data for ARKitScenes and ObjaWorld with explicit shape representations. 52 | - [ ] HuggingFace 🤗 demo. 53 | - [x] Release training code and data for CLEVR, ObjaWorld (complex object shapes), Objectron. 54 | 55 | ## Setup 56 | 57 | This repository uses [torchtune](https://github.com/pytorch/torchtune) for training, included as a submodule. 58 | 59 | 1. **Clone this repository**: 60 | 61 | ```bash 62 | git clone --recurse-submodules https://github.com/AadSah/kyvo.git 63 | cd kyvo 64 | ``` 65 | 66 | 2. **Set up the environment:**: 67 | 68 | **Option A:** Create a new conda environment and install the required dependencies: 69 | 70 | ```bash 71 | cd kyvo 72 | conda create -n kyvo python=3.11 73 | conda activate kyvo 74 | cd torchtune 75 | pip install -e . 76 | cd .. 77 | pip install -r requirements.txt 78 | ``` 79 | 80 | **Option B:** Use the provided conda environment file to create the environment: 81 | 82 | ```bash 83 | conda env create -f kyvo.yml 84 | conda activate kyvo 85 | cd torchtune 86 | pip install -e . 87 | cd .. 88 | ``` 89 | 90 | ## Dataset and VQGAN Codebooks 91 | 92 | Download the pre-tokenized data, VQGAN checkpoints, and codebooks from Hugging Face: 93 | 94 | ```bash 95 | git clone git@hf.co:datasets/aadarsh99/kyvo-datasets-and-codebooks 96 | cd kyvo-datasets-and-codebooks 97 | git lfs install 98 | git pull 99 | ``` 100 | 101 | This will create a folder `./kyvo-datasets-and-codebooks` with the following structure: 102 | 103 | ```python 104 | kyvo-datasets-and-codebooks/ 105 | |-- images-and-scenes-for-evaluation/ # contains all images and scenes for evaluation 106 | | |-- clevr/ # all CLEVR related files 107 | | |-- objaworld/ # all ObjaWorld related files 108 | | | ... 109 | |-- pretokenized-data/ # contains all pre-tokenized data for all the datasets 110 | |-- |-- clevr/ # all CLEVR related files 111 | | |-- objaworld/ # all ObjaWorld related files 112 | | | ... 113 | |-- vqgan-models-and-codebooks/ # contains all VQGAN model checkpoints and codebooks 114 | | |-- clevr/ # all CLEVR related files 115 | | |-- objaworld/ # all ObjaWorld related files 116 | | | ... 117 | ``` 118 | 119 | More details about the dataset and VQGAN codebooks can be found in the [`DATA.md`](DATA.md) file. 120 | 121 | ## Download Llama-3.2-1B Models 122 | 123 | Use `tune download` to fetch the Llama-3.2-1B models: 124 | 125 | ```bash 126 | tune download meta-llama/Llama-3.2-1B --output-dir ./llama-3-models/Llama3.2-1B/ 127 | 128 | tune download meta-llama/Llama-3.2-1B-Instruct --output-dir ./llama-3-models/Llama3.2-1B-Instruct/ 129 | ``` 130 | 131 | For convenience and compatibility with the provided config files, you may need to restructure the downloaded model directories (otherwise, update paths in your config files accordingly): 132 | 133 | ```bash 134 | mv ./llama-3-models/Llama3.2-1B/original/* ./llama-3-models/Llama3.2-1B/ 135 | 136 | mv ./llama-3-models/Llama3.2-1B-Instruct/original/* ./llama-3-models/Llama3.2-1B-Instruct/ 137 | ``` 138 | 139 | ## Training 140 | 141 | All training configuration files are located in `./configs/llama3_2/train/`. 
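Each config is a plain YAML file, so it can be inspected, copied, and tweaked before launching a run. The snippet below is a minimal sketch: it assumes PyYAML is installed and uses a hypothetical `seed` override for illustration; open the YAML itself to see the fields the training recipe actually expects.

```python
# Minimal sketch: read a training config, tweak a field, and write a custom copy.
# Assumes PyYAML is installed; `seed` is a hypothetical key used for illustration --
# check the YAML for the keys the training recipe actually defines.
import yaml

with open("./configs/llama3_2/train/clevr/rendering.yaml") as f:
    cfg = yaml.safe_load(f)

print(sorted(cfg))                 # list the top-level fields of the config

cfg["seed"] = 0                    # hypothetical override for illustration

with open("./configs/llama3_2/train/clevr/rendering-custom.yaml", "w") as f:
    yaml.safe_dump(cfg, f)
```

The custom copy can then be passed to the training script via `--config`, exactly like the stock configs in the commands below.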
Below are sample commands for different datasets and tasks: 142 | 143 | ### CLEVR: 144 | 145 | ```bash 146 | # Rendering 147 | python3 scripts/full_finetune_single_device_3d.py --config ./configs/llama3_2/train/clevr/rendering.yaml 148 | 149 | # Recognition 150 | python3 scripts/full_finetune_single_device_3d.py --config ./configs/llama3_2/train/clevr/recognition.yaml 151 | 152 | # Instruction-Following 153 | python3 scripts/full_finetune_single_device_3d.py --config ./configs/llama3_2/train/clevr/instruction-following.yaml 154 | 155 | # Question-Answering 156 | python3 scripts/full_finetune_single_device_3d.py --config ./configs/llama3_2/train/clevr/question-answering.yaml 157 | ``` 158 | 159 | ### ObjaWorld: 160 | 161 | ```bash 162 | # Rendering 163 | python3 scripts/full_finetune_single_device_3d.py --config ./configs/llama3_2/train/objaworld/rendering.yaml 164 | 165 | # Recognition 166 | python3 scripts/full_finetune_single_device_3d.py --config ./configs/llama3_2/train/objaworld/recognition.yaml 167 | ``` 168 | 169 | ### Objectron: 170 | 171 | ```bash 172 | # Recognition 173 | python3 scripts/full_finetune_single_device_3d.py --config ./configs/llama3_2/train/objectron/recognition.yaml 174 | ``` 175 | 176 | After training, the models will be saved in the `./checkpoints/` directory. 177 | 178 | ## Evaluation 179 | 180 | Evaluation configuration files are located in `./configs/llama3_2/eval/`. Below are sample commands: 181 | 182 | ### CLEVR: 183 | 184 | ```bash 185 | # Rendering 186 | python3 scripts/generate_3d.py --config ./configs/llama3_2/eval/clevr/rendering.yaml 187 | 188 | # Recognition 189 | python3 scripts/generate_3d.py --config ./configs/llama3_2/eval/clevr/recognition.yaml 190 | 191 | # Instruction-Following (5 sub-tasks are individually evaluated) 192 | python3 scripts/generate_3d.py --config ./configs/llama3_2/eval/clevr/instruction-following/appearance-no-relation.yaml 193 | python3 scripts/generate_3d.py --config ./configs/llama3_2/eval/clevr/instruction-following/appearance-with-relation.yaml 194 | python3 scripts/generate_3d.py --config ./configs/llama3_2/eval/clevr/instruction-following/insertion-with-relation.yaml 195 | python3 scripts/generate_3d.py --config ./configs/llama3_2/eval/clevr/instruction-following/moving-with-relation.yaml 196 | python3 scripts/generate_3d.py --config ./configs/llama3_2/eval/clevr/instruction-following/removal-with-relation.yaml 197 | 198 | # Question-Answering 199 | python3 scripts/generate_3d.py --config ./configs/llama3_2/eval/clevr/question-answering.yaml 200 | ``` 201 | 202 | ### ObjaWorld: 203 | 204 | ```bash 205 | # Rendering 206 | python3 scripts/generate_3d.py --config ./configs/llama3_2/eval/objaworld/rendering.yaml 207 | 208 | # Recognition 209 | python3 scripts/generate_3d.py --config ./configs/llama3_2/eval/objaworld/recognition.yaml 210 | ``` 211 | 212 | ### Objectron: 213 | 214 | ```bash 215 | # Recognition 216 | python3 scripts/generate_3d.py --config ./configs/llama3_2/eval/objectron/recognition.yaml 217 | ``` 218 | 219 | Evaluation outputs are saved to `./checkpoints/` by default, or to the directory specified in the config file. Depending on the task, outputs may include 3D scene JSONs, image embeddings, or text predictions. 220 | 221 | 222 | ### Metrics Computation 223 | 224 | Below are scripts for computing the evaluation metrics used in the paper. 225 | 226 | 1. 
**Jaccard Index on 3D Scenes** 227 | 228 | ```bash 229 | # CLEVR 230 | python3 scripts/compute_jaccard_index.py --tau 0.05 --generated_folder ./checkpoints/clevr/recognition-inference/three_d_json/clevr_recognition_inference --groundtruth_folder ./kyvo-datasets-and-codebooks/images-and-scenes-for-evaluation/clevr/original_images/scenes --dataset clevr 231 | 232 | # ObjaWorld 233 | python3 scripts/compute_jaccard_index.py --tau 0.05 --generated_folder ./checkpoints/objaworld/recognition-inference/three_d_json/objaworld_recognition_inference --groundtruth_folder ./kyvo-datasets-and-codebooks/images-and-scenes-for-evaluation/objaworld/original_images/ --dataset objaworld 234 | 235 | # Objectron (kyvo) 236 | python3 scripts/compute_jaccard_index.py --tau 0.05 --dimension_tau 0.05 --predictions_file ./checkpoints/objectron/recognition-inference/three_d_json/objectron_recognition_inference/predicted_scenes_objectron_recognition_inference.json --groundtruth_file kyvo-datasets-and-codebooks/images-and-scenes-for-evaluation/objectron/Objectron_test.json --dataset objaworld --method kyvo 237 | 238 | # Objectron (Cube-RCNN) 239 | python3 scripts/compute_jaccard_index.py --tau 0.05 --dimension_tau 0.05 --predictions_file kyvo-datasets-and-codebooks/images-and-scenes-for-evaluation/objectron/omni_instance_results_resnet34.json --groundtruth_file kyvo-datasets-and-codebooks/images-and-scenes-for-evaluation/objectron/Objectron_test.json --dataset objaworld --method cube-rcnn 240 | ``` 241 | * `--tau`: Threshold for Jaccard index. 242 | * `--generated_folder`: Path to predicted 3D scenes. 243 | * `--groundtruth_folder`: Path to ground truth 3D scenes. 244 | * `--dataset`: Dataset name. 245 | 246 | 2. **Decoding Image Embeddings and Computing SSIM / L2-Loss** 247 | 248 | Create a separate environment to install [`taming-transformers`](https://github.com/CompVis/taming-transformers/tree/master): 249 | 250 | ```bash 251 | conda env create -f taming-kyvo.yml 252 | conda activate taming-kyvo 253 | cd taming-transformers 254 | pip install -e . 255 | cd .. 256 | ``` 257 | 258 | **Decode Image Embeddings:** 259 | 260 | ```bash 261 | # CLEVR 262 | python3 scripts/decode_image_embeddings.py --vqgan_type clevr --folder_path ./checkpoints/clevr/rendering-inference/image_embeddings/clevr_rendering_inference --image_output_path ./checkpoints/clevr/rendering-inference/decoded_images/ 263 | 264 | # ObjaWorld 265 | python3 scripts/decode_image_embeddings.py --vqgan_type objaworld --folder_path ./checkpoints/objaworld/rendering-inference/image_embeddings/objaworld_rendering_inference --image_output_path ./checkpoints/objaworld/rendering-inference/decoded_images/ 266 | ``` 267 | 268 | * `--folder_path`: Path to image embeddings. 269 | * `--vqgan_type`: Dataset name. 270 | * `--image_output_path`: Path to save decoded images. 271 | 272 | **Compute SSIM and L2-Loss:** 273 | 274 | ```bash 275 | # CLEVR 276 | python3 scripts/compute_ssim_l2loss.py --generated_folder ./checkpoints/clevr/rendering-inference/decoded_images/GENERATED --groundtruth_folder ./kyvo-datasets-and-codebooks/images-and-scenes-for-evaluation/clevr/original_images/images --output_folder ./checkpoints/clevr/rendering-inference/decoded_images 277 | ``` 278 | 279 | * `--generated_folder`: Path to predicted images. 280 | * `--groundtruth_folder`: Path to ground truth images. 281 | * `--output_folder`: Path to save computed SSIM and L2-loss values. 282 | 283 | 3. 
**Text Output Accuracy** 284 | 285 | ```bash 286 | python3 scripts/compute_text_answer_accuracy.py --predicted_file ./checkpoints/clevr/question-answering-inference/three_d_json/clevr_question-answering_inference/predicted_answers_clevr_question-answering_inference.json --groundtruth_file ./kyvo-datasets-and-codebooks/pretokenized-data/clevr/text/test_vqa_answers.json 287 | ``` 288 | 289 | * `--predicted_file`: Path to predicted text answers. 290 | * `--groundtruth_file`: Path to ground truth text answers. 291 | 292 | --- 293 | 294 | ## Citation 295 | 296 | If you find this work useful, please consider citing: 297 | 298 | ```bibtex 299 | @misc{sahoo2025aligningtextimages3d, 300 | title={Aligning Text, Images, and 3D Structure Token-by-Token}, 301 | author={Aadarsh Sahoo and Vansh Tibrewal and Georgia Gkioxari}, 302 | year={2025}, 303 | eprint={2506.08002}, 304 | archivePrefix={arXiv}, 305 | primaryClass={cs.CV}, 306 | url={https://arxiv.org/abs/2506.08002}, 307 | } 308 | ``` 309 | 310 | --- 311 | 312 | For any questions or issues, please open a GitHub issue or contact [Aadarsh](mailto:aadarsh.sahoo.99@gmail.com). Thank you for your interest in our work! 313 | 314 | 315 | 316 | 317 | -------------------------------------------------------------------------------- /checkpointer/convert_weights.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import re 8 | 9 | from typing import Any, Dict 10 | 11 | import torch 12 | 13 | _FROM_META_3D_SIN_COS_NUM = { 14 | "tok_embeddings.token_embedding.weight": "tok_embeddings.token_embedding.weight", 15 | "tok_embeddings.numbers_embedding.weight": "tok_embeddings.numbers_embedding.weight", 16 | "tok_embeddings.added_embedding.weight": "tok_embeddings.added_embedding.weight", 17 | "tok_embeddings.vqgan_codebook.weight": "tok_embeddings.vqgan_codebook.weight", 18 | "tok_embeddings.vqgan_embed_proj.weight": "tok_embeddings.vqgan_embed_proj.weight", 19 | "norm.weight": "norm.scale", 20 | "output.weight": "output.weight", 21 | "layers.{}.attention.wk.weight": "layers.{}.attn.k_proj.weight", 22 | "layers.{}.attention.wq.weight": "layers.{}.attn.q_proj.weight", 23 | "layers.{}.attention.wv.weight": "layers.{}.attn.v_proj.weight", 24 | "layers.{}.attention.wo.weight": "layers.{}.attn.output_proj.weight", 25 | "layers.{}.attention_norm.weight": "layers.{}.sa_norm.scale", 26 | "layers.{}.ffn_norm.weight": "layers.{}.mlp_norm.scale", 27 | "layers.{}.feed_forward.w1.weight": "layers.{}.mlp.w1.weight", 28 | "layers.{}.feed_forward.w2.weight": "layers.{}.mlp.w2.weight", 29 | "layers.{}.feed_forward.w3.weight": "layers.{}.mlp.w3.weight", 30 | } 31 | 32 | _FROM_META_3D_SIN_COS_PLUS_LEARNED_NUM = { 33 | "tok_embeddings.token_embedding.weight": "tok_embeddings.token_embedding.weight", 34 | "tok_embeddings.numbers_embedding.weight": "tok_embeddings.numbers_embedding.weight", 35 | "tok_embeddings.static_sin_cos_embedding.weight": "tok_embeddings.static_sin_cos_embedding.weight", 36 | "tok_embeddings.added_embedding.weight": "tok_embeddings.added_embedding.weight", 37 | "tok_embeddings.vqgan_codebook.weight": "tok_embeddings.vqgan_codebook.weight", 38 | "tok_embeddings.vqgan_embed_proj.weight": "tok_embeddings.vqgan_embed_proj.weight", 39 | "norm.weight": "norm.scale", 40 | "output.weight": "output.weight", 41 
| "layers.{}.attention.wk.weight": "layers.{}.attn.k_proj.weight", 42 | "layers.{}.attention.wq.weight": "layers.{}.attn.q_proj.weight", 43 | "layers.{}.attention.wv.weight": "layers.{}.attn.v_proj.weight", 44 | "layers.{}.attention.wo.weight": "layers.{}.attn.output_proj.weight", 45 | "layers.{}.attention_norm.weight": "layers.{}.sa_norm.scale", 46 | "layers.{}.ffn_norm.weight": "layers.{}.mlp_norm.scale", 47 | "layers.{}.feed_forward.w1.weight": "layers.{}.mlp.w1.weight", 48 | "layers.{}.feed_forward.w2.weight": "layers.{}.mlp.w2.weight", 49 | "layers.{}.feed_forward.w3.weight": "layers.{}.mlp.w3.weight", 50 | } 51 | 52 | _FROM_META_3D_SIN_COS_PLUS_LEARNED_NUM_WITH_MLP = { 53 | "tok_embeddings.token_embedding.weight": "tok_embeddings.token_embedding.weight", 54 | "tok_embeddings.numbers_embedding.weight": "tok_embeddings.numbers_embedding.weight", 55 | "tok_embeddings.static_sin_cos_embedding.weight": "tok_embeddings.static_sin_cos_embedding.weight", 56 | "tok_embeddings.added_embedding.weight": "tok_embeddings.added_embedding.weight", 57 | "tok_embeddings.vqgan_codebook.weight": "tok_embeddings.vqgan_codebook.weight", 58 | "tok_embeddings.vqgan_embed_proj.{}.weight": "tok_embeddings.vqgan_embed_proj.{}.weight", 59 | "tok_embeddings.vqgan_embed_proj.{}.bias": "tok_embeddings.vqgan_embed_proj.{}.bias", 60 | "tok_embeddings.vqgan_embed_proj.{}.weight": "tok_embeddings.vqgan_embed_proj.{}.weight", 61 | "tok_embeddings.vqgan_embed_proj.{}.bias": "tok_embeddings.vqgan_embed_proj.{}.bias", 62 | "norm.weight": "norm.scale", 63 | "output.weight": "output.weight", 64 | "layers.{}.attention.wk.weight": "layers.{}.attn.k_proj.weight", 65 | "layers.{}.attention.wq.weight": "layers.{}.attn.q_proj.weight", 66 | "layers.{}.attention.wv.weight": "layers.{}.attn.v_proj.weight", 67 | "layers.{}.attention.wo.weight": "layers.{}.attn.output_proj.weight", 68 | "layers.{}.attention_norm.weight": "layers.{}.sa_norm.scale", 69 | "layers.{}.ffn_norm.weight": "layers.{}.mlp_norm.scale", 70 | "layers.{}.feed_forward.w1.weight": "layers.{}.mlp.w1.weight", 71 | "layers.{}.feed_forward.w2.weight": "layers.{}.mlp.w2.weight", 72 | "layers.{}.feed_forward.w3.weight": "layers.{}.mlp.w3.weight", 73 | } 74 | 75 | _FROM_META_3D_SIN_COS_NUM_WITH_MLP = { 76 | "tok_embeddings.token_embedding.weight": "tok_embeddings.token_embedding.weight", 77 | "tok_embeddings.numbers_embedding.weight": "tok_embeddings.numbers_embedding.weight", 78 | "tok_embeddings.added_embedding.weight": "tok_embeddings.added_embedding.weight", 79 | "tok_embeddings.vqgan_codebook.weight": "tok_embeddings.vqgan_codebook.weight", 80 | "tok_embeddings.vqgan_embed_proj.{}.weight": "tok_embeddings.vqgan_embed_proj.{}.weight", 81 | "tok_embeddings.vqgan_embed_proj.{}.bias": "tok_embeddings.vqgan_embed_proj.{}.bias", 82 | "tok_embeddings.vqgan_embed_proj.{}.weight": "tok_embeddings.vqgan_embed_proj.{}.weight", 83 | "tok_embeddings.vqgan_embed_proj.{}.bias": "tok_embeddings.vqgan_embed_proj.{}.bias", 84 | "norm.weight": "norm.scale", 85 | "output.weight": "output.weight", 86 | "layers.{}.attention.wk.weight": "layers.{}.attn.k_proj.weight", 87 | "layers.{}.attention.wq.weight": "layers.{}.attn.q_proj.weight", 88 | "layers.{}.attention.wv.weight": "layers.{}.attn.v_proj.weight", 89 | "layers.{}.attention.wo.weight": "layers.{}.attn.output_proj.weight", 90 | "layers.{}.attention_norm.weight": "layers.{}.sa_norm.scale", 91 | "layers.{}.ffn_norm.weight": "layers.{}.mlp_norm.scale", 92 | "layers.{}.feed_forward.w1.weight": "layers.{}.mlp.w1.weight", 93 | 
"layers.{}.feed_forward.w2.weight": "layers.{}.mlp.w2.weight", 94 | "layers.{}.feed_forward.w3.weight": "layers.{}.mlp.w3.weight", 95 | } 96 | 97 | _FROM_META_3D = { 98 | "tok_embeddings.token_embedding.weight": "tok_embeddings.token_embedding.weight", 99 | "tok_embeddings.added_embedding.weight": "tok_embeddings.added_embedding.weight", 100 | "tok_embeddings.vqgan_codebook.weight": "tok_embeddings.vqgan_codebook.weight", 101 | "tok_embeddings.vqgan_embed_proj.weight": "tok_embeddings.vqgan_embed_proj.weight", 102 | "norm.weight": "norm.scale", 103 | "output.weight": "output.weight", 104 | "layers.{}.attention.wk.weight": "layers.{}.attn.k_proj.weight", 105 | "layers.{}.attention.wq.weight": "layers.{}.attn.q_proj.weight", 106 | "layers.{}.attention.wv.weight": "layers.{}.attn.v_proj.weight", 107 | "layers.{}.attention.wo.weight": "layers.{}.attn.output_proj.weight", 108 | "layers.{}.attention_norm.weight": "layers.{}.sa_norm.scale", 109 | "layers.{}.ffn_norm.weight": "layers.{}.mlp_norm.scale", 110 | "layers.{}.feed_forward.w1.weight": "layers.{}.mlp.w1.weight", 111 | "layers.{}.feed_forward.w2.weight": "layers.{}.mlp.w2.weight", 112 | "layers.{}.feed_forward.w3.weight": "layers.{}.mlp.w3.weight", 113 | } 114 | 115 | # state dict key mappings from Meta's format to torchtune's format 116 | _FROM_META = { 117 | "tok_embeddings.weight": "tok_embeddings.weight", 118 | "norm.weight": "norm.scale", 119 | "output.weight": "output.weight", 120 | "layers.{}.attention.wk.weight": "layers.{}.attn.k_proj.weight", 121 | "layers.{}.attention.wq.weight": "layers.{}.attn.q_proj.weight", 122 | "layers.{}.attention.wv.weight": "layers.{}.attn.v_proj.weight", 123 | "layers.{}.attention.wo.weight": "layers.{}.attn.output_proj.weight", 124 | "layers.{}.attention_norm.weight": "layers.{}.sa_norm.scale", 125 | "layers.{}.ffn_norm.weight": "layers.{}.mlp_norm.scale", 126 | "layers.{}.feed_forward.w1.weight": "layers.{}.mlp.w1.weight", 127 | "layers.{}.feed_forward.w2.weight": "layers.{}.mlp.w2.weight", 128 | "layers.{}.feed_forward.w3.weight": "layers.{}.mlp.w3.weight", 129 | } 130 | 131 | # state dict key mappings from HF's format to torchtune's format 132 | _FROM_HF = { 133 | "model.embed_tokens.weight": "tok_embeddings.weight", 134 | "model.layers.{}.self_attn.q_proj.weight": "layers.{}.attn.q_proj.weight", 135 | "model.layers.{}.self_attn.k_proj.weight": "layers.{}.attn.k_proj.weight", 136 | "model.layers.{}.self_attn.v_proj.weight": "layers.{}.attn.v_proj.weight", 137 | "model.layers.{}.self_attn.o_proj.weight": "layers.{}.attn.output_proj.weight", 138 | "model.layers.{}.self_attn.rotary_emb.inv_freq": None, 139 | "model.layers.{}.mlp.gate_proj.weight": "layers.{}.mlp.w1.weight", 140 | "model.layers.{}.mlp.up_proj.weight": "layers.{}.mlp.w3.weight", 141 | "model.layers.{}.mlp.down_proj.weight": "layers.{}.mlp.w2.weight", 142 | "model.layers.{}.input_layernorm.weight": "layers.{}.sa_norm.scale", 143 | "model.layers.{}.post_attention_layernorm.weight": "layers.{}.mlp_norm.scale", 144 | "model.norm.weight": "norm.scale", 145 | "lm_head.weight": "output.weight", 146 | } 147 | 148 | 149 | def get_mapped_key(key: str, mapping_dict: Dict[str, str]) -> str: 150 | try: 151 | # Checks if there is a layer # in the key 152 | if any(k.isdigit() for k in key.split(".")): 153 | # Replace layer number with "{}" to create key for lookup 154 | abstract_key = re.sub(r"(\.\d+)", ".{}", key) 155 | layer_num = re.search(r"\d+", key).group(0) 156 | new_key = mapping_dict[abstract_key] 157 | new_key = new_key.format(layer_num) 158 
| else: 159 | new_key = mapping_dict[key] 160 | except KeyError as e: 161 | raise Exception( 162 | f'Error converting the state dict. Found unexpected key: "{key}". ' 163 | "Please make sure you're loading a checkpoint with the right format. " 164 | ) from e 165 | 166 | return new_key 167 | 168 | 169 | def meta_to_tune(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: 170 | """ 171 | Convert a state dict from Meta's format to torchtune's format. State dicts 172 | from multiple checkpoint files should be consolidated into a single state dict 173 | before calling this function. 174 | 175 | Eg of Meta-format state dict can be found in the ``meta-llama/Llama-2-7b`` 176 | repo in HF (https://huggingface.co/meta-llama/Llama-2-7b). 177 | 178 | Args: 179 | state_dict (Dict[str, torch.Tensor]): State dict in Meta's format. 180 | 181 | Returns: 182 | Dict[str, torch.Tensor]: State dict in torchtune's format. 183 | """ 184 | converted_state_dict = {} 185 | for key, value in state_dict.items(): 186 | if key not in ["rope.freqs"]: # Skip loading the position embeddings 187 | new_key = get_mapped_key(key, _FROM_META) 188 | converted_state_dict[new_key] = value 189 | 190 | return converted_state_dict 191 | 192 | 193 | def meta_to_tune_3d( 194 | state_dict: Dict[str, torch.Tensor], convert_weights_type: str 195 | ) -> Dict[str, torch.Tensor]: 196 | """ 197 | Convert a state dict from Meta's format to torchtune's format. State dicts 198 | from multiple checkpoint files should be consolidated into a single state dict 199 | before calling this function. 200 | 201 | Eg of Meta-format state dict can be found in the ``meta-llama/Llama-2-7b`` 202 | repo in HF (https://huggingface.co/meta-llama/Llama-2-7b). 203 | 204 | Args: 205 | state_dict (Dict[str, torch.Tensor]): State dict in Meta's format. 206 | 207 | Returns: 208 | Dict[str, torch.Tensor]: State dict in torchtune's format. 209 | """ 210 | if convert_weights_type == "3d": 211 | dictionary_to_use = _FROM_META_3D 212 | elif convert_weights_type == "3d_sin_cos_num": 213 | dictionary_to_use = _FROM_META_3D_SIN_COS_NUM 214 | elif convert_weights_type == "3d_sin_cos_plus_learned_num": 215 | dictionary_to_use = _FROM_META_3D_SIN_COS_PLUS_LEARNED_NUM 216 | elif convert_weights_type == "3d_sin_cos_plus_learned_num_with_mlp": 217 | dictionary_to_use = _FROM_META_3D_SIN_COS_PLUS_LEARNED_NUM_WITH_MLP 218 | elif convert_weights_type == "3d_sin_cos_num_with_mlp": 219 | dictionary_to_use = _FROM_META_3D_SIN_COS_NUM_WITH_MLP 220 | else: 221 | raise ValueError( 222 | "convert_weights_type should be one of '3d', '3d_sin_cos_num', '3d_sin_cos_plus_learned_num', '3d_sin_cos_plus_learned_num_with_mlp', '3d_sin_cos_num_with_mlp'" 223 | ) 224 | converted_state_dict = {} 225 | for key, value in state_dict.items(): 226 | if key not in ["rope.freqs"]: # Skip loading the position embeddings 227 | new_key = get_mapped_key(key, dictionary_to_use) 228 | converted_state_dict[new_key] = value 229 | 230 | return converted_state_dict 231 | 232 | 233 | def tune_to_meta(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: 234 | """ 235 | Convert a state dict from torchtune's format to Meta's format. This function 236 | doesn't handle any sharding or splitting of state dicts. It follows the 237 | state_dict IN -> state_dict OUT pattern. 238 | 239 | Args: 240 | state_dict (Dict[str, torch.Tensor]): State dict in torchtune's format. 241 | 242 | Returns: 243 | Dict[str, torch.Tensor]: State dict in Meta's format. 
244 | """ 245 | converted_state_dict = {} 246 | inverted_mapping_dict = {v: k for k, v in _FROM_META.items()} 247 | 248 | for key, value in state_dict.items(): 249 | new_key = get_mapped_key(key, inverted_mapping_dict) 250 | converted_state_dict[new_key] = value 251 | 252 | return converted_state_dict 253 | 254 | 255 | def tune_to_meta_3d( 256 | state_dict: Dict[str, torch.Tensor], convert_weights_type: str 257 | ) -> Dict[str, torch.Tensor]: 258 | """ 259 | Convert a state dict from torchtune's format to Meta's format. This function 260 | doesn't handle any sharding or splitting of state dicts. It follows the 261 | state_dict IN -> state_dict OUT pattern. 262 | 263 | Args: 264 | state_dict (Dict[str, torch.Tensor]): State dict in torchtune's format. 265 | 266 | Returns: 267 | Dict[str, torch.Tensor]: State dict in Meta's format. 268 | """ 269 | if convert_weights_type == "3d": 270 | dictionary_to_use = _FROM_META_3D 271 | elif convert_weights_type == "3d_sin_cos_num": 272 | dictionary_to_use = _FROM_META_3D_SIN_COS_NUM 273 | elif convert_weights_type == "3d_sin_cos_plus_learned_num": 274 | dictionary_to_use = _FROM_META_3D_SIN_COS_PLUS_LEARNED_NUM 275 | elif convert_weights_type == "3d_sin_cos_plus_learned_num_with_mlp": 276 | dictionary_to_use = _FROM_META_3D_SIN_COS_PLUS_LEARNED_NUM_WITH_MLP 277 | elif convert_weights_type == "3d_sin_cos_num_with_mlp": 278 | dictionary_to_use = _FROM_META_3D_SIN_COS_NUM_WITH_MLP 279 | else: 280 | raise ValueError( 281 | "convert_weights_type should be one of '3d', '3d_sin_cos_num', '3d_sin_cos_plus_learned_num', '3d_sin_cos_plus_learned_num_with_mlp', '3d_sin_cos_num_with_mlp'" 282 | ) 283 | converted_state_dict = {} 284 | inverted_mapping_dict = {v: k for k, v in dictionary_to_use.items()} 285 | 286 | for key, value in state_dict.items(): 287 | new_key = get_mapped_key(key, inverted_mapping_dict) 288 | converted_state_dict[new_key] = value 289 | 290 | return converted_state_dict 291 | -------------------------------------------------------------------------------- /checkpointer/checkpointer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | import os 7 | 8 | from pathlib import Path 9 | from typing import Any, Dict, List, Optional, Protocol 10 | 11 | import torch 12 | from torchtune import training 13 | 14 | import checkpointer.convert_weights as convert_weights 15 | from torchtune.training.checkpointing._utils import ( 16 | get_path, 17 | ModelType, 18 | safe_torch_load, 19 | ) 20 | from torchtune.utils._logging import get_logger 21 | 22 | logger = get_logger("DEBUG") 23 | 24 | 25 | class _CheckpointerInterface(Protocol): 26 | """ 27 | Interface implemented by Checkpointers in torchtune. 28 | 29 | torchtune checkpointers are designed to be composable components which can be plugged 30 | into any training recipe. Each checkpointer supports a specific set of models and training 31 | scenarios making these easy to understand, debug and extend. For example, the 32 | ``FullModelCheckpointer``s are used for loading and saving all of the model weights. 33 | This checkpointer can be used for Full-Finetuning scenarios or PEFT where the output is a 34 | merged checkpoint. 
In case the current suite of checkpointers is inadequate, 35 | users are encouraged to implement their own and contribute back to torchtune. 36 | 37 | torchtune is also designed to be "state-dict invariant". This means the checkpointer 38 | ensures that the output checkpoint has the same format as the original checkpoint, i.e. 39 | the output checkpoint has the same keys split across the same number of files as the original 40 | checkpoint. Being "state-dict invariant" allows users to seamlessly use torchtune checkpoints 41 | with their favorite post-training tools from the open-source ecosystem without writing 42 | torchtune-specific converters. To be "state-dict invariant", the ``load_checkpoint`` and 43 | ``save_checkpoint`` methods make use of the weight converters available in 44 | ``torchtune/models/