├── data_utils ├── __init__.py └── __pycache__ │ ├── __init__.cpython-38.pyc │ ├── datasets.cpython-38.pyc │ └── processor.cpython-38.pyc ├── aloha_scripts ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-38.pyc │ └── constants.cpython-38.pyc └── constants.py ├── llava-pythia ├── __init__.py ├── llava_pythia │ ├── __init__.py │ ├── model │ │ ├── __init__.py │ │ ├── language_model │ │ │ ├── __init__.py │ │ │ ├── __pycache__ │ │ │ │ └── __init__.cpython-38.pyc │ │ │ └── pythia │ │ │ │ ├── __pycache__ │ │ │ │ ├── llava_pythia.cpython-38.pyc │ │ │ │ └── configuration_llava_pythia.cpython-38.pyc │ │ │ │ └── configuration_llava_pythia.py │ │ ├── __pycache__ │ │ │ ├── builder.cpython-38.pyc │ │ │ ├── __init__.cpython-38.pyc │ │ │ └── llava_arch.cpython-38.pyc │ │ ├── multimodal_projector │ │ │ ├── __pycache__ │ │ │ │ ├── builder.cpython-310.pyc │ │ │ │ └── builder.cpython-38.pyc │ │ │ └── builder.py │ │ ├── multimodal_encoder │ │ │ ├── __pycache__ │ │ │ │ ├── clip_encoder.cpython-310.pyc │ │ │ │ ├── clip_encoder.cpython-38.pyc │ │ │ │ ├── siglip_encoder.cpython-310.pyc │ │ │ │ └── siglip_encoder.cpython-38.pyc │ │ │ ├── clip_encoder.py │ │ │ └── siglip_encoder.py │ │ └── builder.py │ ├── train │ │ └── __pycache__ │ │ │ └── llava_pythia_trainer.cpython-38.pyc │ ├── constants.py │ ├── utils.py │ ├── mm_utils.py │ └── conversation.py ├── scripts │ ├── llava_pythia │ │ ├── all.sh │ │ ├── .ipynb_checkpoints │ │ │ ├── all-checkpoint.sh │ │ │ ├── all_train_robot-checkpoint.sh │ │ │ ├── get_base_model-checkpoint.sh │ │ │ ├── pretrain-checkpoint.sh │ │ │ ├── finetune-checkpoint.sh │ │ │ ├── train_robot-checkpoint.sh │ │ │ └── lora_train_robot-checkpoint.sh │ │ ├── all_train_robot.sh │ │ ├── eval │ │ │ ├── mme.sh │ │ │ ├── .ipynb_checkpoints │ │ │ │ ├── mme-checkpoint.sh │ │ │ │ ├── mmvet-checkpoint.sh │ │ │ │ ├── pope-checkpoint.sh │ │ │ │ ├── textvqa-checkpoint.sh │ │ │ │ ├── mmbench-checkpoint.sh │ │ │ │ ├── sqa-checkpoint.sh │ │ │ │ ├── vqav2-checkpoint.sh │ │ │ │ └── gqa-checkpoint.sh │ │ │ ├── vizwiz.sh │ │ │ ├── mmvet.sh │ │ │ ├── pope.sh │ │ │ ├── textvqa.sh │ │ │ ├── mmbench.sh │ │ │ ├── sqa.sh │ │ │ ├── vqav2.sh │ │ │ └── gqa.sh │ │ ├── get_base_model.sh │ │ ├── pretrain.sh │ │ ├── finetune.sh │ │ ├── train_robot.sh │ │ └── lora_train_robot.sh │ ├── convert_mmvet_for_eval.py │ ├── convert_gqa_for_eval.py │ ├── zero2.json │ ├── .ipynb_checkpoints │ │ ├── zero2-checkpoint.json │ │ ├── zero3-checkpoint.json │ │ ├── training_states_2_tensorboard-checkpoint.py │ │ ├── zero3_offload-checkpoint.json │ │ ├── display_eval_results_all-checkpoint.py │ │ └── convert_sqa_to_llava-checkpoint.py │ ├── merge_lora_weights.py │ ├── zero3.json │ ├── training_states_2_tensorboard.py │ ├── convert_mmbench_for_submission.py │ ├── zero3_offload.json │ ├── convert_vizwiz_for_submission.py │ ├── display_eval_results_all.py │ ├── convert_seed_for_submission.py │ ├── convert_sqa_to_llava.py │ ├── tranfer2llava.py │ └── convert_vqav2_for_submission.py ├── preprocessor_config.json └── pyproject.toml ├── policy_heads ├── __init__.py ├── util │ ├── __init__.py │ ├── __pycache__ │ │ ├── misc.cpython-38.pyc │ │ └── __init__.cpython-38.pyc │ ├── box_ops.py │ └── plot_utils.py ├── models │ ├── __pycache__ │ │ ├── __init__.cpython-38.pyc │ │ ├── backbone.cpython-38.pyc │ │ ├── detr_vae.cpython-38.pyc │ │ ├── transformer.cpython-38.pyc │ │ ├── position_encoding.cpython-38.pyc │ │ └── droid_unet_diffusion.cpython-38.pyc │ ├── __init__.py │ ├── position_encoding.py │ ├── backbone.py │ └── droid_unet_diffusion.py ├── setup.py ├── 
README.md └── LICENSE ├── __pycache__ └── torch_utils.cpython-38.pyc ├── .idea ├── vcs.xml ├── .gitignore ├── inspectionProfiles │ └── profiles_settings.xml ├── misc.xml ├── modules.xml └── Open_TinyVLA.iml ├── setup.py ├── scripts ├── process_ckpts.sh └── train.sh ├── LICENSE ├── .gitignore ├── requirements.txt └── README.md /data_utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /aloha_scripts/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /llava-pythia/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /policy_heads/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /llava-pythia/llava_pythia/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /llava-pythia/llava_pythia/model/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /llava-pythia/llava_pythia/model/language_model/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /policy_heads/util/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 2 | -------------------------------------------------------------------------------- /llava-pythia/scripts/llava_pythia/all.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | ./scripts/llava_pythia/pretrain.sh && ./scripts/llava_pythia/finetune.sh -------------------------------------------------------------------------------- /__pycache__/torch_utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liyaxuanliyaxuan/TinyVLA/HEAD/__pycache__/torch_utils.cpython-38.pyc -------------------------------------------------------------------------------- /data_utils/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liyaxuanliyaxuan/TinyVLA/HEAD/data_utils/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /data_utils/__pycache__/datasets.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liyaxuanliyaxuan/TinyVLA/HEAD/data_utils/__pycache__/datasets.cpython-38.pyc -------------------------------------------------------------------------------- /aloha_scripts/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liyaxuanliyaxuan/TinyVLA/HEAD/aloha_scripts/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /data_utils/__pycache__/processor.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liyaxuanliyaxuan/TinyVLA/HEAD/data_utils/__pycache__/processor.cpython-38.pyc -------------------------------------------------------------------------------- /policy_heads/util/__pycache__/misc.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liyaxuanliyaxuan/TinyVLA/HEAD/policy_heads/util/__pycache__/misc.cpython-38.pyc -------------------------------------------------------------------------------- /aloha_scripts/__pycache__/constants.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liyaxuanliyaxuan/TinyVLA/HEAD/aloha_scripts/__pycache__/constants.cpython-38.pyc -------------------------------------------------------------------------------- /llava-pythia/scripts/llava_pythia/.ipynb_checkpoints/all-checkpoint.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | ./scripts/llava_pythia/pretrain.sh && ./scripts/llava_pythia/finetune.sh -------------------------------------------------------------------------------- /policy_heads/util/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liyaxuanliyaxuan/TinyVLA/HEAD/policy_heads/util/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /policy_heads/models/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/liyaxuanliyaxuan/TinyVLA/HEAD/policy_heads/models/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /policy_heads/models/__pycache__/backbone.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liyaxuanliyaxuan/TinyVLA/HEAD/policy_heads/models/__pycache__/backbone.cpython-38.pyc -------------------------------------------------------------------------------- /policy_heads/models/__pycache__/detr_vae.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liyaxuanliyaxuan/TinyVLA/HEAD/policy_heads/models/__pycache__/detr_vae.cpython-38.pyc -------------------------------------------------------------------------------- /policy_heads/models/__pycache__/transformer.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liyaxuanliyaxuan/TinyVLA/HEAD/policy_heads/models/__pycache__/transformer.cpython-38.pyc -------------------------------------------------------------------------------- /policy_heads/models/__pycache__/position_encoding.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liyaxuanliyaxuan/TinyVLA/HEAD/policy_heads/models/__pycache__/position_encoding.cpython-38.pyc -------------------------------------------------------------------------------- /llava-pythia/llava_pythia/model/__pycache__/builder.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liyaxuanliyaxuan/TinyVLA/HEAD/llava-pythia/llava_pythia/model/__pycache__/builder.cpython-38.pyc -------------------------------------------------------------------------------- /llava-pythia/llava_pythia/model/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liyaxuanliyaxuan/TinyVLA/HEAD/llava-pythia/llava_pythia/model/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /llava-pythia/llava_pythia/model/__pycache__/llava_arch.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liyaxuanliyaxuan/TinyVLA/HEAD/llava-pythia/llava_pythia/model/__pycache__/llava_arch.cpython-38.pyc -------------------------------------------------------------------------------- /policy_heads/models/__pycache__/droid_unet_diffusion.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liyaxuanliyaxuan/TinyVLA/HEAD/policy_heads/models/__pycache__/droid_unet_diffusion.cpython-38.pyc -------------------------------------------------------------------------------- /.idea/vcs.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /llava-pythia/llava_pythia/train/__pycache__/llava_pythia_trainer.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liyaxuanliyaxuan/TinyVLA/HEAD/llava-pythia/llava_pythia/train/__pycache__/llava_pythia_trainer.cpython-38.pyc 
-------------------------------------------------------------------------------- /.idea/.gitignore: -------------------------------------------------------------------------------- 1 | # Default ignored files 2 | /shelf/ 3 | /workspace.xml 4 | # Editor-based HTTP Client requests 5 | /httpRequests/ 6 | # Datasource local storage ignored files 7 | /dataSources/ 8 | /dataSources.local.xml 9 | -------------------------------------------------------------------------------- /llava-pythia/llava_pythia/model/language_model/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liyaxuanliyaxuan/TinyVLA/HEAD/llava-pythia/llava_pythia/model/language_model/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /llava-pythia/llava_pythia/model/multimodal_projector/__pycache__/builder.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liyaxuanliyaxuan/TinyVLA/HEAD/llava-pythia/llava_pythia/model/multimodal_projector/__pycache__/builder.cpython-310.pyc -------------------------------------------------------------------------------- /llava-pythia/llava_pythia/model/multimodal_projector/__pycache__/builder.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liyaxuanliyaxuan/TinyVLA/HEAD/llava-pythia/llava_pythia/model/multimodal_projector/__pycache__/builder.cpython-38.pyc -------------------------------------------------------------------------------- /.idea/inspectionProfiles/profiles_settings.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 6 | -------------------------------------------------------------------------------- /llava-pythia/llava_pythia/model/multimodal_encoder/__pycache__/clip_encoder.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liyaxuanliyaxuan/TinyVLA/HEAD/llava-pythia/llava_pythia/model/multimodal_encoder/__pycache__/clip_encoder.cpython-310.pyc -------------------------------------------------------------------------------- /llava-pythia/llava_pythia/model/multimodal_encoder/__pycache__/clip_encoder.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liyaxuanliyaxuan/TinyVLA/HEAD/llava-pythia/llava_pythia/model/multimodal_encoder/__pycache__/clip_encoder.cpython-38.pyc -------------------------------------------------------------------------------- /llava-pythia/llava_pythia/model/language_model/pythia/__pycache__/llava_pythia.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liyaxuanliyaxuan/TinyVLA/HEAD/llava-pythia/llava_pythia/model/language_model/pythia/__pycache__/llava_pythia.cpython-38.pyc -------------------------------------------------------------------------------- /llava-pythia/llava_pythia/model/multimodal_encoder/__pycache__/siglip_encoder.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liyaxuanliyaxuan/TinyVLA/HEAD/llava-pythia/llava_pythia/model/multimodal_encoder/__pycache__/siglip_encoder.cpython-310.pyc 
-------------------------------------------------------------------------------- /llava-pythia/llava_pythia/model/multimodal_encoder/__pycache__/siglip_encoder.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liyaxuanliyaxuan/TinyVLA/HEAD/llava-pythia/llava_pythia/model/multimodal_encoder/__pycache__/siglip_encoder.cpython-38.pyc -------------------------------------------------------------------------------- /llava-pythia/llava_pythia/model/language_model/pythia/__pycache__/configuration_llava_pythia.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liyaxuanliyaxuan/TinyVLA/HEAD/llava-pythia/llava_pythia/model/language_model/pythia/__pycache__/configuration_llava_pythia.cpython-38.pyc -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from distutils.core import setup 2 | from setuptools import find_packages 3 | 4 | setup( 5 | name='act', 6 | version='0.0.0', 7 | packages=find_packages(), 8 | license='MIT License', 9 | long_description=open('README.md').read(), 10 | ) 11 | -------------------------------------------------------------------------------- /policy_heads/setup.py: -------------------------------------------------------------------------------- 1 | from distutils.core import setup 2 | from setuptools import find_packages 3 | 4 | setup( 5 | name='policy_heads', 6 | version='0.0.0', 7 | packages=find_packages(), 8 | license='MIT License', 9 | long_description=open('README.md').read(), 10 | ) -------------------------------------------------------------------------------- /.idea/misc.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 6 | 7 | -------------------------------------------------------------------------------- /.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /llava-pythia/scripts/llava_pythia/all_train_robot.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #"1B" "410M" "14M" "1_4B" "70M" "2_8B" 3 | # Call the training script in a loop 4 | #"14M" "70M" "1B" "160M" 5 | for i in "14M" "70M" "160M" "1B"; do 6 | echo "Loop iteration $i" 7 | # Call the other script and pass the model size as an argument 8 | ./scripts/llava_pythia/lora_train_robot.sh "$i" 9 | done 10 | -------------------------------------------------------------------------------- /llava-pythia/scripts/llava_pythia/.ipynb_checkpoints/all_train_robot-checkpoint.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #"1B" "410M" "14M" "1_4B" "70M" "2_8B" 3 | # Call the training script in a loop 4 | #"14M" "70M" "1B" "160M" 5 | for i in "14M" "70M" "160M" "1B"; do 6 | echo "Loop iteration $i" 7 | # Call the other script and pass the model size as an argument 8 | ./scripts/llava_pythia/lora_train_robot.sh "$i" 9 | done 10 | -------------------------------------------------------------------------------- /llava-pythia/llava_pythia/constants.py: -------------------------------------------------------------------------------- 1 | CONTROLLER_HEART_BEAT_EXPIRATION = 30 2 | WORKER_HEART_BEAT_INTERVAL = 15 3 | 4 | LOGDIR = "."
5 | 6 | # Model Constants 7 | IGNORE_INDEX = -100 8 | IMAGE_TOKEN_INDEX = -200 9 | DEFAULT_IMAGE_TOKEN = "<image>" 10 | DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>" 11 | DEFAULT_IM_START_TOKEN = "<im_start>" 12 | DEFAULT_IM_END_TOKEN = "<im_end>" 13 | -------------------------------------------------------------------------------- /llava-pythia/preprocessor_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "crop_size": 336, 3 | "do_center_crop": true, 4 | "do_normalize": true, 5 | "do_resize": true, 6 | "feature_extractor_type": "CLIPFeatureExtractor", 7 | "image_mean": [ 8 | 0.48145466, 9 | 0.4578275, 10 | 0.40821073 11 | ], 12 | "image_std": [ 13 | 0.26862954, 14 | 0.26130258, 15 | 0.27577711 16 | ], 17 | "resample": 3, 18 | "size": 336 19 | } 20 | -------------------------------------------------------------------------------- /policy_heads/README.md: -------------------------------------------------------------------------------- 1 | This part of the codebase is modified from DETR https://github.com/facebookresearch/detr under APACHE 2.0. 2 | 3 | @article{Carion2020EndtoEndOD, 4 | title={End-to-End Object Detection with Transformers}, 5 | author={Nicolas Carion and Francisco Massa and Gabriel Synnaeve and Nicolas Usunier and Alexander Kirillov and Sergey Zagoruyko}, 6 | journal={ArXiv}, 7 | year={2020}, 8 | volume={abs/2005.12872} 9 | } -------------------------------------------------------------------------------- /policy_heads/models/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | from .detr_vae import build as build_vae 3 | from .detr_vae import build_cnnmlp as build_cnnmlp 4 | from .detr_vae import build_vae_head 5 | from .droid_unet_diffusion import ConditionalUnet1D 6 | def build_ACT_model(args): 7 | return build_vae(args) 8 | 9 | def build_ACT_head(args): 10 | return build_vae_head(args) 11 | def build_CNNMLP_model(args): 12 | return build_cnnmlp(args) 13 | -------------------------------------------------------------------------------- /llava-pythia/scripts/convert_mmvet_for_eval.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import argparse 4 | 5 | parser = argparse.ArgumentParser() 6 | parser.add_argument("--src", type=str) 7 | parser.add_argument("--dst", type=str) 8 | args = parser.parse_args() 9 | 10 | cur_result = {} 11 | 12 | for line in open(args.src): 13 | data = json.loads(line) 14 | qid = data['question_id'] 15 | cur_result[f'v1_{qid}'] = data['text'] 16 | 17 | with open(args.dst, 'w') as f: 18 | json.dump(cur_result, f, indent=2) 19 | -------------------------------------------------------------------------------- /llava-pythia/scripts/convert_gqa_for_eval.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import argparse 4 | 5 | parser = argparse.ArgumentParser() 6 | parser.add_argument("--src", type=str) 7 | parser.add_argument("--dst", type=str) 8 | args = parser.parse_args() 9 | 10 | all_answers = [] 11 | for line_idx, line in enumerate(open(args.src)): 12 | res = json.loads(line) 13 | question_id = res['question_id'] 14 | text = res['text'].rstrip('.').lower() 15 | all_answers.append({"questionId": question_id, "prediction": text}) 16 | 17 | with open(args.dst, 'w') as f: 18 | json.dump(all_answers, f) 19 |
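20 | # Example usage (mirroring the invocation in scripts/llava_pythia/eval/gqa.sh; $SPLIT and $CKPT are placeholders from that script): 21 | #   python scripts/convert_gqa_for_eval.py --src ./playground/data/eval/gqa/answers/$SPLIT/$CKPT/merge.jsonl --dst ./playground/data/eval/gqa/data/testdev_balanced_predictions.json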
-------------------------------------------------------------------------------- /.idea/Open_TinyVLA.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 14 | -------------------------------------------------------------------------------- /llava-pythia/scripts/zero2.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | "bf16": { 11 | "enabled": "auto" 12 | }, 13 | "train_micro_batch_size_per_gpu": "auto", 14 | "train_batch_size": "auto", 15 | "gradient_accumulation_steps": "auto", 16 | "zero_optimization": { 17 | "stage": 2, 18 | "overlap_comm": true, 19 | "contiguous_gradients": true, 20 | "sub_group_size": 1e9, 21 | "reduce_bucket_size": "auto" 22 | } 23 | } -------------------------------------------------------------------------------- /llava-pythia/scripts/.ipynb_checkpoints/zero2-checkpoint.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | "bf16": { 11 | "enabled": "auto" 12 | }, 13 | "train_micro_batch_size_per_gpu": "auto", 14 | "train_batch_size": "auto", 15 | "gradient_accumulation_steps": "auto", 16 | "zero_optimization": { 17 | "stage": 2, 18 | "overlap_comm": true, 19 | "contiguous_gradients": true, 20 | "sub_group_size": 1e9, 21 | "reduce_bucket_size": "auto" 22 | } 23 | } -------------------------------------------------------------------------------- /llava-pythia/scripts/llava_pythia/eval/mme.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | LLM_MODEL_SIZE=2_8B 3 | python -m llava_pythia.eval.model_vqa_loader \ 4 | --model-path ./checkpoint_all/pythia_$LLM_MODEL_SIZE/vanilla_pythia_pt_f_vit/llavaPythia-v0-finetune \ 5 | --question-file ./playground/data/eval/MME/llava_mme.jsonl \ 6 | --image-folder ./playground/data/eval/MME/MME_Benchmark_release_version \ 7 | --answers-file ./playground/data/eval/MME/answers/llavaPhi-v0-3b.jsonl \ 8 | --temperature 0 \ 9 | --conv-mode pythia 10 | 11 | cd ./playground/data/eval/MME 12 | 13 | python convert_answer_to_mme.py --experiment llavaPhi-v0-3b 14 | 15 | cd eval_tool 16 | 17 | python calculation.py --results_dir answers/llavaPhi-v0-3b 18 | -------------------------------------------------------------------------------- /llava-pythia/scripts/llava_pythia/eval/.ipynb_checkpoints/mme-checkpoint.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | LLM_MODEL_SIZE=2_8B 3 | python -m llava_pythia.eval.model_vqa_loader \ 4 | --model-path ./checkpoint_all/pythia_$LLM_MODEL_SIZE/vanilla_pythia_pt_f_vit/llavaPythia-v0-finetune \ 5 | --question-file ./playground/data/eval/MME/llava_mme.jsonl \ 6 | --image-folder ./playground/data/eval/MME/MME_Benchmark_release_version \ 7 | --answers-file ./playground/data/eval/MME/answers/llavaPhi-v0-3b.jsonl \ 8 | --temperature 0 \ 9 | --conv-mode pythia 10 | 11 | cd ./playground/data/eval/MME 12 | 13 | python convert_answer_to_mme.py --experiment llavaPhi-v0-3b 14 | 15 | cd eval_tool 16 | 17 | python calculation.py --results_dir answers/llavaPhi-v0-3b 18 | 
-------------------------------------------------------------------------------- /llava-pythia/scripts/llava_pythia/eval/vizwiz.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python -m llava_pythia.eval.model_vqa_loader \ 4 | --model-path checkpoints/llavaPhi-v0-3b-finetune \ 5 | --question-file ./playground/data/eval/vizwiz/llava_test.jsonl \ 6 | --image-folder ./playground/data/eval/vizwiz/test \ 7 | --answers-file ./playground/data/eval/vizwiz/answers/llavaPhi-v0-3b.jsonl \ 8 | --temperature 0 \ 9 | --conv-mode phi-2_v0 10 | 11 | python scripts/convert_vizwiz_for_submission.py \ 12 | --annotation-file ./playground/data/eval/vizwiz/llava_test.jsonl \ 13 | --result-file ./playground/data/eval/vizwiz/answers/llavaPhi-v0-3b.jsonl \ 14 | --result-upload-file ./playground/data/eval/vizwiz/answers_upload/llavaPhi-v0-3b.json 15 | -------------------------------------------------------------------------------- /llava-pythia/scripts/llava_pythia/eval/mmvet.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | LLM_MODEL_SIZE=2_8B 3 | python -m llava_pythia.eval.model_vqa \ 4 | --model-path ./checkpoint_all/pythia_$LLM_MODEL_SIZE/vanilla_pythia_pt_f_vit/llavaPythia-v0-finetune \ 5 | --question-file ./playground/data/eval/mm-vet/llava-mm-vet.jsonl \ 6 | --image-folder ./playground/data/eval/mm-vet/images \ 7 | --answers-file ./playground/data/eval/mm-vet/answers/llavaPhi-v0-3b.jsonl \ 8 | --temperature 0 \ 9 | --conv-mode pythia 10 | 11 | mkdir -p ./playground/data/eval/mm-vet/results 12 | 13 | python scripts/convert_mmvet_for_eval.py \ 14 | --src ./playground/data/eval/mm-vet/answers/llavaPhi-v0-3b.jsonl \ 15 | --dst ./playground/data/eval/mm-vet/results/llavaPhi-v0-3b.json 16 | 17 | -------------------------------------------------------------------------------- /llava-pythia/scripts/llava_pythia/eval/pope.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | LLM_MODEL_SIZE=2_8B 3 | python -m llava_pythia.eval.model_vqa_loader \ 4 | --model-path ./checkpoint_all/pythia_$LLM_MODEL_SIZE/vanilla_pythia_pt_f_vit/llavaPythia-v0-finetune \ 5 | --question-file ./playground/data/eval/pope/llava_pope_test.jsonl \ 6 | --image-folder /data/team/zhumj/data/coco/val2014 \ 7 | --answers-file ./playground/data/eval/pope/answers/llavaPhi-v0-3b.jsonl \ 8 | --temperature 0 \ 9 | --conv-mode pythia 10 | 11 | python llava_pythia/eval/eval_pope.py \ 12 | --annotation-dir ./playground/data/eval/pope/coco \ 13 | --question-file ./playground/data/eval/pope/llava_pope_test.jsonl \ 14 | --result-file ./playground/data/eval/pope/answers/llavaPhi-v0-3b.jsonl -------------------------------------------------------------------------------- /llava-pythia/scripts/llava_pythia/eval/textvqa.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | LLM_MODEL_SIZE=2_8B 3 | python -m llava_pythia.eval.model_vqa_loader \ 4 | --model-path ./checkpoint_all/pythia_$LLM_MODEL_SIZE/vanilla_pythia_pt_f_vit/llavaPythia-v0-finetune \ 5 | --question-file ./playground/data/eval/textvqa/llava_textvqa_val_v051_ocr.jsonl \ 6 | --image-folder /data/team/zhumj/data/finetune/data/textvqa/train_images \ 7 | --answers-file ./playground/data/eval/textvqa/answers/llavaPhi-v0-3b.jsonl \ 8 | --temperature 0 \ 9 | --conv-mode pythia 10 | 11 | python -m llava_pythia.eval.eval_textvqa \ 12 | --annotation-file 
./playground/data/eval/textvqa/TextVQA_0.5.1_val.json \ 13 | --result-file ./playground/data/eval/textvqa/answers/llavaPhi-v0-3b.jsonl 14 | -------------------------------------------------------------------------------- /llava-pythia/scripts/llava_pythia/eval/.ipynb_checkpoints/mmvet-checkpoint.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | LLM_MODEL_SIZE=2_8B 3 | python -m llava_pythia.eval.model_vqa \ 4 | --model-path ./checkpoint_all/pythia_$LLM_MODEL_SIZE/vanilla_pythia_pt_f_vit/llavaPythia-v0-finetune \ 5 | --question-file ./playground/data/eval/mm-vet/llava-mm-vet.jsonl \ 6 | --image-folder ./playground/data/eval/mm-vet/images \ 7 | --answers-file ./playground/data/eval/mm-vet/answers/llavaPhi-v0-3b.jsonl \ 8 | --temperature 0 \ 9 | --conv-mode pythia 10 | 11 | mkdir -p ./playground/data/eval/mm-vet/results 12 | 13 | python scripts/convert_mmvet_for_eval.py \ 14 | --src ./playground/data/eval/mm-vet/answers/llavaPhi-v0-3b.jsonl \ 15 | --dst ./playground/data/eval/mm-vet/results/llavaPhi-v0-3b.json 16 | 17 | -------------------------------------------------------------------------------- /llava-pythia/scripts/llava_pythia/eval/.ipynb_checkpoints/pope-checkpoint.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | LLM_MODEL_SIZE=2_8B 3 | python -m llava_pythia.eval.model_vqa_loader \ 4 | --model-path ./checkpoint_all/pythia_$LLM_MODEL_SIZE/vanilla_pythia_pt_f_vit/llavaPythia-v0-finetune \ 5 | --question-file ./playground/data/eval/pope/llava_pope_test.jsonl \ 6 | --image-folder /data/team/zhumj/data/coco/val2014 \ 7 | --answers-file ./playground/data/eval/pope/answers/llavaPhi-v0-3b.jsonl \ 8 | --temperature 0 \ 9 | --conv-mode pythia 10 | 11 | python llava_pythia/eval/eval_pope.py \ 12 | --annotation-dir ./playground/data/eval/pope/coco \ 13 | --question-file ./playground/data/eval/pope/llava_pope_test.jsonl \ 14 | --result-file ./playground/data/eval/pope/answers/llavaPhi-v0-3b.jsonl -------------------------------------------------------------------------------- /llava-pythia/scripts/llava_pythia/eval/.ipynb_checkpoints/textvqa-checkpoint.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | LLM_MODEL_SIZE=2_8B 3 | python -m llava_pythia.eval.model_vqa_loader \ 4 | --model-path ./checkpoint_all/pythia_$LLM_MODEL_SIZE/vanilla_pythia_pt_f_vit/llavaPythia-v0-finetune \ 5 | --question-file ./playground/data/eval/textvqa/llava_textvqa_val_v051_ocr.jsonl \ 6 | --image-folder /data/team/zhumj/data/finetune/data/textvqa/train_images \ 7 | --answers-file ./playground/data/eval/textvqa/answers/llavaPhi-v0-3b.jsonl \ 8 | --temperature 0 \ 9 | --conv-mode pythia 10 | 11 | python -m llava_pythia.eval.eval_textvqa \ 12 | --annotation-file ./playground/data/eval/textvqa/TextVQA_0.5.1_val.json \ 13 | --result-file ./playground/data/eval/textvqa/answers/llavaPhi-v0-3b.jsonl 14 | -------------------------------------------------------------------------------- /llava-pythia/scripts/llava_pythia/eval/mmbench.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | SPLIT="mmbench_dev_20230712" 4 | LLM_MODEL_SIZE=2_8B 5 | python -m llava_pythia.eval.model_vqa_mmbench \ 6 | --model-path ./checkpoint_all/pythia_$LLM_MODEL_SIZE/vanilla_pythia_pt_f_vit/llavaPythia-v0-finetune \ 7 | --question-file ./playground/data/eval/mmbench/$SPLIT.tsv \ 8 | --answers-file 
./playground/data/eval/mmbench/answers/$SPLIT/llavaPhi-v0-3b.jsonl \ 9 | --single-pred-prompt \ 10 | --temperature 0 \ 11 | --conv-mode pythia 12 | 13 | mkdir -p playground/data/eval/mmbench/answers_upload/$SPLIT 14 | 15 | python scripts/convert_mmbench_for_submission.py \ 16 | --annotation-file ./playground/data/eval/mmbench/$SPLIT.tsv \ 17 | --result-dir ./playground/data/eval/mmbench/answers/$SPLIT \ 18 | --upload-dir ./playground/data/eval/mmbench/answers_upload/$SPLIT \ 19 | --experiment llavaPhi-v0-3b -------------------------------------------------------------------------------- /llava-pythia/scripts/merge_lora_weights.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from llava_pythia.model.builder import load_pretrained_model 3 | from llava_pythia.mm_utils import get_model_name_from_path 4 | 5 | 6 | def merge_lora(args): 7 | model_name = get_model_name_from_path(args.model_path) 8 | tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, device_map='cpu') 9 | 10 | model.save_pretrained(args.save_model_path) 11 | tokenizer.save_pretrained(args.save_model_path) 12 | 13 | 14 | if __name__ == "__main__": 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument("--model-path", type=str, required=True) 17 | parser.add_argument("--model-base", type=str, required=True) 18 | parser.add_argument("--save-model-path", type=str, required=True) 19 | 20 | args = parser.parse_args() 21 | 22 | merge_lora(args) 23 | -------------------------------------------------------------------------------- /llava-pythia/scripts/llava_pythia/eval/.ipynb_checkpoints/mmbench-checkpoint.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | SPLIT="mmbench_dev_20230712" 4 | LLM_MODEL_SIZE=2_8B 5 | python -m llava_pythia.eval.model_vqa_mmbench \ 6 | --model-path ./checkpoint_all/pythia_$LLM_MODEL_SIZE/vanilla_pythia_pt_f_vit/llavaPythia-v0-finetune \ 7 | --question-file ./playground/data/eval/mmbench/$SPLIT.tsv \ 8 | --answers-file ./playground/data/eval/mmbench/answers/$SPLIT/llavaPhi-v0-3b.jsonl \ 9 | --single-pred-prompt \ 10 | --temperature 0 \ 11 | --conv-mode pythia 12 | 13 | mkdir -p playground/data/eval/mmbench/answers_upload/$SPLIT 14 | 15 | python scripts/convert_mmbench_for_submission.py \ 16 | --annotation-file ./playground/data/eval/mmbench/$SPLIT.tsv \ 17 | --result-dir ./playground/data/eval/mmbench/answers/$SPLIT \ 18 | --upload-dir ./playground/data/eval/mmbench/answers_upload/$SPLIT \ 19 | --experiment llavaPhi-v0-3b -------------------------------------------------------------------------------- /llava-pythia/scripts/zero3.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | "bf16": { 11 | "enabled": "auto" 12 | }, 13 | "train_micro_batch_size_per_gpu": "auto", 14 | "train_batch_size": "auto", 15 | "gradient_accumulation_steps": "auto", 16 | "zero_optimization": { 17 | "stage": 3, 18 | "overlap_comm": true, 19 | "contiguous_gradients": true, 20 | "sub_group_size": 1e9, 21 | "reduce_bucket_size": "auto", 22 | "stage3_prefetch_bucket_size": "auto", 23 | "stage3_param_persistence_threshold": "auto", 24 | "stage3_max_live_parameters": 1e9, 25 | "stage3_max_reuse_distance": 1e9, 26 | 
"stage3_gather_16bit_weights_on_model_save": true 27 | } 28 | } -------------------------------------------------------------------------------- /llava-pythia/scripts/llava_pythia/eval/sqa.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | LLM_MODEL_SIZE=2_8B 3 | python -m llava_pythia.eval.model_vqa_science \ 4 | --model-path ./checkpoint_all/pythia_$LLM_MODEL_SIZE/vanilla_pythia_pt_f_vit/llavaPythia-v0-finetune \ 5 | --question-file ./playground/data/eval/scienceqa/llava_test_CQM-A.json \ 6 | --image-folder ./playground/data/eval/scienceqa/images/test \ 7 | --answers-file ./playground/data/eval/scienceqa/answers/llavaPhi-v0-3b.jsonl \ 8 | --single-pred-prompt \ 9 | --temperature 0 \ 10 | --conv-mode pythia 11 | 12 | python llava_pythia/eval/eval_science_qa.py \ 13 | --base-dir ./playground/data/eval/scienceqa \ 14 | --result-file ./playground/data/eval/scienceqa/answers/llavaPhi-v0-3b.jsonl \ 15 | --output-file ./playground/data/eval/scienceqa/answers/llavaPhi-v0-3b_output.jsonl \ 16 | --output-result ./playground/data/eval/scienceqa/answers/llavaPhi-v0-3b_result.json 17 | 18 | -------------------------------------------------------------------------------- /llava-pythia/scripts/.ipynb_checkpoints/zero3-checkpoint.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | "bf16": { 11 | "enabled": "auto" 12 | }, 13 | "train_micro_batch_size_per_gpu": "auto", 14 | "train_batch_size": "auto", 15 | "gradient_accumulation_steps": "auto", 16 | "zero_optimization": { 17 | "stage": 3, 18 | "overlap_comm": true, 19 | "contiguous_gradients": true, 20 | "sub_group_size": 1e9, 21 | "reduce_bucket_size": "auto", 22 | "stage3_prefetch_bucket_size": "auto", 23 | "stage3_param_persistence_threshold": "auto", 24 | "stage3_max_live_parameters": 1e9, 25 | "stage3_max_reuse_distance": 1e9, 26 | "stage3_gather_16bit_weights_on_model_save": true 27 | } 28 | } -------------------------------------------------------------------------------- /llava-pythia/scripts/llava_pythia/eval/.ipynb_checkpoints/sqa-checkpoint.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | LLM_MODEL_SIZE=2_8B 3 | python -m llava_pythia.eval.model_vqa_science \ 4 | --model-path ./checkpoint_all/pythia_$LLM_MODEL_SIZE/vanilla_pythia_pt_f_vit/llavaPythia-v0-finetune \ 5 | --question-file ./playground/data/eval/scienceqa/llava_test_CQM-A.json \ 6 | --image-folder ./playground/data/eval/scienceqa/images/test \ 7 | --answers-file ./playground/data/eval/scienceqa/answers/llavaPhi-v0-3b.jsonl \ 8 | --single-pred-prompt \ 9 | --temperature 0 \ 10 | --conv-mode pythia 11 | 12 | python llava_pythia/eval/eval_science_qa.py \ 13 | --base-dir ./playground/data/eval/scienceqa \ 14 | --result-file ./playground/data/eval/scienceqa/answers/llavaPhi-v0-3b.jsonl \ 15 | --output-file ./playground/data/eval/scienceqa/answers/llavaPhi-v0-3b_output.jsonl \ 16 | --output-result ./playground/data/eval/scienceqa/answers/llavaPhi-v0-3b_result.json 17 | 18 | -------------------------------------------------------------------------------- /llava-pythia/scripts/training_states_2_tensorboard.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.tensorboard import SummaryWriter 3 | import 
os 4 | import json 5 | 6 | def main(): 7 | # create SummaryWriter 8 | pythia = "410M" 9 | log_p = f'/data/private/wenjj/llava-pythia/checkpoint_all/pythia_{pythia}/vanilla_pythia_pt_f_vit/llavaPythia-v0-robot-action-w_state_huber/log' 10 | 11 | trainint_state_p = f"/data/private/wenjj/llava-pythia/checkpoint_all/pythia_{pythia}/vanilla_pythia_pt_f_vit/llavaPythia-v0-robot-action-w_state_huber/trainer_state.json" 12 | 13 | os.makedirs(log_p, exist_ok=True) 14 | 15 | writer = SummaryWriter(log_dir=log_p) 16 | 17 | with open(trainint_state_p, "r") as f: 18 | data = json.load(f) 19 | 20 | # save loss in SummaryWriter 21 | for each in data['log_history']: 22 | if not 'loss' in each.keys(): 23 | continue 24 | step, loss = each['step'], each['loss'] 25 | writer.add_scalar('train/loss', loss, step) 26 | 27 | writer.close() 28 | 29 | if __name__ == "__main__": 30 | main() 31 | -------------------------------------------------------------------------------- /llava-pythia/scripts/convert_mmbench_for_submission.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import argparse 4 | import pandas as pd 5 | 6 | def get_args(): 7 | parser = argparse.ArgumentParser() 8 | parser.add_argument("--annotation-file", type=str, required=True) 9 | parser.add_argument("--result-dir", type=str, required=True) 10 | parser.add_argument("--upload-dir", type=str, required=True) 11 | parser.add_argument("--experiment", type=str, required=True) 12 | 13 | return parser.parse_args() 14 | 15 | if __name__ == "__main__": 16 | args = get_args() 17 | 18 | df = pd.read_table(args.annotation_file) 19 | 20 | cur_df = df.copy() 21 | cur_df = cur_df.drop(columns=['hint', 'category', 'source', 'image', 'comment', 'l2-category']) 22 | cur_df.insert(6, 'prediction', None) 23 | for pred in open(os.path.join(args.result_dir, f"{args.experiment}.jsonl")): 24 | pred = json.loads(pred) 25 | cur_df.loc[df['index'] == pred['question_id'], 'prediction'] = pred['text'] 26 | 27 | cur_df.to_excel(os.path.join(args.upload_dir, f"{args.experiment}.xlsx"), index=False, engine='openpyxl') 28 | -------------------------------------------------------------------------------- /scripts/process_ckpts.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # This script is used to process the trained weights and generate a smaller, more compact set of weights 3 | LLM_MODEL_SIZE=410M 4 | 5 | 6 | # path to trained TinyVLA weights 7 | source_dir="/path/to/trained/VLA/weights" 8 | # new path to save weights 9 | target_dir="/path/to/save/processed/VLA/weights" 10 | 11 | mkdir -p $target_dir 12 | 13 | exclude_pattern="global_step*" 14 | 15 | echo "copying checkpoints from $source_dir to $target_dir" 16 | rsync -av --exclude="$exclude_pattern" --exclude="$exclude_pattern/**" "$source_dir/" "$target_dir/" 17 | 18 | echo 'transfer checkpoints to non_lora_trainables.bin' 19 | for dir in "$source_dir"/*/ ; do 20 | 21 | if [[ "$(basename "$dir")" == *"checkpoint"* ]]; then 22 | if !
find "$dir" -mindepth 1 -type f -name "non_lora_trainables.bin" | grep -q .; then 23 | cd "$dir" || exit 24 | python ./zero_to_fp32.py ./ ${target_dir}/$(basename "$dir")/non_lora_trainables.bin 25 | # cp $OUTPUT/non_lora_trainables.bin $dir 26 | fi 27 | fi 28 | done 29 | 30 | cd "/data/junjiewen/droid_results/checkpoint_all" || exit 31 | 32 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Tony Z. Zhao 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /llava-pythia/scripts/.ipynb_checkpoints/training_states_2_tensorboard-checkpoint.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.tensorboard import SummaryWriter 3 | import os 4 | import json 5 | 6 | def main(): 7 | # 创建SummaryWriter对象 8 | pythia = "410M" 9 | log_p = f'/data/private/wenjj/llava-pythia/checkpoint_all/pythia_{pythia}/vanilla_pythia_pt_f_vit/llavaPythia-v0-robot-action-w_state_huber/log' 10 | 11 | trainint_state_p = f"/data/private/wenjj/llava-pythia/checkpoint_all/pythia_{pythia}/vanilla_pythia_pt_f_vit/llavaPythia-v0-robot-action-w_state_huber/trainer_state.json" 12 | 13 | os.makedirs(log_p, exist_ok=True) 14 | 15 | writer = SummaryWriter(log_dir=log_p) 16 | 17 | # 假设loss_data是你保存的loss数据,格式为字典 18 | with open(trainint_state_p, "r") as f: 19 | data = json.load(f) 20 | 21 | # 将loss数据写入SummaryWriter 22 | for each in data['log_history']: 23 | if not 'loss' in each.keys(): 24 | continue 25 | step, loss = each['step'], each['loss'] 26 | writer.add_scalar('train/loss', loss, step) 27 | 28 | # 关闭SummaryWriter对象 29 | writer.close() 30 | 31 | if __name__ == "__main__": 32 | main() 33 | -------------------------------------------------------------------------------- /llava-pythia/pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=61.0"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "llava_pythia" 7 | version = "1.0.0" 8 | description = "Towards GPT-4 like large language and visual assistant." 
9 | readme = "README.md" 10 | requires-python = ">=3.8" 11 | classifiers = [ 12 | "Programming Language :: Python :: 3", 13 | "License :: OSI Approved :: Apache Software License", 14 | ] 15 | dependencies = [ 16 | "einops", "fastapi", "gradio==3.35.2", "markdown2[all]", "numpy", 17 | "requests", "sentencepiece", "tokenizers==0.15.0", 18 | "torch==2.0.1", "torchvision==0.15.2", "uvicorn", "wandb", 19 | "shortuuid", "httpx==0.24.0", 20 | "deepspeed==0.9.5", 21 | "peft==0.4.0", 22 | "transformers==4.37.1", 23 | "accelerate==0.21.0", 24 | "bitsandbytes==0.41.0", 25 | "scikit-learn==1.2.2", 26 | "sentencepiece==0.1.99", 27 | "einops==0.6.1", "einops-exts==0.0.4", "timm==0.6.13", 28 | "gradio_client==0.2.9" 29 | ] 30 | 31 | [project.urls] 32 | "Bug Tracker" = "https://github.com/zhuyiche/llava-phi/issues" 33 | 34 | [tool.setuptools.packages.find] 35 | exclude = ["assets*", "benchmark*", "docs", "dist*", "playground*", "scripts*", "tests*"] 36 | 37 | [tool.wheel] 38 | exclude = ["assets*", "benchmark*", "docs", "dist*", "playground*", "scripts*", "tests*"] 39 | -------------------------------------------------------------------------------- /llava-pythia/scripts/llava_pythia/eval/vqav2.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | gpu_list="${CUDA_VISIBLE_DEVICES:-0}" 4 | IFS=',' read -ra GPULIST <<< "$gpu_list" 5 | 6 | CHUNKS=${#GPULIST[@]} 7 | LLM_MODEL_SIZE=2_8B 8 | CKPT="llavaPhi-v0-3b-finetune" 9 | SPLIT="llava_vqav2_mscoco_test-dev2015" 10 | 11 | for IDX in $(seq 0 $((CHUNKS-1))); do 12 | CUDA_VISIBLE_DEVICES=${GPULIST[$IDX]} python -m llava_pythia.eval.model_vqa_loader \ 13 | --model-path ./checkpoint_all/pythia_$LLM_MODEL_SIZE/vanilla_pythia_pt_f_vit/llavaPythia-v0-finetune \ 14 | --question-file ./playground/data/eval/vqav2/$SPLIT.jsonl \ 15 | --image-folder /data/team/zhumj/data/coco/test2015 \ 16 | --answers-file ./playground/data/eval/vqav2/answers/$SPLIT/$CKPT/${CHUNKS}_${IDX}.jsonl \ 17 | --num-chunks $CHUNKS \ 18 | --chunk-idx $IDX \ 19 | --temperature 0 \ 20 | --conv-mode pythia & 21 | done 22 | 23 | wait 24 | 25 | output_file=./playground/data/eval/vqav2/answers/$SPLIT/$CKPT/merge.jsonl 26 | 27 | # Clear out the output file if it exists. 28 | > "$output_file" 29 | 30 | # Loop through the indices and concatenate each file. 
31 | for IDX in $(seq 0 $((CHUNKS-1))); do 32 | cat ./playground/data/eval/vqav2/answers/$SPLIT/$CKPT/${CHUNKS}_${IDX}.jsonl >> "$output_file" 33 | done 34 | 35 | python scripts/convert_vqav2_for_submission.py --split $SPLIT --ckpt $CKPT 36 | 37 | -------------------------------------------------------------------------------- /llava-pythia/scripts/llava_pythia/eval/.ipynb_checkpoints/vqav2-checkpoint.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | gpu_list="${CUDA_VISIBLE_DEVICES:-0}" 4 | IFS=',' read -ra GPULIST <<< "$gpu_list" 5 | 6 | CHUNKS=${#GPULIST[@]} 7 | LLM_MODEL_SIZE=2_8B 8 | CKPT="llavaPhi-v0-3b-finetune" 9 | SPLIT="llava_vqav2_mscoco_test-dev2015" 10 | 11 | for IDX in $(seq 0 $((CHUNKS-1))); do 12 | CUDA_VISIBLE_DEVICES=${GPULIST[$IDX]} python -m llava_pythia.eval.model_vqa_loader \ 13 | --model-path ./checkpoint_all/pythia_$LLM_MODEL_SIZE/vanilla_pythia_pt_f_vit/llavaPythia-v0-finetune \ 14 | --question-file ./playground/data/eval/vqav2/$SPLIT.jsonl \ 15 | --image-folder /data/team/zhumj/data/coco/test2015 \ 16 | --answers-file ./playground/data/eval/vqav2/answers/$SPLIT/$CKPT/${CHUNKS}_${IDX}.jsonl \ 17 | --num-chunks $CHUNKS \ 18 | --chunk-idx $IDX \ 19 | --temperature 0 \ 20 | --conv-mode pythia & 21 | done 22 | 23 | wait 24 | 25 | output_file=./playground/data/eval/vqav2/answers/$SPLIT/$CKPT/merge.jsonl 26 | 27 | # Clear out the output file if it exists. 28 | > "$output_file" 29 | 30 | # Loop through the indices and concatenate each file. 31 | for IDX in $(seq 0 $((CHUNKS-1))); do 32 | cat ./playground/data/eval/vqav2/answers/$SPLIT/$CKPT/${CHUNKS}_${IDX}.jsonl >> "$output_file" 33 | done 34 | 35 | python scripts/convert_vqav2_for_submission.py --split $SPLIT --ckpt $CKPT 36 | 37 | -------------------------------------------------------------------------------- /llava-pythia/scripts/llava_pythia/eval/gqa.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | gpu_list="${CUDA_VISIBLE_DEVICES:-0}" 4 | IFS=',' read -ra GPULIST <<< "$gpu_list" 5 | 6 | CHUNKS=${#GPULIST[@]} 7 | LLM_MODEL_SIZE=1B 8 | CKPT="llavaPhi-v0-3b" 9 | SPLIT="llavaPhi_gqa_testdev_balanced" 10 | GQADIR="./playground/data/eval/gqa/data" 11 | 12 | for IDX in $(seq 0 $((CHUNKS-1))); do 13 | CUDA_VISIBLE_DEVICES=${GPULIST[$IDX]} python -m llava_pythia.eval.model_vqa_loader \ 14 | --model-path ./checkpoint_all/pythia_$LLM_MODEL_SIZE/vanilla_pythia_pt_f_vit/llavaPythia-v0-finetune \ 15 | --question-file ./playground/data/eval/gqa/$SPLIT.jsonl \ 16 | --image-folder /data/team/zhumj/data/finetune/data/gqa/images \ 17 | --answers-file ./playground/data/eval/gqa/answers/$SPLIT/$CKPT/${CHUNKS}_${IDX}.jsonl \ 18 | --num-chunks $CHUNKS \ 19 | --chunk-idx $IDX \ 20 | --temperature 0 \ 21 | --conv-mode pythia & 22 | done 23 | 24 | wait 25 | 26 | output_file=./playground/data/eval/gqa/answers/$SPLIT/$CKPT/merge.jsonl 27 | 28 | # Clear out the output file if it exists. 29 | > "$output_file" 30 | 31 | # Loop through the indices and concatenate each file. 
32 | for IDX in $(seq 0 $((CHUNKS-1))); do 33 | cat ./playground/data/eval/gqa/answers/$SPLIT/$CKPT/${CHUNKS}_${IDX}.jsonl >> "$output_file" 34 | done 35 | 36 | python scripts/convert_gqa_for_eval.py --src $output_file --dst $GQADIR/testdev_balanced_predictions.json 37 | 38 | cd $GQADIR 39 | python eval/eval.py --tier testdev_balanced 40 | -------------------------------------------------------------------------------- /llava-pythia/scripts/llava_pythia/eval/.ipynb_checkpoints/gqa-checkpoint.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | gpu_list="${CUDA_VISIBLE_DEVICES:-0}" 4 | IFS=',' read -ra GPULIST <<< "$gpu_list" 5 | 6 | CHUNKS=${#GPULIST[@]} 7 | LLM_MODEL_SIZE=1B 8 | CKPT="llavaPhi-v0-3b" 9 | SPLIT="llavaPhi_gqa_testdev_balanced" 10 | GQADIR="./playground/data/eval/gqa/data" 11 | 12 | for IDX in $(seq 0 $((CHUNKS-1))); do 13 | CUDA_VISIBLE_DEVICES=${GPULIST[$IDX]} python -m llava_pythia.eval.model_vqa_loader \ 14 | --model-path ./checkpoint_all/pythia_$LLM_MODEL_SIZE/vanilla_pythia_pt_f_vit/llavaPythia-v0-finetune \ 15 | --question-file ./playground/data/eval/gqa/$SPLIT.jsonl \ 16 | --image-folder /data/team/zhumj/data/finetune/data/gqa/images \ 17 | --answers-file ./playground/data/eval/gqa/answers/$SPLIT/$CKPT/${CHUNKS}_${IDX}.jsonl \ 18 | --num-chunks $CHUNKS \ 19 | --chunk-idx $IDX \ 20 | --temperature 0 \ 21 | --conv-mode pythia & 22 | done 23 | 24 | wait 25 | 26 | output_file=./playground/data/eval/gqa/answers/$SPLIT/$CKPT/merge.jsonl 27 | 28 | # Clear out the output file if it exists. 29 | > "$output_file" 30 | 31 | # Loop through the indices and concatenate each file. 32 | for IDX in $(seq 0 $((CHUNKS-1))); do 33 | cat ./playground/data/eval/gqa/answers/$SPLIT/$CKPT/${CHUNKS}_${IDX}.jsonl >> "$output_file" 34 | done 35 | 36 | python scripts/convert_gqa_for_eval.py --src $output_file --dst $GQADIR/testdev_balanced_predictions.json 37 | 38 | cd $GQADIR 39 | python eval/eval.py --tier testdev_balanced 40 | -------------------------------------------------------------------------------- /llava-pythia/scripts/zero3_offload.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | "bf16": { 11 | "enabled": "auto" 12 | }, 13 | "optimizer": { 14 | "type": "AdamW", 15 | "params": { 16 | "lr": "auto", 17 | "betas": "auto", 18 | "eps": "auto", 19 | "weight_decay": "auto" 20 | } 21 | }, 22 | "scheduler": { 23 | "type": "WarmupLR", 24 | "params": { 25 | "warmup_min_lr": "auto", 26 | "warmup_max_lr": "auto", 27 | "warmup_num_steps": "auto" 28 | } 29 | }, 30 | "zero_optimization": { 31 | "stage": 3, 32 | "offload_optimizer": { 33 | "device": "cpu", 34 | "pin_memory": true 35 | }, 36 | "offload_param": { 37 | "device": "cpu", 38 | "pin_memory": true 39 | }, 40 | "overlap_comm": true, 41 | "contiguous_gradients": true, 42 | "sub_group_size": 1e9, 43 | "reduce_bucket_size": "auto", 44 | "stage3_prefetch_bucket_size": "auto", 45 | "stage3_param_persistence_threshold": "auto", 46 | "stage3_max_live_parameters": 1e9, 47 | "stage3_max_reuse_distance": 1e9, 48 | "gather_16bit_weights_on_model_save": true 49 | }, 50 | "gradient_accumulation_steps": "auto", 51 | "gradient_clipping": "auto", 52 | "train_batch_size": "auto", 53 | "train_micro_batch_size_per_gpu": "auto", 54 | "steps_per_print": 1e5, 55 | "wall_clock_breakdown": 
false 56 | } -------------------------------------------------------------------------------- /llava-pythia/scripts/.ipynb_checkpoints/zero3_offload-checkpoint.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | "bf16": { 11 | "enabled": "auto" 12 | }, 13 | "optimizer": { 14 | "type": "AdamW", 15 | "params": { 16 | "lr": "auto", 17 | "betas": "auto", 18 | "eps": "auto", 19 | "weight_decay": "auto" 20 | } 21 | }, 22 | "scheduler": { 23 | "type": "WarmupLR", 24 | "params": { 25 | "warmup_min_lr": "auto", 26 | "warmup_max_lr": "auto", 27 | "warmup_num_steps": "auto" 28 | } 29 | }, 30 | "zero_optimization": { 31 | "stage": 3, 32 | "offload_optimizer": { 33 | "device": "cpu", 34 | "pin_memory": true 35 | }, 36 | "offload_param": { 37 | "device": "cpu", 38 | "pin_memory": true 39 | }, 40 | "overlap_comm": true, 41 | "contiguous_gradients": true, 42 | "sub_group_size": 1e9, 43 | "reduce_bucket_size": "auto", 44 | "stage3_prefetch_bucket_size": "auto", 45 | "stage3_param_persistence_threshold": "auto", 46 | "stage3_max_live_parameters": 1e9, 47 | "stage3_max_reuse_distance": 1e9, 48 | "gather_16bit_weights_on_model_save": true 49 | }, 50 | "gradient_accumulation_steps": "auto", 51 | "gradient_clipping": "auto", 52 | "train_batch_size": "auto", 53 | "train_micro_batch_size_per_gpu": "auto", 54 | "steps_per_print": 1e5, 55 | "wall_clock_breakdown": false 56 | } -------------------------------------------------------------------------------- /llava-pythia/scripts/llava_pythia/get_base_model.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | LLM_MODEL_SIZE=14M 3 | 4 | python llava_pythia/train/convert_model2base_llava_pythia.py \ 5 | --model_name_or_path /data/team/zhumj/model_Param/EleutherAI/pythia-$LLM_MODEL_SIZE \ 6 | --version plain \ 7 | --data_path /data/team/zhumj/data/llava-pretrain/blip_laion_cc_sbu_558k.json \ 8 | --image_folder /data/team/zhumj/data/llava-pretrain/images \ 9 | --vision_tower openai/clip-vit-large-patch14-336 \ 10 | --mm_projector_type mlp2x_gelu \ 11 | --tune_mm_mlp_adapter True \ 12 | --mm_vision_select_layer -2 \ 13 | --mm_use_im_start_end False \ 14 | --mm_use_im_patch_token False \ 15 | --bf16 True \ 16 | --output_dir ./checkpoint_all/pythia_$LLM_MODEL_SIZE/base_checkpoints_llava_vanilla_pythia \ 17 | --num_train_epochs 1 \ 18 | --per_device_train_batch_size 16 \ 19 | --per_device_eval_batch_size 4 \ 20 | --gradient_accumulation_steps 2 \ 21 | --evaluation_strategy "no" \ 22 | --save_strategy "steps" \ 23 | --save_steps 24000 \ 24 | --save_total_limit 1 \ 25 | --learning_rate 1e-3 \ 26 | --weight_decay 0.1 \ 27 | --warmup_ratio 0. 
\ 28 | --lr_scheduler_type "cosine" \ 29 | --logging_steps 1 \ 30 | --tf32 True \ 31 | --model_max_length 2048 \ 32 | --gradient_checkpointing True \ 33 | --dataloader_num_workers 4 \ 34 | --lazy_preprocess True \ 35 | --report_to wandb 36 | 37 | cp openai/clip-vit-large-patch14-336/preprocessor_config.json ./checkpoint_all/pythia_$LLM_MODEL_SIZE/base_checkpoints_llava_vanilla_pythia -------------------------------------------------------------------------------- /llava-pythia/scripts/llava_pythia/pretrain.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | LLM_MODEL_SIZE=14M 3 | 4 | deepspeed --master_port 29600 llava_pythia/train/train.py \ 5 | --deepspeed ./scripts/zero2.json \ 6 | --model_name_or_path ./checkpoint_all/pythia_$LLM_MODEL_SIZE/base_checkpoints_llava_vanilla_pythia \ 7 | --version plain \ 8 | --data_path /data/team/zhumj/data/llava-pretrain/blip_laion_cc_sbu_558k.json \ 9 | --image_folder /data/team/zhumj/data/llava-pretrain/images \ 10 | --tune_mm_mlp_adapter True \ 11 | --freeze_vision_tower True \ 12 | --freeze_backbone True \ 13 | --mm_use_im_start_end False \ 14 | --mm_use_im_patch_token False \ 15 | --bf16 True \ 16 | --output_dir ./checkpoint_all/pythia_$LLM_MODEL_SIZE/vanilla_pythia_pt_f_vit/llavaPythia-v0-pretrain \ 17 | --num_train_epochs 1 \ 18 | --per_device_train_batch_size 32 \ 19 | --per_device_eval_batch_size 4 \ 20 | --gradient_accumulation_steps 1 \ 21 | --evaluation_strategy "no" \ 22 | --save_strategy "steps" \ 23 | --save_steps 24000 \ 24 | --save_total_limit 1 \ 25 | --learning_rate 1e-3 \ 26 | --weight_decay 0. \ 27 | --warmup_ratio 0.03 \ 28 | --lr_scheduler_type "cosine" \ 29 | --logging_steps 1 \ 30 | --tf32 True \ 31 | --model_max_length 2048 \ 32 | --gradient_checkpointing True \ 33 | --dataloader_num_workers 4 \ 34 | --lazy_preprocess True \ 35 | --report_to wandb 36 | 37 | cp openai/clip-vit-large-patch14-336/preprocessor_config.json ./checkpoint_all/pythia_$LLM_MODEL_SIZE/vanilla_pythia_pt_f_vit/llavaPythia-v0-pretrain -------------------------------------------------------------------------------- /llava-pythia/scripts/llava_pythia/.ipynb_checkpoints/get_base_model-checkpoint.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | LLM_MODEL_SIZE=14M 3 | 4 | python llava_pythia/train/convert_model2base_llava_pythia.py \ 5 | --model_name_or_path /data/team/zhumj/model_Param/EleutherAI/pythia-$LLM_MODEL_SIZE \ 6 | --version plain \ 7 | --data_path /data/team/zhumj/data/llava-pretrain/blip_laion_cc_sbu_558k.json \ 8 | --image_folder /data/team/zhumj/data/llava-pretrain/images \ 9 | --vision_tower openai/clip-vit-large-patch14-336 \ 10 | --mm_projector_type mlp2x_gelu \ 11 | --tune_mm_mlp_adapter True \ 12 | --mm_vision_select_layer -2 \ 13 | --mm_use_im_start_end False \ 14 | --mm_use_im_patch_token False \ 15 | --bf16 True \ 16 | --output_dir ./checkpoint_all/pythia_$LLM_MODEL_SIZE/base_checkpoints_llava_vanilla_pythia \ 17 | --num_train_epochs 1 \ 18 | --per_device_train_batch_size 16 \ 19 | --per_device_eval_batch_size 4 \ 20 | --gradient_accumulation_steps 2 \ 21 | --evaluation_strategy "no" \ 22 | --save_strategy "steps" \ 23 | --save_steps 24000 \ 24 | --save_total_limit 1 \ 25 | --learning_rate 1e-3 \ 26 | --weight_decay 0.1 \ 27 | --warmup_ratio 0. 
\ 28 | --lr_scheduler_type "cosine" \ 29 | --logging_steps 1 \ 30 | --tf32 True \ 31 | --model_max_length 2048 \ 32 | --gradient_checkpointing True \ 33 | --dataloader_num_workers 4 \ 34 | --lazy_preprocess True \ 35 | --report_to wandb 36 | 37 | cp openai/clip-vit-large-patch14-336/preprocessor_config.json ./checkpoint_all/pythia_$LLM_MODEL_SIZE/base_checkpoints_llava_vanilla_pythia -------------------------------------------------------------------------------- /llava-pythia/scripts/llava_pythia/.ipynb_checkpoints/pretrain-checkpoint.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | LLM_MODEL_SIZE=14M 3 | 4 | deepspeed --master_port 29600 llava_pythia/train/train.py \ 5 | --deepspeed ./scripts/zero2.json \ 6 | --model_name_or_path ./checkpoint_all/pythia_$LLM_MODEL_SIZE/base_checkpoints_llava_vanilla_pythia \ 7 | --version plain \ 8 | --data_path /data/team/zhumj/data/llava-pretrain/blip_laion_cc_sbu_558k.json \ 9 | --image_folder /data/team/zhumj/data/llava-pretrain/images \ 10 | --tune_mm_mlp_adapter True \ 11 | --freeze_vision_tower True \ 12 | --freeze_backbone True \ 13 | --mm_use_im_start_end False \ 14 | --mm_use_im_patch_token False \ 15 | --bf16 True \ 16 | --output_dir ./checkpoint_all/pythia_$LLM_MODEL_SIZE/vanilla_pythia_pt_f_vit/llavaPythia-v0-pretrain \ 17 | --num_train_epochs 1 \ 18 | --per_device_train_batch_size 32 \ 19 | --per_device_eval_batch_size 4 \ 20 | --gradient_accumulation_steps 1 \ 21 | --evaluation_strategy "no" \ 22 | --save_strategy "steps" \ 23 | --save_steps 24000 \ 24 | --save_total_limit 1 \ 25 | --learning_rate 1e-3 \ 26 | --weight_decay 0. \ 27 | --warmup_ratio 0.03 \ 28 | --lr_scheduler_type "cosine" \ 29 | --logging_steps 1 \ 30 | --tf32 True \ 31 | --model_max_length 2048 \ 32 | --gradient_checkpointing True \ 33 | --dataloader_num_workers 4 \ 34 | --lazy_preprocess True \ 35 | --report_to wandb 36 | 37 | cp openai/clip-vit-large-patch14-336/preprocessor_config.json ./checkpoint_all/pythia_$LLM_MODEL_SIZE/vanilla_pythia_pt_f_vit/llavaPythia-v0-pretrain -------------------------------------------------------------------------------- /llava-pythia/scripts/convert_vizwiz_for_submission.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import json 4 | 5 | from llava_pythia.eval.m4c_evaluator import EvalAIAnswerProcessor 6 | 7 | 8 | def parse_args(): 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument('--annotation-file', type=str, required=True) 11 | parser.add_argument('--result-file', type=str, required=True) 12 | parser.add_argument('--result-upload-file', type=str, required=True) 13 | return parser.parse_args() 14 | 15 | 16 | if __name__ == '__main__': 17 | 18 | args = parse_args() 19 | 20 | os.makedirs(os.path.dirname(args.result_upload_file), exist_ok=True) 21 | 22 | results = [] 23 | error_line = 0 24 | for line_idx, line in enumerate(open(args.result_file)): 25 | try: 26 | results.append(json.loads(line)) 27 | except: 28 | error_line += 1 29 | results = {x['question_id']: x['text'] for x in results} 30 | test_split = [json.loads(line) for line in open(args.annotation_file)] 31 | split_ids = set([x['question_id'] for x in test_split]) 32 | 33 | print(f'total results: {len(results)}, total split: {len(test_split)}, error_line: {error_line}') 34 | 35 | all_answers = [] 36 | 37 | answer_processor = EvalAIAnswerProcessor() 38 | 39 | for x in test_split: 40 | assert x['question_id'] in results 41 | 
all_answers.append({ 42 | 'image': x['image'], 43 | 'answer': answer_processor(results[x['question_id']]) 44 | }) 45 | 46 | with open(args.result_upload_file, 'w') as f: 47 | json.dump(all_answers, f) 48 | -------------------------------------------------------------------------------- /llava-pythia/scripts/display_eval_results_all.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | ckpt_p = "/data/private/wenjj/llava-pythia/checkpoint_all/pythia_{msz}/vanilla_pythia_pt_f_vit/llavaPythia-v0-robot-action-view_channel_cat_lora2/checkpoint-{ckpt}" 4 | 5 | model_size = ['14M', '70M', '160M', '410M', '1B'] 6 | 7 | TASK = ('kitchen_sdoor_open-v3', 'kitchen_micro_open-v3', 'kitchen_light_on-v3', 'kitchen_ldoor_open-v3','kitchen_knob1_on-v3') 8 | 9 | task_id = 0 10 | for task_id in range(5): 11 | # print(f"{TASK[task_id]}") 12 | with open('eval_results_all.txt', 'w') as f: 13 | f.write(f"eval results on Franka Kitchen {TASK[task_id]}:\n") 14 | 15 | for msz in model_size: 16 | f.write(f"###################Model size:{msz}###################\n") 17 | for i in range(2,6): 18 | ckpt = str(i*1000) 19 | p = ckpt_p.replace('{msz}', msz).replace('{ckpt}', ckpt) 20 | try: 21 | with open(os.path.join(p, f"{TASK[task_id]}.txt"), 'r') as f1: 22 | content = f1.read() 23 | content = content.split('\n') 24 | # for e in content: 25 | # if e == "": 26 | # continue 27 | # if '50_' not in e: 28 | # continue 29 | 30 | # f.write(f"ckpt_{i*1000}:{e.strip()}\n") 31 | f.write(f"ckpt_{ckpt}:{content[-1].strip()}\n") 32 | except Exception as e: 33 | print(e) 34 | pass 35 | with open('eval_results_all.txt', 'r') as f: 36 | data = f.read() 37 | print(data) 38 | 39 | -------------------------------------------------------------------------------- /llava-pythia/scripts/.ipynb_checkpoints/display_eval_results_all-checkpoint.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | ckpt_p = "/data/private/wenjj/llava-pythia/checkpoint_all/pythia_{msz}/vanilla_pythia_pt_f_vit/llavaPythia-v0-robot-action-view_channel_cat_lora2/checkpoint-{ckpt}" 4 | 5 | model_size = ['14M', '70M', '160M', '410M', '1B'] 6 | 7 | TASK = ('kitchen_sdoor_open-v3', 'kitchen_micro_open-v3', 'kitchen_light_on-v3', 'kitchen_ldoor_open-v3','kitchen_knob1_on-v3') 8 | 9 | task_id = 0 10 | for task_id in range(5): 11 | # print(f"{TASK[task_id]}") 12 | with open('eval_results_all.txt', 'w') as f: 13 | f.write(f"eval results on Franka Kitchen {TASK[task_id]}:\n") 14 | 15 | for msz in model_size: 16 | f.write(f"###################Model size:{msz}###################\n") 17 | for i in range(2,6): 18 | ckpt = str(i*1000) 19 | p = ckpt_p.replace('{msz}', msz).replace('{ckpt}', ckpt) 20 | try: 21 | with open(os.path.join(p, f"{TASK[task_id]}.txt"), 'r') as f1: 22 | content = f1.read() 23 | content = content.split('\n') 24 | # for e in content: 25 | # if e == "": 26 | # continue 27 | # if '50_' not in e: 28 | # continue 29 | 30 | # f.write(f"ckpt_{i*1000}:{e.strip()}\n") 31 | f.write(f"ckpt_{ckpt}:{content[-1].strip()}\n") 32 | except Exception as e: 33 | print(e) 34 | pass 35 | with open('eval_results_all.txt', 'r') as f: 36 | data = f.read() 37 | print(data) 38 | 39 | -------------------------------------------------------------------------------- /llava-pythia/scripts/llava_pythia/finetune.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | LLM_MODEL_SIZE=14M 3 | 4 | deepspeed --master_port 29600 --num_gpus=8 
--num_nodes=1 llava_pythia/train/train.py \ 5 | --deepspeed ./scripts/zero2.json \ 6 | --model_name_or_path ./checkpoint_all/pythia_$LLM_MODEL_SIZE/vanilla_pythia_pt_f_vit/llavaPythia-v0-pretrain \ 7 | --version v0 \ 8 | --data_path /data/team/zhumj/data/finetune/data/llava_v1_5_mix665k.json \ 9 | --image_folder /data/team/zhumj/data/finetune/data \ 10 | --tune_mm_mlp_adapter True \ 11 | --freeze_vision_tower False \ 12 | --freeze_backbone False \ 13 | --mm_use_im_start_end False \ 14 | --mm_use_im_patch_token False \ 15 | --image_aspect_ratio pad \ 16 | --group_by_modality_length False \ 17 | --bf16 True \ 18 | --output_dir ./checkpoint_all/pythia_$LLM_MODEL_SIZE/vanilla_pythia_pt_f_vit/llavaPythia-v0-finetune \ 19 | --num_train_epochs 1 \ 20 | --per_device_train_batch_size 4 \ 21 | --per_device_eval_batch_size 4 \ 22 | --gradient_accumulation_steps 4 \ 23 | --evaluation_strategy "no" \ 24 | --save_strategy "steps" \ 25 | --save_steps 50000 \ 26 | --save_total_limit 1 \ 27 | --learning_rate 2e-5 \ 28 | --weight_decay 0. \ 29 | --warmup_ratio 0.03 \ 30 | --lr_scheduler_type "cosine" \ 31 | --logging_steps 1 \ 32 | --tf32 True \ 33 | --model_max_length 2048 \ 34 | --gradient_checkpointing True \ 35 | --dataloader_num_workers 4 \ 36 | --lazy_preprocess True \ 37 | --report_to wandb 38 | 39 | #/data/private/data/llava_data/franka_kitchen_finetune/left_cap2/left_cap2_50k.json 40 | #/data/team/zhumj/data/finetune/data/llava_v1_5_mix665k.json 41 | cp openai/clip-vit-large-patch14-336/preprocessor_config.json ./checkpoint_all/pythia_$LLM_MODEL_SIZE/vanilla_pythia_pt_f_vit/llavaPythia-v0-finetune -------------------------------------------------------------------------------- /llava-pythia/scripts/llava_pythia/.ipynb_checkpoints/finetune-checkpoint.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | LLM_MODEL_SIZE=14M 3 | 4 | deepspeed --master_port 29600 --num_gpus=8 --num_nodes=1 llava_pythia/train/train.py \ 5 | --deepspeed ./scripts/zero2.json \ 6 | --model_name_or_path ./checkpoint_all/pythia_$LLM_MODEL_SIZE/vanilla_pythia_pt_f_vit/llavaPythia-v0-pretrain \ 7 | --version v0 \ 8 | --data_path /data/team/zhumj/data/finetune/data/llava_v1_5_mix665k.json \ 9 | --image_folder /data/team/zhumj/data/finetune/data \ 10 | --tune_mm_mlp_adapter True \ 11 | --freeze_vision_tower False \ 12 | --freeze_backbone False \ 13 | --mm_use_im_start_end False \ 14 | --mm_use_im_patch_token False \ 15 | --image_aspect_ratio pad \ 16 | --group_by_modality_length False \ 17 | --bf16 True \ 18 | --output_dir ./checkpoint_all/pythia_$LLM_MODEL_SIZE/vanilla_pythia_pt_f_vit/llavaPythia-v0-finetune \ 19 | --num_train_epochs 1 \ 20 | --per_device_train_batch_size 4 \ 21 | --per_device_eval_batch_size 4 \ 22 | --gradient_accumulation_steps 4 \ 23 | --evaluation_strategy "no" \ 24 | --save_strategy "steps" \ 25 | --save_steps 50000 \ 26 | --save_total_limit 1 \ 27 | --learning_rate 2e-5 \ 28 | --weight_decay 0.
\ 29 | --warmup_ratio 0.03 \ 30 | --lr_scheduler_type "cosine" \ 31 | --logging_steps 1 \ 32 | --tf32 True \ 33 | --model_max_length 2048 \ 34 | --gradient_checkpointing True \ 35 | --dataloader_num_workers 4 \ 36 | --lazy_preprocess True \ 37 | --report_to wandb 38 | 39 | #/data/private/data/llava_data/franka_kitchen_finetune/left_cap2/left_cap2_50k.json 40 | #/data/team/zhumj/data/finetune/data/llava_v1_5_mix665k.json 41 | cp openai/clip-vit-large-patch14-336/preprocessor_config.json ./checkpoint_all/pythia_$LLM_MODEL_SIZE/vanilla_pythia_pt_f_vit/llavaPythia-v0-finetune -------------------------------------------------------------------------------- /scripts/train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | ACTION_HEAD=droid_diffusion # specify action policy head type 4 | # define OUTPUT path 5 | 6 | OUTPUT=/path/to/save_dir 7 | 8 | if [ -d "$OUTPUT" ]; then 9 | echo 'output exists' 10 | else 11 | echo '!!output does not exist!!' 12 | mkdir -p $OUTPUT 13 | fi 14 | # backup the train scripts 15 | cp ./scripts/train.sh $OUTPUT 16 | 17 | # detailed usage of each parameter can be found in train_tinyvla.py 18 | 19 | deepspeed --master_port 29600 --num_gpus=8 --num_nodes=1 ./train_tinyvla.py \ 20 | --deepspeed scripts/zero2.json \ 21 | --lora_enable True \ 22 | --lora_module 'vit llm' \ 23 | --load_pretrain False \ 24 | --pretrain_image_size 320 \ 25 | --lora_r 64 \ 26 | --lora_alpha 256 \ 27 | --non_lora_lr 2e-5 \ 28 | --task_name "example_task_config" \ 29 | --model_name_or_path /path/to/pretrained_vlm \ 30 | --version v0 \ 31 | --tune_mm_mlp_adapter True \ 32 | --freeze_vision_tower True \ 33 | --freeze_backbone True \ 34 | --mm_use_im_start_end False \ 35 | --mm_use_im_patch_token False \ 36 | --image_aspect_ratio pad \ 37 | --group_by_modality_length False \ 38 | --bf16 True \ 39 | --output_dir $OUTPUT \ 40 | --max_steps 10000 \ 41 | --per_device_train_batch_size 32 \ 42 | --gradient_accumulation_steps 1 \ 43 | --save_strategy "steps" \ 44 | --save_steps 1000 \ 45 | --save_total_limit 50 \ 46 | --learning_rate 2e-4 \ 47 | --weight_decay 0. \ 48 | --warmup_ratio 0.005 \ 49 | --lr_scheduler_type "cosine" \ 50 | --logging_steps 10 \ 51 | --tf32 True \ 52 | --model_max_length 2048 \ 53 | --gradient_checkpointing True \ 54 | --dataloader_num_workers 8 \ 55 | --lazy_preprocess True \ 56 | --action_head_type $ACTION_HEAD \ 57 | --use_state True \ 58 | --concat "token_cat" \ 59 | --window_size 6 \ 60 | --report_to tensorboard \ 61 | --logging_dir $OUTPUT/log 62 | 63 | for dir in "$OUTPUT"/*/ ; do 64 | # Check whether the folder name contains 'checkpoint' 65 | if [[ "$(basename "$dir")" == *"checkpoint"* ]]; then 66 | cp llava-pythia/preprocessor_config.json $dir 67 | fi 68 | done 69 | 70 | -------------------------------------------------------------------------------- /llava-pythia/scripts/llava_pythia/train_robot.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # LLM_MODEL_SIZE=$1 3 | LLM_MODEL_SIZE=2_8B 4 | # ./checkpoint_all/pythia_$LLM_MODEL_SIZE/vanilla_pythia_pt_f_vit/llavaPythia-v0-finetune 5 | OUTPUT=./checkpoint_all/pythia_$LLM_MODEL_SIZE/vanilla_pythia_pt_f_vit/llavaPythia-v0-robot-action-1view_adapter3 6 | 7 | # deepspeed --master_port 29601 --include localhost:4,5,6,7 llava_pythia/train/train.py \ 8 | # echo "waiting for 20 minutes..."
9 | # sleep 20m 10 | 11 | deepspeed --master_port 29601 --num_gpus=8 --num_nodes=1 llava_pythia/train/train.py \ 12 | --deepspeed ./scripts/zero2.json \ 13 | --model_name_or_path /data/team/zhumj/model_Param/llava_pythia_checkpoints/pythia_$LLM_MODEL_SIZE/vanilla_pythia_pt_f_vit/llavaPythia-v0-finetune \ 14 | --version v0 \ 15 | --data_path /data/private/data/llava_data/franka_kitchen_finetune/left_cap2/std_train_left_cap2_50k.json \ 16 | --image_folder /data/team/zhumj/data/finetune/data \ 17 | --tune_mm_mlp_adapter True \ 18 | --freeze_vision_tower True \ 19 | --freeze_backbone True \ 20 | --mm_use_im_start_end False \ 21 | --mm_use_im_patch_token False \ 22 | --image_aspect_ratio pad \ 23 | --group_by_modality_length False \ 24 | --bf16 True \ 25 | --output_dir $OUTPUT \ 26 | --num_train_epochs 15 \ 27 | --per_device_train_batch_size 4 \ 28 | --per_device_eval_batch_size 4 \ 29 | --gradient_accumulation_steps 4 \ 30 | --evaluation_strategy "steps" \ 31 | --save_strategy "steps" \ 32 | --save_steps 1000 \ 33 | --save_total_limit 15 \ 34 | --learning_rate 3e-5 \ 35 | --weight_decay 0. \ 36 | --warmup_ratio 0.005 \ 37 | --lr_scheduler_type "cosine" \ 38 | --logging_steps 10 \ 39 | --tf32 True \ 40 | --model_max_length 2048 \ 41 | --gradient_checkpointing True \ 42 | --dataloader_num_workers 4 \ 43 | --lazy_preprocess True \ 44 | --action_head "fc" \ 45 | --use_state True \ 46 | --lora_enable False \ 47 | --window_size 6 \ 48 | --logging_dir $OUTPUT/log \ 49 | --report_to wandb 50 | 51 | #/data/private/data/llava_data/franka_kitchen_finetune/left_cap2/left_cap2_50k.json 52 | #/data/team/zhumj/data/finetune/data/llava_v1_5_mix665k.json 53 | # cp openai/clip-vit-large-patch14-336/preprocessor_config.json $OUTPUT 54 | for dir in "$OUTPUT"/*/ ; do 55 | # Check whether the folder name contains 'checkpoint' 56 | if [[ "$(basename "$dir")" == *"checkpoint"* ]]; then 57 | cp openai/clip-vit-large-patch14-336/preprocessor_config.json $dir 58 | fi 59 | done 60 | 61 | cp ./scripts/llava_pythia/train_robot.sh $OUTPUT 62 | -------------------------------------------------------------------------------- /llava-pythia/llava_pythia/model/multimodal_projector/builder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import re 4 | 5 | 6 | class IdentityMap(nn.Module): 7 | def __init__(self): 8 | super().__init__() 9 | 10 | def forward(self, x, *args, **kwargs): 11 | return x 12 | 13 | @property 14 | def config(self): 15 | return {"mm_projector_type": 'identity'} 16 | 17 | 18 | class SimpleResBlock(nn.Module): 19 | def __init__(self, channels): 20 | super().__init__() 21 | self.pre_norm = nn.LayerNorm(channels) 22 | 23 | self.proj = nn.Sequential( 24 | nn.Linear(channels, channels), 25 | nn.GELU(), 26 | nn.Linear(channels, channels) 27 | ) 28 | def forward(self, x): 29 | x = self.pre_norm(x) 30 | return x + self.proj(x) 31 | 32 | 33 | def build_vision_projector(config): 34 | """ 35 | Constructs a vision projector based on the specified configuration. 36 | 37 | Args: 38 | - config: An object containing configuration attributes. It should have 39 | 'mm_projector_type' to specify the type of projector and 'mm_hidden_size' 40 | and 'hidden_size' for the dimensions of the layers. 41 | 42 | Returns: 43 | - A PyTorch module that acts as the vision projector. The type of module 44 | returned depends on the 'mm_projector_type' attribute in the config: 45 | - 'linear': Returns a linear layer mapping from mm_hidden_size to hidden_size.
46 | - 'mlp{n}x_gelu': Returns a sequential model with n layers, each consisting 47 | of a GELU activation followed by a linear layer. 48 | - 'identity': Returns an IdentityMap, which simply returns the input as is. 49 | 50 | Raises: 51 | - ValueError: If the 'mm_projector_type' is not recognized. 52 | """ 53 | projector_type = getattr(config, 'mm_projector_type', 'linear') 54 | 55 | if projector_type == 'linear': 56 | return nn.Linear(config.mm_hidden_size, config.hidden_size) 57 | 58 | mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu$', projector_type) 59 | if mlp_gelu_match: 60 | mlp_depth = int(mlp_gelu_match.group(1)) 61 | modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)] 62 | for _ in range(1, mlp_depth): 63 | modules.append(nn.GELU()) 64 | modules.append(nn.Linear(config.hidden_size, config.hidden_size)) 65 | return nn.Sequential(*modules) 66 | 67 | if projector_type == 'identity': 68 | return IdentityMap() 69 | 70 | raise ValueError(f'Unknown projector type: {projector_type}') 71 | -------------------------------------------------------------------------------- /llava-pythia/scripts/llava_pythia/.ipynb_checkpoints/train_robot-checkpoint.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # LLM_MODEL_SIZE=$1 3 | LLM_MODEL_SIZE=2_8B 4 | # ./checkpoint_all/pythia_$LLM_MODEL_SIZE/vanilla_pythia_pt_f_vit/llavaPythia-v0-finetune 5 | OUTPUT=./checkpoint_all/pythia_$LLM_MODEL_SIZE/vanilla_pythia_pt_f_vit/llavaPythia-v0-robot-action-1view_adapter3 6 | 7 | # deepspeed --master_port 29601 --include localhost:4,5,6,7 llava_pythia/train/train.py \ 8 | # echo "waiting for 20minutes..." 9 | # sleep 20m 10 | 11 | deepspeed --master_port 29601 --num_gpus=8 --num_nodes=1 llava_pythia/train/train.py \ 12 | --deepspeed ./scripts/zero2.json \ 13 | --model_name_or_path /data/team/zhumj/model_Param/llava_pythia_checkpoints/pythia_$LLM_MODEL_SIZE/vanilla_pythia_pt_f_vit/llavaPythia-v0-finetune \ 14 | --version v0 \ 15 | --data_path /data/private/data/llava_data/franka_kitchen_finetune/left_cap2/std_train_left_cap2_50k.json \ 16 | --image_folder /data/team/zhumj/data/finetune/data \ 17 | --tune_mm_mlp_adapter True \ 18 | --freeze_vision_tower True \ 19 | --freeze_backbone True \ 20 | --mm_use_im_start_end False \ 21 | --mm_use_im_patch_token False \ 22 | --image_aspect_ratio pad \ 23 | --group_by_modality_length False \ 24 | --bf16 True \ 25 | --output_dir $OUTPUT \ 26 | --num_train_epochs 15 \ 27 | --per_device_train_batch_size 4 \ 28 | --per_device_eval_batch_size 4 \ 29 | --gradient_accumulation_steps 4 \ 30 | --evaluation_strategy "steps" \ 31 | --save_strategy "steps" \ 32 | --save_steps 1000 \ 33 | --save_total_limit 15 \ 34 | --learning_rate 3e-5 \ 35 | --weight_decay 0. 
\ 36 | --warmup_ratio 0.005 \ 37 | --lr_scheduler_type "cosine" \ 38 | --logging_steps 10 \ 39 | --tf32 True \ 40 | --model_max_length 2048 \ 41 | --gradient_checkpointing True \ 42 | --dataloader_num_workers 4 \ 43 | --lazy_preprocess True \ 44 | --action_head "fc" \ 45 | --use_state True \ 46 | --lora_enable False \ 47 | --window_size 6 \ 48 | --logging_dir $OUTPUT/log 49 | --report_to wandb 50 | 51 | #/data/private/data/llava_data/franka_kitchen_finetune/left_cap2/left_cap2_50k.json 52 | #/data/team/zhumj/data/finetune/data/llava_v1_5_mix665k.json 53 | # cp openai/clip-vit-large-patch14-336/preprocessor_config.json $OUTPUT 54 | for dir in "$OUTPUT"/*/ ; do 55 | # 检查文件夹名称是否包含'checkpoint' 56 | if [[ "$(basename "$dir")" == *"checkpoint"* ]]; then 57 | cp openai/clip-vit-large-patch14-336/preprocessor_config.json $dir 58 | fi 59 | done 60 | 61 | cp ./scripts/llava_pythia/train_robot.sh $OUTPUT 62 | -------------------------------------------------------------------------------- /llava-pythia/scripts/llava_pythia/lora_train_robot.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | LLM_MODEL_SIZE=$1 3 | # LLM_MODEL_SIZE=2_8B 4 | # ./checkpoint_all/pythia_$LLM_MODEL_SIZE/vanilla_pythia_pt_f_vit/llavaPythia-v0-finetune 5 | # lora only vit and tune adapter 6 | OUTPUT=./checkpoint_all/pythia_$LLM_MODEL_SIZE/vanilla_pythia_pt_f_vit/llavaPythia-v0-robot-action-view_channel_cat_lora2 7 | 8 | # deepspeed --master_port 29601 --include localhost:4,5,6,7 llava_pythia/train/train.py \ 9 | # echo "waiting for 6 hours..." 10 | # sleep 6h 11 | 12 | deepspeed --master_port 29601 --num_gpus=8 --num_nodes=1 llava_pythia/train/train.py \ 13 | --lora_enable True --lora_r 64 --lora_alpha 256 --non_lora_lr 3e-5 \ 14 | --deepspeed ./scripts/zero2.json \ 15 | --model_name_or_path /data/team/zhumj/model_Param/llava_pythia_checkpoints/pythia_$LLM_MODEL_SIZE/vanilla_pythia_pt_f_vit/llavaPythia-v0-finetune \ 16 | --version v0 \ 17 | --data_path /data/private/data/llava_data/franka_kitchen_finetune/left_cap2/std_train_left_cap2_50k.json \ 18 | --image_folder /data/team/zhumj/data/finetune/data \ 19 | --tune_mm_mlp_adapter True \ 20 | --freeze_vision_tower True \ 21 | --freeze_backbone True \ 22 | --mm_use_im_start_end False \ 23 | --mm_use_im_patch_token False \ 24 | --image_aspect_ratio pad \ 25 | --group_by_modality_length False \ 26 | --bf16 True \ 27 | --output_dir $OUTPUT \ 28 | --num_train_epochs 15 \ 29 | --per_device_train_batch_size 4 \ 30 | --per_device_eval_batch_size 4 \ 31 | --gradient_accumulation_steps 4 \ 32 | --evaluation_strategy "steps" \ 33 | --save_strategy "steps" \ 34 | --save_steps 1000 \ 35 | --save_total_limit 15 \ 36 | --learning_rate 2e-4 \ 37 | --weight_decay 0. 
\ 38 | --warmup_ratio 0.005 \ 39 | --lr_scheduler_type "cosine" \ 40 | --logging_steps 10 \ 41 | --tf32 True \ 42 | --model_max_length 2048 \ 43 | --gradient_checkpointing True \ 44 | --dataloader_num_workers 4 \ 45 | --lazy_preprocess True \ 46 | --action_head "fc" \ 47 | --use_state True \ 48 | --concat "channel_cat" \ 49 | --window_size 6 \ 50 | --report_to wandb \ 51 | --logging_dir $OUTPUT/log 52 | 53 | 54 | #/data/private/data/llava_data/franka_kitchen_finetune/left_cap2/left_cap2_50k.json 55 | #/data/team/zhumj/data/finetune/data/llava_v1_5_mix665k.json 56 | # cp openai/clip-vit-large-patch14-336/preprocessor_config.json $OUTPUT 57 | for dir in "$OUTPUT"/*/ ; do 58 | # Check whether the folder name contains 'checkpoint' 59 | if [[ "$(basename "$dir")" == *"checkpoint"* ]]; then 60 | cp openai/clip-vit-large-patch14-336/preprocessor_config.json $dir 61 | # cp $OUTPUT/non_lora_trainables.bin $dir 62 | fi 63 | done 64 | 65 | cp ./scripts/llava_pythia/lora_train_robot.sh $OUTPUT 66 | -------------------------------------------------------------------------------- /llava-pythia/scripts/llava_pythia/.ipynb_checkpoints/lora_train_robot-checkpoint.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | LLM_MODEL_SIZE=$1 3 | # LLM_MODEL_SIZE=2_8B 4 | # ./checkpoint_all/pythia_$LLM_MODEL_SIZE/vanilla_pythia_pt_f_vit/llavaPythia-v0-finetune 5 | # lora only vit and tune adapter 6 | OUTPUT=./checkpoint_all/pythia_$LLM_MODEL_SIZE/vanilla_pythia_pt_f_vit/llavaPythia-v0-robot-action-view_channel_cat_lora2 7 | 8 | # deepspeed --master_port 29601 --include localhost:4,5,6,7 llava_pythia/train/train.py \ 9 | # echo "waiting for 6 hours..." 10 | # sleep 6h 11 | 12 | deepspeed --master_port 29601 --num_gpus=8 --num_nodes=1 llava_pythia/train/train.py \ 13 | --lora_enable True --lora_r 64 --lora_alpha 256 --non_lora_lr 3e-5 \ 14 | --deepspeed ./scripts/zero2.json \ 15 | --model_name_or_path /data/team/zhumj/model_Param/llava_pythia_checkpoints/pythia_$LLM_MODEL_SIZE/vanilla_pythia_pt_f_vit/llavaPythia-v0-finetune \ 16 | --version v0 \ 17 | --data_path /data/private/data/llava_data/franka_kitchen_finetune/left_cap2/std_train_left_cap2_50k.json \ 18 | --image_folder /data/team/zhumj/data/finetune/data \ 19 | --tune_mm_mlp_adapter True \ 20 | --freeze_vision_tower True \ 21 | --freeze_backbone True \ 22 | --mm_use_im_start_end False \ 23 | --mm_use_im_patch_token False \ 24 | --image_aspect_ratio pad \ 25 | --group_by_modality_length False \ 26 | --bf16 True \ 27 | --output_dir $OUTPUT \ 28 | --num_train_epochs 15 \ 29 | --per_device_train_batch_size 4 \ 30 | --per_device_eval_batch_size 4 \ 31 | --gradient_accumulation_steps 4 \ 32 | --evaluation_strategy "steps" \ 33 | --save_strategy "steps" \ 34 | --save_steps 1000 \ 35 | --save_total_limit 15 \ 36 | --learning_rate 2e-4 \ 37 | --weight_decay 0.
\ 38 | --warmup_ratio 0.005 \ 39 | --lr_scheduler_type "cosine" \ 40 | --logging_steps 10 \ 41 | --tf32 True \ 42 | --model_max_length 2048 \ 43 | --gradient_checkpointing True \ 44 | --dataloader_num_workers 4 \ 45 | --lazy_preprocess True \ 46 | --action_head "fc" \ 47 | --use_state True \ 48 | --concat "channel_cat" \ 49 | --window_size 6 \ 50 | --logging_dir $OUTPUT/log 51 | --report_to wandb 52 | 53 | #/data/private/data/llava_data/franka_kitchen_finetune/left_cap2/left_cap2_50k.json 54 | #/data/team/zhumj/data/finetune/data/llava_v1_5_mix665k.json 55 | # cp openai/clip-vit-large-patch14-336/preprocessor_config.json $OUTPUT 56 | for dir in "$OUTPUT"/*/ ; do 57 | # 检查文件夹名称是否包含'checkpoint' 58 | if [[ "$(basename "$dir")" == *"checkpoint"* ]]; then 59 | cp openai/clip-vit-large-patch14-336/preprocessor_config.json $dir 60 | # cp $OUTPUT/non_lora_trainables.bin $dir 61 | fi 62 | done 63 | 64 | cp ./scripts/llava_pythia/lora_train_robot.sh $OUTPUT 65 | -------------------------------------------------------------------------------- /llava-pythia/scripts/convert_seed_for_submission.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import argparse 4 | 5 | 6 | def get_args(): 7 | parser = argparse.ArgumentParser() 8 | parser.add_argument("--annotation-file", type=str) 9 | parser.add_argument("--result-file", type=str) 10 | parser.add_argument("--result-upload-file", type=str) 11 | return parser.parse_args() 12 | 13 | 14 | def eval_single(result_file, eval_only_type=None): 15 | results = {} 16 | for line in open(result_file): 17 | row = json.loads(line) 18 | results[row['question_id']] = row 19 | 20 | type_counts = {} 21 | correct_counts = {} 22 | for question_data in data['questions']: 23 | if eval_only_type is not None and question_data['data_type'] != eval_only_type: continue 24 | data_type = question_data['question_type_id'] 25 | type_counts[data_type] = type_counts.get(data_type, 0) + 1 26 | try: 27 | question_id = int(question_data['question_id']) 28 | except: 29 | question_id = question_data['question_id'] 30 | if question_id not in results: 31 | correct_counts[data_type] = correct_counts.get(data_type, 0) 32 | continue 33 | row = results[question_id] 34 | if row['text'] == question_data['answer']: 35 | correct_counts[data_type] = correct_counts.get(data_type, 0) + 1 36 | 37 | total_count = 0 38 | total_correct = 0 39 | for data_type in sorted(type_counts.keys()): 40 | accuracy = correct_counts[data_type] / type_counts[data_type] * 100 41 | if eval_only_type is None: 42 | print(f"{ques_type_id_to_name[data_type]}: {accuracy:.2f}%") 43 | 44 | total_count += type_counts[data_type] 45 | total_correct += correct_counts[data_type] 46 | 47 | total_accuracy = total_correct / total_count * 100 48 | if eval_only_type is None: 49 | print(f"Total accuracy: {total_accuracy:.2f}%") 50 | else: 51 | print(f"{eval_only_type} accuracy: {total_accuracy:.2f}%") 52 | 53 | return results 54 | 55 | if __name__ == "__main__": 56 | args = get_args() 57 | data = json.load(open(args.annotation_file)) 58 | ques_type_id_to_name = {id:n for n,id in data['question_type'].items()} 59 | 60 | results = eval_single(args.result_file) 61 | eval_single(args.result_file, eval_only_type='image') 62 | eval_single(args.result_file, eval_only_type='video') 63 | 64 | with open(args.result_upload_file, 'w') as fp: 65 | for question in data['questions']: 66 | qid = question['question_id'] 67 | if qid in results: 68 | result = results[qid] 69 | else: 70 | result = 
results[int(qid)] 71 | fp.write(json.dumps({ 72 | 'question_id': qid, 73 | 'prediction': result['text'] 74 | }) + '\n') 75 | -------------------------------------------------------------------------------- /llava-pythia/scripts/convert_sqa_to_llava.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import fire 4 | import re 5 | from convert_sqa_to_llava_base_prompt import build_prompt_chatbot 6 | 7 | 8 | def convert_to_llava(base_dir, split, prompt_format="QCM-LEA"): 9 | split_indices = json.load(open(os.path.join(base_dir, "pid_splits.json")))[split] 10 | problems = json.load(open(os.path.join(base_dir, "problems.json"))) 11 | 12 | split_problems = build_prompt_chatbot( 13 | problems, split_indices, prompt_format, 14 | use_caption=False, is_test=False) 15 | 16 | target_format = [] 17 | for prob_id, (input, output) in split_problems.items(): 18 | if input.startswith('Question: '): 19 | input = input.replace('Question: ', '') 20 | if output.startswith('Answer: '): 21 | output = output.replace('Answer: ', '') 22 | 23 | raw_prob_data = problems[prob_id] 24 | if raw_prob_data['image'] is None: 25 | target_format.append({ 26 | "id": prob_id, 27 | "conversations": [ 28 | {'from': 'human', 'value': f"{input}"}, 29 | {'from': 'gpt', 'value': f"{output}"}, 30 | ], 31 | }) 32 | 33 | else: 34 | target_format.append({ 35 | "id": prob_id, 36 | "image": os.path.join(prob_id, raw_prob_data['image']), 37 | "conversations": [ 38 | {'from': 'human', 'value': f"{input}\n"}, 39 | {'from': 'gpt', 'value': f"{output}"}, 40 | ], 41 | }) 42 | 43 | print(f'Number of samples: {len(target_format)}') 44 | 45 | with open(os.path.join(base_dir, f"llava_{split}_{prompt_format}.json"), "w") as f: 46 | json.dump(target_format, f, indent=2) 47 | 48 | 49 | def convert_to_jsonl(base_dir, split, prompt_format="QCM-LEPA"): 50 | split_indices = json.load(open(os.path.join(base_dir, "pid_splits.json")))[split] 51 | problems = json.load(open(os.path.join(base_dir, "problems.json"))) 52 | 53 | split_problems = build_prompt_chatbot( 54 | problems, split_indices, prompt_format, 55 | use_caption=False, is_test=False) 56 | 57 | writer = open(os.path.join(base_dir, f"scienceqa_{split}_{prompt_format}.jsonl"), "w") 58 | for prob_id, (input, output) in split_problems.items(): 59 | if input.startswith('Question: '): 60 | input = input.replace('Question: ', '') 61 | if output.startswith('Answer: '): 62 | output = output.replace('Answer: ', '') 63 | 64 | raw_prob_data = problems[prob_id] 65 | if raw_prob_data['image'] is None: 66 | data = { 67 | "id": prob_id, 68 | "instruction": f"{input}", 69 | "output": f"{output}", 70 | } 71 | 72 | else: 73 | data = { 74 | "id": prob_id, 75 | "image": os.path.join(prob_id, raw_prob_data['image']), 76 | "instruction": f"{input}\n", 77 | "output": f"{output}", 78 | } 79 | writer.write(json.dumps(data) + '\n') 80 | writer.close() 81 | 82 | 83 | def main(task, **kwargs): 84 | globals()[task](**kwargs) 85 | 86 | 87 | if __name__ == "__main__": 88 | fire.Fire(main) 89 | -------------------------------------------------------------------------------- /llava-pythia/scripts/.ipynb_checkpoints/convert_sqa_to_llava-checkpoint.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import fire 4 | import re 5 | from convert_sqa_to_llava_base_prompt import build_prompt_chatbot 6 | 7 | 8 | def convert_to_llava(base_dir, split, prompt_format="QCM-LEA"): 9 | split_indices = 
json.load(open(os.path.join(base_dir, "pid_splits.json")))[split] 10 | problems = json.load(open(os.path.join(base_dir, "problems.json"))) 11 | 12 | split_problems = build_prompt_chatbot( 13 | problems, split_indices, prompt_format, 14 | use_caption=False, is_test=False) 15 | 16 | target_format = [] 17 | for prob_id, (input, output) in split_problems.items(): 18 | if input.startswith('Question: '): 19 | input = input.replace('Question: ', '') 20 | if output.startswith('Answer: '): 21 | output = output.replace('Answer: ', '') 22 | 23 | raw_prob_data = problems[prob_id] 24 | if raw_prob_data['image'] is None: 25 | target_format.append({ 26 | "id": prob_id, 27 | "conversations": [ 28 | {'from': 'human', 'value': f"{input}"}, 29 | {'from': 'gpt', 'value': f"{output}"}, 30 | ], 31 | }) 32 | 33 | else: 34 | target_format.append({ 35 | "id": prob_id, 36 | "image": os.path.join(prob_id, raw_prob_data['image']), 37 | "conversations": [ 38 | {'from': 'human', 'value': f"{input}\n"}, 39 | {'from': 'gpt', 'value': f"{output}"}, 40 | ], 41 | }) 42 | 43 | print(f'Number of samples: {len(target_format)}') 44 | 45 | with open(os.path.join(base_dir, f"llava_{split}_{prompt_format}.json"), "w") as f: 46 | json.dump(target_format, f, indent=2) 47 | 48 | 49 | def convert_to_jsonl(base_dir, split, prompt_format="QCM-LEPA"): 50 | split_indices = json.load(open(os.path.join(base_dir, "pid_splits.json")))[split] 51 | problems = json.load(open(os.path.join(base_dir, "problems.json"))) 52 | 53 | split_problems = build_prompt_chatbot( 54 | problems, split_indices, prompt_format, 55 | use_caption=False, is_test=False) 56 | 57 | writer = open(os.path.join(base_dir, f"scienceqa_{split}_{prompt_format}.jsonl"), "w") 58 | for prob_id, (input, output) in split_problems.items(): 59 | if input.startswith('Question: '): 60 | input = input.replace('Question: ', '') 61 | if output.startswith('Answer: '): 62 | output = output.replace('Answer: ', '') 63 | 64 | raw_prob_data = problems[prob_id] 65 | if raw_prob_data['image'] is None: 66 | data = { 67 | "id": prob_id, 68 | "instruction": f"{input}", 69 | "output": f"{output}", 70 | } 71 | 72 | else: 73 | data = { 74 | "id": prob_id, 75 | "image": os.path.join(prob_id, raw_prob_data['image']), 76 | "instruction": f"{input}\n", 77 | "output": f"{output}", 78 | } 79 | writer.write(json.dumps(data) + '\n') 80 | writer.close() 81 | 82 | 83 | def main(task, **kwargs): 84 | globals()[task](**kwargs) 85 | 86 | 87 | if __name__ == "__main__": 88 | fire.Fire(main) 89 | -------------------------------------------------------------------------------- /llava-pythia/llava_pythia/model/multimodal_encoder/clip_encoder.py: -------------------------------------------------------------------------------- 1 | from abc import ABC 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from transformers import CLIPPreTrainedModel, CLIPVisionConfig 7 | from transformers.models.clip.modeling_clip import CLIPVisionTransformer 8 | from llava_pythia.model.language_model.pythia.configuration_llava_pythia import LlavaPythiaVisionConfig 9 | 10 | 11 | class CLIPVisionTower(CLIPPreTrainedModel): 12 | config_class = LlavaPythiaVisionConfig 13 | 14 | def __init__(self, config): 15 | super().__init__(config) 16 | 17 | self.vision_model = CLIPVisionTransformer(config) 18 | # Initialize weights and apply final processing 19 | self.post_init() 20 | 21 | def get_input_embeddings(self) -> nn.Module: 22 | return self.vision_model.embeddings.patch_embedding 23 | 24 | def feature_select(self, image_forward_outs): 25 | 
image_features = image_forward_outs.hidden_states[self.config.mm_vision_select_layer] 26 | if self.config.mm_vision_select_feature == 'patch': 27 | image_features = image_features[:, 1:] 28 | elif self.config.mm_vision_select_feature == 'cls_patch': 29 | image_features = image_features 30 | else: 31 | raise ValueError(f'Unexpected select feature: {self.config.mm_vision_select_feature}') 32 | return image_features 33 | 34 | def forward(self, images): 35 | if type(images) is list: 36 | image_features = [] 37 | for image in images: 38 | image_forward_out = self.vision_model(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), 39 | output_hidden_states=True) 40 | image_feature = self.feature_select(image_forward_out).to(image.dtype) 41 | image_features.append(image_feature) 42 | else: 43 | image_forward_outs = self.vision_model(images.to(device=self.device, dtype=self.dtype), 44 | output_hidden_states=True) 45 | image_features = self.feature_select(image_forward_outs).to(images.dtype) 46 | 47 | return image_features 48 | 49 | @property 50 | def dummy_feature(self): 51 | return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype) 52 | 53 | @property 54 | def dtype(self): 55 | return list(self.vision_model.parameters())[0].dtype 56 | 57 | @property 58 | def device(self): 59 | return list(self.vision_model.parameters())[0].device 60 | 61 | @property 62 | def hidden_size(self): 63 | return self.config.hidden_size 64 | 65 | @property 66 | def num_patches(self): 67 | return (self.config.image_size // self.config.patch_size) ** 2 68 | 69 | 70 | if __name__ == '__main__': 71 | clip_config = CLIPVisionConfig.from_pretrained( 72 | "/data/private/zhumj/GPTcode/mm-phi/openai/clip-vit-large-patch14-336" 73 | ) 74 | print("################ clip_config ##############") 75 | print(clip_config) 76 | pythia_vis_config = LlavaPythiaVisionConfig(**clip_config.to_dict()) 77 | print("################ pythia_vis_config ##############") 78 | print(pythia_vis_config) 79 | 80 | model = CLIPVisionTower(clip_config) 81 | # print(list(model.vision_model.parameters())[0].dtype) 82 | -------------------------------------------------------------------------------- /llava-pythia/llava_pythia/model/multimodal_encoder/siglip_encoder.py: -------------------------------------------------------------------------------- 1 | from abc import ABC 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from transformers.models.siglip import SiglipPreTrainedModel, SiglipVisionConfig 7 | from transformers.models.siglip.modeling_siglip import SiglipVisionTransformer 8 | from llava_pythia.model.language_model.pythia.configuration_llava_pythia import LlavaPythiaVisionConfig 9 | 10 | 11 | class SiglipVisionTower(SiglipPreTrainedModel): 12 | config_class = LlavaPythiaVisionConfig 13 | 14 | def __init__(self, config): 15 | super().__init__(config) 16 | 17 | self.vision_model = SiglipVisionTransformer(config) 18 | # Initialize weights and apply final processing 19 | self.post_init() 20 | 21 | def get_input_embeddings(self) -> nn.Module: 22 | return self.vision_model.embeddings.patch_embedding 23 | 24 | def feature_select(self, image_forward_outs): 25 | image_features = image_forward_outs.hidden_states[self.config.mm_vision_select_layer] 26 | if self.config.mm_vision_select_feature == 'patch': 27 | image_features = image_features 28 | elif self.config.mm_vision_select_feature == 'cls_patch': 29 | image_features = image_features 30 | else: 31 | raise ValueError(f'Unexpected select feature: 
{self.config.mm_vision_select_feature}') 32 | return image_features 33 | 34 | def forward(self, images): 35 | if type(images) is list: 36 | image_features = [] 37 | for image in images: 38 | image_forward_out = self.vision_model(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), 39 | output_hidden_states=True) 40 | image_feature = self.feature_select(image_forward_out).to(image.dtype) 41 | image_features.append(image_feature) 42 | else: 43 | image_forward_outs = self.vision_model(images.to(device=self.device, dtype=self.dtype), 44 | output_hidden_states=True) 45 | image_features = self.feature_select(image_forward_outs).to(images.dtype) 46 | 47 | return image_features 48 | 49 | @property 50 | def dummy_feature(self): 51 | return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype) 52 | 53 | @property 54 | def dtype(self): 55 | return list(self.vision_model.parameters())[0].dtype 56 | 57 | @property 58 | def device(self): 59 | return list(self.vision_model.parameters())[0].device 60 | 61 | @property 62 | def hidden_size(self): 63 | return self.config.hidden_size 64 | 65 | @property 66 | def num_patches(self): 67 | return (self.config.image_size // self.config.patch_size) ** 2 68 | 69 | 70 | if __name__ == '__main__': 71 | clip_config = SiglipVisionConfig.from_pretrained( 72 | "/data/private/zhumj/GPTcode/mm-phi/openai/clip-vit-large-patch14-336" 73 | ) 74 | print("################ clip_config ##############") 75 | print(clip_config) 76 | pythia_vis_config = LlavaPythiaVisionConfig(**clip_config.to_dict()) 77 | print("################ pythia_vis_config ##############") 78 | print(pythia_vis_config) 79 | 80 | model = SiglipVisionTower(clip_config) 81 | # print(list(model.vision_model.parameters())[0].dtype) 82 | -------------------------------------------------------------------------------- /aloha_scripts/constants.py: -------------------------------------------------------------------------------- 1 | 2 | ##################### Setting of training data ##################################### 3 | 4 | # DATA_DIR = '/path/to/your/data_dir' 5 | 6 | TASK_CONFIGS = { 7 | 'example_task_config': { # for local debug 8 | 'dataset_dir': [ 9 | "/media/rl/HDD/data/data/act/new_view/8_29_tennis", # task 1 10 | ], 11 | 'episode_len': 1000, # 1000, 12 | 'camera_names': ['left', 'right', 'wrist'] # corresponding to image keys saved in h5py files 13 | }, 14 | } 15 | #################################################################################### 16 | 17 | #!!!!!!!!!!!!!!!!!!!!!!Followings are copied from aloha which are not used!!!!!!!!!!!!!!!!!!!!!! 
18 | ### ALOHA fixed constants 19 | DT = 0.02 20 | 21 | FPS = 50 22 | 23 | JOINT_NAMES = ["waist", "shoulder", "elbow", "forearm_roll", "wrist_angle", "wrist_rotate"] 24 | START_ARM_POSE = [0, -0.96, 1.16, 0, -0.3, 0, 0.02239, -0.02239, 0, -0.96, 1.16, 0, -0.3, 0, 0.02239, -0.02239] 25 | 26 | # Left finger position limits (qpos[7]), right_finger = -1 * left_finger 27 | MASTER_GRIPPER_POSITION_OPEN = 0.02417 28 | MASTER_GRIPPER_POSITION_CLOSE = 0.01244 29 | PUPPET_GRIPPER_POSITION_OPEN = 0.05800 30 | PUPPET_GRIPPER_POSITION_CLOSE = 0.01844 31 | 32 | # Gripper joint limits (qpos[6]) 33 | MASTER_GRIPPER_JOINT_OPEN = 0.3083 34 | MASTER_GRIPPER_JOINT_CLOSE = -0.6842 35 | PUPPET_GRIPPER_JOINT_OPEN = 1.4910 36 | PUPPET_GRIPPER_JOINT_CLOSE = -0.6213 37 | 38 | ############################ Helper functions ############################ 39 | 40 | MASTER_GRIPPER_POSITION_NORMALIZE_FN = lambda x: (x - MASTER_GRIPPER_POSITION_CLOSE) / (MASTER_GRIPPER_POSITION_OPEN - MASTER_GRIPPER_POSITION_CLOSE) 41 | PUPPET_GRIPPER_POSITION_NORMALIZE_FN = lambda x: (x - PUPPET_GRIPPER_POSITION_CLOSE) / (PUPPET_GRIPPER_POSITION_OPEN - PUPPET_GRIPPER_POSITION_CLOSE) 42 | MASTER_GRIPPER_POSITION_UNNORMALIZE_FN = lambda x: x * (MASTER_GRIPPER_POSITION_OPEN - MASTER_GRIPPER_POSITION_CLOSE) + MASTER_GRIPPER_POSITION_CLOSE 43 | PUPPET_GRIPPER_POSITION_UNNORMALIZE_FN = lambda x: x * (PUPPET_GRIPPER_POSITION_OPEN - PUPPET_GRIPPER_POSITION_CLOSE) + PUPPET_GRIPPER_POSITION_CLOSE 44 | MASTER2PUPPET_POSITION_FN = lambda x: PUPPET_GRIPPER_POSITION_UNNORMALIZE_FN(MASTER_GRIPPER_POSITION_NORMALIZE_FN(x)) 45 | 46 | MASTER_GRIPPER_JOINT_NORMALIZE_FN = lambda x: (x - MASTER_GRIPPER_JOINT_CLOSE) / (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE) 47 | PUPPET_GRIPPER_JOINT_NORMALIZE_FN = lambda x: (x - PUPPET_GRIPPER_JOINT_CLOSE) / (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE) 48 | MASTER_GRIPPER_JOINT_UNNORMALIZE_FN = lambda x: x * (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE) + MASTER_GRIPPER_JOINT_CLOSE 49 | PUPPET_GRIPPER_JOINT_UNNORMALIZE_FN = lambda x: x * (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE) + PUPPET_GRIPPER_JOINT_CLOSE 50 | MASTER2PUPPET_JOINT_FN = lambda x: PUPPET_GRIPPER_JOINT_UNNORMALIZE_FN(MASTER_GRIPPER_JOINT_NORMALIZE_FN(x)) 51 | 52 | MASTER_GRIPPER_VELOCITY_NORMALIZE_FN = lambda x: x / (MASTER_GRIPPER_POSITION_OPEN - MASTER_GRIPPER_POSITION_CLOSE) 53 | PUPPET_GRIPPER_VELOCITY_NORMALIZE_FN = lambda x: x / (PUPPET_GRIPPER_POSITION_OPEN - PUPPET_GRIPPER_POSITION_CLOSE) 54 | 55 | MASTER_POS2JOINT = lambda x: MASTER_GRIPPER_POSITION_NORMALIZE_FN(x) * (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE) + MASTER_GRIPPER_JOINT_CLOSE 56 | MASTER_JOINT2POS = lambda x: MASTER_GRIPPER_POSITION_UNNORMALIZE_FN((x - MASTER_GRIPPER_JOINT_CLOSE) / (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE)) 57 | PUPPET_POS2JOINT = lambda x: PUPPET_GRIPPER_POSITION_NORMALIZE_FN(x) * (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE) + PUPPET_GRIPPER_JOINT_CLOSE 58 | PUPPET_JOINT2POS = lambda x: PUPPET_GRIPPER_POSITION_UNNORMALIZE_FN((x - PUPPET_GRIPPER_JOINT_CLOSE) / (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE)) 59 | 60 | MASTER_GRIPPER_JOINT_MID = (MASTER_GRIPPER_JOINT_OPEN + MASTER_GRIPPER_JOINT_CLOSE)/2 61 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | # *.py[cod] 4 | *$py.class 5 
| 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | .idea 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | cover/ 54 | 55 | # Translations 56 | *.mo 57 | *.pot 58 | 59 | # Django stuff: 60 | *.log 61 | local_settings.py 62 | db.sqlite3 63 | db.sqlite3-journal 64 | 65 | *.egg-info/ 66 | # Flask stuff: 67 | instance/ 68 | .webassets-cache 69 | 70 | # Scrapy stuff: 71 | .scrapy 72 | 73 | # Sphinx documentation 74 | docs/_build/ 75 | 76 | # PyBuilder 77 | .pybuilder/ 78 | target/ 79 | 80 | # Jupyter Notebook 81 | .ipynb_checkpoints 82 | 83 | # IPython 84 | profile_default/ 85 | ipython_config.py 86 | 87 | # pyenv 88 | # For a library or package, you might want to ignore these files since the code is 89 | # intended to run in multiple environments; otherwise, check them in: 90 | # .python-version 91 | 92 | # pipenv 93 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 94 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 95 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 96 | # install all needed dependencies. 97 | #Pipfile.lock 98 | 99 | # poetry 100 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 101 | # This is especially recommended for binary packages to ensure reproducibility, and is more 102 | # commonly ignored for libraries. 103 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 104 | #poetry.lock 105 | 106 | # pdm 107 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 108 | #pdm.lock 109 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 110 | # in version control. 111 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 112 | .pdm.toml 113 | .pdm-python 114 | .pdm-build/ 115 | 116 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 117 | __pypackages__/ 118 | __pycache__/ 119 | 120 | # Celery stuff 121 | celerybeat-schedule 122 | celerybeat.pid 123 | 124 | # SageMath parsed files 125 | *.sage.py 126 | 127 | # Environments 128 | .env 129 | .venv 130 | env/ 131 | venv/ 132 | ENV/ 133 | env.bak/ 134 | venv.bak/ 135 | 136 | # Spyder project settings 137 | .spyderproject 138 | .spyproject 139 | 140 | # Rope project settings 141 | .ropeproject 142 | 143 | # mkdocs documentation 144 | /site 145 | 146 | # mypy 147 | .mypy_cache/ 148 | .dmypy.json 149 | dmypy.json 150 | 151 | robomimic-r2d2/ 152 | robomimic-r2d2.zip 153 | OUTPUT/ 154 | data/ 155 | checkpoints/ 156 | loglog/ 157 | 158 | # Pyre type checker 159 | .pyre/ 160 | 161 | # pytype static type analyzer 162 | .pytype/ 163 | 164 | # Cython debug symbols 165 | cython_debug/ 166 | 167 | # PyCharm 168 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 169 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 170 | # and can be added to the global gitignore or merged into this file. For a more nuclear 171 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 172 | #.idea/ 173 | -------------------------------------------------------------------------------- /policy_heads/util/box_ops.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | """ 3 | Utilities for bounding box manipulation and GIoU. 4 | """ 5 | import torch 6 | from torchvision.ops.boxes import box_area 7 | 8 | 9 | def box_cxcywh_to_xyxy(x): 10 | """ 11 | Convert bounding boxes from center-size format (cx, cy, w, h) to corner format (x0, y0, x1, y1). 12 | 13 | Args: 14 | x: Tensor of shape (..., 4) representing bounding boxes in center-size format. 15 | 16 | Returns: 17 | Tensor of shape (..., 4) representing bounding boxes in corner format. 18 | """ 19 | x_c, y_c, w, h = x.unbind(-1) 20 | b = [(x_c - 0.5 * w), (y_c - 0.5 * h), 21 | (x_c + 0.5 * w), (y_c + 0.5 * h)] 22 | return torch.stack(b, dim=-1) 23 | 24 | 25 | def box_xyxy_to_cxcywh(x): 26 | """ 27 | Convert bounding boxes from corner format (x0, y0, x1, y1) to center-size format (cx, cy, w, h). 28 | 29 | Args: 30 | x: Tensor of shape (..., 4) representing bounding boxes in corner format. 31 | 32 | Returns: 33 | Tensor of shape (..., 4) representing bounding boxes in center-size format. 34 | """ 35 | x0, y0, x1, y1 = x.unbind(-1) 36 | b = [(x0 + x1) / 2, (y0 + y1) / 2, 37 | (x1 - x0), (y1 - y0)] 38 | return torch.stack(b, dim=-1) 39 | 40 | 41 | # modified from torchvision to also return the union 42 | def box_iou(boxes1, boxes2): 43 | """ 44 | Compute the Intersection over Union (IoU) between two sets of boxes. 45 | 46 | Args: 47 | boxes1: Tensor of shape (N, 4) representing the first set of boxes. 48 | boxes2: Tensor of shape (M, 4) representing the second set of boxes. 49 | 50 | Returns: 51 | iou: Tensor of shape (N, M) representing pairwise IoU values. 52 | union: Tensor of shape (N, M) representing pairwise union areas. 
53 | """ 54 | area1 = box_area(boxes1) 55 | area2 = box_area(boxes2) 56 | 57 | lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2] 58 | rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2] 59 | 60 | wh = (rb - lt).clamp(min=0) # [N,M,2] 61 | inter = wh[:, :, 0] * wh[:, :, 1] # [N,M] 62 | 63 | union = area1[:, None] + area2 - inter 64 | 65 | iou = inter / union 66 | return iou, union 67 | 68 | 69 | def generalized_box_iou(boxes1, boxes2): 70 | """ 71 | Compute the Generalized Intersection over Union (GIoU) between two sets of boxes. 72 | 73 | Args: 74 | boxes1: Tensor of shape (N, 4) representing the first set of boxes in [x0, y0, x1, y1] format. 75 | boxes2: Tensor of shape (M, 4) representing the second set of boxes in [x0, y0, x1, y1] format. 76 | 77 | Returns: 78 | Tensor of shape (N, M) representing pairwise GIoU values. 79 | """ 80 | # degenerate boxes give inf / nan results, so do an early check 81 | assert (boxes1[:, 2:] >= boxes1[:, :2]).all() 82 | assert (boxes2[:, 2:] >= boxes2[:, :2]).all() 83 | iou, union = box_iou(boxes1, boxes2) 84 | 85 | lt = torch.min(boxes1[:, None, :2], boxes2[:, :2]) 86 | rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:]) 87 | 88 | wh = (rb - lt).clamp(min=0) # [N,M,2] 89 | area = wh[:, :, 0] * wh[:, :, 1] 90 | 91 | return iou - (area - union) / area 92 | 93 | 94 | def masks_to_boxes(masks): 95 | """ 96 | Compute the bounding boxes around the provided masks. 97 | 98 | Args: 99 | masks: Tensor of shape (N, H, W) where N is the number of masks, (H, W) are the spatial dimensions. 100 | 101 | Returns: 102 | Tensor of shape (N, 4) with the boxes in xyxy format. 103 | """ 104 | if masks.numel() == 0: 105 | return torch.zeros((0, 4), device=masks.device) 106 | 107 | h, w = masks.shape[-2:] 108 | 109 | y = torch.arange(0, h, dtype=torch.float) 110 | x = torch.arange(0, w, dtype=torch.float) 111 | y, x = torch.meshgrid(y, x) 112 | 113 | x_mask = (masks * x.unsqueeze(0)) 114 | x_max = x_mask.flatten(1).max(-1)[0] 115 | x_min = x_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0] 116 | 117 | y_mask = (masks * y.unsqueeze(0)) 118 | y_max = y_mask.flatten(1).max(-1)[0] 119 | y_min = y_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0] 120 | 121 | return torch.stack([x_min, y_min, x_max, y_max], 1) 122 | -------------------------------------------------------------------------------- /llava-pythia/llava_pythia/utils.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import logging 3 | import logging.handlers 4 | import os 5 | import sys 6 | 7 | import requests 8 | 9 | from llava_pythia.constants import LOGDIR 10 | 11 | server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**" 12 | moderation_msg = "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE TRY AGAIN." 
13 | 14 | handler = None 15 | 16 | 17 | def build_logger(logger_name, logger_filename): 18 | global handler 19 | 20 | formatter = logging.Formatter( 21 | fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s", 22 | datefmt="%Y-%m-%d %H:%M:%S", 23 | ) 24 | 25 | # Set the format of root handlers 26 | if not logging.getLogger().handlers: 27 | logging.basicConfig(level=logging.INFO) 28 | logging.getLogger().handlers[0].setFormatter(formatter) 29 | 30 | # Redirect stdout and stderr to loggers 31 | stdout_logger = logging.getLogger("stdout") 32 | stdout_logger.setLevel(logging.INFO) 33 | sl = StreamToLogger(stdout_logger, logging.INFO) 34 | sys.stdout = sl 35 | 36 | stderr_logger = logging.getLogger("stderr") 37 | stderr_logger.setLevel(logging.ERROR) 38 | sl = StreamToLogger(stderr_logger, logging.ERROR) 39 | sys.stderr = sl 40 | 41 | # Get logger 42 | logger = logging.getLogger(logger_name) 43 | logger.setLevel(logging.INFO) 44 | 45 | # Add a file handler for all loggers 46 | if handler is None: 47 | os.makedirs(LOGDIR, exist_ok=True) 48 | filename = os.path.join(LOGDIR, logger_filename) 49 | handler = logging.handlers.TimedRotatingFileHandler( 50 | filename, when='D', utc=True) 51 | handler.setFormatter(formatter) 52 | 53 | for name, item in logging.root.manager.loggerDict.items(): 54 | if isinstance(item, logging.Logger): 55 | item.addHandler(handler) 56 | 57 | return logger 58 | 59 | 60 | class StreamToLogger(object): 61 | """ 62 | Fake file-like stream object that redirects writes to a logger instance. 63 | """ 64 | def __init__(self, logger, log_level=logging.INFO): 65 | self.terminal = sys.stdout 66 | self.logger = logger 67 | self.log_level = log_level 68 | self.linebuf = '' 69 | 70 | def __getattr__(self, attr): 71 | return getattr(self.terminal, attr) 72 | 73 | def write(self, buf): 74 | temp_linebuf = self.linebuf + buf 75 | self.linebuf = '' 76 | for line in temp_linebuf.splitlines(True): 77 | # From the io.TextIOWrapper docs: 78 | # On output, if newline is None, any '\n' characters written 79 | # are translated to the system default line separator. 80 | # By default sys.stdout.write() expects '\n' newlines and then 81 | # translates them so this is still cross platform. 82 | if line[-1] == '\n': 83 | self.logger.log(self.log_level, line.rstrip()) 84 | else: 85 | self.linebuf += line 86 | 87 | def flush(self): 88 | if self.linebuf != '': 89 | self.logger.log(self.log_level, self.linebuf.rstrip()) 90 | self.linebuf = '' 91 | 92 | 93 | def disable_torch_init(): 94 | """ 95 | Disable the redundant torch default initialization to accelerate model creation. 96 | """ 97 | import torch 98 | setattr(torch.nn.Linear, "reset_parameters", lambda self: None) 99 | setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None) 100 | 101 | 102 | def violates_moderation(text): 103 | """ 104 | Check whether the text violates OpenAI moderation API. 
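    Reads the API key from the OPENAI_API_KEY environment variable; any request or
    parsing failure is treated as non-violating and False is returned.

    Example (illustrative sketch, not from the original source):

        flagged = violates_moderation("some user-supplied text")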
105 | """ 106 | url = "https://api.openai.com/v1/moderations" 107 | headers = {"Content-Type": "application/json", 108 | "Authorization": "Bearer " + os.environ["OPENAI_API_KEY"]} 109 | text = text.replace("\n", "") 110 | data = "{" + '"input": ' + f'"{text}"' + "}" 111 | data = data.encode("utf-8") 112 | try: 113 | ret = requests.post(url, headers=headers, data=data, timeout=5) 114 | flagged = ret.json()["results"][0]["flagged"] 115 | except requests.exceptions.RequestException as e: 116 | flagged = False 117 | except KeyError as e: 118 | flagged = False 119 | 120 | return flagged 121 | 122 | 123 | def pretty_print_semaphore(semaphore): 124 | if semaphore is None: 125 | return "None" 126 | return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})" 127 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==2.1.0 2 | accelerate==0.21.0 3 | aiofiles==23.2.1 4 | aiohappyeyeballs==2.4.4 5 | aiohttp==3.10.11 6 | aiosignal==1.3.1 7 | altair==5.3.0 8 | anyio==4.5.2 9 | appdirs==1.4.4 10 | argcomplete==3.3.0 11 | asttokens==2.4.1 12 | async-timeout==5.0.1 13 | attrs==23.2.0 14 | backcall==0.2.0 15 | beautifulsoup4==4.12.3 16 | bitsandbytes==0.41.0 17 | bleach==6.1.0 18 | cachetools==5.3.3 19 | catkin-pkg==1.0.0 20 | certifi==2024.2.2 21 | charset-normalizer==3.3.2 22 | click==8.1.8 23 | cloudpickle==3.0.0 24 | cmake==3.29.2 25 | colorama==0.3.0 26 | contourpy==1.1.1 27 | cycler==0.12.1 28 | decorator==5.1.1 29 | deepspeed==0.9.5 30 | defusedxml==0.7.1 31 | diffusers==0.11.1 32 | distro==1.9.0 33 | dm-control==1.0.14 34 | dm-env==1.6 35 | dm-tree==0.1.8 36 | docker-pycreds==0.4.0 37 | docopt==0.6.2 38 | docutils==0.20.1 39 | egl-probe==1.0.2 40 | einops==0.6.1 41 | einops-exts==0.0.4 42 | evdev==1.7.0 43 | exceptiongroup==1.2.2 44 | executing==2.0.1 45 | fastapi==0.110.2 46 | fastjsonschema==2.20.0 47 | ffmpy==0.3.2 48 | filelock==3.16.1 49 | fonttools==4.51.0 50 | frozenlist==1.5.0 51 | fsspec==2025.2.0 52 | gitdb==4.0.12 53 | GitPython==3.1.44 54 | glfw==2.7.0 55 | google-auth==2.29.0 56 | google-auth-oauthlib==1.0.0 57 | gradio==3.35.2 58 | gradio_client==0.2.9 59 | grpcio==1.62.2 60 | gym==0.26.2 61 | gym-notices==0.0.8 62 | h11==0.14.0 63 | h5py==3.11.0 64 | hjson==3.1.0 65 | httpcore==0.17.3 66 | httpx==0.24.0 67 | huggingface-hub==0.23.5 68 | idna==3.7 69 | imageio==2.34.1 70 | imageio-ffmpeg==0.4.9 71 | importlib_metadata==8.5.0 72 | importlib_resources==6.4.5 73 | ipython==8.12.3 74 | jedi==0.19.1 75 | Jinja2==3.1.4 76 | joblib==1.4.0 77 | jsonschema==4.21.1 78 | jsonschema-specifications==2023.12.1 79 | jupyter_client==8.6.3 80 | jupyter_core==5.7.2 81 | jupyterlab_pygments==0.3.0 82 | kiwisolver==1.4.5 83 | labmaze==1.0.6 84 | linkify-it-py==2.0.3 85 | lit==18.1.3 86 | llvmlite==0.41.1 87 | lxml==5.2.1 88 | Markdown==3.6 89 | markdown-it-py==2.2.0 90 | markdown2==2.4.13 91 | MarkupSafe==2.1.5 92 | matplotlib==3.7.5 93 | matplotlib-inline==0.1.7 94 | mdit-py-plugins==0.3.3 95 | mdurl==0.1.2 96 | mistune==3.0.2 97 | mpmath==1.3.0 98 | mujoco==2.3.7 99 | multidict==6.1.0 100 | nbclient==0.10.0 101 | nbconvert==7.16.4 102 | nbformat==5.10.4 103 | networkx==3.1 104 | ninja==1.11.1.1 105 | numba==0.58.1 106 | numpy==1.24.4 107 | nvidia-cublas-cu11==11.10.3.66 108 | nvidia-cuda-cupti-cu11==11.7.101 109 | nvidia-cuda-nvrtc-cu11==11.7.99 110 | nvidia-cuda-runtime-cu11==11.7.99 111 | nvidia-cudnn-cu11==8.5.0.96 112 | 
nvidia-cufft-cu11==10.9.0.58 113 | nvidia-curand-cu11==10.2.10.91 114 | nvidia-cusolver-cu11==11.4.0.1 115 | nvidia-cusparse-cu11==11.7.4.91 116 | nvidia-nccl-cu11==2.14.3 117 | nvidia-nvtx-cu11==11.7.91 118 | oauthlib==3.2.2 119 | opencv-python==4.6.0.66 120 | orjson==3.10.1 121 | packaging==24.0 122 | pandas==2.0.3 123 | pandocfilters==1.5.1 124 | parso==0.8.4 125 | peft==0.4.0 126 | pexpect==4.9.0 127 | pickleshare==0.7.5 128 | pillow==10.3.0 129 | pipreqs==0.5.0 130 | pkgutil_resolve_name==1.3.10 131 | platformdirs==4.3.6 132 | prompt_toolkit==3.0.47 133 | propcache==0.2.0 134 | protobuf==3.19.6 135 | psutil==6.1.1 136 | ptyprocess==0.7.0 137 | pure-eval==0.2.2 138 | py-cpuinfo==9.0.0 139 | pyasn1==0.6.0 140 | pyasn1_modules==0.4.0 141 | pydantic==1.10.15 142 | pydub==0.25.1 143 | Pygments==2.17.2 144 | pynput==1.7.6 145 | PyOpenGL==3.1.7 146 | pyparsing==3.1.4 147 | pyquaternion==0.9.9 148 | python-dateutil==2.9.0.post0 149 | python-multipart==0.0.9 150 | python-xlib==0.33 151 | pytz==2024.1 152 | PyYAML==6.0.1 153 | pyzmq==26.2.0 154 | referencing==0.34.0 155 | regex==2024.4.16 156 | requests==2.31.0 157 | requests-oauthlib==2.0.0 158 | rospkg==1.5.1 159 | rpds-py==0.18.0 160 | rsa==4.9 161 | safetensors==0.4.3 162 | scikit-learn==1.2.2 163 | scipy==1.10.1 164 | semantic-version==2.10.0 165 | sentencepiece==0.1.99 166 | sentry-sdk==1.45.0 167 | setproctitle==1.3.3 168 | shortuuid==1.0.13 169 | six==1.16.0 170 | smmap==5.0.2 171 | sniffio==1.3.1 172 | snowballstemmer==2.2.0 173 | soupsieve==2.6 174 | stack-data==0.6.3 175 | starlette==0.37.2 176 | svgwrite==1.4.3 177 | sympy==1.12 178 | tensorboard==2.14.0 179 | tensorboard-data-server==0.7.2 180 | tensorboardX==2.6 181 | termcolor==2.4.0 182 | threadpoolctl==3.4.0 183 | tianshou==0.4.10 184 | timm==0.6.13 185 | tinycss2==1.3.0 186 | tokenizers==0.15.0 187 | toolz==0.12.1 188 | torch==2.0.1 189 | torchvision==0.15.2 190 | tornado==6.4.1 191 | tqdm==4.67.1 192 | traitlets==5.14.3 193 | transformers==4.37.1 194 | triton==2.0.0 195 | typing_extensions==4.11.0 196 | tzdata==2025.1 197 | uc-micro-py==1.0.3 198 | urllib3==2.2.3 199 | uvicorn==0.29.0 200 | wandb==0.16.6 201 | wavedrom==2.0.3.post3 202 | wcwidth==0.2.13 203 | webencodings==0.5.1 204 | websockets==13.1 205 | Werkzeug==3.0.2 206 | yarg==0.1.9 207 | yarl==1.15.2 208 | zipp==3.20.2 209 | -------------------------------------------------------------------------------- /policy_heads/models/position_encoding.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | """ 3 | Various positional encodings for the transformer. 4 | """ 5 | import math 6 | import torch 7 | from torch import nn 8 | 9 | from policy_heads.util.misc import NestedTensor 10 | 11 | import IPython 12 | e = IPython.embed 13 | 14 | class PositionEmbeddingSine(nn.Module): 15 | """ 16 | This is a more standard version of the position embedding, very similar to the one 17 | used by the Attention is all you need paper, generalized to work on images. 
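    Example (illustrative sketch, not from the original source; shapes follow the forward
    pass below, which builds the mask from a single batch element):

        pos_embed = PositionEmbeddingSine(num_pos_feats=128, normalize=True)
        feats = torch.zeros(2, 256, 7, 7)   # (batch, channels, H, W) feature map
        pos = pos_embed(feats)              # -> (1, 2 * num_pos_feats, 7, 7) = (1, 256, 7, 7)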
18 | """ 19 | def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None): 20 | super().__init__() 21 | self.num_pos_feats = num_pos_feats 22 | self.temperature = temperature 23 | self.normalize = normalize 24 | if scale is not None and normalize is False: 25 | raise ValueError("normalize should be True if scale is passed") 26 | if scale is None: 27 | scale = 2 * math.pi 28 | self.scale = scale 29 | 30 | def forward(self, tensor): 31 | x = tensor 32 | # mask = tensor_list.mask 33 | # assert mask is not None 34 | # not_mask = ~mask 35 | 36 | not_mask = torch.ones_like(x[0, [0]]) 37 | y_embed = not_mask.cumsum(1, dtype=tensor.dtype) 38 | x_embed = not_mask.cumsum(2, dtype=tensor.dtype) 39 | if self.normalize: 40 | eps = 1e-6 41 | y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale 42 | x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale 43 | 44 | dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) 45 | dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) 46 | 47 | pos_x = x_embed[:, :, :, None] / dim_t 48 | pos_y = y_embed[:, :, :, None] / dim_t 49 | pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3) 50 | pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3) 51 | pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) 52 | return pos 53 | 54 | 55 | class PositionEmbeddingLearned(nn.Module): 56 | """ 57 | Computes learned absolute positional embeddings for the input tensor. 58 | 59 | This method generates position embeddings based on the spatial dimensions 60 | (height and width) of the input tensor. The embeddings are learned through 61 | trainable `nn.Embedding` layers for both row (height) and column (width) indices. 62 | """ 63 | def __init__(self, num_pos_feats=256): 64 | super().__init__() 65 | self.row_embed = nn.Embedding(50, num_pos_feats) 66 | self.col_embed = nn.Embedding(50, num_pos_feats) 67 | self.reset_parameters() 68 | 69 | def reset_parameters(self): 70 | nn.init.uniform_(self.row_embed.weight) 71 | nn.init.uniform_(self.col_embed.weight) 72 | 73 | def forward(self, tensor_list: NestedTensor): 74 | x = tensor_list.tensors 75 | h, w = x.shape[-2:] 76 | i = torch.arange(w, device=x.device) 77 | j = torch.arange(h, device=x.device) 78 | x_emb = self.col_embed(i) 79 | y_emb = self.row_embed(j) 80 | pos = torch.cat([ 81 | x_emb.unsqueeze(0).repeat(h, 1, 1), 82 | y_emb.unsqueeze(1).repeat(1, w, 1), 83 | ], dim=-1).permute(2, 0, 1).unsqueeze(0).repeat(x.shape[0], 1, 1, 1) 84 | return pos 85 | 86 | def position_encoding_1d(x): 87 | """ 88 | Generates 1-dimensional positional encoding. 89 | 90 | Args: 91 | seq_len (int): Length of the input sequence. 92 | d_model (int): Dimension of the model. 93 | 94 | Returns: 95 | np.array: 2D array of shape (seq_len, d_model) containing positional encodings. 
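    In this implementation, ``seq_len`` and ``d_model`` are read from ``x.shape[1:]``, i.e. the
    input is a tensor of shape (batch, seq_len, d_model); the returned positional encoding is a
    ``torch.Tensor`` of shape (seq_len, d_model) placed on ``x``'s device, and the internal
    ``.cuda()`` calls assume a CUDA device is available.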
96 | """ 97 | seq_len, d_model = x.shape[1:] 98 | pos_enc = torch.zeros((seq_len, d_model), dtype=x.dtype) 99 | position = torch.arange(0, seq_len, dtype=x.dtype).cuda().unsqueeze(1) 100 | 101 | div_term = torch.exp((torch.arange(0, d_model, 2, dtype=x.dtype) * (-math.log(10000.0) / d_model)).cuda()) 102 | 103 | pos_enc[:, 0::2] = torch.sin(position * div_term) 104 | pos_enc[:, 1::2] = torch.cos(position * div_term) 105 | 106 | return pos_enc.to(x.device) 107 | 108 | def build_position_encoding(args): 109 | N_steps = args.hidden_dim // 2 110 | if args.position_embedding in ('v2', 'sine'): 111 | position_embedding = PositionEmbeddingSine(N_steps, normalize=True) 112 | elif args.position_embedding in ('v3', 'learned'): 113 | position_embedding = PositionEmbeddingLearned(N_steps) 114 | else: 115 | raise ValueError(f"not supported {args.position_embedding}") 116 | 117 | return position_embedding -------------------------------------------------------------------------------- /policy_heads/util/plot_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Plotting utilities to visualize training logs. 3 | """ 4 | import torch 5 | import pandas as pd 6 | import numpy as np 7 | import seaborn as sns 8 | import matplotlib.pyplot as plt 9 | 10 | from pathlib import Path, PurePath 11 | 12 | 13 | def plot_logs(logs, fields=('class_error', 'loss_bbox_unscaled', 'mAP'), ewm_col=0, log_name='log.txt'): 14 | ''' 15 | Function to plot specific fields from training log(s). Plots both training and test results. 16 | 17 | :: Inputs - logs = list containing Path objects, each pointing to individual dir with a log file 18 | - fields = which results to plot from each log file - plots both training and test for each field. 19 | - ewm_col = optional, which column to use as the exponential weighted smoothing of the plots 20 | - log_name = optional, name of log file if different than default 'log.txt'. 21 | 22 | :: Outputs - matplotlib plots of results in fields, color coded for each log file. 23 | - solid lines are training results, dashed lines are test results. 24 | 25 | ''' 26 | func_name = "plot_utils.py::plot_logs" 27 | 28 | # verify logs is a list of Paths (list[Paths]) or single Pathlib object Path, 29 | # convert single Path to list to avoid 'not iterable' error 30 | 31 | if not isinstance(logs, list): 32 | if isinstance(logs, PurePath): 33 | logs = [logs] 34 | print(f"{func_name} info: logs param expects a list argument, converted to list[Path].") 35 | else: 36 | raise ValueError(f"{func_name} - invalid argument for logs parameter.\n \ 37 | Expect list[Path] or single Path obj, received {type(logs)}") 38 | 39 | # Quality checks - verify valid dir(s), that every item in list is Path object, and that log_name exists in each dir 40 | for i, dir in enumerate(logs): 41 | if not isinstance(dir, PurePath): 42 | raise ValueError(f"{func_name} - non-Path object in logs argument of {type(dir)}: \n{dir}") 43 | if not dir.exists(): 44 | raise ValueError(f"{func_name} - invalid directory in logs argument:\n{dir}") 45 | # verify log_name exists 46 | fn = Path(dir / log_name) 47 | if not fn.exists(): 48 | print(f"-> missing {log_name}. 
Have you gotten to Epoch 1 in training?") 49 | print(f"--> full path of missing log file: {fn}") 50 | return 51 | 52 | # load log file(s) and plot 53 | dfs = [pd.read_json(Path(p) / log_name, lines=True) for p in logs] 54 | 55 | fig, axs = plt.subplots(ncols=len(fields), figsize=(16, 5)) 56 | 57 | for df, color in zip(dfs, sns.color_palette(n_colors=len(logs))): 58 | for j, field in enumerate(fields): 59 | if field == 'mAP': 60 | coco_eval = pd.DataFrame( 61 | np.stack(df.test_coco_eval_bbox.dropna().values)[:, 1] 62 | ).ewm(com=ewm_col).mean() 63 | axs[j].plot(coco_eval, c=color) 64 | else: 65 | df.interpolate().ewm(com=ewm_col).mean().plot( 66 | y=[f'train_{field}', f'test_{field}'], 67 | ax=axs[j], 68 | color=[color] * 2, 69 | style=['-', '--'] 70 | ) 71 | for ax, field in zip(axs, fields): 72 | ax.legend([Path(p).name for p in logs]) 73 | ax.set_title(field) 74 | 75 | 76 | def plot_precision_recall(files, naming_scheme='iter'): 77 | if naming_scheme == 'exp_id': 78 | # name becomes exp_id 79 | names = [f.parts[-3] for f in files] 80 | elif naming_scheme == 'iter': 81 | names = [f.stem for f in files] 82 | else: 83 | raise ValueError(f'not supported {naming_scheme}') 84 | fig, axs = plt.subplots(ncols=2, figsize=(16, 5)) 85 | for f, color, name in zip(files, sns.color_palette("Blues", n_colors=len(files)), names): 86 | data = torch.load(f) 87 | # precision is n_iou, n_points, n_cat, n_area, max_det 88 | precision = data['precision'] 89 | recall = data['params'].recThrs 90 | scores = data['scores'] 91 | # take precision for all classes, all areas and 100 detections 92 | precision = precision[0, :, :, 0, -1].mean(1) 93 | scores = scores[0, :, :, 0, -1].mean(1) 94 | prec = precision.mean() 95 | rec = data['recall'][0, :, 0, -1].mean() 96 | print(f'{naming_scheme} {name}: mAP@50={prec * 100: 05.1f}, ' + 97 | f'score={scores.mean():0.3f}, ' + 98 | f'f1={2 * prec * rec / (prec + rec + 1e-8):0.3f}' 99 | ) 100 | axs[0].plot(recall, precision, c=color) 101 | axs[1].plot(recall, scores, c=color) 102 | 103 | axs[0].set_title('Precision / Recall') 104 | axs[0].legend(names) 105 | axs[1].set_title('Scores / Recall') 106 | axs[1].legend(names) 107 | return fig, axs 108 | -------------------------------------------------------------------------------- /llava-pythia/llava_pythia/mm_utils.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | from io import BytesIO 3 | import base64 4 | 5 | import torch 6 | from transformers import StoppingCriteria 7 | from llava_pythia.constants import IMAGE_TOKEN_INDEX 8 | 9 | 10 | def load_image_from_base64(image): 11 | return Image.open(BytesIO(base64.b64decode(image))) 12 | 13 | 14 | def expand2square(pil_img, background_color): 15 | width, height = pil_img.size 16 | if width == height: 17 | return pil_img 18 | elif width > height: 19 | result = Image.new(pil_img.mode, (width, width), background_color) 20 | result.paste(pil_img, (0, (width - height) // 2)) 21 | return result 22 | else: 23 | result = Image.new(pil_img.mode, (height, height), background_color) 24 | result.paste(pil_img, ((height - width) // 2, 0)) 25 | return result 26 | 27 | 28 | def process_images(images, image_processor, model_cfg): 29 | image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None) 30 | new_images = [] 31 | if image_aspect_ratio == 'pad': 32 | for image in images: 33 | image = expand2square(image, tuple(int(x*255) for x in image_processor.image_mean)) 34 | image = image_processor.preprocess(image, 
return_tensors='pt')['pixel_values'][0] 35 | new_images.append(image) 36 | else: 37 | return image_processor(images, return_tensors='pt')['pixel_values'] 38 | if all(x.shape == new_images[0].shape for x in new_images): 39 | new_images = torch.stack(new_images, dim=0) 40 | return new_images 41 | 42 | 43 | def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None): 44 | """ 45 | Tokenizes a prompt string that may contain image placeholders and returns the tokenized input IDs. 46 | 47 | Args: 48 | prompt (str): The input string containing text and '' placeholders. 49 | tokenizer: The tokenizer object used to tokenize the text. 50 | image_token_index (int, optional): The token index used to represent the '' placeholder. Defaults to IMAGE_TOKEN_INDEX. 51 | return_tensors (str, optional): If specified, returns the tokenized input as a tensor of the specified type ('pt' for PyTorch). Defaults to None. 52 | 53 | Returns: 54 | list or torch.Tensor: The tokenized input IDs. If `return_tensors` is specified as 'pt', returns a PyTorch tensor; otherwise, returns a list of input IDs. 55 | """ 56 | prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('')] 57 | # print("####"*100) 58 | # attention = [tokenizer(chunk).attention_mask for chunk in prompt.split('')] 59 | def insert_separator(X, sep): 60 | return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1] 61 | 62 | input_ids = [] 63 | offset = 0 64 | if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id: 65 | offset = 1 66 | input_ids.append(prompt_chunks[0][0]) 67 | for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)): 68 | input_ids.extend(x[offset:]) 69 | 70 | if return_tensors is not None: 71 | if return_tensors == 'pt': 72 | return torch.tensor(input_ids, dtype=torch.long) 73 | raise ValueError(f'Unsupported tensor type: {return_tensors}') 74 | return input_ids 75 | 76 | 77 | def get_model_name_from_path(model_path): 78 | model_path = model_path.strip("/") 79 | model_paths = model_path.split("/") 80 | if model_paths[-1].startswith('checkpoint-'): 81 | return model_paths[-2] + "_" + model_paths[-1] 82 | else: 83 | return model_paths[-1] 84 | 85 | 86 | class KeywordsStoppingCriteria(StoppingCriteria): 87 | def __init__(self, keywords, tokenizer, input_ids): 88 | self.keywords = keywords 89 | self.keyword_ids = [] 90 | for keyword in keywords: 91 | cur_keyword_ids = tokenizer(keyword).input_ids 92 | if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id: 93 | cur_keyword_ids = cur_keyword_ids[1:] 94 | self.keyword_ids.append(torch.tensor(cur_keyword_ids)) 95 | self.tokenizer = tokenizer 96 | self.start_len = input_ids.shape[1] 97 | 98 | def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: 99 | assert output_ids.shape[0] == 1, "Only support batch size 1 (yet)" # TODO 100 | offset = min(output_ids.shape[1] - self.start_len, 3) 101 | self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids] 102 | for keyword_id in self.keyword_ids: 103 | if output_ids[0, -keyword_id.shape[0]:] == keyword_id: 104 | return True 105 | outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0] 106 | for keyword in self.keywords: 107 | if keyword in outputs: 108 | return True 109 | return False 110 | -------------------------------------------------------------------------------- /README.md: 
-------------------------------------------------------------------------------- 1 | 
2 | TinyVLA: Towards Fast, Data-Efficient Vision-Language-Action Models for Robotic Manipulation
3 | 
4 | 5 | 6 | * **TinyVLA: Towards Fast, Data-Efficient Vision-Language-Action Models for Robotic Manipulation**
7 | [![arXiv](https://img.shields.io/badge/Arxiv-2402.03766-b31b1b.svg?logo=arXiv)](https://arxiv.org/abs/2409.12514) 8 | 9 | 10 | 11 | ## 📰 News 12 | * **`Feb. 17th, 2025`**: 🔥🔥🔥Our code is released! 13 | * **`Feb. 9th, 2025`**: 🔥🔥🔥**TinyVLA** is accepted by IEEE Robotics and Automation Letters (RA-L) 2025! 14 | * **`Nov. 19th, 2024`**: **TinyVLA** is out! **Paper** can be found [here](https://arxiv.org/abs/2409.12514). The **project web** can be found [here](https://tiny-vla.github.io/). 15 | 16 | ## Contents 17 | - [📰 News](#-news) 18 | - [Contents](#contents) 19 | - [Install](#install) 20 | - [Data Preparation](#data-preparation) 21 | - [Download Pretrained VLM](#download-pretrained-vlm) 22 | - [Train](#train) 23 | - [Evaluation](#evaluation) 24 | - [Acknowledgement](#acknowledgement) 25 | - [Citation](#citation) 26 | 27 | ## Install 28 | 29 | 1. Clone this repository and navigate to diffusion-vla folder 30 | ```bash 31 | git clone https://github.com/liyaxuanliyaxuan/TinyVLA 32 | ``` 33 | 34 | 2. Install Package 35 | ```Shell 36 | conda create -n tinyvla python=3.10 -y 37 | conda activate tinyvla 38 | pip install --upgrade pip # 39 | pip install -r requirements.txt 40 | cd policy_heads 41 | pip install -e . 42 | # install llava-pythia 43 | cd ../llava-pythia 44 | pip install -e . 45 | ``` 46 | 47 | ## Data Preparation 48 | 1. Our data format is the same as [act](https://github.com/MarkFzp/act-plus-plus), so you need to transfer your data into h5py format. You can refer to the [rlds_to_h5py.py](https://github.com/lesjie-wen/tinyvla/blob/main/data_utils/rlds_to_h5py.py) which is used to transfer the data from rlds format to h5py format. 49 | ```angular2html 50 | # h5 data structure 51 | root 52 | |-action (100,10) 53 | |-language_raw (1,) 54 | |-observations 55 | |-images # multi-view 56 | |-left (100,480,640,3) 57 | |-right (100,480,640,3) 58 | |-wrist (100,480,640,3) 59 | |-joint_positions (100,7) 60 | |-qpos (100,7) 61 | |-qvel (100,7) 62 | ``` 63 | 2. You have to add one entry in [constants.py](https://github.com/lesjie-wen/tinyvla/blob/main/aloha_scripts/constants.py) to specify the path of your data as follows. 64 | ```python 65 | 'your_task_name':{ 66 | 'dataset_dir': DATA_DIR + '/your_task_path', # define the path of the dataset 67 | 'episode_len': 1000, #max length of the episode, 68 | 'camera_names': ['front', 'wrist'] # define the camera names which are used as the key when reading data 69 | } 70 | ``` 71 | ## Download Pretrained VLM 72 | We construct the VLM backbone by integrating a series of tiny LLM([Pythia](https://github.com/EleutherAI/pythia)) into [Llava](https://github.com/haotian-liu/LLaVA) framework. We follow the standard training pipe line and data provided by [Llava](https://github.com/haotian-liu/LLaVA). All the weights of VLM used in our paper are listed as following: 73 | 74 | | Model | Usage | Link | 75 | |---------------------|---------------|----------------------------------------------------------------| 76 | | Llava-Pythia(~400M) | For TinyVLA-S | [huggingface](https://huggingface.co/lesjie/Llava-Pythia-400M) | 77 | | Llava-Pythia(~700M) | For TinyVLA-B | [huggingface](https://huggingface.co/lesjie/Llava-Pythia-700M) | 78 | | Llava-Pythia(~1.3B) | For TinyVLA-H | [huggingface](https://huggingface.co/lesjie/Llava-Pythia-1.3B) | 79 | 80 | 81 | ## Train 82 | The training script is "scripts/train.sh". And you need to change following parameters: 83 | 1. 
**OUTPUT** :refers to the save directory for training, which must include the keyword "llava_pythia" (and optionally "lora"). If LoRA training is used, the name must include "lora" (e.g., "llava_pythia_lora"). 84 | 2. **task_name** :refers to the tasks used for training, which should be corresponded to "your_task_name" in aloha_scripts/constant.py 85 | 3. **model_name_or_path** :path to the pretrained VLM weights 86 | 4. Other hyperparameters like "batch_size", "save_steps" could be customized according to your computation resources. 87 | 88 | Start training by following commands: 89 | ```shell 90 | ./scripts/train.sh 91 | ``` 92 | 93 | ## Evaluation 94 | Before evaluation, we provide a post process script to generate a usable and smaller weights. 95 | The process script is "scripts/process_ckpts.sh". And you need to change following parameters: 96 | 1. **source_dir** :path to trained VLA dir equals to **OUTPUT** in train.sh 97 | 2. **target_dir** :path to save processed VLA weights 98 | 99 | You can refer to our evaluation script [eval_real_franka.py](https://github.com/lesjie-wen/tinyvla/blob/main/eval_real_franka.py). 100 | ## Acknowledgement 101 | We build our project based on: 102 | - [LLaVA](https://github.com/haotian-liu/LLaVA): an amazing open-sourced project for vision language assistant 103 | - [act-plus-plus](https://github.com/haotian-liu/LLaVA): an amazing open-sourced project for robotics visuomotor learning 104 | - [Miphi](https://github.com/zhuyiche/llava-phi): an amazing open-sourced project for tiny vision language model 105 | 106 | ## Citation 107 | 108 | If you find Tiny-VLA useful for your research and applications, please cite using this BibTeX: 109 | ```bibtex 110 | @misc{ 111 | @inproceedings{wen2024tinyvla, 112 | title={Tinyvla: Towards fast, data-efficient vision-language-action models for robotic manipulation}, 113 | author={Wen, Junjie and Zhu, Yichen and Li, Jinming and Zhu, Minjie and Wu, Kun and Xu, Zhiyuan and Liu, Ning and Cheng, Ran and Shen, Chaomin and Peng, Yaxin and others}, 114 | booktitle={IEEE Robotics and Automation Letters (RA-L)}, 115 | year={2025} 116 | } 117 | ``` 118 | 119 | 120 | -------------------------------------------------------------------------------- /policy_heads/models/backbone.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | """ 3 | Backbone modules. 4 | """ 5 | from collections import OrderedDict 6 | 7 | import torch 8 | import torch.nn.functional as F 9 | import torchvision 10 | from torch import nn 11 | from torchvision.models._utils import IntermediateLayerGetter 12 | from typing import Dict, List 13 | 14 | from policy_heads.util.misc import is_main_process, NestedTensor 15 | from .position_encoding import build_position_encoding 16 | 17 | import IPython 18 | e = IPython.embed 19 | 20 | class FrozenBatchNorm2d(torch.nn.Module): 21 | """ 22 | BatchNorm2d where the batch statistics and the affine parameters are fixed. 23 | 24 | This implementation is a copy-paste from torchvision.misc.ops with added eps before rsqrt, 25 | without which any other models than torchvision.models.resnet[18,34,50,101] produce NaNs. 
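    Example (illustrative sketch, not from the original source):

        frozen_bn = FrozenBatchNorm2d(64)
        out = frozen_bn(torch.randn(2, 64, 32, 32))   # same shape; scale and bias come from the fixed buffers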
26 | """ 27 | def __init__(self, n): 28 | super(FrozenBatchNorm2d, self).__init__() 29 | self.register_buffer("weight", torch.ones(n)) 30 | self.register_buffer("bias", torch.zeros(n)) 31 | self.register_buffer("running_mean", torch.zeros(n)) 32 | self.register_buffer("running_var", torch.ones(n)) 33 | 34 | def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, 35 | missing_keys, unexpected_keys, error_msgs): 36 | # remove num_batches_tracked from state_dict if present 37 | num_batches_tracked_key = prefix + 'num_batches_tracked' 38 | if num_batches_tracked_key in state_dict: 39 | del state_dict[num_batches_tracked_key] 40 | 41 | super(FrozenBatchNorm2d, self)._load_from_state_dict( 42 | state_dict, prefix, local_metadata, strict, 43 | missing_keys, unexpected_keys, error_msgs) 44 | 45 | def forward(self, x): 46 | """ 47 | Forward pass for the frozen batch normalization. 48 | 49 | Args: 50 | x: Input tensor. 51 | 52 | Returns: 53 | Normalized tensor. 54 | """ 55 | # move reshapes to the beginning to make it fuser-friendly 56 | w = self.weight.reshape(1, -1, 1, 1) 57 | b = self.bias.reshape(1, -1, 1, 1) 58 | rv = self.running_var.reshape(1, -1, 1, 1) 59 | rm = self.running_mean.reshape(1, -1, 1, 1) 60 | eps = 1e-5 61 | scale = w * (rv + eps).rsqrt() 62 | bias = b - rm * scale 63 | return x * scale + bias 64 | 65 | 66 | class BackboneBase(nn.Module): 67 | """ 68 | Base class for backbone networks. 69 | 70 | Args: 71 | backbone: The backbone model. 72 | train_backbone: Whether to train the backbone. 73 | num_channels: Number of output channels. 74 | return_interm_layers: Whether to return intermediate layers. 75 | """ 76 | def __init__(self, backbone: nn.Module, train_backbone: bool, num_channels: int, return_interm_layers: bool): 77 | super().__init__() 78 | # determine which layers to return 79 | if return_interm_layers: 80 | return_layers = {"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"} 81 | else: 82 | return_layers = {'layer4': "0"} 83 | self.body = IntermediateLayerGetter(backbone, return_layers=return_layers) 84 | self.num_channels = num_channels 85 | 86 | def forward(self, tensor): 87 | """ 88 | Forward pass for the backbone. 89 | 90 | Args: 91 | tensor: Input tensor. 92 | 93 | Returns: 94 | Dictionary of feature maps from the specified layers. 95 | """ 96 | xs = self.body(tensor) 97 | return xs 98 | 99 | 100 | class Backbone(BackboneBase): 101 | """ 102 | ResNet backbone with frozen BatchNorm. 103 | 104 | Args: 105 | name: Name of the ResNet model. 106 | train_backbone: Whether to train the backbone. 107 | return_interm_layers: Whether to return intermediate layers. 108 | dilation: Whether to use dilation in the last block. 109 | """ 110 | def __init__(self, name: str, train_backbone: bool, return_interm_layers: bool, dilation: bool): 111 | backbone = getattr(torchvision.models, name)( 112 | replace_stride_with_dilation=[False, False, dilation], 113 | pretrained=is_main_process(), norm_layer=FrozenBatchNorm2d) # use pretrained model 114 | num_channels = 512 if name in ('resnet18', 'resnet34') else 2048 115 | super().__init__(backbone, train_backbone, num_channels, return_interm_layers) 116 | 117 | 118 | class Joiner(nn.Sequential): 119 | """ 120 | Combines a backbone and a position encoding module. 121 | 122 | Args: 123 | backbone: The backbone model. 124 | position_embedding: The position encoding module. 
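    Example (illustrative sketch, not from the original source; the ``args`` fields are
    hypothetical and mirror those consumed by ``build_backbone`` at the end of this file;
    with the sine embedding a plain image batch can be passed despite the NestedTensor
    annotation, and pretrained ResNet weights may be downloaded on first use):

        from argparse import Namespace
        args = Namespace(hidden_dim=256, position_embedding='sine', lr_backbone=1e-5,
                         masks=False, backbone='resnet18', dilation=False)
        model = build_backbone(args)                        # Joiner(backbone, position_embedding)
        features, pos = model(torch.randn(2, 3, 224, 224))  # lists of feature maps and position encodings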
125 | """ 126 | def __init__(self, backbone, position_embedding): 127 | super().__init__(backbone, position_embedding) 128 | 129 | def forward(self, tensor_list: NestedTensor): 130 | """ 131 | Forward pass for the joiner. 132 | 133 | Args: 134 | tensor_list: NestedTensor containing input data and mask. 135 | 136 | Returns: 137 | Tuple of feature maps and position encodings. 138 | """ 139 | xs = self[0](tensor_list) 140 | out: List[NestedTensor] = [] 141 | pos = [] 142 | for name, x in xs.items(): 143 | out.append(x) 144 | # position encoding 145 | pos.append(self[1](x).to(x.dtype)) 146 | 147 | return out, pos 148 | 149 | 150 | def build_backbone(args): 151 | """ 152 | Builds the backbone model with position encoding. 153 | 154 | Args: 155 | args: Arguments containing configuration for the backbone. 156 | 157 | Returns: 158 | A model combining the backbone and position encoding. 159 | """ 160 | position_embedding = build_position_encoding(args) 161 | train_backbone = args.lr_backbone > 0 162 | return_interm_layers = args.masks 163 | backbone = Backbone(args.backbone, train_backbone, return_interm_layers, args.dilation) 164 | model = Joiner(backbone, position_embedding) 165 | model.num_channels = backbone.num_channels 166 | return model 167 | -------------------------------------------------------------------------------- /llava-pythia/scripts/tranfer2llava.py: -------------------------------------------------------------------------------- 1 | # Convert the R3M-formatted Franka kitchen data to the LLaVA format 2 | output_path = "/media/rl/HDD/data/data/franka_kitchen/frankakitchen_llava" 3 | import pickle 4 | import json 5 | import os 6 | from PIL import Image 7 | from tqdm import tqdm 8 | import numpy as np 9 | TASK = { 10 | 'kitchen_sdoor_open-v3':"Open the sliding door", 11 | 'kitchen_micro_open-v3':"Open the microwave oven", 12 | 'kitchen_light_on-v3':"Toggle the light switch", 13 | 'kitchen_ldoor_open-v3':"Open the cabinet door", 14 | 'kitchen_knob1_on-v3':"Rotate the round stovetop knob" 15 | } 16 | 17 | def franka_kitchen2llava_format(data_path:str, ratio:float, view): 18 | os.makedirs(os.path.join(output_path, view), exist_ok=True) 19 | os.makedirs(os.path.join(output_path, view, 'images'), exist_ok=True) 20 | 21 | pickle_path = os.listdir(data_path) 22 | pickle_path = [p for p in pickle_path if p.endswith('.pickle')] 23 | all_task_demo_paths = [] 24 | 25 | t_bar = tqdm(total=1000 * ratio) 26 | 27 | try: 28 | with open(os.path.join(data_path, f'mean_var_{view}.json'), 'r') as f: 29 | mean_std = json.load(f) 30 | except: 31 | mean_std = {'action':{},'state':{}} 32 | 33 | 34 | data_processed = [] 35 | eval_data = [] 36 | train_data = [] 37 | for p in pickle_path: 38 | cur_data = [] 39 | 40 | if p.split('/')[-1] not in mean_std['action'].keys(): 41 | mean_std['action'][p.split('/')[-1]] = {} 42 | mean_std['state'][p.split('/')[-1]] = {} 43 | 44 | action_all = [] 45 | state_all = [] 46 | # print(p.split('.')[0]) 47 | max_action = [0] * 9 48 | min_action = [1000000] * 9 49 | max_state = [0] * 9 50 | min_state = [1000000] * 9 51 | 52 | demo_paths_loc = os.path.join(data_path, p) 53 | demo_paths = pickle.load(open(demo_paths_loc, 'rb')) 54 | all_task_demo_paths += demo_paths[:int(ratio*200)] 55 | # print(all_task_demo_paths[0].keys()) 56 | # print(all_task_demo_paths[0]['actions'].shape) 57 | for idx, each in enumerate(demo_paths[:int(ratio*200)]): # trajectory id 58 | traj_len = each['images'].shape[0] 59 | 60 | # normalize action and state 61 | m_a,v_a = 
np.array([mean_std['action'][p.split('/')[-1]]['mean']]), np.array([mean_std['action'][p.split('/')[-1]]['var']]) 62 | m_s,v_s = np.array([mean_std['state'][p.split('/')[-1]]['mean']]), np.array([mean_std['state'][p.split('/')[-1]]['var']]) 63 | each['actions'] = (each['actions'] - m_a) /np.sqrt(v_a) 64 | # print(each['observations'][:,:9].shape) 65 | each['observations'][:,:9] = (each['observations'][:,:9] - m_s) / np.sqrt(v_s) 66 | # ############################# 67 | for i in range(traj_len): # frame id 68 | t = { 69 | "id": "", 70 | "image": "", 71 | 'state': [], 72 | 'action': [], 73 | "conversations": [{"from": "human", "value": "\n"}, {"from": "gpt", "value": " "}] 74 | } 75 | img_p = os.path.join(output_path, view, 'images', f'{p.split(".")[0]}_{idx}_{i}.png') 76 | # print(each['images'].shape) 77 | # break 78 | if not os.path.exists(img_p): 79 | Image.fromarray(each['images'][i]).save(img_p) 80 | t['image'] = img_p 81 | t['id'] = img_p.split('/')[-1] 82 | t["conversations"][0]["value"] += TASK[p.split('.')[0]] 83 | t['state'] = each['observations'][i].tolist() 84 | t['action'] = each['actions'][i].tolist() 85 | # print(t['action']) 86 | ########################################查看一下数据范围 87 | # for j,a in enumerate(zip(t['action'])): 88 | # if max_action[j] < a: 89 | # max_action[j] = a 90 | # if min_action[j] > a: 91 | # min_action[j] = a 92 | # for j,a in enumerate(t['state'][:9]): 93 | # if max_state[j] < a: 94 | # max_state[j] = a 95 | # if min_state[j] > a: 96 | # min_state[j] = a 97 | action_all.append(t['action']) 98 | state_all.append(t['state'][:9]) 99 | 100 | data_processed.append(t) 101 | cur_data.append(t) 102 | 103 | t_bar.update(1) 104 | # print(t) 105 | # break 106 | # break 107 | # print(p.split('/')[-1]) 108 | # print(max_action,min_action) 109 | # print(max_state,min_state) 110 | 111 | mean_action = np.mean(np.array(action_all), axis=0) 112 | var_action = np.var(np.array(action_all), axis=0) 113 | 114 | mean_state = np.mean(np.array(state_all), axis=0) 115 | var_state = np.var(np.array(state_all), axis=0) 116 | 117 | mean_std['action'][p.split('/')[-1]]['mean'] = mean_action.tolist() 118 | mean_std['action'][p.split('/')[-1]]['var'] = var_action.tolist() 119 | 120 | mean_std['state'][p.split('/')[-1]]['mean'] = mean_state.tolist() 121 | mean_std['state'][p.split('/')[-1]]['var'] = var_state.tolist() 122 | eval_data += cur_data[int(200*50*0.9):] 123 | train_data += cur_data[:int(200*50*0.9)] 124 | # print("action") 125 | # print(mean_action, var_action) 126 | 127 | # print("state") 128 | # print(mean_state, var_state) 129 | # print(mean_std) 130 | 131 | # with open(f'mean_var_{view}.json', 'w') as f: 132 | # json.dump(mean_std,f) 133 | 134 | # with open('action_all.txt', 'w') as f: 135 | # for a in action_all: 136 | # f.write(str(a) + '\n') 137 | # with open('state_all.txt', 'w') as f: 138 | # for s in state_all: 139 | # f.write(str(s) + '\n') 140 | 141 | # with open(os.path.join(output_path, view, f"std_{view}_50k.json"), "w") as f: 142 | # json.dump(data_processed, f, indent=4) 143 | print(len(train_data), len(eval_data)) 144 | with open(os.path.join(output_path, view, f"std_eval_{view}_50k.json"), "w") as f: 145 | json.dump(eval_data, f, indent=4) 146 | with open(os.path.join(output_path, view, f"std_train_{view}_50k.json"), "w") as f: 147 | json.dump(train_data, f, indent=4) 148 | 149 | for view in ['default', 'left_cap2', 'right_cap2']: 150 | data_path = f"/media/rl/HDD/data/data/franka_kitchen/FrankaKitchen/{view}" 151 | franka_kitchen2llava_format(data_path, 
1, view) -------------------------------------------------------------------------------- /llava-pythia/llava_pythia/model/builder.py: -------------------------------------------------------------------------------- 1 | import os 2 | import warnings 3 | import shutil 4 | 5 | from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig, CLIPImageProcessor, SiglipImageProcessor, \ 6 | GPTNeoXModel, GPTNeoXPreTrainedModel 7 | import torch 8 | from llava_pythia.model import * 9 | from llava_pythia.constants import DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN 10 | 11 | 12 | def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, load_4bit=False, device_map="cuda", device="cuda"): 13 | """ 14 | Loads a pretrained model with optional quantization and device mapping. 15 | 16 | Args: 17 | - model_path (str): Path to the model directory or file. 18 | - model_base (str): Base model path, used when loading LoRA models. 19 | - model_name (str): Name of the model to load. 20 | - load_8bit (bool): Whether to load the model in 8-bit precision. 21 | - load_4bit (bool): Whether to load the model in 4-bit precision. 22 | - device_map (str): Device map for model loading, default is "cuda". 23 | - device (str): Device to load the model onto, default is "cuda". 24 | 25 | Returns: 26 | - tokenizer: The tokenizer associated with the model. 27 | - model: The loaded model. 28 | - image_processor: The image processor if applicable. 29 | - context_len (int): The context length of the model. 30 | """ 31 | kwargs = {"device_map": device_map} 32 | if load_8bit: 33 | kwargs['load_in_8bit'] = True 34 | elif load_4bit: 35 | kwargs['load_in_4bit'] = True 36 | kwargs['quantization_config'] = BitsAndBytesConfig( 37 | load_in_4bit=True, 38 | bnb_4bit_compute_dtype=torch.float16, 39 | bnb_4bit_use_double_quant=True, 40 | bnb_4bit_quant_type='nf4' 41 | ) 42 | else: 43 | kwargs['torch_dtype'] = torch.float16 44 | 45 | if 'pythia' in model_name.lower(): 46 | # Load LLaVA-Phi model 47 | if 'lora' in model_name.lower() and model_base is None: 48 | warnings.warn('There is `lora` in model name but no `model_base` is provided. 
If you are loading a LoRA model, please provide the `model_base` argument.') 49 | if 'lora' in model_name.lower() and model_base is not None: 50 | 51 | path = model_path.split('/')[0:-1] 52 | root_path = '/'.join(path) 53 | lora_cfg_pretrained = AutoConfig.from_pretrained(root_path) 54 | config = lora_cfg_pretrained 55 | tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=True) # default use_fast=False 56 | print('Loading LLaVA-Pythia from base model...') 57 | model = LlavaPythiaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=lora_cfg_pretrained, **kwargs) 58 | 59 | # token_num, tokem_dim = model.embed_out.out_features, model.embed_out.in_features 60 | # if model.embed_out.weight.shape[0] != token_num: 61 | # model.embed_out.weight = torch.nn.Parameter(torch.empty(token_num, tokem_dim, device=model.device, dtype=model.dtype)) 62 | # model.model.embed_tokens.weight = torch.nn.Parameter(torch.empty(token_num, tokem_dim, device=model.device, dtype=model.dtype)) 63 | 64 | print('Loading additional LLaVA-Pythia weights...') 65 | if os.path.exists(os.path.join(model_path, 'non_lora_trainables.bin')): 66 | non_lora_trainables = torch.load(os.path.join(model_path, 'non_lora_trainables.bin'), map_location='cpu') 67 | else: 68 | # this is probably from HF Hub 69 | from huggingface_hub import hf_hub_download 70 | def load_from_hf(repo_id, filename, subfolder=None): 71 | cache_file = hf_hub_download( 72 | repo_id=repo_id, 73 | filename=filename, 74 | subfolder=subfolder) 75 | return torch.load(cache_file, map_location='cpu') 76 | non_lora_trainables = load_from_hf(model_path, 'non_lora_trainables.bin') 77 | non_lora_trainables = {(k[11:] if k.startswith('base_model.') else k): v for k, v in non_lora_trainables.items()} 78 | if any(k.startswith('model.gpt_neox.') for k in non_lora_trainables): 79 | non_lora_trainables = {(k[6:] if k.startswith('model.') else k): v for k, v in non_lora_trainables.items()} 80 | 81 | # 删除lora相关的参数 82 | keys_to_del = [] 83 | for k,v in non_lora_trainables.items(): 84 | if 'lora' in k: 85 | keys_to_del.append(k) 86 | for key in keys_to_del: 87 | del non_lora_trainables[key] 88 | 89 | model.load_state_dict(non_lora_trainables, strict=False) 90 | 91 | from peft import PeftModel 92 | print('Loading LoRA weights...') 93 | model = PeftModel.from_pretrained(model, model_path) 94 | print('Merging LoRA weights...') 95 | model = model.merge_and_unload() 96 | print('Model is loaded...') 97 | elif model_base is not None: 98 | # this may be mm projector only 99 | print('Loading LLaVA-Pythia from base model...') 100 | tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False) 101 | cfg_pretrained = AutoConfig.from_pretrained(model_path) 102 | model = LlavaPythiaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs) 103 | 104 | mm_projector_weights = torch.load(os.path.join(model_path, 'mm_projector.bin'), map_location='cpu') 105 | mm_projector_weights = {k: v.to(torch.float16) for k, v in mm_projector_weights.items()} 106 | model.load_state_dict(mm_projector_weights, strict=False) 107 | else: 108 | print("load llaVA-Pythia MLLM!!!") 109 | config = LlavaPythiaConfig.from_pretrained(model_path, trust_remote_code=True) 110 | tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True) 111 | model = LlavaPythiaForCausalLM.from_pretrained( 112 | model_path, 113 | config=config, 114 | use_safetensors=True, 115 | **kwargs).to("cuda") 116 | else: 117 | # Load language model 118 | if model_base is not 
None: 119 | # PEFT model 120 | from peft import PeftModel 121 | tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False) 122 | model = AutoModelForCausalLM.from_pretrained(model_base, torch_dtype=torch.float16, low_cpu_mem_usage=True, device_map="auto") 123 | print(f"Loading LoRA weights from {model_path}") 124 | model = PeftModel.from_pretrained(model, model_path) 125 | print(f"Merging weights") 126 | model = model.merge_and_unload() 127 | print('Convert to FP16...') 128 | model.to(torch.float16) 129 | else: 130 | tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False) 131 | model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs) 132 | if "clip" in config.vision_config["vision_tower"]["vision_model_name_or_path"]: 133 | image_processor = CLIPImageProcessor.from_pretrained(model_path) 134 | elif "siglip" in config.vision_config["vision_tower"]["vision_model_name_or_path"]: 135 | image_processor = SiglipImageProcessor.from_pretrained(model_path) 136 | else: 137 | return NotImplementedError 138 | # image_processor = CLIPImageProcessor.from_pretrained(model_path) 139 | 140 | if 'pythia' in model_name.lower(): 141 | mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False) 142 | mm_use_im_patch_token = getattr(model.config, "mm_use_im_patch_token", True) 143 | 144 | # TODO: the tokenizer length of phi-2 is 50295, but the output class of lm_head is 51200 145 | if mm_use_im_patch_token: 146 | tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True) 147 | if mm_use_im_start_end: 148 | tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True) 149 | # model.resize_token_embeddings(len(tokenizer)) 150 | else: 151 | raise ValueError(f"Unsupported model name: {model_name}") 152 | 153 | if hasattr(model.config, "max_sequence_length"): 154 | context_len = model.config.max_sequence_length 155 | else: 156 | context_len = 2048 157 | model.to(device="cuda") 158 | print(kwargs) 159 | return tokenizer, model, image_processor, context_len 160 | -------------------------------------------------------------------------------- /llava-pythia/llava_pythia/model/language_model/pythia/configuration_llava_pythia.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Union 3 | from transformers import PretrainedConfig, GPTNeoXConfig 4 | from transformers.utils import logging 5 | 6 | logger = logging.get_logger(__name__) 7 | 8 | 9 | class LlavaPythiaVisionConfig(PretrainedConfig): 10 | r""" 11 | This is the configuration class to store the configuration of a [`CLIPVisionModel`]. It is used to instantiate a 12 | CLIP vision encoder according to the specified arguments, defining the model architecture. Instantiating a 13 | configuration with the defaults will yield a similar configuration to that of the vision encoder of the CLIP 14 | [openai/clip-vit-base-patch32](https://huggingface.co/openai/clip-vit-base-patch32) architecture. 15 | 16 | Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the 17 | documentation from [`PretrainedConfig`] for more information. 18 | 19 | Args: 20 | hidden_size (`int`, *optional*, defaults to 768): 21 | Dimensionality of the encoder layers and the pooler layer. 22 | intermediate_size (`int`, *optional*, defaults to 3072): 23 | Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. 
24 | projection_dim (`int`, *optional*, defaults to 512): 25 | Dimentionality of text and vision projection layers. 26 | num_hidden_layers (`int`, *optional*, defaults to 12): 27 | Number of hidden layers in the Transformer encoder. 28 | num_attention_heads (`int`, *optional*, defaults to 12): 29 | Number of attention heads for each attention layer in the Transformer encoder. 30 | num_channels (`int`, *optional*, defaults to 3): 31 | The number of input channels. 32 | image_size (`int`, *optional*, defaults to 224): 33 | The size (resolution) of each image. 34 | patch_size (`int`, *optional*, defaults to 32): 35 | The size (resolution) of each patch. 36 | hidden_act (`str` or `function`, *optional*, defaults to `"quick_gelu"`): 37 | The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, 38 | `"relu"`, `"selu"` and `"gelu_new"` ``"quick_gelu"` are supported. 39 | layer_norm_eps (`float`, *optional*, defaults to 1e-05): 40 | The epsilon used by the layer normalization layers. 41 | attention_dropout (`float`, *optional*, defaults to 0.0): 42 | The dropout ratio for the attention probabilities. 43 | initializer_range (`float`, *optional*, defaults to 0.02): 44 | The standard deviation of the truncated_normal_initializer for initializing all weight matrices. 45 | initializer_factor (`float`, *optional*, defaults to 1.0): 46 | A factor for initializing all weight matrices (should be kept to 1, used internally for initialization 47 | testing). 48 | mm_vision_select_feature (`str`, *optional*, defaults to `"patch"`): 49 | The feature to select from the vision encoder output. Can be one of `"patch"` or `"cls_patch"`. 50 | mm_vision_select_layer (`int`, *optional*, defaults to `-2`): 51 | The layer to select from the vision encoder output. 
52 | 53 | Example: 54 | 55 | ```python 56 | >>> from transformers import CLIPVisionConfig, CLIPVisionModel 57 | 58 | >>> # Initializing a CLIPVisionConfig with openai/clip-vit-base-patch32 style configuration 59 | >>> configuration = CLIPVisionConfig() 60 | 61 | >>> # Initializing a CLIPVisionModel (with random weights) from the openai/clip-vit-base-patch32 style configuration 62 | >>> model = CLIPVisionModel(configuration) 63 | 64 | >>> # Accessing the model configuration 65 | >>> configuration = model.config 66 | ```""" 67 | 68 | model_type = "llava_pythia_clip_vision_model" 69 | 70 | def __init__( 71 | self, 72 | hidden_size=768, 73 | intermediate_size=3072, 74 | projection_dim=512, 75 | num_hidden_layers=12, 76 | num_attention_heads=12, 77 | num_channels=3, 78 | image_size=224, 79 | patch_size=32, 80 | hidden_act="quick_gelu", 81 | layer_norm_eps=1e-5, 82 | attention_dropout=0.0, 83 | initializer_range=0.02, 84 | initializer_factor=1.0, 85 | mm_vision_select_feature="patch", 86 | mm_vision_select_layer=-2, 87 | vision_model_name_or_path="clip", 88 | concat="None", 89 | **kwargs, 90 | ): 91 | super().__init__(**kwargs) 92 | 93 | self.hidden_size = hidden_size 94 | self.intermediate_size = intermediate_size 95 | self.projection_dim = projection_dim 96 | self.num_hidden_layers = num_hidden_layers 97 | self.num_attention_heads = num_attention_heads 98 | self.num_channels = num_channels 99 | self.patch_size = patch_size 100 | self.image_size = image_size 101 | self.initializer_range = initializer_range 102 | self.initializer_factor = initializer_factor 103 | self.attention_dropout = attention_dropout 104 | self.layer_norm_eps = layer_norm_eps 105 | self.hidden_act = hidden_act 106 | self.mm_vision_select_feature = mm_vision_select_feature 107 | self.mm_vision_select_layer = mm_vision_select_layer 108 | self.vision_model_name_or_path = vision_model_name_or_path 109 | self.concat = concat 110 | 111 | @classmethod 112 | def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig": 113 | cls._set_token_in_kwargs(kwargs) 114 | 115 | config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs) 116 | 117 | # get the vision config dict if we are loading from CLIPConfig 118 | if config_dict.get("model_type") == "llava_pythia": 119 | config_dict = config_dict["vision_config"]["vision_tower"] 120 | 121 | if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type: 122 | logger.warning( 123 | f"You are using a model of type {config_dict['model_type']} to instantiate a model of type " 124 | f"{cls.model_type}. This is not supported for all configurations of models and can yield errors." 
125 | ) 126 | 127 | return cls.from_dict(config_dict, **kwargs) 128 | 129 | 130 | class ProjectorConfig(PretrainedConfig): 131 | model_type = "llava_pythia_projector" 132 | 133 | def __init__( 134 | self, 135 | mm_projector_type="linear", 136 | mm_hidden_size=768, 137 | hidden_size=2560, 138 | **kwargs 139 | ): 140 | self.mm_projector_type = mm_projector_type 141 | self.mm_hidden_size = mm_hidden_size 142 | self.hidden_size = hidden_size 143 | super().__init__(**kwargs) 144 | 145 | @classmethod 146 | def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig": 147 | cls._set_token_in_kwargs(kwargs) 148 | 149 | config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs) 150 | 151 | # get the vision config dict if we are loading from CLIPConfig 152 | if config_dict.get("model_type") == "llava_pythia": 153 | config_dict = config_dict["vision_config"]["mm_projector"] 154 | 155 | if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type: 156 | logger.warning( 157 | f"You are using a model of type {config_dict['model_type']} to instantiate a model of type " 158 | f"{cls.model_type}. This is not supported for all configurations of models and can yield errors." 159 | ) 160 | 161 | return cls.from_dict(config_dict, **kwargs) 162 | 163 | 164 | 165 | from typing import List 166 | 167 | # for initialize act head 168 | 169 | DEFAULT_VISUAL_CONFIG = { 170 | "vision_tower": LlavaPythiaVisionConfig().to_dict(), 171 | "mm_projector": ProjectorConfig().to_dict(), 172 | } 173 | 174 | # print(DEFAULT_ACT_CONFIG['act']) 175 | 176 | class LlavaPythiaConfig(GPTNeoXConfig): 177 | model_type = "llava_pythia" 178 | 179 | def __init__(self, vision_config=None, **kwargs): 180 | if vision_config is None: 181 | self.vision_config = DEFAULT_VISUAL_CONFIG 182 | else: 183 | self.vision_config = vision_config 184 | 185 | self.concat = "None" 186 | super().__init__(**kwargs) 187 | 188 | 189 | if __name__ == "__main__": 190 | print(LlavaPythiaVisionConfig()) 191 | -------------------------------------------------------------------------------- /llava-pythia/llava_pythia/conversation.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | from enum import auto, Enum 3 | from typing import List, Tuple 4 | 5 | 6 | class SeparatorStyle(Enum): 7 | """Different separator style.""" 8 | SINGLE = auto() 9 | TWO = auto() 10 | MPT = auto() 11 | PLAIN = auto() 12 | LLAMA_2 = auto() 13 | 14 | 15 | @dataclasses.dataclass 16 | class Conversation: 17 | """A class that keeps all conversation history.""" 18 | system: str 19 | roles: List[str] 20 | messages: List[List[str]] 21 | offset: int 22 | sep_style: SeparatorStyle = SeparatorStyle.SINGLE 23 | sep: str = "###" 24 | sep2: str = None 25 | version: str = "Unknown" 26 | 27 | skip_next: bool = False 28 | 29 | def get_prompt(self) -> str: 30 | """Generate a prompt string based on the conversation history and separator style.""" 31 | messages = self.messages 32 | # check if the first message is a tuple and process accordingly 33 | if len(messages) > 0 and type(messages[0][1]) is tuple: 34 | messages = self.messages.copy() 35 | init_role, init_msg = messages[0].copy() 36 | init_msg = init_msg[0].replace("", "").strip() 37 | if 'mmtag' in self.version: 38 | messages[0] = (init_role, init_msg) 39 | messages.insert(0, (self.roles[0], "")) 40 | messages.insert(1, (self.roles[1], "Received.")) 41 | else: 42 | 
messages[0] = (init_role, "<image>" + "\n" + init_msg) 43 | 44 | # construct the prompt based on the separator style 45 | if self.sep_style == SeparatorStyle.SINGLE: 46 | ret = self.system + self.sep 47 | for role, message in messages: 48 | if message: 49 | if type(message) is tuple: 50 | message, _, _ = message 51 | ret += role + ": " + message + self.sep 52 | else: 53 | ret += role + ":" 54 | elif self.sep_style == SeparatorStyle.TWO: 55 | seps = [self.sep, self.sep2] 56 | ret = self.system + seps[0] 57 | for i, (role, message) in enumerate(messages): 58 | if message: 59 | if type(message) is tuple: 60 | message, _, _ = message 61 | ret += role + ": " + message + seps[i % 2] 62 | else: 63 | ret += role + ":" 64 | elif self.sep_style == SeparatorStyle.PLAIN: 65 | seps = [self.sep, self.sep2] 66 | ret = self.system 67 | for i, (role, message) in enumerate(messages): 68 | if message: 69 | if type(message) is tuple: 70 | message, _, _ = message 71 | ret += message + seps[i % 2] 72 | else: 73 | ret += "" 74 | else: 75 | raise ValueError(f"Invalid style: {self.sep_style}") 76 | 77 | return ret 78 | 79 | def append_message(self, role: str, message: str) -> None: 80 | """Append a new message to the conversation.""" 81 | self.messages.append([role, message]) 82 | 83 | def get_images(self, return_pil: bool = False) -> List: 84 | """Extract images from the conversation messages.""" 85 | images = [] 86 | for i, (role, msg) in enumerate(self.messages[self.offset:]): 87 | if i % 2 == 0: 88 | if type(msg) is tuple: 89 | import base64 90 | from io import BytesIO 91 | from PIL import Image 92 | msg, image, image_process_mode = msg 93 | # process image based on the mode 94 | if image_process_mode == "Pad": 95 | def expand2square(pil_img, background_color=(122, 116, 104)): 96 | width, height = pil_img.size 97 | if width == height: 98 | return pil_img 99 | elif width > height: 100 | result = Image.new(pil_img.mode, (width, width), background_color) 101 | result.paste(pil_img, (0, (width - height) // 2)) 102 | return result 103 | else: 104 | result = Image.new(pil_img.mode, (height, height), background_color) 105 | result.paste(pil_img, ((height - width) // 2, 0)) 106 | return result 107 | image = expand2square(image) 108 | elif image_process_mode in ["Default", "Crop"]: 109 | pass 110 | elif image_process_mode == "Resize": 111 | image = image.resize((336, 336)) 112 | else: 113 | raise ValueError(f"Invalid image_process_mode: {image_process_mode}") 114 | # resize image to maintain aspect ratio 115 | max_hw, min_hw = max(image.size), min(image.size) 116 | aspect_ratio = max_hw / min_hw 117 | max_len, min_len = 800, 400 118 | shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw)) 119 | longest_edge = int(shortest_edge * aspect_ratio) 120 | W, H = image.size 121 | if longest_edge != max(image.size): 122 | if H > W: 123 | H, W = longest_edge, shortest_edge 124 | else: 125 | H, W = shortest_edge, longest_edge 126 | image = image.resize((W, H)) 127 | if return_pil: 128 | images.append(image) 129 | else: 130 | buffered = BytesIO() 131 | image.save(buffered, format="PNG") 132 | img_b64_str = base64.b64encode(buffered.getvalue()).decode() 133 | images.append(img_b64_str) 134 | return images 135 | 136 | def to_gradio_chatbot(self): 137 | ret = [] 138 | for i, (role, msg) in enumerate(self.messages[self.offset:]): 139 | if i % 2 == 0: 140 | if type(msg) is tuple: 141 | import base64 142 | from io import BytesIO 143 | msg, image, image_process_mode = msg 144 | max_hw, min_hw = max(image.size), min(image.size) 145 |
aspect_ratio = max_hw / min_hw 146 | max_len, min_len = 800, 400 147 | shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw)) 148 | longest_edge = int(shortest_edge * aspect_ratio) 149 | W, H = image.size 150 | if H > W: 151 | H, W = longest_edge, shortest_edge 152 | else: 153 | H, W = shortest_edge, longest_edge 154 | image = image.resize((W, H)) 155 | buffered = BytesIO() 156 | image.save(buffered, format="JPEG") 157 | img_b64_str = base64.b64encode(buffered.getvalue()).decode() 158 | img_str = f'<img src="data:image/jpeg;base64,{img_b64_str}" alt="user upload image" />' 159 | msg = img_str + msg.replace('<image>', '').strip() 160 | ret.append([msg, None]) 161 | else: 162 | ret.append([msg, None]) 163 | else: 164 | ret[-1][-1] = msg 165 | return ret 166 | 167 | def copy(self): 168 | return Conversation( 169 | system=self.system, 170 | roles=self.roles, 171 | messages=[[x, y] for x, y in self.messages], 172 | offset=self.offset, 173 | sep_style=self.sep_style, 174 | sep=self.sep, 175 | sep2=self.sep2, 176 | version=self.version) 177 | 178 | def dict(self): 179 | if len(self.get_images()) > 0: 180 | return { 181 | "system": self.system, 182 | "roles": self.roles, 183 | "messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages], 184 | "offset": self.offset, 185 | "sep": self.sep, 186 | "sep2": self.sep2, 187 | } 188 | return { 189 | "system": self.system, 190 | "roles": self.roles, 191 | "messages": self.messages, 192 | "offset": self.offset, 193 | "sep": self.sep, 194 | "sep2": self.sep2, 195 | } 196 | 197 | 198 | conv_pythia = Conversation( 199 | system="A chat between a curious user and an artificial intelligence assistant. " 200 | "The assistant gives helpful, detailed, and polite answers to the user's questions.", 201 | roles=("USER", "ASSISTANT"), 202 | version="v0", 203 | messages=(), 204 | offset=0, 205 | sep_style=SeparatorStyle.TWO, 206 | sep=" ", 207 | sep2="<|endoftext|>", 208 | ) 209 | 210 | conv_llava_plain = Conversation( 211 | system="", 212 | roles=("", ""), 213 | messages=(), 214 | offset=0, 215 | sep_style=SeparatorStyle.PLAIN, 216 | sep="\n", 217 | ) 218 | 219 | default_conversation = conv_pythia 220 | conv_templates = { 221 | "default": conv_pythia, 222 | "v0": conv_pythia, 223 | "pythia": conv_pythia, 224 | 225 | "plain": conv_llava_plain, 226 | } 227 | 228 | 229 | if __name__ == "__main__": 230 | print(default_conversation.get_prompt()) 231 | -------------------------------------------------------------------------------- /policy_heads/models/droid_unet_diffusion.py: -------------------------------------------------------------------------------- 1 | """ 2 | Implementation of Diffusion Policy https://diffusion-policy.cs.columbia.edu/ by Cheng Chi 3 | """ 4 | from typing import Callable, Union 5 | import math 6 | from collections import OrderedDict, deque 7 | from packaging.version import parse as parse_version 8 | import random 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | # requires diffusers==0.11.1 13 | from diffusers.schedulers.scheduling_ddpm import DDPMScheduler 14 | from diffusers.schedulers.scheduling_ddim import DDIMScheduler 15 | from diffusers.training_utils import EMAModel 16 | 17 | 18 | # =================== UNet for Diffusion ============== 19 | 20 | class SinusoidalPosEmb(nn.Module): 21 | """ 22 | Sinusoidal positional embedding for diffusion models. 23 | 24 | Args: 25 | dim: The dimension of the embedding. 26 | dtype: The data type for the embedding.
27 | """ 28 | def __init__(self, dim, dtype): 29 | super().__init__() 30 | self.dim = dim 31 | self.dtype = dtype 32 | 33 | def forward(self, x): 34 | """ 35 | Forward pass to compute the sinusoidal positional embedding. 36 | 37 | Args: 38 | x: Input tensor. 39 | 40 | Returns: 41 | The sinusoidal positional embedding. 42 | """ 43 | device = x.device 44 | half_dim = self.dim // 2 45 | emb = math.log(10000) / (half_dim - 1) 46 | emb = torch.exp(torch.arange(half_dim, device=device, dtype=self.dtype) * -emb) 47 | emb = x[:, None] * emb[None, :] 48 | emb = torch.cat((emb.sin(), emb.cos()), dim=-1) 49 | return emb 50 | 51 | 52 | class Downsample1d(nn.Module): 53 | """ 54 | 1D downsampling layer using convolution. 55 | 56 | Args: 57 | dim: The number of input and output channels. 58 | """ 59 | def __init__(self, dim): 60 | super().__init__() 61 | self.conv = nn.Conv1d(dim, dim, 3, 2, 1) 62 | 63 | def forward(self, x): 64 | return self.conv(x) 65 | 66 | 67 | class Upsample1d(nn.Module): 68 | """ 69 | 1D upsampling layer using transposed convolution. 70 | 71 | Args: 72 | dim: The number of input and output channels. 73 | """ 74 | def __init__(self, dim): 75 | super().__init__() 76 | self.conv = nn.ConvTranspose1d(dim, dim, 4, 2, 1) 77 | 78 | def forward(self, x): 79 | return self.conv(x) 80 | 81 | 82 | class Conv1dBlock(nn.Module): 83 | """ 84 | A block consisting of Conv1d, GroupNorm, and Mish activation. 85 | 86 | Args: 87 | inp_channels: Number of input channels. 88 | out_channels: Number of output channels. 89 | kernel_size: Size of the convolutional kernel. 90 | n_groups: Number of groups for GroupNorm. 91 | """ 92 | def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8): 93 | super().__init__() 94 | 95 | self.block = nn.Sequential( 96 | nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2), 97 | nn.GroupNorm(n_groups, out_channels), 98 | nn.Mish(), 99 | ) 100 | 101 | def forward(self, x): 102 | return self.block(x) 103 | 104 | 105 | class ConditionalResidualBlock1D(nn.Module): 106 | """ 107 | Conditional residual block with FiLM modulation. 108 | 109 | Args: 110 | in_channels: Number of input channels. 111 | out_channels: Number of output channels. 112 | cond_dim: Dimension of the conditioning input. 113 | kernel_size: Size of the convolutional kernel. 114 | n_groups: Number of groups for GroupNorm. 115 | """ 116 | def __init__(self, in_channels, out_channels, cond_dim, kernel_size=3, n_groups=8): 117 | super().__init__() 118 | 119 | self.blocks = nn.ModuleList([ 120 | Conv1dBlock(in_channels, out_channels, kernel_size, n_groups=n_groups), 121 | Conv1dBlock(out_channels, out_channels, kernel_size, n_groups=n_groups), 122 | ]) 123 | 124 | # FiLM modulation https://arxiv.org/abs/1709.07871 125 | # predicts per-channel scale and bias 126 | cond_channels = out_channels * 2 127 | self.out_channels = out_channels 128 | self.cond_encoder = nn.Sequential( 129 | nn.Mish(), 130 | nn.Linear(cond_dim, cond_channels), 131 | nn.Unflatten(-1, (-1, 1)) 132 | ) 133 | 134 | # ensure dimensions are compatible 135 | self.residual_conv = nn.Conv1d(in_channels, out_channels, 1) \ 136 | if in_channels != out_channels else nn.Identity() 137 | 138 | def forward(self, x, cond): 139 | """ 140 | Forward pass for the conditional residual block. 141 | 142 | Args: 143 | x: Input tensor of shape [batch_size, in_channels, horizon]. 144 | cond: Conditioning tensor of shape [batch_size, cond_dim]. 145 | 146 | Returns: 147 | Output tensor of shape [batch_size, out_channels, horizon]. 
148 | """ 149 | out = self.blocks[0](x) 150 | embed = self.cond_encoder(cond) 151 | 152 | embed = embed.reshape(embed.shape[0], 2, self.out_channels, 1) 153 | scale = embed[:, 0, ...] 154 | bias = embed[:, 1, ...] 155 | out = scale * out + bias 156 | 157 | out = self.blocks[1](out) 158 | out = out + self.residual_conv(x) 159 | return out 160 | 161 | 162 | class ConditionalUnet1D(nn.Module): 163 | """ 164 | Conditional 1D UNet for diffusion models. 165 | 166 | Args: 167 | input_dim: Dimension of the input actions. 168 | global_cond_dim: Dimension of global conditioning applied with FiLM. 169 | diffusion_step_embed_dim: Size of positional encoding for diffusion iteration k. 170 | down_dims: Channel size for each UNet level. 171 | kernel_size: Convolutional kernel size. 172 | n_groups: Number of groups for GroupNorm. 173 | state_dim: Dimension of the state input. 174 | """ 175 | def __init__(self, input_dim, global_cond_dim, diffusion_step_embed_dim=256, 176 | down_dims=[256, 512, 1024], kernel_size=5, n_groups=8, state_dim=7): 177 | super().__init__() 178 | all_dims = [input_dim] + list(down_dims) 179 | start_dim = down_dims[0] 180 | 181 | self.global_1d_pool = nn.AdaptiveAvgPool1d(1) 182 | self.norm_after_pool = nn.LayerNorm(global_cond_dim) 183 | self.combine = nn.Linear(global_cond_dim + state_dim, global_cond_dim) 184 | 185 | dsed = diffusion_step_embed_dim 186 | diffusion_step_encoder = nn.Sequential( 187 | SinusoidalPosEmb(dsed, torch.bfloat16), 188 | nn.Linear(dsed, dsed * 4), 189 | nn.Mish(), 190 | nn.Linear(dsed * 4, dsed), 191 | ) 192 | cond_dim = dsed + global_cond_dim 193 | 194 | in_out = list(zip(all_dims[:-1], all_dims[1:])) 195 | mid_dim = all_dims[-1] 196 | self.mid_modules = nn.ModuleList([ 197 | ConditionalResidualBlock1D( 198 | mid_dim, mid_dim, cond_dim=cond_dim, 199 | kernel_size=kernel_size, n_groups=n_groups 200 | ), 201 | ConditionalResidualBlock1D( 202 | mid_dim, mid_dim, cond_dim=cond_dim, 203 | kernel_size=kernel_size, n_groups=n_groups 204 | ), 205 | ]) 206 | 207 | down_modules = nn.ModuleList([]) 208 | for ind, (dim_in, dim_out) in enumerate(in_out): 209 | is_last = ind >= (len(in_out) - 1) 210 | down_modules.append(nn.ModuleList([ 211 | ConditionalResidualBlock1D( 212 | dim_in, dim_out, cond_dim=cond_dim, 213 | kernel_size=kernel_size, n_groups=n_groups), 214 | ConditionalResidualBlock1D( 215 | dim_out, dim_out, cond_dim=cond_dim, 216 | kernel_size=kernel_size, n_groups=n_groups), 217 | Downsample1d(dim_out) if not is_last else nn.Identity() 218 | ])) 219 | 220 | up_modules = nn.ModuleList([]) 221 | for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])): 222 | is_last = ind >= (len(in_out) - 1) 223 | up_modules.append(nn.ModuleList([ 224 | ConditionalResidualBlock1D( 225 | dim_out * 2, dim_in, cond_dim=cond_dim, 226 | kernel_size=kernel_size, n_groups=n_groups), 227 | ConditionalResidualBlock1D( 228 | dim_in, dim_in, cond_dim=cond_dim, 229 | kernel_size=kernel_size, n_groups=n_groups), 230 | Upsample1d(dim_in) if not is_last else nn.Identity() 231 | ])) 232 | 233 | final_conv = nn.Sequential( 234 | Conv1dBlock(start_dim, start_dim, kernel_size=kernel_size), 235 | nn.Conv1d(start_dim, input_dim, 1), 236 | ) 237 | 238 | self.diffusion_step_encoder = diffusion_step_encoder 239 | self.up_modules = up_modules 240 | self.down_modules = down_modules 241 | self.final_conv = final_conv 242 | 243 | print("number of parameters: {:e}".format( 244 | sum(p.numel() for p in self.parameters())) 245 | ) 246 | 247 | def forward(self, 248 | sample: torch.Tensor, 249 | 
timestep: Union[torch.Tensor, float, int], 250 | global_cond=None, 251 | states=None): 252 | """ 253 | Forward pass for the Conditional UNet. 254 | 255 | Args: 256 | sample: Input tensor of shape (B, T, input_dim). 257 | timestep: Diffusion step, can be a tensor or an integer. 258 | global_cond: Global conditioning tensor of shape (B, global_cond_dim). 259 | states: Optional state tensor. 260 | 261 | Returns: 262 | Output tensor of shape (B, T, input_dim). 263 | """ 264 | # move axis for processing 265 | sample = sample.moveaxis(-1, -2) 266 | # process global conditioning 267 | global_cond = self.global_1d_pool(global_cond.permute(0, 2, 1)).squeeze(-1) 268 | global_cond = self.norm_after_pool(global_cond) # layernorm 269 | global_cond = torch.cat([global_cond, states], dim=-1) if states is not None else global_cond 270 | global_cond = self.combine(global_cond) 271 | 272 | timesteps = timestep 273 | if not torch.is_tensor(timesteps): 274 | timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device) 275 | elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0: 276 | timesteps = timesteps[None].to(sample.device) 277 | # broadcast to batch dimension in a way that's compatible with ONNX/Core ML 278 | timesteps = timesteps.expand(sample.shape[0]) 279 | 280 | global_feature = self.diffusion_step_encoder(timesteps) 281 | 282 | if global_cond is not None: 283 | global_feature = torch.cat([ 284 | global_feature, global_cond 285 | ], axis=-1) 286 | 287 | x = sample 288 | h = [] 289 | for idx, (resnet, resnet2, downsample) in enumerate(self.down_modules): 290 | x = resnet(x, global_feature) 291 | x = resnet2(x, global_feature) 292 | h.append(x) 293 | x = downsample(x) 294 | 295 | for mid_module in self.mid_modules: 296 | x = mid_module(x, global_feature) 297 | 298 | for idx, (resnet, resnet2, upsample) in enumerate(self.up_modules): 299 | x = torch.cat((x, h.pop()), dim=1) 300 | x = resnet(x, global_feature) 301 | x = resnet2(x, global_feature) 302 | x = upsample(x) 303 | 304 | x = self.final_conv(x) 305 | 306 | # (B,C,T) 307 | x = x.moveaxis(-1, -2) 308 | # (B,T,C) 309 | return x 310 | -------------------------------------------------------------------------------- /policy_heads/LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 
25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2020 - present, Facebook, Inc 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /llava-pythia/scripts/convert_vqav2_for_submission.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import json 4 | 5 | # from llava_pythia.eval.m4c_evaluator import EvalAIAnswerProcessor 6 | 7 | import re 8 | 9 | from tqdm import tqdm 10 | 11 | 12 | class EvalAIAnswerProcessor: 13 | """ 14 | Processes an answer similar to Eval AI 15 | copied from 16 | https://github.com/facebookresearch/mmf/blob/c46b3b3391275b4181567db80943473a89ab98ab/pythia/tasks/processors.py#L897 17 | """ 18 | 19 | CONTRACTIONS = { 20 | "aint": "ain't", 21 | "arent": "aren't", 22 | "cant": "can't", 23 | "couldve": "could've", 24 | "couldnt": "couldn't", 25 | "couldn'tve": "couldn't've", 26 | "couldnt've": "couldn't've", 27 | "didnt": "didn't", 28 | "doesnt": "doesn't", 29 | "dont": "don't", 30 | "hadnt": "hadn't", 31 | "hadnt've": "hadn't've", 32 | "hadn'tve": "hadn't've", 33 | "hasnt": "hasn't", 34 | "havent": "haven't", 35 | "hed": "he'd", 36 | "hed've": "he'd've", 37 | "he'dve": "he'd've", 38 | "hes": "he's", 39 | "howd": "how'd", 40 | "howll": "how'll", 41 | "hows": "how's", 42 | "Id've": "I'd've", 43 | "I'dve": "I'd've", 44 | "Im": "I'm", 45 | "Ive": "I've", 46 | "isnt": "isn't", 47 | "itd": "it'd", 48 | "itd've": "it'd've", 49 | "it'dve": "it'd've", 50 | "itll": "it'll", 51 | "let's": "let's", 52 | "maam": "ma'am", 53 | "mightnt": "mightn't", 54 | "mightnt've": "mightn't've", 55 | "mightn'tve": "mightn't've", 56 | "mightve": "might've", 57 | "mustnt": "mustn't", 58 | "mustve": "must've", 59 | "neednt": "needn't", 60 | "notve": "not've", 61 | "oclock": "o'clock", 62 | "oughtnt": "oughtn't", 63 | "ow's'at": "'ow's'at", 64 | "'ows'at": "'ow's'at", 65 | "'ow'sat": "'ow's'at", 66 | "shant": "shan't", 67 | "shed've": "she'd've", 68 | "she'dve": "she'd've", 69 | "she's": "she's", 70 | "shouldve": "should've", 71 | "shouldnt": "shouldn't", 72 | "shouldnt've": "shouldn't've", 73 | "shouldn'tve": "shouldn't've", 74 | "somebody'd": "somebodyd", 75 | "somebodyd've": "somebody'd've", 76 | "somebody'dve": "somebody'd've", 77 | "somebodyll": "somebody'll", 78 | "somebodys": "somebody's", 79 | "someoned": "someone'd", 80 | "someoned've": "someone'd've", 81 | "someone'dve": "someone'd've", 82 | "someonell": "someone'll", 83 | "someones": "someone's", 84 | "somethingd": "something'd", 85 | "somethingd've": "something'd've", 86 | "something'dve": "something'd've", 87 | "somethingll": "something'll", 88 | "thats": "that's", 89 | "thered": "there'd", 90 | "thered've": "there'd've", 91 | "there'dve": "there'd've", 92 | "therere": "there're", 93 | "theres": "there's", 94 | "theyd": "they'd", 95 | "theyd've": "they'd've", 96 | "they'dve": "they'd've", 97 | "theyll": "they'll", 98 | "theyre": "they're", 99 | "theyve": "they've", 100 | "twas": "'twas", 101 | "wasnt": "wasn't", 102 | "wed've": "we'd've", 103 | "we'dve": "we'd've", 104 | "weve": "we've", 105 | "werent": "weren't", 106 | "whatll": "what'll", 107 | "whatre": "what're", 108 | "whats": "what's", 109 | "whatve": "what've", 110 | "whens": "when's", 111 | "whered": "where'd", 112 | "wheres": "where's", 113 | "whereve": "where've", 114 | "whod": "who'd", 115 | "whod've": "who'd've", 116 | "who'dve": "who'd've", 117 | "wholl": "who'll", 118 | "whos": "who's", 119 | "whove": "who've", 120 | "whyll": "why'll", 121 | "whyre": "why're", 122 | "whys": "why's", 123 | "wont": "won't", 124 | "wouldve": "would've", 125 | "wouldnt": 
"wouldn't", 126 | "wouldnt've": "wouldn't've", 127 | "wouldn'tve": "wouldn't've", 128 | "yall": "y'all", 129 | "yall'll": "y'all'll", 130 | "y'allll": "y'all'll", 131 | "yall'd've": "y'all'd've", 132 | "y'alld've": "y'all'd've", 133 | "y'all'dve": "y'all'd've", 134 | "youd": "you'd", 135 | "youd've": "you'd've", 136 | "you'dve": "you'd've", 137 | "youll": "you'll", 138 | "youre": "you're", 139 | "youve": "you've", 140 | } 141 | 142 | NUMBER_MAP = { 143 | "none": "0", 144 | "zero": "0", 145 | "one": "1", 146 | "two": "2", 147 | "three": "3", 148 | "four": "4", 149 | "five": "5", 150 | "six": "6", 151 | "seven": "7", 152 | "eight": "8", 153 | "nine": "9", 154 | "ten": "10", 155 | } 156 | ARTICLES = ["a", "an", "the"] 157 | PERIOD_STRIP = re.compile(r"(?!<=\d)(\.)(?!\d)") 158 | COMMA_STRIP = re.compile(r"(?<=\d)(\,)+(?=\d)") 159 | PUNCTUATIONS = [ 160 | ";", 161 | r"/", 162 | "[", 163 | "]", 164 | '"', 165 | "{", 166 | "}", 167 | "(", 168 | ")", 169 | "=", 170 | "+", 171 | "\\", 172 | "_", 173 | "-", 174 | ">", 175 | "<", 176 | "@", 177 | "`", 178 | ",", 179 | "?", 180 | "!", 181 | ] 182 | 183 | def __init__(self, *args, **kwargs): 184 | pass 185 | 186 | def word_tokenize(self, word): 187 | word = word.lower() 188 | word = word.replace(",", "").replace("?", "").replace("'s", " 's") 189 | return word.strip() 190 | 191 | def process_punctuation(self, in_text): 192 | out_text = in_text 193 | for p in self.PUNCTUATIONS: 194 | if (p + " " in in_text or " " + p in in_text) or ( 195 | re.search(self.COMMA_STRIP, in_text) is not None 196 | ): 197 | out_text = out_text.replace(p, "") 198 | else: 199 | out_text = out_text.replace(p, " ") 200 | out_text = self.PERIOD_STRIP.sub("", out_text, re.UNICODE) 201 | return out_text 202 | 203 | def process_digit_article(self, in_text): 204 | out_text = [] 205 | temp_text = in_text.lower().split() 206 | for word in temp_text: 207 | word = self.NUMBER_MAP.setdefault(word, word) 208 | if word not in self.ARTICLES: 209 | out_text.append(word) 210 | else: 211 | pass 212 | for word_id, word in enumerate(out_text): 213 | if word in self.CONTRACTIONS: 214 | out_text[word_id] = self.CONTRACTIONS[word] 215 | out_text = " ".join(out_text) 216 | return out_text 217 | 218 | def __call__(self, item): 219 | item = self.word_tokenize(item) 220 | item = item.replace("\n", " ").replace("\t", " ").strip() 221 | item = self.process_punctuation(item) 222 | item = self.process_digit_article(item) 223 | return item 224 | 225 | 226 | class TextVQAAccuracyEvaluator: 227 | def __init__(self): 228 | self.answer_processor = EvalAIAnswerProcessor() 229 | 230 | def _compute_answer_scores(self, raw_answers): 231 | """ 232 | compute the accuracy (soft score) of human answers 233 | """ 234 | answers = [self.answer_processor(a) for a in raw_answers] 235 | assert len(answers) == 10 236 | gt_answers = list(enumerate(answers)) 237 | unique_answers = set(answers) 238 | unique_answer_scores = {} 239 | 240 | for unique_answer in unique_answers: 241 | accs = [] 242 | for gt_answer in gt_answers: 243 | other_answers = [item for item in gt_answers if item != gt_answer] 244 | matching_answers = [ 245 | item for item in other_answers if item[1] == unique_answer 246 | ] 247 | acc = min(1, float(len(matching_answers)) / 3) 248 | accs.append(acc) 249 | unique_answer_scores[unique_answer] = sum(accs) / len(accs) 250 | 251 | return unique_answer_scores 252 | 253 | def eval_pred_list(self, pred_list): 254 | pred_scores = [] 255 | for entry in tqdm(pred_list): 256 | pred_answer = 
self.answer_processor(entry["pred_answer"]) 257 | unique_answer_scores = self._compute_answer_scores(entry["gt_answers"]) 258 | score = unique_answer_scores.get(pred_answer, 0.0) 259 | pred_scores.append(score) 260 | 261 | accuracy = sum(pred_scores) / len(pred_scores) 262 | return accuracy 263 | 264 | 265 | class STVQAAccuracyEvaluator: 266 | def __init__(self): 267 | self.answer_processor = EvalAIAnswerProcessor() 268 | 269 | def eval_pred_list(self, pred_list): 270 | pred_scores = [] 271 | for entry in pred_list: 272 | pred_answer = self.answer_processor(entry["pred_answer"]) 273 | gts = [self.answer_processor(a) for a in entry["gt_answers"]] 274 | score = 1.0 if pred_answer in gts else 0.0 275 | pred_scores.append(score) 276 | 277 | accuracy = sum(pred_scores) / len(pred_scores) 278 | return accuracy 279 | 280 | 281 | class STVQAANLSEvaluator: 282 | def __init__(self): 283 | import editdistance # install with `pip install editdistance` 284 | 285 | self.get_edit_distance = editdistance.eval 286 | 287 | def get_anls(self, s1, s2): 288 | s1 = s1.lower().strip() 289 | s2 = s2.lower().strip() 290 | iou = 1 - self.get_edit_distance(s1, s2) / max(len(s1), len(s2)) 291 | anls = iou if iou >= 0.5 else 0.0 292 | return anls 293 | 294 | def eval_pred_list(self, pred_list): 295 | pred_scores = [] 296 | for entry in pred_list: 297 | anls = max( 298 | self.get_anls(entry["pred_answer"], gt) for gt in entry["gt_answers"] 299 | ) 300 | pred_scores.append(anls) 301 | 302 | accuracy = sum(pred_scores) / len(pred_scores) 303 | return accuracy 304 | 305 | 306 | class TextCapsBleu4Evaluator: 307 | def __init__(self): 308 | # The following script requires Java 1.8.0 and pycocotools installed. 309 | # The pycocoevalcap can be installed with pip as 310 | # pip install git+https://github.com/ronghanghu/coco-caption.git@python23 311 | # Original pycocoevalcap code is at https://github.com/tylin/coco-caption 312 | # but has no python3 support yet. 313 | try: 314 | from pycocoevalcap.bleu.bleu import Bleu 315 | from pycocoevalcap.tokenizer.ptbtokenizer import PTBTokenizer 316 | except ModuleNotFoundError: 317 | print( 318 | "Please install pycocoevalcap module using " 319 | "pip install git+https://github.com/ronghanghu/coco-caption.git@python23" # noqa 320 | ) 321 | raise 322 | 323 | self.tokenizer = PTBTokenizer() 324 | self.scorer = Bleu(4) 325 | 326 | def eval_pred_list(self, pred_list): 327 | # Create reference and hypotheses captions. 
328 | gts = {} 329 | res = {} 330 | for idx, entry in enumerate(pred_list): 331 | gts[idx] = [{"caption": a} for a in entry["gt_answers"]] 332 | res[idx] = [{"caption": entry["pred_answer"]}] 333 | 334 | gts = self.tokenizer.tokenize(gts) 335 | res = self.tokenizer.tokenize(res) 336 | score, _ = self.scorer.compute_score(gts, res) 337 | 338 | bleu4 = score[3] # score is (Bleu-1, Bleu-2, Bleu-3, Bleu-4) 339 | return bleu4 340 | 341 | 342 | def parse_args(): 343 | parser = argparse.ArgumentParser() 344 | parser.add_argument('--dir', type=str, default="./playground/data/eval/vqav2") 345 | parser.add_argument('--ckpt', type=str, required=True) 346 | parser.add_argument('--split', type=str, required=True) 347 | return parser.parse_args() 348 | 349 | 350 | if __name__ == '__main__': 351 | 352 | args = parse_args() 353 | 354 | src = os.path.join(args.dir, 'answers', args.split, args.ckpt, 'merge.jsonl') 355 | test_split = os.path.join(args.dir, 'llava_vqav2_mscoco_test2015.jsonl') 356 | dst = os.path.join(args.dir, 'answers_upload', args.split, f'{args.ckpt}.json') 357 | os.makedirs(os.path.dirname(dst), exist_ok=True) 358 | 359 | results = [] 360 | error_line = 0 361 | for line_idx, line in enumerate(open(src)): 362 | try: 363 | results.append(json.loads(line)) 364 | except: 365 | error_line += 1 366 | 367 | results = {x['question_id']: x['text'] for x in results} 368 | test_split = [json.loads(line) for line in open(test_split)] 369 | split_ids = set([x['question_id'] for x in test_split]) 370 | 371 | print(f'total results: {len(results)}, total split: {len(test_split)}, error_line: {error_line}') 372 | 373 | all_answers = [] 374 | 375 | answer_processor = EvalAIAnswerProcessor() 376 | 377 | for x in test_split: 378 | if x['question_id'] not in results: 379 | all_answers.append({ 380 | 'question_id': x['question_id'], 381 | 'answer': '' 382 | }) 383 | else: 384 | all_answers.append({ 385 | 'question_id': x['question_id'], 386 | 'answer': answer_processor(results[x['question_id']]) 387 | }) 388 | 389 | with open(dst, 'w') as f: 390 | json.dump(all_answers, f) 391 | --------------------------------------------------------------------------------
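
For reference, a minimal usage sketch (not part of the repository) of the answer normalization that the script above applies before writing the EvalAI submission file. The import path is an assumption, presuming the llava-pythia/scripts directory is on PYTHONPATH, and the input strings are made-up examples.

```python
# Hypothetical sketch: how EvalAIAnswerProcessor normalizes predicted answers.
# Assumes llava-pythia/scripts/ is importable; example inputs are illustrative only.
from convert_vqav2_for_submission import EvalAIAnswerProcessor

processor = EvalAIAnswerProcessor()

# Contractions are restored, number words become digits, articles are dropped,
# and stray punctuation is stripped before answers are written to the JSON file.
print(processor("isnt it a Dog?"))  # -> "isn't it dog"
print(processor("Two"))             # -> "2"
```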