├── data_utils
├── __init__.py
└── __pycache__
│ ├── __init__.cpython-38.pyc
│ ├── datasets.cpython-38.pyc
│ └── processor.cpython-38.pyc
├── aloha_scripts
├── __init__.py
├── __pycache__
│ ├── __init__.cpython-38.pyc
│ └── constants.cpython-38.pyc
└── constants.py
├── llava-pythia
├── __init__.py
├── llava_pythia
│ ├── __init__.py
│ ├── model
│ │ ├── __init__.py
│ │ ├── language_model
│ │ │ ├── __init__.py
│ │ │ ├── __pycache__
│ │ │ │ └── __init__.cpython-38.pyc
│ │ │ └── pythia
│ │ │ │ ├── __pycache__
│ │ │ │ │ ├── llava_pythia.cpython-38.pyc
│ │ │ │ │ └── configuration_llava_pythia.cpython-38.pyc
│ │ │ │ └── configuration_llava_pythia.py
│ │ ├── __pycache__
│ │ │ ├── builder.cpython-38.pyc
│ │ │ ├── __init__.cpython-38.pyc
│ │ │ └── llava_arch.cpython-38.pyc
│ │ ├── multimodal_projector
│ │ │ ├── __pycache__
│ │ │ │ ├── builder.cpython-310.pyc
│ │ │ │ └── builder.cpython-38.pyc
│ │ │ └── builder.py
│ │ ├── multimodal_encoder
│ │ │ ├── __pycache__
│ │ │ │ ├── clip_encoder.cpython-310.pyc
│ │ │ │ ├── clip_encoder.cpython-38.pyc
│ │ │ │ ├── siglip_encoder.cpython-310.pyc
│ │ │ │ └── siglip_encoder.cpython-38.pyc
│ │ │ ├── clip_encoder.py
│ │ │ └── siglip_encoder.py
│ │ └── builder.py
│ ├── train
│ │ └── __pycache__
│ │ │ └── llava_pythia_trainer.cpython-38.pyc
│ ├── constants.py
│ ├── utils.py
│ ├── mm_utils.py
│ └── conversation.py
├── scripts
│ ├── llava_pythia
│ │ ├── all.sh
│ │ ├── .ipynb_checkpoints
│ │ │ ├── all-checkpoint.sh
│ │ │ ├── all_train_robot-checkpoint.sh
│ │ │ ├── get_base_model-checkpoint.sh
│ │ │ ├── pretrain-checkpoint.sh
│ │ │ ├── finetune-checkpoint.sh
│ │ │ ├── train_robot-checkpoint.sh
│ │ │ └── lora_train_robot-checkpoint.sh
│ │ ├── all_train_robot.sh
│ │ ├── eval
│ │ │ ├── mme.sh
│ │ │ ├── .ipynb_checkpoints
│ │ │ │ ├── mme-checkpoint.sh
│ │ │ │ ├── mmvet-checkpoint.sh
│ │ │ │ ├── pope-checkpoint.sh
│ │ │ │ ├── textvqa-checkpoint.sh
│ │ │ │ ├── mmbench-checkpoint.sh
│ │ │ │ ├── sqa-checkpoint.sh
│ │ │ │ ├── vqav2-checkpoint.sh
│ │ │ │ └── gqa-checkpoint.sh
│ │ │ ├── vizwiz.sh
│ │ │ ├── mmvet.sh
│ │ │ ├── pope.sh
│ │ │ ├── textvqa.sh
│ │ │ ├── mmbench.sh
│ │ │ ├── sqa.sh
│ │ │ ├── vqav2.sh
│ │ │ └── gqa.sh
│ │ ├── get_base_model.sh
│ │ ├── pretrain.sh
│ │ ├── finetune.sh
│ │ ├── train_robot.sh
│ │ └── lora_train_robot.sh
│ ├── convert_mmvet_for_eval.py
│ ├── convert_gqa_for_eval.py
│ ├── zero2.json
│ ├── .ipynb_checkpoints
│ │ ├── zero2-checkpoint.json
│ │ ├── zero3-checkpoint.json
│ │ ├── training_states_2_tensorboard-checkpoint.py
│ │ ├── zero3_offload-checkpoint.json
│ │ ├── display_eval_results_all-checkpoint.py
│ │ └── convert_sqa_to_llava-checkpoint.py
│ ├── merge_lora_weights.py
│ ├── zero3.json
│ ├── training_states_2_tensorboard.py
│ ├── convert_mmbench_for_submission.py
│ ├── zero3_offload.json
│ ├── convert_vizwiz_for_submission.py
│ ├── display_eval_results_all.py
│ ├── convert_seed_for_submission.py
│ ├── convert_sqa_to_llava.py
│ ├── tranfer2llava.py
│ └── convert_vqav2_for_submission.py
├── preprocessor_config.json
└── pyproject.toml
├── policy_heads
├── __init__.py
├── util
│ ├── __init__.py
│ ├── __pycache__
│ │ ├── misc.cpython-38.pyc
│ │ └── __init__.cpython-38.pyc
│ ├── box_ops.py
│ └── plot_utils.py
├── models
│ ├── __pycache__
│ │ ├── __init__.cpython-38.pyc
│ │ ├── backbone.cpython-38.pyc
│ │ ├── detr_vae.cpython-38.pyc
│ │ ├── transformer.cpython-38.pyc
│ │ ├── position_encoding.cpython-38.pyc
│ │ └── droid_unet_diffusion.cpython-38.pyc
│ ├── __init__.py
│ ├── position_encoding.py
│ ├── backbone.py
│ └── droid_unet_diffusion.py
├── setup.py
├── README.md
└── LICENSE
├── __pycache__
└── torch_utils.cpython-38.pyc
├── .idea
├── vcs.xml
├── .gitignore
├── inspectionProfiles
│ └── profiles_settings.xml
├── misc.xml
├── modules.xml
└── Open_TinyVLA.iml
├── setup.py
├── scripts
├── process_ckpts.sh
└── train.sh
├── LICENSE
├── .gitignore
├── requirements.txt
└── README.md
/data_utils/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/aloha_scripts/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/llava-pythia/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/policy_heads/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/llava-pythia/llava_pythia/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/llava-pythia/llava_pythia/model/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/llava-pythia/llava_pythia/model/language_model/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/policy_heads/util/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2 |
--------------------------------------------------------------------------------
/llava-pythia/scripts/llava_pythia/all.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | ./scripts/llava_pythia/pretrain.sh && ./scripts/llava_pythia/finetune.sh
--------------------------------------------------------------------------------
/__pycache__/torch_utils.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/liyaxuanliyaxuan/TinyVLA/HEAD/__pycache__/torch_utils.cpython-38.pyc
--------------------------------------------------------------------------------
/data_utils/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/liyaxuanliyaxuan/TinyVLA/HEAD/data_utils/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/data_utils/__pycache__/datasets.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/liyaxuanliyaxuan/TinyVLA/HEAD/data_utils/__pycache__/datasets.cpython-38.pyc
--------------------------------------------------------------------------------
/aloha_scripts/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/liyaxuanliyaxuan/TinyVLA/HEAD/aloha_scripts/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/data_utils/__pycache__/processor.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/liyaxuanliyaxuan/TinyVLA/HEAD/data_utils/__pycache__/processor.cpython-38.pyc
--------------------------------------------------------------------------------
/policy_heads/util/__pycache__/misc.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/liyaxuanliyaxuan/TinyVLA/HEAD/policy_heads/util/__pycache__/misc.cpython-38.pyc
--------------------------------------------------------------------------------
/aloha_scripts/__pycache__/constants.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/liyaxuanliyaxuan/TinyVLA/HEAD/aloha_scripts/__pycache__/constants.cpython-38.pyc
--------------------------------------------------------------------------------
/llava-pythia/scripts/llava_pythia/.ipynb_checkpoints/all-checkpoint.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | ./scripts/llava_pythia/pretrain.sh && ./scripts/llava_pythia/finetune.sh
--------------------------------------------------------------------------------
/policy_heads/util/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/liyaxuanliyaxuan/TinyVLA/HEAD/policy_heads/util/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/policy_heads/models/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/liyaxuanliyaxuan/TinyVLA/HEAD/policy_heads/models/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/policy_heads/models/__pycache__/backbone.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/liyaxuanliyaxuan/TinyVLA/HEAD/policy_heads/models/__pycache__/backbone.cpython-38.pyc
--------------------------------------------------------------------------------
/policy_heads/models/__pycache__/detr_vae.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/liyaxuanliyaxuan/TinyVLA/HEAD/policy_heads/models/__pycache__/detr_vae.cpython-38.pyc
--------------------------------------------------------------------------------
/policy_heads/models/__pycache__/transformer.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/liyaxuanliyaxuan/TinyVLA/HEAD/policy_heads/models/__pycache__/transformer.cpython-38.pyc
--------------------------------------------------------------------------------
/policy_heads/models/__pycache__/position_encoding.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/liyaxuanliyaxuan/TinyVLA/HEAD/policy_heads/models/__pycache__/position_encoding.cpython-38.pyc
--------------------------------------------------------------------------------
/llava-pythia/llava_pythia/model/__pycache__/builder.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/liyaxuanliyaxuan/TinyVLA/HEAD/llava-pythia/llava_pythia/model/__pycache__/builder.cpython-38.pyc
--------------------------------------------------------------------------------
/llava-pythia/llava_pythia/model/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/liyaxuanliyaxuan/TinyVLA/HEAD/llava-pythia/llava_pythia/model/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/llava-pythia/llava_pythia/model/__pycache__/llava_arch.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/liyaxuanliyaxuan/TinyVLA/HEAD/llava-pythia/llava_pythia/model/__pycache__/llava_arch.cpython-38.pyc
--------------------------------------------------------------------------------
/policy_heads/models/__pycache__/droid_unet_diffusion.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/liyaxuanliyaxuan/TinyVLA/HEAD/policy_heads/models/__pycache__/droid_unet_diffusion.cpython-38.pyc
--------------------------------------------------------------------------------
/.idea/vcs.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/llava-pythia/llava_pythia/train/__pycache__/llava_pythia_trainer.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/liyaxuanliyaxuan/TinyVLA/HEAD/llava-pythia/llava_pythia/train/__pycache__/llava_pythia_trainer.cpython-38.pyc
--------------------------------------------------------------------------------
/.idea/.gitignore:
--------------------------------------------------------------------------------
1 | # Default ignored files
2 | /shelf/
3 | /workspace.xml
4 | # Editor-based HTTP Client requests
5 | /httpRequests/
6 | # Datasource local storage ignored files
7 | /dataSources/
8 | /dataSources.local.xml
9 |
--------------------------------------------------------------------------------
/llava-pythia/llava_pythia/model/language_model/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/liyaxuanliyaxuan/TinyVLA/HEAD/llava-pythia/llava_pythia/model/language_model/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/llava-pythia/llava_pythia/model/multimodal_projector/__pycache__/builder.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/liyaxuanliyaxuan/TinyVLA/HEAD/llava-pythia/llava_pythia/model/multimodal_projector/__pycache__/builder.cpython-310.pyc
--------------------------------------------------------------------------------
/llava-pythia/llava_pythia/model/multimodal_projector/__pycache__/builder.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/liyaxuanliyaxuan/TinyVLA/HEAD/llava-pythia/llava_pythia/model/multimodal_projector/__pycache__/builder.cpython-38.pyc
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/profiles_settings.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/llava-pythia/llava_pythia/model/multimodal_encoder/__pycache__/clip_encoder.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/liyaxuanliyaxuan/TinyVLA/HEAD/llava-pythia/llava_pythia/model/multimodal_encoder/__pycache__/clip_encoder.cpython-310.pyc
--------------------------------------------------------------------------------
/llava-pythia/llava_pythia/model/multimodal_encoder/__pycache__/clip_encoder.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/liyaxuanliyaxuan/TinyVLA/HEAD/llava-pythia/llava_pythia/model/multimodal_encoder/__pycache__/clip_encoder.cpython-38.pyc
--------------------------------------------------------------------------------
/llava-pythia/llava_pythia/model/language_model/pythia/__pycache__/llava_pythia.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/liyaxuanliyaxuan/TinyVLA/HEAD/llava-pythia/llava_pythia/model/language_model/pythia/__pycache__/llava_pythia.cpython-38.pyc
--------------------------------------------------------------------------------
/llava-pythia/llava_pythia/model/multimodal_encoder/__pycache__/siglip_encoder.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/liyaxuanliyaxuan/TinyVLA/HEAD/llava-pythia/llava_pythia/model/multimodal_encoder/__pycache__/siglip_encoder.cpython-310.pyc
--------------------------------------------------------------------------------
/llava-pythia/llava_pythia/model/multimodal_encoder/__pycache__/siglip_encoder.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/liyaxuanliyaxuan/TinyVLA/HEAD/llava-pythia/llava_pythia/model/multimodal_encoder/__pycache__/siglip_encoder.cpython-38.pyc
--------------------------------------------------------------------------------
/llava-pythia/llava_pythia/model/language_model/pythia/__pycache__/configuration_llava_pythia.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/liyaxuanliyaxuan/TinyVLA/HEAD/llava-pythia/llava_pythia/model/language_model/pythia/__pycache__/configuration_llava_pythia.cpython-38.pyc
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from distutils.core import setup
2 | from setuptools import find_packages
3 |
4 | setup(
5 | name='act',
6 | version='0.0.0',
7 | packages=find_packages(),
8 | license='MIT License',
9 | long_description=open('README.md').read(),
10 | )
11 |
--------------------------------------------------------------------------------
/policy_heads/setup.py:
--------------------------------------------------------------------------------
1 | from distutils.core import setup
2 | from setuptools import find_packages
3 |
4 | setup(
5 | name='policy_heads',
6 | version='0.0.0',
7 | packages=find_packages(),
8 | license='MIT License',
9 | long_description=open('README.md').read(),
10 | )
--------------------------------------------------------------------------------
/.idea/misc.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/llava-pythia/scripts/llava_pythia/all_train_robot.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #"1B" "410M" "14M" "1_4B" "70M" "2_8B"
3 | # Loop over the model sizes and call the training script for each
4 | #"14M" "70M" "1B" "160M"
5 | for i in "14M" "70M" "160M" "1B"; do
6 | echo "Loop iteration $i"
7 | # Call the training script, passing the model size as an argument
8 | ./scripts/llava_pythia/lora_train_robot.sh "$i"
9 | done
10 |
--------------------------------------------------------------------------------
/llava-pythia/scripts/llava_pythia/.ipynb_checkpoints/all_train_robot-checkpoint.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #"1B" "410M" "14M" "1_4B" "70M" "2_8B"
3 | # Loop over the model sizes and call the training script for each
4 | #"14M" "70M" "1B" "160M"
5 | for i in "14M" "70M" "160M" "1B"; do
6 | echo "Loop iteration $i"
7 | # Call the training script, passing the model size as an argument
8 | ./scripts/llava_pythia/lora_train_robot.sh "$i"
9 | done
10 |
--------------------------------------------------------------------------------
/llava-pythia/llava_pythia/constants.py:
--------------------------------------------------------------------------------
1 | CONTROLLER_HEART_BEAT_EXPIRATION = 30
2 | WORKER_HEART_BEAT_INTERVAL = 15
3 |
4 | LOGDIR = "."
5 |
6 | # Model Constants
7 | IGNORE_INDEX = -100
8 | IMAGE_TOKEN_INDEX = -200
9 | DEFAULT_IMAGE_TOKEN = "<image>"
10 | DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
11 | DEFAULT_IM_START_TOKEN = "<im_start>"
12 | DEFAULT_IM_END_TOKEN = "<im_end>"
13 |
--------------------------------------------------------------------------------
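
These constants follow the LLaVA convention: IGNORE_INDEX masks tokens out of the language-modeling loss, and IMAGE_TOKEN_INDEX is a placeholder id that is later swapped for projected image features. A minimal, hypothetical sketch of that masking pattern (not the repo's actual preprocessing code; the token ids in the example are made up):

import torch

IGNORE_INDEX = -100        # the default ignore_index of torch.nn.CrossEntropyLoss
IMAGE_TOKEN_INDEX = -200   # placeholder id later replaced by image features

def build_labels(input_ids: torch.Tensor, prompt_len: int) -> torch.Tensor:
    # Supervise only the answer: mask the prompt and the image placeholder.
    labels = input_ids.clone()
    labels[:prompt_len] = IGNORE_INDEX
    labels[input_ids == IMAGE_TOKEN_INDEX] = IGNORE_INDEX
    return labels

# Example: a 5-token prompt containing one image placeholder, then a 3-token answer.
ids = torch.tensor([10, IMAGE_TOKEN_INDEX, 11, 12, 13, 20, 21, 22])
print(build_labels(ids, prompt_len=5))  # the first five labels become -100
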
/llava-pythia/preprocessor_config.json:
--------------------------------------------------------------------------------
1 | {
2 | "crop_size": 336,
3 | "do_center_crop": true,
4 | "do_normalize": true,
5 | "do_resize": true,
6 | "feature_extractor_type": "CLIPFeatureExtractor",
7 | "image_mean": [
8 | 0.48145466,
9 | 0.4578275,
10 | 0.40821073
11 | ],
12 | "image_std": [
13 | 0.26862954,
14 | 0.26130258,
15 | 0.27577711
16 | ],
17 | "resample": 3,
18 | "size": 336
19 | }
20 |
--------------------------------------------------------------------------------
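
The file above is a stock CLIP preprocessing config (336x336 crops, CLIP mean/std, bicubic resample). A small usage sketch, assuming transformers and Pillow are installed; the directory and image paths are placeholders, not paths from this repo:

from PIL import Image
from transformers import CLIPImageProcessor

# Load the config from a directory that contains preprocessor_config.json.
processor = CLIPImageProcessor.from_pretrained("./path/to/config_dir")
image = Image.open("example.jpg").convert("RGB")
pixel_values = processor(images=image, return_tensors="pt")["pixel_values"]
print(pixel_values.shape)  # expected: torch.Size([1, 3, 336, 336])
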
/policy_heads/README.md:
--------------------------------------------------------------------------------
1 | This part of the codebase is modified from DETR (https://github.com/facebookresearch/detr), which is released under the Apache 2.0 license.
2 |
3 | @article{Carion2020EndtoEndOD,
4 | title={End-to-End Object Detection with Transformers},
5 | author={Nicolas Carion and Francisco Massa and Gabriel Synnaeve and Nicolas Usunier and Alexander Kirillov and Sergey Zagoruyko},
6 | journal={ArXiv},
7 | year={2020},
8 | volume={abs/2005.12872}
9 | }
--------------------------------------------------------------------------------
/policy_heads/models/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2 | from .detr_vae import build as build_vae
3 | from .detr_vae import build_cnnmlp as build_cnnmlp
4 | from .detr_vae import build_vae_head
5 | from .droid_unet_diffusion import ConditionalUnet1D
6 | def build_ACT_model(args):
7 | return build_vae(args)
8 |
9 | def build_ACT_head(args):
10 | return build_vae_head(args)
11 | def build_CNNMLP_model(args):
12 | return build_cnnmlp(args)
13 |
--------------------------------------------------------------------------------
/llava-pythia/scripts/convert_mmvet_for_eval.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import argparse
4 |
5 | parser = argparse.ArgumentParser()
6 | parser.add_argument("--src", type=str)
7 | parser.add_argument("--dst", type=str)
8 | args = parser.parse_args()
9 |
10 | cur_result = {}
11 |
12 | for line in open(args.src):
13 | data = json.loads(line)
14 | qid = data['question_id']
15 | cur_result[f'v1_{qid}'] = data['text']
16 |
17 | with open(args.dst, 'w') as f:
18 | json.dump(cur_result, f, indent=2)
19 |
--------------------------------------------------------------------------------
/llava-pythia/scripts/convert_gqa_for_eval.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import argparse
4 |
5 | parser = argparse.ArgumentParser()
6 | parser.add_argument("--src", type=str)
7 | parser.add_argument("--dst", type=str)
8 | args = parser.parse_args()
9 |
10 | all_answers = []
11 | for line_idx, line in enumerate(open(args.src)):
12 | res = json.loads(line)
13 | question_id = res['question_id']
14 | text = res['text'].rstrip('.').lower()
15 | all_answers.append({"questionId": question_id, "prediction": text})
16 |
17 | with open(args.dst, 'w') as f:
18 | json.dump(all_answers, f)
19 |
--------------------------------------------------------------------------------
/.idea/Open_TinyVLA.iml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
--------------------------------------------------------------------------------
/llava-pythia/scripts/zero2.json:
--------------------------------------------------------------------------------
1 | {
2 | "fp16": {
3 | "enabled": "auto",
4 | "loss_scale": 0,
5 | "loss_scale_window": 1000,
6 | "initial_scale_power": 16,
7 | "hysteresis": 2,
8 | "min_loss_scale": 1
9 | },
10 | "bf16": {
11 | "enabled": "auto"
12 | },
13 | "train_micro_batch_size_per_gpu": "auto",
14 | "train_batch_size": "auto",
15 | "gradient_accumulation_steps": "auto",
16 | "zero_optimization": {
17 | "stage": 2,
18 | "overlap_comm": true,
19 | "contiguous_gradients": true,
20 | "sub_group_size": 1e9,
21 | "reduce_bucket_size": "auto"
22 | }
23 | }
--------------------------------------------------------------------------------
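
This ZeRO stage-2 config is passed to the deepspeed launcher via --deepspeed (see pretrain.sh below); the "auto" fields are resolved from the HuggingFace training arguments at launch time. A rough sketch of the equivalent programmatic wiring, with placeholder values for everything not taken from this repo:

from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="./checkpoints/example",   # placeholder path
    per_device_train_batch_size=32,
    gradient_accumulation_steps=1,
    bf16=True,
    deepspeed="./scripts/zero2.json",     # the config file shown above
)
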
/llava-pythia/scripts/.ipynb_checkpoints/zero2-checkpoint.json:
--------------------------------------------------------------------------------
1 | {
2 | "fp16": {
3 | "enabled": "auto",
4 | "loss_scale": 0,
5 | "loss_scale_window": 1000,
6 | "initial_scale_power": 16,
7 | "hysteresis": 2,
8 | "min_loss_scale": 1
9 | },
10 | "bf16": {
11 | "enabled": "auto"
12 | },
13 | "train_micro_batch_size_per_gpu": "auto",
14 | "train_batch_size": "auto",
15 | "gradient_accumulation_steps": "auto",
16 | "zero_optimization": {
17 | "stage": 2,
18 | "overlap_comm": true,
19 | "contiguous_gradients": true,
20 | "sub_group_size": 1e9,
21 | "reduce_bucket_size": "auto"
22 | }
23 | }
--------------------------------------------------------------------------------
/llava-pythia/scripts/llava_pythia/eval/mme.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | LLM_MODEL_SIZE=2_8B
3 | python -m llava_pythia.eval.model_vqa_loader \
4 | --model-path ./checkpoint_all/pythia_$LLM_MODEL_SIZE/vanilla_pythia_pt_f_vit/llavaPythia-v0-finetune \
5 | --question-file ./playground/data/eval/MME/llava_mme.jsonl \
6 | --image-folder ./playground/data/eval/MME/MME_Benchmark_release_version \
7 | --answers-file ./playground/data/eval/MME/answers/llavaPhi-v0-3b.jsonl \
8 | --temperature 0 \
9 | --conv-mode pythia
10 |
11 | cd ./playground/data/eval/MME
12 |
13 | python convert_answer_to_mme.py --experiment llavaPhi-v0-3b
14 |
15 | cd eval_tool
16 |
17 | python calculation.py --results_dir answers/llavaPhi-v0-3b
18 |
--------------------------------------------------------------------------------
/llava-pythia/scripts/llava_pythia/eval/.ipynb_checkpoints/mme-checkpoint.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | LLM_MODEL_SIZE=2_8B
3 | python -m llava_pythia.eval.model_vqa_loader \
4 | --model-path ./checkpoint_all/pythia_$LLM_MODEL_SIZE/vanilla_pythia_pt_f_vit/llavaPythia-v0-finetune \
5 | --question-file ./playground/data/eval/MME/llava_mme.jsonl \
6 | --image-folder ./playground/data/eval/MME/MME_Benchmark_release_version \
7 | --answers-file ./playground/data/eval/MME/answers/llavaPhi-v0-3b.jsonl \
8 | --temperature 0 \
9 | --conv-mode pythia
10 |
11 | cd ./playground/data/eval/MME
12 |
13 | python convert_answer_to_mme.py --experiment llavaPhi-v0-3b
14 |
15 | cd eval_tool
16 |
17 | python calculation.py --results_dir answers/llavaPhi-v0-3b
18 |
--------------------------------------------------------------------------------
/llava-pythia/scripts/llava_pythia/eval/vizwiz.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | python -m llava_pythia.eval.model_vqa_loader \
4 | --model-path checkpoints/llavaPhi-v0-3b-finetune \
5 | --question-file ./playground/data/eval/vizwiz/llava_test.jsonl \
6 | --image-folder ./playground/data/eval/vizwiz/test \
7 | --answers-file ./playground/data/eval/vizwiz/answers/llavaPhi-v0-3b.jsonl \
8 | --temperature 0 \
9 | --conv-mode phi-2_v0
10 |
11 | python scripts/convert_vizwiz_for_submission.py \
12 | --annotation-file ./playground/data/eval/vizwiz/llava_test.jsonl \
13 | --result-file ./playground/data/eval/vizwiz/answers/llavaPhi-v0-3b.jsonl \
14 | --result-upload-file ./playground/data/eval/vizwiz/answers_upload/llavaPhi-v0-3b.json
15 |
--------------------------------------------------------------------------------
/llava-pythia/scripts/llava_pythia/eval/mmvet.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | LLM_MODEL_SIZE=2_8B
3 | python -m llava_pythia.eval.model_vqa \
4 | --model-path ./checkpoint_all/pythia_$LLM_MODEL_SIZE/vanilla_pythia_pt_f_vit/llavaPythia-v0-finetune \
5 | --question-file ./playground/data/eval/mm-vet/llava-mm-vet.jsonl \
6 | --image-folder ./playground/data/eval/mm-vet/images \
7 | --answers-file ./playground/data/eval/mm-vet/answers/llavaPhi-v0-3b.jsonl \
8 | --temperature 0 \
9 | --conv-mode pythia
10 |
11 | mkdir -p ./playground/data/eval/mm-vet/results
12 |
13 | python scripts/convert_mmvet_for_eval.py \
14 | --src ./playground/data/eval/mm-vet/answers/llavaPhi-v0-3b.jsonl \
15 | --dst ./playground/data/eval/mm-vet/results/llavaPhi-v0-3b.json
16 |
17 |
--------------------------------------------------------------------------------
/llava-pythia/scripts/llava_pythia/eval/pope.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | LLM_MODEL_SIZE=2_8B
3 | python -m llava_pythia.eval.model_vqa_loader \
4 | --model-path ./checkpoint_all/pythia_$LLM_MODEL_SIZE/vanilla_pythia_pt_f_vit/llavaPythia-v0-finetune \
5 | --question-file ./playground/data/eval/pope/llava_pope_test.jsonl \
6 | --image-folder /data/team/zhumj/data/coco/val2014 \
7 | --answers-file ./playground/data/eval/pope/answers/llavaPhi-v0-3b.jsonl \
8 | --temperature 0 \
9 | --conv-mode pythia
10 |
11 | python llava_pythia/eval/eval_pope.py \
12 | --annotation-dir ./playground/data/eval/pope/coco \
13 | --question-file ./playground/data/eval/pope/llava_pope_test.jsonl \
14 | --result-file ./playground/data/eval/pope/answers/llavaPhi-v0-3b.jsonl
--------------------------------------------------------------------------------
/llava-pythia/scripts/llava_pythia/eval/textvqa.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | LLM_MODEL_SIZE=2_8B
3 | python -m llava_pythia.eval.model_vqa_loader \
4 | --model-path ./checkpoint_all/pythia_$LLM_MODEL_SIZE/vanilla_pythia_pt_f_vit/llavaPythia-v0-finetune \
5 | --question-file ./playground/data/eval/textvqa/llava_textvqa_val_v051_ocr.jsonl \
6 | --image-folder /data/team/zhumj/data/finetune/data/textvqa/train_images \
7 | --answers-file ./playground/data/eval/textvqa/answers/llavaPhi-v0-3b.jsonl \
8 | --temperature 0 \
9 | --conv-mode pythia
10 |
11 | python -m llava_pythia.eval.eval_textvqa \
12 | --annotation-file ./playground/data/eval/textvqa/TextVQA_0.5.1_val.json \
13 | --result-file ./playground/data/eval/textvqa/answers/llavaPhi-v0-3b.jsonl
14 |
--------------------------------------------------------------------------------
/llava-pythia/scripts/llava_pythia/eval/.ipynb_checkpoints/mmvet-checkpoint.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | LLM_MODEL_SIZE=2_8B
3 | python -m llava_pythia.eval.model_vqa \
4 | --model-path ./checkpoint_all/pythia_$LLM_MODEL_SIZE/vanilla_pythia_pt_f_vit/llavaPythia-v0-finetune \
5 | --question-file ./playground/data/eval/mm-vet/llava-mm-vet.jsonl \
6 | --image-folder ./playground/data/eval/mm-vet/images \
7 | --answers-file ./playground/data/eval/mm-vet/answers/llavaPhi-v0-3b.jsonl \
8 | --temperature 0 \
9 | --conv-mode pythia
10 |
11 | mkdir -p ./playground/data/eval/mm-vet/results
12 |
13 | python scripts/convert_mmvet_for_eval.py \
14 | --src ./playground/data/eval/mm-vet/answers/llavaPhi-v0-3b.jsonl \
15 | --dst ./playground/data/eval/mm-vet/results/llavaPhi-v0-3b.json
16 |
17 |
--------------------------------------------------------------------------------
/llava-pythia/scripts/llava_pythia/eval/.ipynb_checkpoints/pope-checkpoint.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | LLM_MODEL_SIZE=2_8B
3 | python -m llava_pythia.eval.model_vqa_loader \
4 | --model-path ./checkpoint_all/pythia_$LLM_MODEL_SIZE/vanilla_pythia_pt_f_vit/llavaPythia-v0-finetune \
5 | --question-file ./playground/data/eval/pope/llava_pope_test.jsonl \
6 | --image-folder /data/team/zhumj/data/coco/val2014 \
7 | --answers-file ./playground/data/eval/pope/answers/llavaPhi-v0-3b.jsonl \
8 | --temperature 0 \
9 | --conv-mode pythia
10 |
11 | python llava_pythia/eval/eval_pope.py \
12 | --annotation-dir ./playground/data/eval/pope/coco \
13 | --question-file ./playground/data/eval/pope/llava_pope_test.jsonl \
14 | --result-file ./playground/data/eval/pope/answers/llavaPhi-v0-3b.jsonl
--------------------------------------------------------------------------------
/llava-pythia/scripts/llava_pythia/eval/.ipynb_checkpoints/textvqa-checkpoint.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | LLM_MODEL_SIZE=2_8B
3 | python -m llava_pythia.eval.model_vqa_loader \
4 | --model-path ./checkpoint_all/pythia_$LLM_MODEL_SIZE/vanilla_pythia_pt_f_vit/llavaPythia-v0-finetune \
5 | --question-file ./playground/data/eval/textvqa/llava_textvqa_val_v051_ocr.jsonl \
6 | --image-folder /data/team/zhumj/data/finetune/data/textvqa/train_images \
7 | --answers-file ./playground/data/eval/textvqa/answers/llavaPhi-v0-3b.jsonl \
8 | --temperature 0 \
9 | --conv-mode pythia
10 |
11 | python -m llava_pythia.eval.eval_textvqa \
12 | --annotation-file ./playground/data/eval/textvqa/TextVQA_0.5.1_val.json \
13 | --result-file ./playground/data/eval/textvqa/answers/llavaPhi-v0-3b.jsonl
14 |
--------------------------------------------------------------------------------
/llava-pythia/scripts/llava_pythia/eval/mmbench.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | SPLIT="mmbench_dev_20230712"
4 | LLM_MODEL_SIZE=2_8B
5 | python -m llava_pythia.eval.model_vqa_mmbench \
6 | --model-path ./checkpoint_all/pythia_$LLM_MODEL_SIZE/vanilla_pythia_pt_f_vit/llavaPythia-v0-finetune \
7 | --question-file ./playground/data/eval/mmbench/$SPLIT.tsv \
8 | --answers-file ./playground/data/eval/mmbench/answers/$SPLIT/llavaPhi-v0-3b.jsonl \
9 | --single-pred-prompt \
10 | --temperature 0 \
11 | --conv-mode pythia
12 |
13 | mkdir -p playground/data/eval/mmbench/answers_upload/$SPLIT
14 |
15 | python scripts/convert_mmbench_for_submission.py \
16 | --annotation-file ./playground/data/eval/mmbench/$SPLIT.tsv \
17 | --result-dir ./playground/data/eval/mmbench/answers/$SPLIT \
18 | --upload-dir ./playground/data/eval/mmbench/answers_upload/$SPLIT \
19 | --experiment llavaPhi-v0-3b
--------------------------------------------------------------------------------
/llava-pythia/scripts/merge_lora_weights.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from llava_pythia.model.builder import load_pretrained_model
3 | from llava_pythia.mm_utils import get_model_name_from_path
4 |
5 |
6 | def merge_lora(args):
7 | model_name = get_model_name_from_path(args.model_path)
8 | tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, device_map='cpu')
9 |
10 | model.save_pretrained(args.save_model_path)
11 | tokenizer.save_pretrained(args.save_model_path)
12 |
13 |
14 | if __name__ == "__main__":
15 | parser = argparse.ArgumentParser()
16 | parser.add_argument("--model-path", type=str, required=True)
17 | parser.add_argument("--model-base", type=str, required=True)
18 | parser.add_argument("--save-model-path", type=str, required=True)
19 |
20 | args = parser.parse_args()
21 |
22 | merge_lora(args)
23 |
--------------------------------------------------------------------------------
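
merge_lora() is normally driven from the command line via the argparse flags above; for programmatic use it only needs an object exposing the three path attributes. A hypothetical sketch, with placeholder paths and an import path that depends on how the script is made importable:

from types import SimpleNamespace
# Illustrative import; adjust to wherever merge_lora_weights.py lives on your path.
from merge_lora_weights import merge_lora

args = SimpleNamespace(
    model_path="checkpoints/llavaPythia-lora",        # LoRA checkpoint directory
    model_base="checkpoints/llavaPythia-base",        # base model the LoRA was trained on
    save_model_path="checkpoints/llavaPythia-merged",
)
merge_lora(args)  # loads base + LoRA weights, then saves the merged model and tokenizer
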
/llava-pythia/scripts/llava_pythia/eval/.ipynb_checkpoints/mmbench-checkpoint.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | SPLIT="mmbench_dev_20230712"
4 | LLM_MODEL_SIZE=2_8B
5 | python -m llava_pythia.eval.model_vqa_mmbench \
6 | --model-path ./checkpoint_all/pythia_$LLM_MODEL_SIZE/vanilla_pythia_pt_f_vit/llavaPythia-v0-finetune \
7 | --question-file ./playground/data/eval/mmbench/$SPLIT.tsv \
8 | --answers-file ./playground/data/eval/mmbench/answers/$SPLIT/llavaPhi-v0-3b.jsonl \
9 | --single-pred-prompt \
10 | --temperature 0 \
11 | --conv-mode pythia
12 |
13 | mkdir -p playground/data/eval/mmbench/answers_upload/$SPLIT
14 |
15 | python scripts/convert_mmbench_for_submission.py \
16 | --annotation-file ./playground/data/eval/mmbench/$SPLIT.tsv \
17 | --result-dir ./playground/data/eval/mmbench/answers/$SPLIT \
18 | --upload-dir ./playground/data/eval/mmbench/answers_upload/$SPLIT \
19 | --experiment llavaPhi-v0-3b
--------------------------------------------------------------------------------
/llava-pythia/scripts/zero3.json:
--------------------------------------------------------------------------------
1 | {
2 | "fp16": {
3 | "enabled": "auto",
4 | "loss_scale": 0,
5 | "loss_scale_window": 1000,
6 | "initial_scale_power": 16,
7 | "hysteresis": 2,
8 | "min_loss_scale": 1
9 | },
10 | "bf16": {
11 | "enabled": "auto"
12 | },
13 | "train_micro_batch_size_per_gpu": "auto",
14 | "train_batch_size": "auto",
15 | "gradient_accumulation_steps": "auto",
16 | "zero_optimization": {
17 | "stage": 3,
18 | "overlap_comm": true,
19 | "contiguous_gradients": true,
20 | "sub_group_size": 1e9,
21 | "reduce_bucket_size": "auto",
22 | "stage3_prefetch_bucket_size": "auto",
23 | "stage3_param_persistence_threshold": "auto",
24 | "stage3_max_live_parameters": 1e9,
25 | "stage3_max_reuse_distance": 1e9,
26 | "stage3_gather_16bit_weights_on_model_save": true
27 | }
28 | }
--------------------------------------------------------------------------------
/llava-pythia/scripts/llava_pythia/eval/sqa.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | LLM_MODEL_SIZE=2_8B
3 | python -m llava_pythia.eval.model_vqa_science \
4 | --model-path ./checkpoint_all/pythia_$LLM_MODEL_SIZE/vanilla_pythia_pt_f_vit/llavaPythia-v0-finetune \
5 | --question-file ./playground/data/eval/scienceqa/llava_test_CQM-A.json \
6 | --image-folder ./playground/data/eval/scienceqa/images/test \
7 | --answers-file ./playground/data/eval/scienceqa/answers/llavaPhi-v0-3b.jsonl \
8 | --single-pred-prompt \
9 | --temperature 0 \
10 | --conv-mode pythia
11 |
12 | python llava_pythia/eval/eval_science_qa.py \
13 | --base-dir ./playground/data/eval/scienceqa \
14 | --result-file ./playground/data/eval/scienceqa/answers/llavaPhi-v0-3b.jsonl \
15 | --output-file ./playground/data/eval/scienceqa/answers/llavaPhi-v0-3b_output.jsonl \
16 | --output-result ./playground/data/eval/scienceqa/answers/llavaPhi-v0-3b_result.json
17 |
18 |
--------------------------------------------------------------------------------
/llava-pythia/scripts/.ipynb_checkpoints/zero3-checkpoint.json:
--------------------------------------------------------------------------------
1 | {
2 | "fp16": {
3 | "enabled": "auto",
4 | "loss_scale": 0,
5 | "loss_scale_window": 1000,
6 | "initial_scale_power": 16,
7 | "hysteresis": 2,
8 | "min_loss_scale": 1
9 | },
10 | "bf16": {
11 | "enabled": "auto"
12 | },
13 | "train_micro_batch_size_per_gpu": "auto",
14 | "train_batch_size": "auto",
15 | "gradient_accumulation_steps": "auto",
16 | "zero_optimization": {
17 | "stage": 3,
18 | "overlap_comm": true,
19 | "contiguous_gradients": true,
20 | "sub_group_size": 1e9,
21 | "reduce_bucket_size": "auto",
22 | "stage3_prefetch_bucket_size": "auto",
23 | "stage3_param_persistence_threshold": "auto",
24 | "stage3_max_live_parameters": 1e9,
25 | "stage3_max_reuse_distance": 1e9,
26 | "stage3_gather_16bit_weights_on_model_save": true
27 | }
28 | }
--------------------------------------------------------------------------------
/llava-pythia/scripts/llava_pythia/eval/.ipynb_checkpoints/sqa-checkpoint.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | LLM_MODEL_SIZE=2_8B
3 | python -m llava_pythia.eval.model_vqa_science \
4 | --model-path ./checkpoint_all/pythia_$LLM_MODEL_SIZE/vanilla_pythia_pt_f_vit/llavaPythia-v0-finetune \
5 | --question-file ./playground/data/eval/scienceqa/llava_test_CQM-A.json \
6 | --image-folder ./playground/data/eval/scienceqa/images/test \
7 | --answers-file ./playground/data/eval/scienceqa/answers/llavaPhi-v0-3b.jsonl \
8 | --single-pred-prompt \
9 | --temperature 0 \
10 | --conv-mode pythia
11 |
12 | python llava_pythia/eval/eval_science_qa.py \
13 | --base-dir ./playground/data/eval/scienceqa \
14 | --result-file ./playground/data/eval/scienceqa/answers/llavaPhi-v0-3b.jsonl \
15 | --output-file ./playground/data/eval/scienceqa/answers/llavaPhi-v0-3b_output.jsonl \
16 | --output-result ./playground/data/eval/scienceqa/answers/llavaPhi-v0-3b_result.json
17 |
18 |
--------------------------------------------------------------------------------
/llava-pythia/scripts/training_states_2_tensorboard.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.tensorboard import SummaryWriter
3 | import os
4 | import json
5 |
6 | def main():
7 | # create SummaryWriter
8 | pythia = "410M"
9 | log_p = f'/data/private/wenjj/llava-pythia/checkpoint_all/pythia_{pythia}/vanilla_pythia_pt_f_vit/llavaPythia-v0-robot-action-w_state_huber/log'
10 |
11 | trainint_state_p = f"/data/private/wenjj/llava-pythia/checkpoint_all/pythia_{pythia}/vanilla_pythia_pt_f_vit/llavaPythia-v0-robot-action-w_state_huber/trainer_state.json"
12 |
13 | os.makedirs(log_p, exist_ok=True)
14 |
15 | writer = SummaryWriter(log_dir=log_p)
16 |
17 | with open(trainint_state_p, "r") as f:
18 | data = json.load(f)
19 |
20 | # save loss in SummaryWriter
21 | for each in data['log_history']:
22 | if not 'loss' in each.keys():
23 | continue
24 | step, loss = each['step'], each['loss']
25 | writer.add_scalar('train/loss', loss, step)
26 |
27 | writer.close()
28 |
29 | if __name__ == "__main__":
30 | main()
31 |
--------------------------------------------------------------------------------
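
The script assumes a HuggingFace trainer_state.json whose log_history entries carry at least "step" and "loss". An illustrative fragment of that structure (the values are made up), showing what the loop above iterates over:

example_state = {
    "log_history": [
        {"step": 10, "loss": 2.31, "learning_rate": 1e-4},
        {"step": 20, "loss": 1.97, "learning_rate": 9.8e-5},
        {"step": 20, "eval_loss": 1.90},   # entries without "loss" are skipped
    ]
}
for each in example_state["log_history"]:
    if "loss" not in each:
        continue
    print(each["step"], each["loss"])
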
/llava-pythia/scripts/convert_mmbench_for_submission.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import argparse
4 | import pandas as pd
5 |
6 | def get_args():
7 | parser = argparse.ArgumentParser()
8 | parser.add_argument("--annotation-file", type=str, required=True)
9 | parser.add_argument("--result-dir", type=str, required=True)
10 | parser.add_argument("--upload-dir", type=str, required=True)
11 | parser.add_argument("--experiment", type=str, required=True)
12 |
13 | return parser.parse_args()
14 |
15 | if __name__ == "__main__":
16 | args = get_args()
17 |
18 | df = pd.read_table(args.annotation_file)
19 |
20 | cur_df = df.copy()
21 | cur_df = cur_df.drop(columns=['hint', 'category', 'source', 'image', 'comment', 'l2-category'])
22 | cur_df.insert(6, 'prediction', None)
23 | for pred in open(os.path.join(args.result_dir, f"{args.experiment}.jsonl")):
24 | pred = json.loads(pred)
25 | cur_df.loc[df['index'] == pred['question_id'], 'prediction'] = pred['text']
26 |
27 | cur_df.to_excel(os.path.join(args.upload_dir, f"{args.experiment}.xlsx"), index=False, engine='openpyxl')
28 |
--------------------------------------------------------------------------------
/scripts/process_ckpts.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # This script processes the trained weights and generates a smaller, more compact set of weights
3 | LLM_MODEL_SIZE=410M
4 |
5 |
6 | # path to trained TinyVLA weights
7 | source_dir="/path/to/trained/VLA/weights"
8 | # new path to save weights
9 | target_dir="/path/to/save/processed/VLA/weights"
10 |
11 | mkdir -p $target_dir
12 |
13 | exclude_pattern="global_step*"
14 |
15 | echo "copying checkpoints from $source_dir to $target_dir"
16 | rsync -av --exclude="$exclude_pattern" --exclude="$exclude_pattern/**" "$source_dir/" "$target_dir/"
17 |
18 | echo 'transfer checkpoints to non_lora_trainables.bin'
19 | for dir in "$source_dir"/*/ ; do
20 |
21 | if [[ "$(basename "$dir")" == *"checkpoint"* ]]; then
22 | if ! find "$dir" -mindepth 1 -type f -name "non_lora_trainables.bin" | grep -q .; then
23 | cd "$dir" || exit
24 | python ./zero_to_fp32.py ./ ${target_dir}/$(basename "$dir")/non_lora_trainables.bin
25 | # cp $OUTPUT/non_lora_trainables.bin $dir
26 | fi
27 | fi
28 | done
29 |
30 | cd "/data/junjiewen/droid_results/checkpoint_all" || exit
31 |
32 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Tony Z. Zhao
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/llava-pythia/scripts/.ipynb_checkpoints/training_states_2_tensorboard-checkpoint.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.tensorboard import SummaryWriter
3 | import os
4 | import json
5 |
6 | def main():
7 | # Create the SummaryWriter
8 | pythia = "410M"
9 | log_p = f'/data/private/wenjj/llava-pythia/checkpoint_all/pythia_{pythia}/vanilla_pythia_pt_f_vit/llavaPythia-v0-robot-action-w_state_huber/log'
10 |
11 | trainint_state_p = f"/data/private/wenjj/llava-pythia/checkpoint_all/pythia_{pythia}/vanilla_pythia_pt_f_vit/llavaPythia-v0-robot-action-w_state_huber/trainer_state.json"
12 |
13 | os.makedirs(log_p, exist_ok=True)
14 |
15 | writer = SummaryWriter(log_dir=log_p)
16 |
17 | # The saved loss data is assumed to be a dict (loaded from trainer_state.json)
18 | with open(trainint_state_p, "r") as f:
19 | data = json.load(f)
20 |
21 | # Write the loss values to the SummaryWriter
22 | for each in data['log_history']:
23 | if not 'loss' in each.keys():
24 | continue
25 | step, loss = each['step'], each['loss']
26 | writer.add_scalar('train/loss', loss, step)
27 |
28 | # Close the SummaryWriter
29 | writer.close()
30 |
31 | if __name__ == "__main__":
32 | main()
33 |
--------------------------------------------------------------------------------
/llava-pythia/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["setuptools>=61.0"]
3 | build-backend = "setuptools.build_meta"
4 |
5 | [project]
6 | name = "llava_pythia"
7 | version = "1.0.0"
8 | description = "Towards GPT-4 like large language and visual assistant."
9 | readme = "README.md"
10 | requires-python = ">=3.8"
11 | classifiers = [
12 | "Programming Language :: Python :: 3",
13 | "License :: OSI Approved :: Apache Software License",
14 | ]
15 | dependencies = [
16 | "einops", "fastapi", "gradio==3.35.2", "markdown2[all]", "numpy",
17 | "requests", "sentencepiece", "tokenizers==0.15.0",
18 | "torch==2.0.1", "torchvision==0.15.2", "uvicorn", "wandb",
19 | "shortuuid", "httpx==0.24.0",
20 | "deepspeed==0.9.5",
21 | "peft==0.4.0",
22 | "transformers==4.37.1",
23 | "accelerate==0.21.0",
24 | "bitsandbytes==0.41.0",
25 | "scikit-learn==1.2.2",
26 | "sentencepiece==0.1.99",
27 | "einops==0.6.1", "einops-exts==0.0.4", "timm==0.6.13",
28 | "gradio_client==0.2.9"
29 | ]
30 |
31 | [project.urls]
32 | "Bug Tracker" = "https://github.com/zhuyiche/llava-phi/issues"
33 |
34 | [tool.setuptools.packages.find]
35 | exclude = ["assets*", "benchmark*", "docs", "dist*", "playground*", "scripts*", "tests*"]
36 |
37 | [tool.wheel]
38 | exclude = ["assets*", "benchmark*", "docs", "dist*", "playground*", "scripts*", "tests*"]
39 |
--------------------------------------------------------------------------------
/llava-pythia/scripts/llava_pythia/eval/vqav2.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | gpu_list="${CUDA_VISIBLE_DEVICES:-0}"
4 | IFS=',' read -ra GPULIST <<< "$gpu_list"
5 |
6 | CHUNKS=${#GPULIST[@]}
7 | LLM_MODEL_SIZE=2_8B
8 | CKPT="llavaPhi-v0-3b-finetune"
9 | SPLIT="llava_vqav2_mscoco_test-dev2015"
10 |
11 | for IDX in $(seq 0 $((CHUNKS-1))); do
12 | CUDA_VISIBLE_DEVICES=${GPULIST[$IDX]} python -m llava_pythia.eval.model_vqa_loader \
13 | --model-path ./checkpoint_all/pythia_$LLM_MODEL_SIZE/vanilla_pythia_pt_f_vit/llavaPythia-v0-finetune \
14 | --question-file ./playground/data/eval/vqav2/$SPLIT.jsonl \
15 | --image-folder /data/team/zhumj/data/coco/test2015 \
16 | --answers-file ./playground/data/eval/vqav2/answers/$SPLIT/$CKPT/${CHUNKS}_${IDX}.jsonl \
17 | --num-chunks $CHUNKS \
18 | --chunk-idx $IDX \
19 | --temperature 0 \
20 | --conv-mode pythia &
21 | done
22 |
23 | wait
24 |
25 | output_file=./playground/data/eval/vqav2/answers/$SPLIT/$CKPT/merge.jsonl
26 |
27 | # Clear out the output file if it exists.
28 | > "$output_file"
29 |
30 | # Loop through the indices and concatenate each file.
31 | for IDX in $(seq 0 $((CHUNKS-1))); do
32 | cat ./playground/data/eval/vqav2/answers/$SPLIT/$CKPT/${CHUNKS}_${IDX}.jsonl >> "$output_file"
33 | done
34 |
35 | python scripts/convert_vqav2_for_submission.py --split $SPLIT --ckpt $CKPT
36 |
37 |
--------------------------------------------------------------------------------
/llava-pythia/scripts/llava_pythia/eval/.ipynb_checkpoints/vqav2-checkpoint.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | gpu_list="${CUDA_VISIBLE_DEVICES:-0}"
4 | IFS=',' read -ra GPULIST <<< "$gpu_list"
5 |
6 | CHUNKS=${#GPULIST[@]}
7 | LLM_MODEL_SIZE=2_8B
8 | CKPT="llavaPhi-v0-3b-finetune"
9 | SPLIT="llava_vqav2_mscoco_test-dev2015"
10 |
11 | for IDX in $(seq 0 $((CHUNKS-1))); do
12 | CUDA_VISIBLE_DEVICES=${GPULIST[$IDX]} python -m llava_pythia.eval.model_vqa_loader \
13 | --model-path ./checkpoint_all/pythia_$LLM_MODEL_SIZE/vanilla_pythia_pt_f_vit/llavaPythia-v0-finetune \
14 | --question-file ./playground/data/eval/vqav2/$SPLIT.jsonl \
15 | --image-folder /data/team/zhumj/data/coco/test2015 \
16 | --answers-file ./playground/data/eval/vqav2/answers/$SPLIT/$CKPT/${CHUNKS}_${IDX}.jsonl \
17 | --num-chunks $CHUNKS \
18 | --chunk-idx $IDX \
19 | --temperature 0 \
20 | --conv-mode pythia &
21 | done
22 |
23 | wait
24 |
25 | output_file=./playground/data/eval/vqav2/answers/$SPLIT/$CKPT/merge.jsonl
26 |
27 | # Clear out the output file if it exists.
28 | > "$output_file"
29 |
30 | # Loop through the indices and concatenate each file.
31 | for IDX in $(seq 0 $((CHUNKS-1))); do
32 | cat ./playground/data/eval/vqav2/answers/$SPLIT/$CKPT/${CHUNKS}_${IDX}.jsonl >> "$output_file"
33 | done
34 |
35 | python scripts/convert_vqav2_for_submission.py --split $SPLIT --ckpt $CKPT
36 |
37 |
--------------------------------------------------------------------------------
/llava-pythia/scripts/llava_pythia/eval/gqa.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | gpu_list="${CUDA_VISIBLE_DEVICES:-0}"
4 | IFS=',' read -ra GPULIST <<< "$gpu_list"
5 |
6 | CHUNKS=${#GPULIST[@]}
7 | LLM_MODEL_SIZE=1B
8 | CKPT="llavaPhi-v0-3b"
9 | SPLIT="llavaPhi_gqa_testdev_balanced"
10 | GQADIR="./playground/data/eval/gqa/data"
11 |
12 | for IDX in $(seq 0 $((CHUNKS-1))); do
13 | CUDA_VISIBLE_DEVICES=${GPULIST[$IDX]} python -m llava_pythia.eval.model_vqa_loader \
14 | --model-path ./checkpoint_all/pythia_$LLM_MODEL_SIZE/vanilla_pythia_pt_f_vit/llavaPythia-v0-finetune \
15 | --question-file ./playground/data/eval/gqa/$SPLIT.jsonl \
16 | --image-folder /data/team/zhumj/data/finetune/data/gqa/images \
17 | --answers-file ./playground/data/eval/gqa/answers/$SPLIT/$CKPT/${CHUNKS}_${IDX}.jsonl \
18 | --num-chunks $CHUNKS \
19 | --chunk-idx $IDX \
20 | --temperature 0 \
21 | --conv-mode pythia &
22 | done
23 |
24 | wait
25 |
26 | output_file=./playground/data/eval/gqa/answers/$SPLIT/$CKPT/merge.jsonl
27 |
28 | # Clear out the output file if it exists.
29 | > "$output_file"
30 |
31 | # Loop through the indices and concatenate each file.
32 | for IDX in $(seq 0 $((CHUNKS-1))); do
33 | cat ./playground/data/eval/gqa/answers/$SPLIT/$CKPT/${CHUNKS}_${IDX}.jsonl >> "$output_file"
34 | done
35 |
36 | python scripts/convert_gqa_for_eval.py --src $output_file --dst $GQADIR/testdev_balanced_predictions.json
37 |
38 | cd $GQADIR
39 | python eval/eval.py --tier testdev_balanced
40 |
--------------------------------------------------------------------------------
/llava-pythia/scripts/llava_pythia/eval/.ipynb_checkpoints/gqa-checkpoint.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | gpu_list="${CUDA_VISIBLE_DEVICES:-0}"
4 | IFS=',' read -ra GPULIST <<< "$gpu_list"
5 |
6 | CHUNKS=${#GPULIST[@]}
7 | LLM_MODEL_SIZE=1B
8 | CKPT="llavaPhi-v0-3b"
9 | SPLIT="llavaPhi_gqa_testdev_balanced"
10 | GQADIR="./playground/data/eval/gqa/data"
11 |
12 | for IDX in $(seq 0 $((CHUNKS-1))); do
13 | CUDA_VISIBLE_DEVICES=${GPULIST[$IDX]} python -m llava_pythia.eval.model_vqa_loader \
14 | --model-path ./checkpoint_all/pythia_$LLM_MODEL_SIZE/vanilla_pythia_pt_f_vit/llavaPythia-v0-finetune \
15 | --question-file ./playground/data/eval/gqa/$SPLIT.jsonl \
16 | --image-folder /data/team/zhumj/data/finetune/data/gqa/images \
17 | --answers-file ./playground/data/eval/gqa/answers/$SPLIT/$CKPT/${CHUNKS}_${IDX}.jsonl \
18 | --num-chunks $CHUNKS \
19 | --chunk-idx $IDX \
20 | --temperature 0 \
21 | --conv-mode pythia &
22 | done
23 |
24 | wait
25 |
26 | output_file=./playground/data/eval/gqa/answers/$SPLIT/$CKPT/merge.jsonl
27 |
28 | # Clear out the output file if it exists.
29 | > "$output_file"
30 |
31 | # Loop through the indices and concatenate each file.
32 | for IDX in $(seq 0 $((CHUNKS-1))); do
33 | cat ./playground/data/eval/gqa/answers/$SPLIT/$CKPT/${CHUNKS}_${IDX}.jsonl >> "$output_file"
34 | done
35 |
36 | python scripts/convert_gqa_for_eval.py --src $output_file --dst $GQADIR/testdev_balanced_predictions.json
37 |
38 | cd $GQADIR
39 | python eval/eval.py --tier testdev_balanced
40 |
--------------------------------------------------------------------------------
/llava-pythia/scripts/zero3_offload.json:
--------------------------------------------------------------------------------
1 | {
2 | "fp16": {
3 | "enabled": "auto",
4 | "loss_scale": 0,
5 | "loss_scale_window": 1000,
6 | "initial_scale_power": 16,
7 | "hysteresis": 2,
8 | "min_loss_scale": 1
9 | },
10 | "bf16": {
11 | "enabled": "auto"
12 | },
13 | "optimizer": {
14 | "type": "AdamW",
15 | "params": {
16 | "lr": "auto",
17 | "betas": "auto",
18 | "eps": "auto",
19 | "weight_decay": "auto"
20 | }
21 | },
22 | "scheduler": {
23 | "type": "WarmupLR",
24 | "params": {
25 | "warmup_min_lr": "auto",
26 | "warmup_max_lr": "auto",
27 | "warmup_num_steps": "auto"
28 | }
29 | },
30 | "zero_optimization": {
31 | "stage": 3,
32 | "offload_optimizer": {
33 | "device": "cpu",
34 | "pin_memory": true
35 | },
36 | "offload_param": {
37 | "device": "cpu",
38 | "pin_memory": true
39 | },
40 | "overlap_comm": true,
41 | "contiguous_gradients": true,
42 | "sub_group_size": 1e9,
43 | "reduce_bucket_size": "auto",
44 | "stage3_prefetch_bucket_size": "auto",
45 | "stage3_param_persistence_threshold": "auto",
46 | "stage3_max_live_parameters": 1e9,
47 | "stage3_max_reuse_distance": 1e9,
48 | "gather_16bit_weights_on_model_save": true
49 | },
50 | "gradient_accumulation_steps": "auto",
51 | "gradient_clipping": "auto",
52 | "train_batch_size": "auto",
53 | "train_micro_batch_size_per_gpu": "auto",
54 | "steps_per_print": 1e5,
55 | "wall_clock_breakdown": false
56 | }
--------------------------------------------------------------------------------
/llava-pythia/scripts/.ipynb_checkpoints/zero3_offload-checkpoint.json:
--------------------------------------------------------------------------------
1 | {
2 | "fp16": {
3 | "enabled": "auto",
4 | "loss_scale": 0,
5 | "loss_scale_window": 1000,
6 | "initial_scale_power": 16,
7 | "hysteresis": 2,
8 | "min_loss_scale": 1
9 | },
10 | "bf16": {
11 | "enabled": "auto"
12 | },
13 | "optimizer": {
14 | "type": "AdamW",
15 | "params": {
16 | "lr": "auto",
17 | "betas": "auto",
18 | "eps": "auto",
19 | "weight_decay": "auto"
20 | }
21 | },
22 | "scheduler": {
23 | "type": "WarmupLR",
24 | "params": {
25 | "warmup_min_lr": "auto",
26 | "warmup_max_lr": "auto",
27 | "warmup_num_steps": "auto"
28 | }
29 | },
30 | "zero_optimization": {
31 | "stage": 3,
32 | "offload_optimizer": {
33 | "device": "cpu",
34 | "pin_memory": true
35 | },
36 | "offload_param": {
37 | "device": "cpu",
38 | "pin_memory": true
39 | },
40 | "overlap_comm": true,
41 | "contiguous_gradients": true,
42 | "sub_group_size": 1e9,
43 | "reduce_bucket_size": "auto",
44 | "stage3_prefetch_bucket_size": "auto",
45 | "stage3_param_persistence_threshold": "auto",
46 | "stage3_max_live_parameters": 1e9,
47 | "stage3_max_reuse_distance": 1e9,
48 | "gather_16bit_weights_on_model_save": true
49 | },
50 | "gradient_accumulation_steps": "auto",
51 | "gradient_clipping": "auto",
52 | "train_batch_size": "auto",
53 | "train_micro_batch_size_per_gpu": "auto",
54 | "steps_per_print": 1e5,
55 | "wall_clock_breakdown": false
56 | }
--------------------------------------------------------------------------------
/llava-pythia/scripts/llava_pythia/get_base_model.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | LLM_MODEL_SIZE=14M
3 |
4 | python llava_pythia/train/convert_model2base_llava_pythia.py \
5 | --model_name_or_path /data/team/zhumj/model_Param/EleutherAI/pythia-$LLM_MODEL_SIZE \
6 | --version plain \
7 | --data_path /data/team/zhumj/data/llava-pretrain/blip_laion_cc_sbu_558k.json \
8 | --image_folder /data/team/zhumj/data/llava-pretrain/images \
9 | --vision_tower openai/clip-vit-large-patch14-336 \
10 | --mm_projector_type mlp2x_gelu \
11 | --tune_mm_mlp_adapter True \
12 | --mm_vision_select_layer -2 \
13 | --mm_use_im_start_end False \
14 | --mm_use_im_patch_token False \
15 | --bf16 True \
16 | --output_dir ./checkpoint_all/pythia_$LLM_MODEL_SIZE/base_checkpoints_llava_vanilla_pythia \
17 | --num_train_epochs 1 \
18 | --per_device_train_batch_size 16 \
19 | --per_device_eval_batch_size 4 \
20 | --gradient_accumulation_steps 2 \
21 | --evaluation_strategy "no" \
22 | --save_strategy "steps" \
23 | --save_steps 24000 \
24 | --save_total_limit 1 \
25 | --learning_rate 1e-3 \
26 | --weight_decay 0.1 \
27 | --warmup_ratio 0. \
28 | --lr_scheduler_type "cosine" \
29 | --logging_steps 1 \
30 | --tf32 True \
31 | --model_max_length 2048 \
32 | --gradient_checkpointing True \
33 | --dataloader_num_workers 4 \
34 | --lazy_preprocess True \
35 | --report_to wandb
36 |
37 | cp openai/clip-vit-large-patch14-336/preprocessor_config.json ./checkpoint_all/pythia_$LLM_MODEL_SIZE/base_checkpoints_llava_vanilla_pythia
--------------------------------------------------------------------------------
/llava-pythia/scripts/llava_pythia/pretrain.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | LLM_MODEL_SIZE=14M
3 |
4 | deepspeed --master_port 29600 llava_pythia/train/train.py \
5 | --deepspeed ./scripts/zero2.json \
6 | --model_name_or_path ./checkpoint_all/pythia_$LLM_MODEL_SIZE/base_checkpoints_llava_vanilla_pythia \
7 | --version plain \
8 | --data_path /data/team/zhumj/data/llava-pretrain/blip_laion_cc_sbu_558k.json \
9 | --image_folder /data/team/zhumj/data/llava-pretrain/images \
10 | --tune_mm_mlp_adapter True \
11 | --freeze_vision_tower True \
12 | --freeze_backbone True \
13 | --mm_use_im_start_end False \
14 | --mm_use_im_patch_token False \
15 | --bf16 True \
16 | --output_dir ./checkpoint_all/pythia_$LLM_MODEL_SIZE/vanilla_pythia_pt_f_vit/llavaPythia-v0-pretrain \
17 | --num_train_epochs 1 \
18 | --per_device_train_batch_size 32 \
19 | --per_device_eval_batch_size 4 \
20 | --gradient_accumulation_steps 1 \
21 | --evaluation_strategy "no" \
22 | --save_strategy "steps" \
23 | --save_steps 24000 \
24 | --save_total_limit 1 \
25 | --learning_rate 1e-3 \
26 | --weight_decay 0. \
27 | --warmup_ratio 0.03 \
28 | --lr_scheduler_type "cosine" \
29 | --logging_steps 1 \
30 | --tf32 True \
31 | --model_max_length 2048 \
32 | --gradient_checkpointing True \
33 | --dataloader_num_workers 4 \
34 | --lazy_preprocess True \
35 | --report_to wandb
36 |
37 | cp openai/clip-vit-large-patch14-336/preprocessor_config.json ./checkpoint_all/pythia_$LLM_MODEL_SIZE/vanilla_pythia_pt_f_vit/llavaPythia-v0-pretrain
--------------------------------------------------------------------------------
/llava-pythia/scripts/llava_pythia/.ipynb_checkpoints/get_base_model-checkpoint.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | LLM_MODEL_SIZE=14M
3 |
4 | python llava_pythia/train/convert_model2base_llava_pythia.py \
5 | --model_name_or_path /data/team/zhumj/model_Param/EleutherAI/pythia-$LLM_MODEL_SIZE \
6 | --version plain \
7 | --data_path /data/team/zhumj/data/llava-pretrain/blip_laion_cc_sbu_558k.json \
8 | --image_folder /data/team/zhumj/data/llava-pretrain/images \
9 | --vision_tower openai/clip-vit-large-patch14-336 \
10 | --mm_projector_type mlp2x_gelu \
11 | --tune_mm_mlp_adapter True \
12 | --mm_vision_select_layer -2 \
13 | --mm_use_im_start_end False \
14 | --mm_use_im_patch_token False \
15 | --bf16 True \
16 | --output_dir ./checkpoint_all/pythia_$LLM_MODEL_SIZE/base_checkpoints_llava_vanilla_pythia \
17 | --num_train_epochs 1 \
18 | --per_device_train_batch_size 16 \
19 | --per_device_eval_batch_size 4 \
20 | --gradient_accumulation_steps 2 \
21 | --evaluation_strategy "no" \
22 | --save_strategy "steps" \
23 | --save_steps 24000 \
24 | --save_total_limit 1 \
25 | --learning_rate 1e-3 \
26 | --weight_decay 0.1 \
27 | --warmup_ratio 0. \
28 | --lr_scheduler_type "cosine" \
29 | --logging_steps 1 \
30 | --tf32 True \
31 | --model_max_length 2048 \
32 | --gradient_checkpointing True \
33 | --dataloader_num_workers 4 \
34 | --lazy_preprocess True \
35 | --report_to wandb
36 |
37 | cp openai/clip-vit-large-patch14-336/preprocessor_config.json ./checkpoint_all/pythia_$LLM_MODEL_SIZE/base_checkpoints_llava_vanilla_pythia
--------------------------------------------------------------------------------
/llava-pythia/scripts/llava_pythia/.ipynb_checkpoints/pretrain-checkpoint.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | LLM_MODEL_SIZE=14M
3 |
4 | deepspeed --master_port 29600 llava_pythia/train/train.py \
5 | --deepspeed ./scripts/zero2.json \
6 | --model_name_or_path ./checkpoint_all/pythia_$LLM_MODEL_SIZE/base_checkpoints_llava_vanilla_pythia \
7 | --version plain \
8 | --data_path /data/team/zhumj/data/llava-pretrain/blip_laion_cc_sbu_558k.json \
9 | --image_folder /data/team/zhumj/data/llava-pretrain/images \
10 | --tune_mm_mlp_adapter True \
11 | --freeze_vision_tower True \
12 | --freeze_backbone True \
13 | --mm_use_im_start_end False \
14 | --mm_use_im_patch_token False \
15 | --bf16 True \
16 | --output_dir ./checkpoint_all/pythia_$LLM_MODEL_SIZE/vanilla_pythia_pt_f_vit/llavaPythia-v0-pretrain \
17 | --num_train_epochs 1 \
18 | --per_device_train_batch_size 32 \
19 | --per_device_eval_batch_size 4 \
20 | --gradient_accumulation_steps 1 \
21 | --evaluation_strategy "no" \
22 | --save_strategy "steps" \
23 | --save_steps 24000 \
24 | --save_total_limit 1 \
25 | --learning_rate 1e-3 \
26 | --weight_decay 0. \
27 | --warmup_ratio 0.03 \
28 | --lr_scheduler_type "cosine" \
29 | --logging_steps 1 \
30 | --tf32 True \
31 | --model_max_length 2048 \
32 | --gradient_checkpointing True \
33 | --dataloader_num_workers 4 \
34 | --lazy_preprocess True \
35 | --report_to wandb
36 |
37 | cp openai/clip-vit-large-patch14-336/preprocessor_config.json ./checkpoint_all/pythia_$LLM_MODEL_SIZE/vanilla_pythia_pt_f_vit/llavaPythia-v0-pretrain
--------------------------------------------------------------------------------
/llava-pythia/scripts/convert_vizwiz_for_submission.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import json
4 |
5 | from llava_pythia.eval.m4c_evaluator import EvalAIAnswerProcessor
6 |
7 |
8 | def parse_args():
9 | parser = argparse.ArgumentParser()
10 | parser.add_argument('--annotation-file', type=str, required=True)
11 | parser.add_argument('--result-file', type=str, required=True)
12 | parser.add_argument('--result-upload-file', type=str, required=True)
13 | return parser.parse_args()
14 |
15 |
16 | if __name__ == '__main__':
17 |
18 | args = parse_args()
19 |
20 | os.makedirs(os.path.dirname(args.result_upload_file), exist_ok=True)
21 |
22 | results = []
23 | error_line = 0
24 | for line_idx, line in enumerate(open(args.result_file)):
25 | try:
26 | results.append(json.loads(line))
27 | except:
28 | error_line += 1
29 | results = {x['question_id']: x['text'] for x in results}
30 | test_split = [json.loads(line) for line in open(args.annotation_file)]
31 | split_ids = set([x['question_id'] for x in test_split])
32 |
33 | print(f'total results: {len(results)}, total split: {len(test_split)}, error_line: {error_line}')
34 |
35 | all_answers = []
36 |
37 | answer_processor = EvalAIAnswerProcessor()
38 |
39 | for x in test_split:
40 | assert x['question_id'] in results
41 | all_answers.append({
42 | 'image': x['image'],
43 | 'answer': answer_processor(results[x['question_id']])
44 | })
45 |
46 | with open(args.result_upload_file, 'w') as f:
47 | json.dump(all_answers, f)
48 |
--------------------------------------------------------------------------------
/llava-pythia/scripts/display_eval_results_all.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | ckpt_p = "/data/private/wenjj/llava-pythia/checkpoint_all/pythia_{msz}/vanilla_pythia_pt_f_vit/llavaPythia-v0-robot-action-view_channel_cat_lora2/checkpoint-{ckpt}"
4 |
5 | model_size = ['14M', '70M', '160M', '410M', '1B']
6 |
7 | TASK = ('kitchen_sdoor_open-v3', 'kitchen_micro_open-v3', 'kitchen_light_on-v3', 'kitchen_ldoor_open-v3','kitchen_knob1_on-v3')
8 |
9 | task_id = 0
10 | for task_id in range(5):
11 | # print(f"{TASK[task_id]}")
12 | with open('eval_results_all.txt', 'w') as f:
13 | f.write(f"eval results on Franka Kitchen {TASK[task_id]}:\n")
14 |
15 |         for msz in model_size:  # stay inside the with-block so f is still open for writing
16 |             f.write(f"###################Model size:{msz}###################\n")
17 |             for i in range(2,6):
18 |                 ckpt = str(i*1000)
19 |                 p = ckpt_p.replace('{msz}', msz).replace('{ckpt}', ckpt)
20 |                 try:
21 |                     with open(os.path.join(p, f"{TASK[task_id]}.txt"), 'r') as f1:
22 |                         content = f1.read()
23 |                     content = content.split('\n')
24 |                     # for e in content:
25 |                     #     if e == "":
26 |                     #         continue
27 |                     #     if '50_' not in e:
28 |                     #         continue
29 | 
30 |                     #     f.write(f"ckpt_{i*1000}:{e.strip()}\n")
31 |                     f.write(f"ckpt_{ckpt}:{content[-1].strip()}\n")
32 |                 except Exception as e:
33 |                     print(e)
34 |                     pass
35 | with open('eval_results_all.txt', 'r') as f:
36 | data = f.read()
37 | print(data)
38 |
39 |
--------------------------------------------------------------------------------
/llava-pythia/scripts/.ipynb_checkpoints/display_eval_results_all-checkpoint.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | ckpt_p = "/data/private/wenjj/llava-pythia/checkpoint_all/pythia_{msz}/vanilla_pythia_pt_f_vit/llavaPythia-v0-robot-action-view_channel_cat_lora2/checkpoint-{ckpt}"
4 |
5 | model_size = ['14M', '70M', '160M', '410M', '1B']
6 |
7 | TASK = ('kitchen_sdoor_open-v3', 'kitchen_micro_open-v3', 'kitchen_light_on-v3', 'kitchen_ldoor_open-v3','kitchen_knob1_on-v3')
8 |
9 | task_id = 0
10 | for task_id in range(5):
11 | # print(f"{TASK[task_id]}")
12 | with open('eval_results_all.txt', 'w') as f:
13 | f.write(f"eval results on Franka Kitchen {TASK[task_id]}:\n")
14 |
15 |         for msz in model_size:  # stay inside the with-block so f is still open for writing
16 |             f.write(f"###################Model size:{msz}###################\n")
17 |             for i in range(2,6):
18 |                 ckpt = str(i*1000)
19 |                 p = ckpt_p.replace('{msz}', msz).replace('{ckpt}', ckpt)
20 |                 try:
21 |                     with open(os.path.join(p, f"{TASK[task_id]}.txt"), 'r') as f1:
22 |                         content = f1.read()
23 |                     content = content.split('\n')
24 |                     # for e in content:
25 |                     #     if e == "":
26 |                     #         continue
27 |                     #     if '50_' not in e:
28 |                     #         continue
29 | 
30 |                     #     f.write(f"ckpt_{i*1000}:{e.strip()}\n")
31 |                     f.write(f"ckpt_{ckpt}:{content[-1].strip()}\n")
32 |                 except Exception as e:
33 |                     print(e)
34 |                     pass
35 | with open('eval_results_all.txt', 'r') as f:
36 | data = f.read()
37 | print(data)
38 |
39 |
--------------------------------------------------------------------------------
/llava-pythia/scripts/llava_pythia/finetune.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | LLM_MODEL_SIZE=14M
3 |
4 | deepspeed --master_port 29600 --num_gpus=8 --num_nodes=1 llava_pythia/train/train.py \
5 | --deepspeed ./scripts/zero2.json \
6 | --model_name_or_path ./checkpoint_all/pythia_$LLM_MODEL_SIZE/vanilla_pythia_pt_f_vit/llavaPythia-v0-pretrain \
7 | --version v0 \
8 | --data_path /data/team/zhumj/data/finetune/data/llava_v1_5_mix665k.json \
9 | --image_folder /data/team/zhumj/data/finetune/data \
10 | --tune_mm_mlp_adapter True \
11 | --freeze_vision_tower False \
12 |     --freeze_backbone False \
13 | --mm_use_im_start_end False \
14 | --mm_use_im_patch_token False \
15 | --image_aspect_ratio pad \
16 | --group_by_modality_length False \
17 | --bf16 True \
18 | --output_dir ./checkpoint_all/pythia_$LLM_MODEL_SIZE/vanilla_pythia_pt_f_vit/llavaPythia-v0-finetune \
19 | --num_train_epochs 1 \
20 | --per_device_train_batch_size 4 \
21 | --per_device_eval_batch_size 4 \
22 | --gradient_accumulation_steps 4 \
23 | --evaluation_strategy "no" \
24 | --save_strategy "steps" \
25 | --save_steps 50000 \
26 | --save_total_limit 1 \
27 | --learning_rate 2e-5 \
28 | --weight_decay 0. \
29 | --warmup_ratio 0.03 \
30 | --lr_scheduler_type "cosine" \
31 | --logging_steps 1 \
32 | --tf32 True \
33 | --model_max_length 2048 \
34 | --gradient_checkpointing True \
35 | --dataloader_num_workers 4 \
36 | --lazy_preprocess True \
37 | --report_to wandb
38 |
39 | #/data/private/data/llava_data/franka_kitchen_finetune/left_cap2/left_cap2_50k.json
40 | #/data/team/zhumj/data/finetune/data/llava_v1_5_mix665k.json
41 | cp openai/clip-vit-large-patch14-336/preprocessor_config.json ./checkpoint_all/pythia_$LLM_MODEL_SIZE/vanilla_pythia_pt_f_vit/llavaPythia-v0-finetune
--------------------------------------------------------------------------------
/llava-pythia/scripts/llava_pythia/.ipynb_checkpoints/finetune-checkpoint.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | LLM_MODEL_SIZE=14M
3 |
4 | deepspeed --master_port 29600 --num_gpus=8 --num_nodes=1 llava_pythia/train/train.py \
5 | --deepspeed ./scripts/zero2.json \
6 | --model_name_or_path ./checkpoint_all/pythia_$LLM_MODEL_SIZE/vanilla_pythia_pt_f_vit/llavaPythia-v0-pretrain \
7 | --version v0 \
8 | --data_path /data/team/zhumj/data/finetune/data/llava_v1_5_mix665k.json \
9 | --image_folder /data/team/zhumj/data/finetune/data \
10 | --tune_mm_mlp_adapter True \
11 | --freeze_vision_tower False \
12 |     --freeze_backbone False \
13 | --mm_use_im_start_end False \
14 | --mm_use_im_patch_token False \
15 | --image_aspect_ratio pad \
16 | --group_by_modality_length False \
17 | --bf16 True \
18 | --output_dir ./checkpoint_all/pythia_$LLM_MODEL_SIZE/vanilla_pythia_pt_f_vit/llavaPythia-v0-finetune \
19 | --num_train_epochs 1 \
20 | --per_device_train_batch_size 4 \
21 | --per_device_eval_batch_size 4 \
22 | --gradient_accumulation_steps 4 \
23 | --evaluation_strategy "no" \
24 | --save_strategy "steps" \
25 | --save_steps 50000 \
26 | --save_total_limit 1 \
27 | --learning_rate 2e-5 \
28 | --weight_decay 0. \
29 | --warmup_ratio 0.03 \
30 | --lr_scheduler_type "cosine" \
31 | --logging_steps 1 \
32 | --tf32 True \
33 | --model_max_length 2048 \
34 | --gradient_checkpointing True \
35 | --dataloader_num_workers 4 \
36 | --lazy_preprocess True \
37 | --report_to wandb
38 |
39 | #/data/private/data/llava_data/franka_kitchen_finetune/left_cap2/left_cap2_50k.json
40 | #/data/team/zhumj/data/finetune/data/llava_v1_5_mix665k.json
41 | cp openai/clip-vit-large-patch14-336/preprocessor_config.json ./checkpoint_all/pythia_$LLM_MODEL_SIZE/vanilla_pythia_pt_f_vit/llavaPythia-v0-finetune
--------------------------------------------------------------------------------
/scripts/train.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | ACTION_HEAD=droid_diffusion # specify action policy head type
4 | # define OUTPUT path
5 |
6 | OUTPUT=/path/to/save_dir
7 |
8 | if [ -d "$OUTPUT" ]; then
9 | echo 'output exists'
10 | else
11 |     echo '!! output does not exist, creating it !!'
12 | mkdir -p $OUTPUT
13 | fi
14 | # backup the train scripts
15 | cp ./scripts/train.sh $OUTPUT
16 |
17 | # detailed usage of each parameter can be found in train_tinyvla.py
18 |
19 | deepspeed --master_port 29600 --num_gpus=8 --num_nodes=1 ./train_tinyvla.py \
20 | --deepspeed scripts/zero2.json \
21 | --lora_enable True \
22 | --lora_module 'vit llm' \
23 | --load_pretrain False \
24 | --pretrain_image_size 320 \
25 | --lora_r 64 \
26 | --lora_alpha 256 \
27 | --non_lora_lr 2e-5 \
28 | --task_name "example_task_config" \
29 | --model_name_or_path /path/to/pretrained_vlm \
30 | --version v0 \
31 | --tune_mm_mlp_adapter True \
32 | --freeze_vision_tower True \
33 | --freeze_backbone True \
34 | --mm_use_im_start_end False \
35 | --mm_use_im_patch_token False \
36 | --image_aspect_ratio pad \
37 | --group_by_modality_length False \
38 | --bf16 True \
39 | --output_dir $OUTPUT \
40 | --max_steps 10000 \
41 | --per_device_train_batch_size 32 \
42 | --gradient_accumulation_steps 1 \
43 | --save_strategy "steps" \
44 | --save_steps 1000 \
45 | --save_total_limit 50 \
46 | --learning_rate 2e-4 \
47 | --weight_decay 0. \
48 | --warmup_ratio 0.005 \
49 | --lr_scheduler_type "cosine" \
50 | --logging_steps 10 \
51 | --tf32 True \
52 | --model_max_length 2048 \
53 | --gradient_checkpointing True \
54 | --dataloader_num_workers 8 \
55 | --lazy_preprocess True \
56 | --action_head_type $ACTION_HEAD \
57 | --use_state True \
58 | --concat "token_cat" \
59 | --window_size 6 \
60 | --report_to tensorboard \
61 | --logging_dir $OUTPUT/log
62 |
63 | for dir in "$OUTPUT"/*/ ; do
64 |   # check whether the folder name contains 'checkpoint'
65 | if [[ "$(basename "$dir")" == *"checkpoint"* ]]; then
66 | cp llava-pythia/preprocessor_config.json $dir
67 | fi
68 | done
69 |
70 |
--------------------------------------------------------------------------------
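train.sh above launches 8 GPUs with a per-device batch size of 32 and gradient accumulation of 1, and --max_steps counts optimizer steps. A back-of-the-envelope check with the values copied from the flags:

    num_gpus = 8            # --num_gpus=8
    per_device_batch = 32   # --per_device_train_batch_size 32
    grad_accum = 1          # --gradient_accumulation_steps 1
    max_steps = 10000       # --max_steps 10000

    global_batch = num_gpus * per_device_batch * grad_accum
    print(global_batch)              # -> 256 samples per optimizer step
    print(global_batch * max_steps)  # -> 2560000 samples processed over the whole run

With --save_steps 1000, the run produces at most 10 checkpoints, so --save_total_limit 50 is never reached.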
/llava-pythia/scripts/llava_pythia/train_robot.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # LLM_MODEL_SIZE=$1
3 | LLM_MODEL_SIZE=2_8B
4 | # ./checkpoint_all/pythia_$LLM_MODEL_SIZE/vanilla_pythia_pt_f_vit/llavaPythia-v0-finetune
5 | OUTPUT=./checkpoint_all/pythia_$LLM_MODEL_SIZE/vanilla_pythia_pt_f_vit/llavaPythia-v0-robot-action-1view_adapter3
6 |
7 | # deepspeed --master_port 29601 --include localhost:4,5,6,7 llava_pythia/train/train.py \
8 | # echo "waiting for 20minutes..."
9 | # sleep 20m
10 |
11 | deepspeed --master_port 29601 --num_gpus=8 --num_nodes=1 llava_pythia/train/train.py \
12 | --deepspeed ./scripts/zero2.json \
13 | --model_name_or_path /data/team/zhumj/model_Param/llava_pythia_checkpoints/pythia_$LLM_MODEL_SIZE/vanilla_pythia_pt_f_vit/llavaPythia-v0-finetune \
14 | --version v0 \
15 | --data_path /data/private/data/llava_data/franka_kitchen_finetune/left_cap2/std_train_left_cap2_50k.json \
16 | --image_folder /data/team/zhumj/data/finetune/data \
17 | --tune_mm_mlp_adapter True \
18 | --freeze_vision_tower True \
19 | --freeze_backbone True \
20 | --mm_use_im_start_end False \
21 | --mm_use_im_patch_token False \
22 | --image_aspect_ratio pad \
23 | --group_by_modality_length False \
24 | --bf16 True \
25 | --output_dir $OUTPUT \
26 | --num_train_epochs 15 \
27 | --per_device_train_batch_size 4 \
28 | --per_device_eval_batch_size 4 \
29 | --gradient_accumulation_steps 4 \
30 | --evaluation_strategy "steps" \
31 | --save_strategy "steps" \
32 | --save_steps 1000 \
33 | --save_total_limit 15 \
34 | --learning_rate 3e-5 \
35 | --weight_decay 0. \
36 | --warmup_ratio 0.005 \
37 | --lr_scheduler_type "cosine" \
38 | --logging_steps 10 \
39 | --tf32 True \
40 | --model_max_length 2048 \
41 | --gradient_checkpointing True \
42 | --dataloader_num_workers 4 \
43 | --lazy_preprocess True \
44 | --action_head "fc" \
45 | --use_state True \
46 | --lora_enable False \
47 | --window_size 6 \
48 |     --logging_dir $OUTPUT/log \
49 |     --report_to wandb
50 |
51 | #/data/private/data/llava_data/franka_kitchen_finetune/left_cap2/left_cap2_50k.json
52 | #/data/team/zhumj/data/finetune/data/llava_v1_5_mix665k.json
53 | # cp openai/clip-vit-large-patch14-336/preprocessor_config.json $OUTPUT
54 | for dir in "$OUTPUT"/*/ ; do
55 |   # check whether the folder name contains 'checkpoint'
56 | if [[ "$(basename "$dir")" == *"checkpoint"* ]]; then
57 | cp openai/clip-vit-large-patch14-336/preprocessor_config.json $dir
58 | fi
59 | done
60 |
61 | cp ./scripts/llava_pythia/train_robot.sh $OUTPUT
62 |
--------------------------------------------------------------------------------
/llava-pythia/llava_pythia/model/multimodal_projector/builder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import re
4 |
5 |
6 | class IdentityMap(nn.Module):
7 | def __init__(self):
8 | super().__init__()
9 |
10 | def forward(self, x, *args, **kwargs):
11 | return x
12 |
13 | @property
14 | def config(self):
15 | return {"mm_projector_type": 'identity'}
16 |
17 |
18 | class SimpleResBlock(nn.Module):
19 | def __init__(self, channels):
20 | super().__init__()
21 | self.pre_norm = nn.LayerNorm(channels)
22 |
23 | self.proj = nn.Sequential(
24 | nn.Linear(channels, channels),
25 | nn.GELU(),
26 | nn.Linear(channels, channels)
27 | )
28 | def forward(self, x):
29 | x = self.pre_norm(x)
30 | return x + self.proj(x)
31 |
32 |
33 | def build_vision_projector(config):
34 | """
35 | Constructs a vision projector based on the specified configuration.
36 |
37 | Args:
38 | - config: An object containing configuration attributes. It should have
39 | 'mm_projector_type' to specify the type of projector and 'mm_hidden_size'
40 | and 'hidden_size' for the dimensions of the layers.
41 |
42 | Returns:
43 | - A PyTorch module that acts as the vision projector. The type of module
44 | returned depends on the 'mm_projector_type' attribute in the config:
45 | - 'linear': Returns a linear layer mapping from mm_hidden_size to hidden_size.
46 | - 'mlp{n}x_gelu': Returns a sequential model with n layers, each consisting
47 | of a GELU activation followed by a linear layer.
48 | - 'identity': Returns an IdentityMap, which simply returns the input as is.
49 |
50 | Raises:
51 | - ValueError: If the 'mm_projector_type' is not recognized.
52 | """
53 | projector_type = getattr(config, 'mm_projector_type', 'linear')
54 |
55 | if projector_type == 'linear':
56 | return nn.Linear(config.mm_hidden_size, config.hidden_size)
57 |
58 | mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu$', projector_type)
59 | if mlp_gelu_match:
60 | mlp_depth = int(mlp_gelu_match.group(1))
61 | modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)]
62 | for _ in range(1, mlp_depth):
63 | modules.append(nn.GELU())
64 | modules.append(nn.Linear(config.hidden_size, config.hidden_size))
65 | return nn.Sequential(*modules)
66 |
67 | if projector_type == 'identity':
68 | return IdentityMap()
69 |
70 | raise ValueError(f'Unknown projector type: {projector_type}')
71 |
--------------------------------------------------------------------------------
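The training scripts pass --mm_projector_type mlp2x_gelu, so build_vision_projector above returns a two-layer MLP with a GELU in between, mapping the vision tower's feature width onto the LLM's hidden size. A minimal usage sketch; the config is mocked with SimpleNamespace, 1024 is the CLIP ViT-L/14-336 feature width, and 512 is only an example hidden size:

    import torch
    from types import SimpleNamespace
    # assumes the llava_pythia package is importable, e.g. after an editable install
    from llava_pythia.model.multimodal_projector.builder import build_vision_projector

    cfg = SimpleNamespace(mm_projector_type="mlp2x_gelu",
                          mm_hidden_size=1024,  # vision feature dim
                          hidden_size=512)      # LLM hidden dim (example value)
    projector = build_vision_projector(cfg)

    patches = torch.randn(1, 576, 1024)  # (batch, num_patches, mm_hidden_size)
    print(projector(patches).shape)      # -> torch.Size([1, 576, 512])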
/llava-pythia/scripts/llava_pythia/.ipynb_checkpoints/train_robot-checkpoint.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # LLM_MODEL_SIZE=$1
3 | LLM_MODEL_SIZE=2_8B
4 | # ./checkpoint_all/pythia_$LLM_MODEL_SIZE/vanilla_pythia_pt_f_vit/llavaPythia-v0-finetune
5 | OUTPUT=./checkpoint_all/pythia_$LLM_MODEL_SIZE/vanilla_pythia_pt_f_vit/llavaPythia-v0-robot-action-1view_adapter3
6 |
7 | # deepspeed --master_port 29601 --include localhost:4,5,6,7 llava_pythia/train/train.py \
8 | # echo "waiting for 20minutes..."
9 | # sleep 20m
10 |
11 | deepspeed --master_port 29601 --num_gpus=8 --num_nodes=1 llava_pythia/train/train.py \
12 | --deepspeed ./scripts/zero2.json \
13 | --model_name_or_path /data/team/zhumj/model_Param/llava_pythia_checkpoints/pythia_$LLM_MODEL_SIZE/vanilla_pythia_pt_f_vit/llavaPythia-v0-finetune \
14 | --version v0 \
15 | --data_path /data/private/data/llava_data/franka_kitchen_finetune/left_cap2/std_train_left_cap2_50k.json \
16 | --image_folder /data/team/zhumj/data/finetune/data \
17 | --tune_mm_mlp_adapter True \
18 | --freeze_vision_tower True \
19 | --freeze_backbone True \
20 | --mm_use_im_start_end False \
21 | --mm_use_im_patch_token False \
22 | --image_aspect_ratio pad \
23 | --group_by_modality_length False \
24 | --bf16 True \
25 | --output_dir $OUTPUT \
26 | --num_train_epochs 15 \
27 | --per_device_train_batch_size 4 \
28 | --per_device_eval_batch_size 4 \
29 | --gradient_accumulation_steps 4 \
30 | --evaluation_strategy "steps" \
31 | --save_strategy "steps" \
32 | --save_steps 1000 \
33 | --save_total_limit 15 \
34 | --learning_rate 3e-5 \
35 | --weight_decay 0. \
36 | --warmup_ratio 0.005 \
37 | --lr_scheduler_type "cosine" \
38 | --logging_steps 10 \
39 | --tf32 True \
40 | --model_max_length 2048 \
41 | --gradient_checkpointing True \
42 | --dataloader_num_workers 4 \
43 | --lazy_preprocess True \
44 | --action_head "fc" \
45 | --use_state True \
46 | --lora_enable False \
47 | --window_size 6 \
48 |     --logging_dir $OUTPUT/log \
49 |     --report_to wandb
50 |
51 | #/data/private/data/llava_data/franka_kitchen_finetune/left_cap2/left_cap2_50k.json
52 | #/data/team/zhumj/data/finetune/data/llava_v1_5_mix665k.json
53 | # cp openai/clip-vit-large-patch14-336/preprocessor_config.json $OUTPUT
54 | for dir in "$OUTPUT"/*/ ; do
55 |   # check whether the folder name contains 'checkpoint'
56 | if [[ "$(basename "$dir")" == *"checkpoint"* ]]; then
57 | cp openai/clip-vit-large-patch14-336/preprocessor_config.json $dir
58 | fi
59 | done
60 |
61 | cp ./scripts/llava_pythia/train_robot.sh $OUTPUT
62 |
--------------------------------------------------------------------------------
/llava-pythia/scripts/llava_pythia/lora_train_robot.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | LLM_MODEL_SIZE=$1
3 | # LLM_MODEL_SIZE=2_8B
4 | # ./checkpoint_all/pythia_$LLM_MODEL_SIZE/vanilla_pythia_pt_f_vit/llavaPythia-v0-finetune
5 | # apply LoRA only to the ViT and tune the adapter
6 | OUTPUT=./checkpoint_all/pythia_$LLM_MODEL_SIZE/vanilla_pythia_pt_f_vit/llavaPythia-v0-robot-action-view_channel_cat_lora2
7 |
8 | # deepspeed --master_port 29601 --include localhost:4,5,6,7 llava_pythia/train/train.py \
9 | # echo "waiting for 6 hours..."
10 | # sleep 6h
11 |
12 | deepspeed --master_port 29601 --num_gpus=8 --num_nodes=1 llava_pythia/train/train.py \
13 | --lora_enable True --lora_r 64 --lora_alpha 256 --non_lora_lr 3e-5 \
14 | --deepspeed ./scripts/zero2.json \
15 | --model_name_or_path /data/team/zhumj/model_Param/llava_pythia_checkpoints/pythia_$LLM_MODEL_SIZE/vanilla_pythia_pt_f_vit/llavaPythia-v0-finetune \
16 | --version v0 \
17 | --data_path /data/private/data/llava_data/franka_kitchen_finetune/left_cap2/std_train_left_cap2_50k.json \
18 | --image_folder /data/team/zhumj/data/finetune/data \
19 | --tune_mm_mlp_adapter True \
20 | --freeze_vision_tower True \
21 | --freeze_backbone True \
22 | --mm_use_im_start_end False \
23 | --mm_use_im_patch_token False \
24 | --image_aspect_ratio pad \
25 | --group_by_modality_length False \
26 | --bf16 True \
27 | --output_dir $OUTPUT \
28 | --num_train_epochs 15 \
29 | --per_device_train_batch_size 4 \
30 | --per_device_eval_batch_size 4 \
31 | --gradient_accumulation_steps 4 \
32 | --evaluation_strategy "steps" \
33 | --save_strategy "steps" \
34 | --save_steps 1000 \
35 | --save_total_limit 15 \
36 | --learning_rate 2e-4 \
37 | --weight_decay 0. \
38 | --warmup_ratio 0.005 \
39 | --lr_scheduler_type "cosine" \
40 | --logging_steps 10 \
41 | --tf32 True \
42 | --model_max_length 2048 \
43 | --gradient_checkpointing True \
44 | --dataloader_num_workers 4 \
45 | --lazy_preprocess True \
46 | --action_head "fc" \
47 | --use_state True \
48 | --concat "channel_cat" \
49 | --window_size 6 \
50 | --report_to wandb \
51 | --logging_dir $OUTPUT/log
52 |
53 |
54 | #/data/private/data/llava_data/franka_kitchen_finetune/left_cap2/left_cap2_50k.json
55 | #/data/team/zhumj/data/finetune/data/llava_v1_5_mix665k.json
56 | # cp openai/clip-vit-large-patch14-336/preprocessor_config.json $OUTPUT
57 | for dir in "$OUTPUT"/*/ ; do
58 |   # check whether the folder name contains 'checkpoint'
59 | if [[ "$(basename "$dir")" == *"checkpoint"* ]]; then
60 | cp openai/clip-vit-large-patch14-336/preprocessor_config.json $dir
61 | # cp $OUTPUT/non_lora_trainables.bin $dir
62 | fi
63 | done
64 |
65 | cp ./scripts/llava_pythia/lora_train_robot.sh $OUTPUT
66 |
--------------------------------------------------------------------------------
/llava-pythia/scripts/llava_pythia/.ipynb_checkpoints/lora_train_robot-checkpoint.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | LLM_MODEL_SIZE=$1
3 | # LLM_MODEL_SIZE=2_8B
4 | # ./checkpoint_all/pythia_$LLM_MODEL_SIZE/vanilla_pythia_pt_f_vit/llavaPythia-v0-finetune
5 | # apply LoRA only to the ViT and tune the adapter
6 | OUTPUT=./checkpoint_all/pythia_$LLM_MODEL_SIZE/vanilla_pythia_pt_f_vit/llavaPythia-v0-robot-action-view_channel_cat_lora2
7 |
8 | # deepspeed --master_port 29601 --include localhost:4,5,6,7 llava_pythia/train/train.py \
9 | # echo "waiting for 6 hours..."
10 | # sleep 6h
11 |
12 | deepspeed --master_port 29601 --num_gpus=8 --num_nodes=1 llava_pythia/train/train.py \
13 | --lora_enable True --lora_r 64 --lora_alpha 256 --non_lora_lr 3e-5 \
14 | --deepspeed ./scripts/zero2.json \
15 | --model_name_or_path /data/team/zhumj/model_Param/llava_pythia_checkpoints/pythia_$LLM_MODEL_SIZE/vanilla_pythia_pt_f_vit/llavaPythia-v0-finetune \
16 | --version v0 \
17 | --data_path /data/private/data/llava_data/franka_kitchen_finetune/left_cap2/std_train_left_cap2_50k.json \
18 | --image_folder /data/team/zhumj/data/finetune/data \
19 | --tune_mm_mlp_adapter True \
20 | --freeze_vision_tower True \
21 | --freeze_backbone True \
22 | --mm_use_im_start_end False \
23 | --mm_use_im_patch_token False \
24 | --image_aspect_ratio pad \
25 | --group_by_modality_length False \
26 | --bf16 True \
27 | --output_dir $OUTPUT \
28 | --num_train_epochs 15 \
29 | --per_device_train_batch_size 4 \
30 | --per_device_eval_batch_size 4 \
31 | --gradient_accumulation_steps 4 \
32 | --evaluation_strategy "steps" \
33 | --save_strategy "steps" \
34 | --save_steps 1000 \
35 | --save_total_limit 15 \
36 | --learning_rate 2e-4 \
37 | --weight_decay 0. \
38 | --warmup_ratio 0.005 \
39 | --lr_scheduler_type "cosine" \
40 | --logging_steps 10 \
41 | --tf32 True \
42 | --model_max_length 2048 \
43 | --gradient_checkpointing True \
44 | --dataloader_num_workers 4 \
45 | --lazy_preprocess True \
46 | --action_head "fc" \
47 | --use_state True \
48 | --concat "channel_cat" \
49 | --window_size 6 \
50 |     --logging_dir $OUTPUT/log \
51 |     --report_to wandb
52 |
53 | #/data/private/data/llava_data/franka_kitchen_finetune/left_cap2/left_cap2_50k.json
54 | #/data/team/zhumj/data/finetune/data/llava_v1_5_mix665k.json
55 | # cp openai/clip-vit-large-patch14-336/preprocessor_config.json $OUTPUT
56 | for dir in "$OUTPUT"/*/ ; do
57 |   # check whether the folder name contains 'checkpoint'
58 | if [[ "$(basename "$dir")" == *"checkpoint"* ]]; then
59 | cp openai/clip-vit-large-patch14-336/preprocessor_config.json $dir
60 | # cp $OUTPUT/non_lora_trainables.bin $dir
61 | fi
62 | done
63 |
64 | cp ./scripts/llava_pythia/lora_train_robot.sh $OUTPUT
65 |
--------------------------------------------------------------------------------
/llava-pythia/scripts/convert_seed_for_submission.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import argparse
4 |
5 |
6 | def get_args():
7 | parser = argparse.ArgumentParser()
8 | parser.add_argument("--annotation-file", type=str)
9 | parser.add_argument("--result-file", type=str)
10 | parser.add_argument("--result-upload-file", type=str)
11 | return parser.parse_args()
12 |
13 |
14 | def eval_single(result_file, eval_only_type=None):
15 | results = {}
16 | for line in open(result_file):
17 | row = json.loads(line)
18 | results[row['question_id']] = row
19 |
20 | type_counts = {}
21 | correct_counts = {}
22 | for question_data in data['questions']:
23 | if eval_only_type is not None and question_data['data_type'] != eval_only_type: continue
24 | data_type = question_data['question_type_id']
25 | type_counts[data_type] = type_counts.get(data_type, 0) + 1
26 | try:
27 | question_id = int(question_data['question_id'])
28 | except:
29 | question_id = question_data['question_id']
30 | if question_id not in results:
31 | correct_counts[data_type] = correct_counts.get(data_type, 0)
32 | continue
33 | row = results[question_id]
34 | if row['text'] == question_data['answer']:
35 | correct_counts[data_type] = correct_counts.get(data_type, 0) + 1
36 |
37 | total_count = 0
38 | total_correct = 0
39 | for data_type in sorted(type_counts.keys()):
40 |         accuracy = correct_counts.get(data_type, 0) / type_counts[data_type] * 100
41 | if eval_only_type is None:
42 | print(f"{ques_type_id_to_name[data_type]}: {accuracy:.2f}%")
43 |
44 | total_count += type_counts[data_type]
45 |         total_correct += correct_counts.get(data_type, 0)
46 |
47 | total_accuracy = total_correct / total_count * 100
48 | if eval_only_type is None:
49 | print(f"Total accuracy: {total_accuracy:.2f}%")
50 | else:
51 | print(f"{eval_only_type} accuracy: {total_accuracy:.2f}%")
52 |
53 | return results
54 |
55 | if __name__ == "__main__":
56 | args = get_args()
57 | data = json.load(open(args.annotation_file))
58 | ques_type_id_to_name = {id:n for n,id in data['question_type'].items()}
59 |
60 | results = eval_single(args.result_file)
61 | eval_single(args.result_file, eval_only_type='image')
62 | eval_single(args.result_file, eval_only_type='video')
63 |
64 | with open(args.result_upload_file, 'w') as fp:
65 | for question in data['questions']:
66 | qid = question['question_id']
67 | if qid in results:
68 | result = results[qid]
69 | else:
70 | result = results[int(qid)]
71 | fp.write(json.dumps({
72 | 'question_id': qid,
73 | 'prediction': result['text']
74 | }) + '\n')
75 |
--------------------------------------------------------------------------------
/llava-pythia/scripts/convert_sqa_to_llava.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import fire
4 | import re
5 | from convert_sqa_to_llava_base_prompt import build_prompt_chatbot
6 |
7 |
8 | def convert_to_llava(base_dir, split, prompt_format="QCM-LEA"):
9 | split_indices = json.load(open(os.path.join(base_dir, "pid_splits.json")))[split]
10 | problems = json.load(open(os.path.join(base_dir, "problems.json")))
11 |
12 | split_problems = build_prompt_chatbot(
13 | problems, split_indices, prompt_format,
14 | use_caption=False, is_test=False)
15 |
16 | target_format = []
17 | for prob_id, (input, output) in split_problems.items():
18 | if input.startswith('Question: '):
19 | input = input.replace('Question: ', '')
20 | if output.startswith('Answer: '):
21 | output = output.replace('Answer: ', '')
22 |
23 | raw_prob_data = problems[prob_id]
24 | if raw_prob_data['image'] is None:
25 | target_format.append({
26 | "id": prob_id,
27 | "conversations": [
28 | {'from': 'human', 'value': f"{input}"},
29 | {'from': 'gpt', 'value': f"{output}"},
30 | ],
31 | })
32 |
33 | else:
34 | target_format.append({
35 | "id": prob_id,
36 | "image": os.path.join(prob_id, raw_prob_data['image']),
37 | "conversations": [
38 | {'from': 'human', 'value': f"{input}\n"},
39 | {'from': 'gpt', 'value': f"{output}"},
40 | ],
41 | })
42 |
43 | print(f'Number of samples: {len(target_format)}')
44 |
45 | with open(os.path.join(base_dir, f"llava_{split}_{prompt_format}.json"), "w") as f:
46 | json.dump(target_format, f, indent=2)
47 |
48 |
49 | def convert_to_jsonl(base_dir, split, prompt_format="QCM-LEPA"):
50 | split_indices = json.load(open(os.path.join(base_dir, "pid_splits.json")))[split]
51 | problems = json.load(open(os.path.join(base_dir, "problems.json")))
52 |
53 | split_problems = build_prompt_chatbot(
54 | problems, split_indices, prompt_format,
55 | use_caption=False, is_test=False)
56 |
57 | writer = open(os.path.join(base_dir, f"scienceqa_{split}_{prompt_format}.jsonl"), "w")
58 | for prob_id, (input, output) in split_problems.items():
59 | if input.startswith('Question: '):
60 | input = input.replace('Question: ', '')
61 | if output.startswith('Answer: '):
62 | output = output.replace('Answer: ', '')
63 |
64 | raw_prob_data = problems[prob_id]
65 | if raw_prob_data['image'] is None:
66 | data = {
67 | "id": prob_id,
68 | "instruction": f"{input}",
69 | "output": f"{output}",
70 | }
71 |
72 | else:
73 | data = {
74 | "id": prob_id,
75 | "image": os.path.join(prob_id, raw_prob_data['image']),
76 | "instruction": f"{input}\n",
77 | "output": f"{output}",
78 | }
79 | writer.write(json.dumps(data) + '\n')
80 | writer.close()
81 |
82 |
83 | def main(task, **kwargs):
84 | globals()[task](**kwargs)
85 |
86 |
87 | if __name__ == "__main__":
88 | fire.Fire(main)
89 |
--------------------------------------------------------------------------------
/llava-pythia/scripts/.ipynb_checkpoints/convert_sqa_to_llava-checkpoint.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import fire
4 | import re
5 | from convert_sqa_to_llava_base_prompt import build_prompt_chatbot
6 |
7 |
8 | def convert_to_llava(base_dir, split, prompt_format="QCM-LEA"):
9 | split_indices = json.load(open(os.path.join(base_dir, "pid_splits.json")))[split]
10 | problems = json.load(open(os.path.join(base_dir, "problems.json")))
11 |
12 | split_problems = build_prompt_chatbot(
13 | problems, split_indices, prompt_format,
14 | use_caption=False, is_test=False)
15 |
16 | target_format = []
17 | for prob_id, (input, output) in split_problems.items():
18 | if input.startswith('Question: '):
19 | input = input.replace('Question: ', '')
20 | if output.startswith('Answer: '):
21 | output = output.replace('Answer: ', '')
22 |
23 | raw_prob_data = problems[prob_id]
24 | if raw_prob_data['image'] is None:
25 | target_format.append({
26 | "id": prob_id,
27 | "conversations": [
28 | {'from': 'human', 'value': f"{input}"},
29 | {'from': 'gpt', 'value': f"{output}"},
30 | ],
31 | })
32 |
33 | else:
34 | target_format.append({
35 | "id": prob_id,
36 | "image": os.path.join(prob_id, raw_prob_data['image']),
37 | "conversations": [
38 | {'from': 'human', 'value': f"{input}\n"},
39 | {'from': 'gpt', 'value': f"{output}"},
40 | ],
41 | })
42 |
43 | print(f'Number of samples: {len(target_format)}')
44 |
45 | with open(os.path.join(base_dir, f"llava_{split}_{prompt_format}.json"), "w") as f:
46 | json.dump(target_format, f, indent=2)
47 |
48 |
49 | def convert_to_jsonl(base_dir, split, prompt_format="QCM-LEPA"):
50 | split_indices = json.load(open(os.path.join(base_dir, "pid_splits.json")))[split]
51 | problems = json.load(open(os.path.join(base_dir, "problems.json")))
52 |
53 | split_problems = build_prompt_chatbot(
54 | problems, split_indices, prompt_format,
55 | use_caption=False, is_test=False)
56 |
57 | writer = open(os.path.join(base_dir, f"scienceqa_{split}_{prompt_format}.jsonl"), "w")
58 | for prob_id, (input, output) in split_problems.items():
59 | if input.startswith('Question: '):
60 | input = input.replace('Question: ', '')
61 | if output.startswith('Answer: '):
62 | output = output.replace('Answer: ', '')
63 |
64 | raw_prob_data = problems[prob_id]
65 | if raw_prob_data['image'] is None:
66 | data = {
67 | "id": prob_id,
68 | "instruction": f"{input}",
69 | "output": f"{output}",
70 | }
71 |
72 | else:
73 | data = {
74 | "id": prob_id,
75 | "image": os.path.join(prob_id, raw_prob_data['image']),
76 | "instruction": f"{input}\n",
77 | "output": f"{output}",
78 | }
79 | writer.write(json.dumps(data) + '\n')
80 | writer.close()
81 |
82 |
83 | def main(task, **kwargs):
84 | globals()[task](**kwargs)
85 |
86 |
87 | if __name__ == "__main__":
88 | fire.Fire(main)
89 |
--------------------------------------------------------------------------------
/llava-pythia/llava_pythia/model/multimodal_encoder/clip_encoder.py:
--------------------------------------------------------------------------------
1 | from abc import ABC
2 |
3 | import torch
4 | import torch.nn as nn
5 |
6 | from transformers import CLIPPreTrainedModel, CLIPVisionConfig
7 | from transformers.models.clip.modeling_clip import CLIPVisionTransformer
8 | from llava_pythia.model.language_model.pythia.configuration_llava_pythia import LlavaPythiaVisionConfig
9 |
10 |
11 | class CLIPVisionTower(CLIPPreTrainedModel):
12 | config_class = LlavaPythiaVisionConfig
13 |
14 | def __init__(self, config):
15 | super().__init__(config)
16 |
17 | self.vision_model = CLIPVisionTransformer(config)
18 | # Initialize weights and apply final processing
19 | self.post_init()
20 |
21 | def get_input_embeddings(self) -> nn.Module:
22 | return self.vision_model.embeddings.patch_embedding
23 |
24 | def feature_select(self, image_forward_outs):
25 | image_features = image_forward_outs.hidden_states[self.config.mm_vision_select_layer]
26 | if self.config.mm_vision_select_feature == 'patch':
27 | image_features = image_features[:, 1:]
28 | elif self.config.mm_vision_select_feature == 'cls_patch':
29 | image_features = image_features
30 | else:
31 | raise ValueError(f'Unexpected select feature: {self.config.mm_vision_select_feature}')
32 | return image_features
33 |
34 | def forward(self, images):
35 | if type(images) is list:
36 | image_features = []
37 | for image in images:
38 | image_forward_out = self.vision_model(image.to(device=self.device, dtype=self.dtype).unsqueeze(0),
39 | output_hidden_states=True)
40 | image_feature = self.feature_select(image_forward_out).to(image.dtype)
41 | image_features.append(image_feature)
42 | else:
43 | image_forward_outs = self.vision_model(images.to(device=self.device, dtype=self.dtype),
44 | output_hidden_states=True)
45 | image_features = self.feature_select(image_forward_outs).to(images.dtype)
46 |
47 | return image_features
48 |
49 | @property
50 | def dummy_feature(self):
51 | return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
52 |
53 | @property
54 | def dtype(self):
55 | return list(self.vision_model.parameters())[0].dtype
56 |
57 | @property
58 | def device(self):
59 | return list(self.vision_model.parameters())[0].device
60 |
61 | @property
62 | def hidden_size(self):
63 | return self.config.hidden_size
64 |
65 | @property
66 | def num_patches(self):
67 | return (self.config.image_size // self.config.patch_size) ** 2
68 |
69 |
70 | if __name__ == '__main__':
71 | clip_config = CLIPVisionConfig.from_pretrained(
72 | "/data/private/zhumj/GPTcode/mm-phi/openai/clip-vit-large-patch14-336"
73 | )
74 | print("################ clip_config ##############")
75 | print(clip_config)
76 | pythia_vis_config = LlavaPythiaVisionConfig(**clip_config.to_dict())
77 | print("################ pythia_vis_config ##############")
78 | print(pythia_vis_config)
79 |
80 | model = CLIPVisionTower(clip_config)
81 | # print(list(model.vision_model.parameters())[0].dtype)
82 |
--------------------------------------------------------------------------------
/llava-pythia/llava_pythia/model/multimodal_encoder/siglip_encoder.py:
--------------------------------------------------------------------------------
1 | from abc import ABC
2 |
3 | import torch
4 | import torch.nn as nn
5 |
6 | from transformers.models.siglip import SiglipPreTrainedModel, SiglipVisionConfig
7 | from transformers.models.siglip.modeling_siglip import SiglipVisionTransformer
8 | from llava_pythia.model.language_model.pythia.configuration_llava_pythia import LlavaPythiaVisionConfig
9 |
10 |
11 | class SiglipVisionTower(SiglipPreTrainedModel):
12 | config_class = LlavaPythiaVisionConfig
13 |
14 | def __init__(self, config):
15 | super().__init__(config)
16 |
17 | self.vision_model = SiglipVisionTransformer(config)
18 | # Initialize weights and apply final processing
19 | self.post_init()
20 |
21 | def get_input_embeddings(self) -> nn.Module:
22 | return self.vision_model.embeddings.patch_embedding
23 |
24 | def feature_select(self, image_forward_outs):
25 | image_features = image_forward_outs.hidden_states[self.config.mm_vision_select_layer]
26 |         if self.config.mm_vision_select_feature == 'patch':
27 |             image_features = image_features  # SigLIP has no CLS token, so all hidden states are patch tokens
28 |         elif self.config.mm_vision_select_feature == 'cls_patch':
29 |             image_features = image_features  # identical to 'patch' here: there is no CLS token to keep
30 | else:
31 | raise ValueError(f'Unexpected select feature: {self.config.mm_vision_select_feature}')
32 | return image_features
33 |
34 | def forward(self, images):
35 | if type(images) is list:
36 | image_features = []
37 | for image in images:
38 | image_forward_out = self.vision_model(image.to(device=self.device, dtype=self.dtype).unsqueeze(0),
39 | output_hidden_states=True)
40 | image_feature = self.feature_select(image_forward_out).to(image.dtype)
41 | image_features.append(image_feature)
42 | else:
43 | image_forward_outs = self.vision_model(images.to(device=self.device, dtype=self.dtype),
44 | output_hidden_states=True)
45 | image_features = self.feature_select(image_forward_outs).to(images.dtype)
46 |
47 | return image_features
48 |
49 | @property
50 | def dummy_feature(self):
51 | return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
52 |
53 | @property
54 | def dtype(self):
55 | return list(self.vision_model.parameters())[0].dtype
56 |
57 | @property
58 | def device(self):
59 | return list(self.vision_model.parameters())[0].device
60 |
61 | @property
62 | def hidden_size(self):
63 | return self.config.hidden_size
64 |
65 | @property
66 | def num_patches(self):
67 | return (self.config.image_size // self.config.patch_size) ** 2
68 |
69 |
70 | if __name__ == '__main__':
71 | clip_config = SiglipVisionConfig.from_pretrained(
72 | "/data/private/zhumj/GPTcode/mm-phi/openai/clip-vit-large-patch14-336"
73 | )
74 | print("################ clip_config ##############")
75 | print(clip_config)
76 | pythia_vis_config = LlavaPythiaVisionConfig(**clip_config.to_dict())
77 | print("################ pythia_vis_config ##############")
78 | print(pythia_vis_config)
79 |
80 | model = SiglipVisionTower(clip_config)
81 | # print(list(model.vision_model.parameters())[0].dtype)
82 |
--------------------------------------------------------------------------------
/aloha_scripts/constants.py:
--------------------------------------------------------------------------------
1 |
2 | ##################### Setting of training data #####################################
3 |
4 | # DATA_DIR = '/path/to/your/data_dir'
5 |
6 | TASK_CONFIGS = {
7 | 'example_task_config': { # for local debug
8 | 'dataset_dir': [
9 | "/media/rl/HDD/data/data/act/new_view/8_29_tennis", # task 1
10 | ],
11 | 'episode_len': 1000, # 1000,
12 | 'camera_names': ['left', 'right', 'wrist'] # corresponding to image keys saved in h5py files
13 | },
14 | }
15 | ####################################################################################
16 |
17 | # !!! The following constants are copied from ALOHA and are not used in this project !!!
18 | ### ALOHA fixed constants
19 | DT = 0.02
20 |
21 | FPS = 50
22 |
23 | JOINT_NAMES = ["waist", "shoulder", "elbow", "forearm_roll", "wrist_angle", "wrist_rotate"]
24 | START_ARM_POSE = [0, -0.96, 1.16, 0, -0.3, 0, 0.02239, -0.02239, 0, -0.96, 1.16, 0, -0.3, 0, 0.02239, -0.02239]
25 |
26 | # Left finger position limits (qpos[7]), right_finger = -1 * left_finger
27 | MASTER_GRIPPER_POSITION_OPEN = 0.02417
28 | MASTER_GRIPPER_POSITION_CLOSE = 0.01244
29 | PUPPET_GRIPPER_POSITION_OPEN = 0.05800
30 | PUPPET_GRIPPER_POSITION_CLOSE = 0.01844
31 |
32 | # Gripper joint limits (qpos[6])
33 | MASTER_GRIPPER_JOINT_OPEN = 0.3083
34 | MASTER_GRIPPER_JOINT_CLOSE = -0.6842
35 | PUPPET_GRIPPER_JOINT_OPEN = 1.4910
36 | PUPPET_GRIPPER_JOINT_CLOSE = -0.6213
37 |
38 | ############################ Helper functions ############################
39 |
40 | MASTER_GRIPPER_POSITION_NORMALIZE_FN = lambda x: (x - MASTER_GRIPPER_POSITION_CLOSE) / (MASTER_GRIPPER_POSITION_OPEN - MASTER_GRIPPER_POSITION_CLOSE)
41 | PUPPET_GRIPPER_POSITION_NORMALIZE_FN = lambda x: (x - PUPPET_GRIPPER_POSITION_CLOSE) / (PUPPET_GRIPPER_POSITION_OPEN - PUPPET_GRIPPER_POSITION_CLOSE)
42 | MASTER_GRIPPER_POSITION_UNNORMALIZE_FN = lambda x: x * (MASTER_GRIPPER_POSITION_OPEN - MASTER_GRIPPER_POSITION_CLOSE) + MASTER_GRIPPER_POSITION_CLOSE
43 | PUPPET_GRIPPER_POSITION_UNNORMALIZE_FN = lambda x: x * (PUPPET_GRIPPER_POSITION_OPEN - PUPPET_GRIPPER_POSITION_CLOSE) + PUPPET_GRIPPER_POSITION_CLOSE
44 | MASTER2PUPPET_POSITION_FN = lambda x: PUPPET_GRIPPER_POSITION_UNNORMALIZE_FN(MASTER_GRIPPER_POSITION_NORMALIZE_FN(x))
45 |
46 | MASTER_GRIPPER_JOINT_NORMALIZE_FN = lambda x: (x - MASTER_GRIPPER_JOINT_CLOSE) / (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE)
47 | PUPPET_GRIPPER_JOINT_NORMALIZE_FN = lambda x: (x - PUPPET_GRIPPER_JOINT_CLOSE) / (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE)
48 | MASTER_GRIPPER_JOINT_UNNORMALIZE_FN = lambda x: x * (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE) + MASTER_GRIPPER_JOINT_CLOSE
49 | PUPPET_GRIPPER_JOINT_UNNORMALIZE_FN = lambda x: x * (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE) + PUPPET_GRIPPER_JOINT_CLOSE
50 | MASTER2PUPPET_JOINT_FN = lambda x: PUPPET_GRIPPER_JOINT_UNNORMALIZE_FN(MASTER_GRIPPER_JOINT_NORMALIZE_FN(x))
51 |
52 | MASTER_GRIPPER_VELOCITY_NORMALIZE_FN = lambda x: x / (MASTER_GRIPPER_POSITION_OPEN - MASTER_GRIPPER_POSITION_CLOSE)
53 | PUPPET_GRIPPER_VELOCITY_NORMALIZE_FN = lambda x: x / (PUPPET_GRIPPER_POSITION_OPEN - PUPPET_GRIPPER_POSITION_CLOSE)
54 |
55 | MASTER_POS2JOINT = lambda x: MASTER_GRIPPER_POSITION_NORMALIZE_FN(x) * (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE) + MASTER_GRIPPER_JOINT_CLOSE
56 | MASTER_JOINT2POS = lambda x: MASTER_GRIPPER_POSITION_UNNORMALIZE_FN((x - MASTER_GRIPPER_JOINT_CLOSE) / (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE))
57 | PUPPET_POS2JOINT = lambda x: PUPPET_GRIPPER_POSITION_NORMALIZE_FN(x) * (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE) + PUPPET_GRIPPER_JOINT_CLOSE
58 | PUPPET_JOINT2POS = lambda x: PUPPET_GRIPPER_POSITION_UNNORMALIZE_FN((x - PUPPET_GRIPPER_JOINT_CLOSE) / (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE))
59 |
60 | MASTER_GRIPPER_JOINT_MID = (MASTER_GRIPPER_JOINT_OPEN + MASTER_GRIPPER_JOINT_CLOSE)/2
61 |
--------------------------------------------------------------------------------
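The lambdas above normalize raw gripper readings into [0, 1] within each device's open/close range and map them back, and MASTER2PUPPET_POSITION_FN chains the two so a master reading can be retargeted onto the puppet gripper. A quick illustration, assuming aloha_scripts is importable from the repo root:

    from aloha_scripts.constants import (
        MASTER_GRIPPER_POSITION_OPEN,
        MASTER_GRIPPER_POSITION_NORMALIZE_FN,
        MASTER2PUPPET_POSITION_FN,
    )

    x = MASTER_GRIPPER_POSITION_OPEN                # fully open master gripper
    print(MASTER_GRIPPER_POSITION_NORMALIZE_FN(x))  # -> 1.0
    print(MASTER2PUPPET_POSITION_FN(x))             # -> ~0.058, i.e. PUPPET_GRIPPER_POSITION_OPEN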
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | # *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 | .idea
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 | cover/
54 |
55 | # Translations
56 | *.mo
57 | *.pot
58 |
59 | # Django stuff:
60 | *.log
61 | local_settings.py
62 | db.sqlite3
63 | db.sqlite3-journal
64 |
65 | *.egg-info/
66 | # Flask stuff:
67 | instance/
68 | .webassets-cache
69 |
70 | # Scrapy stuff:
71 | .scrapy
72 |
73 | # Sphinx documentation
74 | docs/_build/
75 |
76 | # PyBuilder
77 | .pybuilder/
78 | target/
79 |
80 | # Jupyter Notebook
81 | .ipynb_checkpoints
82 |
83 | # IPython
84 | profile_default/
85 | ipython_config.py
86 |
87 | # pyenv
88 | # For a library or package, you might want to ignore these files since the code is
89 | # intended to run in multiple environments; otherwise, check them in:
90 | # .python-version
91 |
92 | # pipenv
93 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
94 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
95 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
96 | # install all needed dependencies.
97 | #Pipfile.lock
98 |
99 | # poetry
100 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
101 | # This is especially recommended for binary packages to ensure reproducibility, and is more
102 | # commonly ignored for libraries.
103 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
104 | #poetry.lock
105 |
106 | # pdm
107 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
108 | #pdm.lock
109 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
110 | # in version control.
111 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
112 | .pdm.toml
113 | .pdm-python
114 | .pdm-build/
115 |
116 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
117 | __pypackages__/
118 | __pycache__/
119 |
120 | # Celery stuff
121 | celerybeat-schedule
122 | celerybeat.pid
123 |
124 | # SageMath parsed files
125 | *.sage.py
126 |
127 | # Environments
128 | .env
129 | .venv
130 | env/
131 | venv/
132 | ENV/
133 | env.bak/
134 | venv.bak/
135 |
136 | # Spyder project settings
137 | .spyderproject
138 | .spyproject
139 |
140 | # Rope project settings
141 | .ropeproject
142 |
143 | # mkdocs documentation
144 | /site
145 |
146 | # mypy
147 | .mypy_cache/
148 | .dmypy.json
149 | dmypy.json
150 |
151 | robomimic-r2d2/
152 | robomimic-r2d2.zip
153 | OUTPUT/
154 | data/
155 | checkpoints/
156 | loglog/
157 |
158 | # Pyre type checker
159 | .pyre/
160 |
161 | # pytype static type analyzer
162 | .pytype/
163 |
164 | # Cython debug symbols
165 | cython_debug/
166 |
167 | # PyCharm
168 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
169 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
170 | # and can be added to the global gitignore or merged into this file. For a more nuclear
171 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
172 | #.idea/
173 |
--------------------------------------------------------------------------------
/policy_heads/util/box_ops.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2 | """
3 | Utilities for bounding box manipulation and GIoU.
4 | """
5 | import torch
6 | from torchvision.ops.boxes import box_area
7 |
8 |
9 | def box_cxcywh_to_xyxy(x):
10 | """
11 | Convert bounding boxes from center-size format (cx, cy, w, h) to corner format (x0, y0, x1, y1).
12 |
13 | Args:
14 | x: Tensor of shape (..., 4) representing bounding boxes in center-size format.
15 |
16 | Returns:
17 | Tensor of shape (..., 4) representing bounding boxes in corner format.
18 | """
19 | x_c, y_c, w, h = x.unbind(-1)
20 | b = [(x_c - 0.5 * w), (y_c - 0.5 * h),
21 | (x_c + 0.5 * w), (y_c + 0.5 * h)]
22 | return torch.stack(b, dim=-1)
23 |
24 |
25 | def box_xyxy_to_cxcywh(x):
26 | """
27 | Convert bounding boxes from corner format (x0, y0, x1, y1) to center-size format (cx, cy, w, h).
28 |
29 | Args:
30 | x: Tensor of shape (..., 4) representing bounding boxes in corner format.
31 |
32 | Returns:
33 | Tensor of shape (..., 4) representing bounding boxes in center-size format.
34 | """
35 | x0, y0, x1, y1 = x.unbind(-1)
36 | b = [(x0 + x1) / 2, (y0 + y1) / 2,
37 | (x1 - x0), (y1 - y0)]
38 | return torch.stack(b, dim=-1)
39 |
40 |
41 | # modified from torchvision to also return the union
42 | def box_iou(boxes1, boxes2):
43 | """
44 | Compute the Intersection over Union (IoU) between two sets of boxes.
45 |
46 | Args:
47 | boxes1: Tensor of shape (N, 4) representing the first set of boxes.
48 | boxes2: Tensor of shape (M, 4) representing the second set of boxes.
49 |
50 | Returns:
51 | iou: Tensor of shape (N, M) representing pairwise IoU values.
52 | union: Tensor of shape (N, M) representing pairwise union areas.
53 | """
54 | area1 = box_area(boxes1)
55 | area2 = box_area(boxes2)
56 |
57 | lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2]
58 | rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2]
59 |
60 | wh = (rb - lt).clamp(min=0) # [N,M,2]
61 | inter = wh[:, :, 0] * wh[:, :, 1] # [N,M]
62 |
63 | union = area1[:, None] + area2 - inter
64 |
65 | iou = inter / union
66 | return iou, union
67 |
68 |
69 | def generalized_box_iou(boxes1, boxes2):
70 | """
71 | Compute the Generalized Intersection over Union (GIoU) between two sets of boxes.
72 |
73 | Args:
74 | boxes1: Tensor of shape (N, 4) representing the first set of boxes in [x0, y0, x1, y1] format.
75 | boxes2: Tensor of shape (M, 4) representing the second set of boxes in [x0, y0, x1, y1] format.
76 |
77 | Returns:
78 | Tensor of shape (N, M) representing pairwise GIoU values.
79 | """
80 | # degenerate boxes give inf / nan results, so do an early check
81 | assert (boxes1[:, 2:] >= boxes1[:, :2]).all()
82 | assert (boxes2[:, 2:] >= boxes2[:, :2]).all()
83 | iou, union = box_iou(boxes1, boxes2)
84 |
85 | lt = torch.min(boxes1[:, None, :2], boxes2[:, :2])
86 | rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:])
87 |
88 | wh = (rb - lt).clamp(min=0) # [N,M,2]
89 | area = wh[:, :, 0] * wh[:, :, 1]
90 |
91 | return iou - (area - union) / area
92 |
93 |
94 | def masks_to_boxes(masks):
95 | """
96 | Compute the bounding boxes around the provided masks.
97 |
98 | Args:
99 | masks: Tensor of shape (N, H, W) where N is the number of masks, (H, W) are the spatial dimensions.
100 |
101 | Returns:
102 | Tensor of shape (N, 4) with the boxes in xyxy format.
103 | """
104 | if masks.numel() == 0:
105 | return torch.zeros((0, 4), device=masks.device)
106 |
107 | h, w = masks.shape[-2:]
108 |
109 | y = torch.arange(0, h, dtype=torch.float)
110 | x = torch.arange(0, w, dtype=torch.float)
111 |     y, x = torch.meshgrid(y, x, indexing="ij")  # explicit indexing avoids the deprecation warning
112 |
113 | x_mask = (masks * x.unsqueeze(0))
114 | x_max = x_mask.flatten(1).max(-1)[0]
115 | x_min = x_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0]
116 |
117 | y_mask = (masks * y.unsqueeze(0))
118 | y_max = y_mask.flatten(1).max(-1)[0]
119 | y_min = y_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0]
120 |
121 | return torch.stack([x_min, y_min, x_max, y_max], 1)
122 |
--------------------------------------------------------------------------------
/llava-pythia/llava_pythia/utils.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import logging
3 | import logging.handlers
4 | import os
5 | import sys
6 |
7 | import requests
8 |
9 | from llava_pythia.constants import LOGDIR
10 |
11 | server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**"
12 | moderation_msg = "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE TRY AGAIN."
13 |
14 | handler = None
15 |
16 |
17 | def build_logger(logger_name, logger_filename):
18 | global handler
19 |
20 | formatter = logging.Formatter(
21 | fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
22 | datefmt="%Y-%m-%d %H:%M:%S",
23 | )
24 |
25 | # Set the format of root handlers
26 | if not logging.getLogger().handlers:
27 | logging.basicConfig(level=logging.INFO)
28 | logging.getLogger().handlers[0].setFormatter(formatter)
29 |
30 | # Redirect stdout and stderr to loggers
31 | stdout_logger = logging.getLogger("stdout")
32 | stdout_logger.setLevel(logging.INFO)
33 | sl = StreamToLogger(stdout_logger, logging.INFO)
34 | sys.stdout = sl
35 |
36 | stderr_logger = logging.getLogger("stderr")
37 | stderr_logger.setLevel(logging.ERROR)
38 | sl = StreamToLogger(stderr_logger, logging.ERROR)
39 | sys.stderr = sl
40 |
41 | # Get logger
42 | logger = logging.getLogger(logger_name)
43 | logger.setLevel(logging.INFO)
44 |
45 | # Add a file handler for all loggers
46 | if handler is None:
47 | os.makedirs(LOGDIR, exist_ok=True)
48 | filename = os.path.join(LOGDIR, logger_filename)
49 | handler = logging.handlers.TimedRotatingFileHandler(
50 | filename, when='D', utc=True)
51 | handler.setFormatter(formatter)
52 |
53 | for name, item in logging.root.manager.loggerDict.items():
54 | if isinstance(item, logging.Logger):
55 | item.addHandler(handler)
56 |
57 | return logger
58 |
59 |
60 | class StreamToLogger(object):
61 | """
62 | Fake file-like stream object that redirects writes to a logger instance.
63 | """
64 | def __init__(self, logger, log_level=logging.INFO):
65 | self.terminal = sys.stdout
66 | self.logger = logger
67 | self.log_level = log_level
68 | self.linebuf = ''
69 |
70 | def __getattr__(self, attr):
71 | return getattr(self.terminal, attr)
72 |
73 | def write(self, buf):
74 | temp_linebuf = self.linebuf + buf
75 | self.linebuf = ''
76 | for line in temp_linebuf.splitlines(True):
77 | # From the io.TextIOWrapper docs:
78 | # On output, if newline is None, any '\n' characters written
79 | # are translated to the system default line separator.
80 | # By default sys.stdout.write() expects '\n' newlines and then
81 | # translates them so this is still cross platform.
82 | if line[-1] == '\n':
83 | self.logger.log(self.log_level, line.rstrip())
84 | else:
85 | self.linebuf += line
86 |
87 | def flush(self):
88 | if self.linebuf != '':
89 | self.logger.log(self.log_level, self.linebuf.rstrip())
90 | self.linebuf = ''
91 |
92 |
93 | def disable_torch_init():
94 | """
95 | Disable the redundant torch default initialization to accelerate model creation.
96 | """
97 | import torch
98 | setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
99 | setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
100 |
101 |
102 | def violates_moderation(text):
103 | """
104 | Check whether the text violates OpenAI moderation API.
105 | """
106 | url = "https://api.openai.com/v1/moderations"
107 | headers = {"Content-Type": "application/json",
108 | "Authorization": "Bearer " + os.environ["OPENAI_API_KEY"]}
109 |     import json
110 |     text = text.replace("\n", "")
111 |     data = json.dumps({"input": text}).encode("utf-8")
112 | try:
113 | ret = requests.post(url, headers=headers, data=data, timeout=5)
114 | flagged = ret.json()["results"][0]["flagged"]
115 | except requests.exceptions.RequestException as e:
116 | flagged = False
117 | except KeyError as e:
118 | flagged = False
119 |
120 | return flagged
121 |
122 |
123 | def pretty_print_semaphore(semaphore):
124 | if semaphore is None:
125 | return "None"
126 | return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})"
127 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | absl-py==2.1.0
2 | accelerate==0.21.0
3 | aiofiles==23.2.1
4 | aiohappyeyeballs==2.4.4
5 | aiohttp==3.10.11
6 | aiosignal==1.3.1
7 | altair==5.3.0
8 | anyio==4.5.2
9 | appdirs==1.4.4
10 | argcomplete==3.3.0
11 | asttokens==2.4.1
12 | async-timeout==5.0.1
13 | attrs==23.2.0
14 | backcall==0.2.0
15 | beautifulsoup4==4.12.3
16 | bitsandbytes==0.41.0
17 | bleach==6.1.0
18 | cachetools==5.3.3
19 | catkin-pkg==1.0.0
20 | certifi==2024.2.2
21 | charset-normalizer==3.3.2
22 | click==8.1.8
23 | cloudpickle==3.0.0
24 | cmake==3.29.2
25 | colorama==0.3.0
26 | contourpy==1.1.1
27 | cycler==0.12.1
28 | decorator==5.1.1
29 | deepspeed==0.9.5
30 | defusedxml==0.7.1
31 | diffusers==0.11.1
32 | distro==1.9.0
33 | dm-control==1.0.14
34 | dm-env==1.6
35 | dm-tree==0.1.8
36 | docker-pycreds==0.4.0
37 | docopt==0.6.2
38 | docutils==0.20.1
39 | egl-probe==1.0.2
40 | einops==0.6.1
41 | einops-exts==0.0.4
42 | evdev==1.7.0
43 | exceptiongroup==1.2.2
44 | executing==2.0.1
45 | fastapi==0.110.2
46 | fastjsonschema==2.20.0
47 | ffmpy==0.3.2
48 | filelock==3.16.1
49 | fonttools==4.51.0
50 | frozenlist==1.5.0
51 | fsspec==2025.2.0
52 | gitdb==4.0.12
53 | GitPython==3.1.44
54 | glfw==2.7.0
55 | google-auth==2.29.0
56 | google-auth-oauthlib==1.0.0
57 | gradio==3.35.2
58 | gradio_client==0.2.9
59 | grpcio==1.62.2
60 | gym==0.26.2
61 | gym-notices==0.0.8
62 | h11==0.14.0
63 | h5py==3.11.0
64 | hjson==3.1.0
65 | httpcore==0.17.3
66 | httpx==0.24.0
67 | huggingface-hub==0.23.5
68 | idna==3.7
69 | imageio==2.34.1
70 | imageio-ffmpeg==0.4.9
71 | importlib_metadata==8.5.0
72 | importlib_resources==6.4.5
73 | ipython==8.12.3
74 | jedi==0.19.1
75 | Jinja2==3.1.4
76 | joblib==1.4.0
77 | jsonschema==4.21.1
78 | jsonschema-specifications==2023.12.1
79 | jupyter_client==8.6.3
80 | jupyter_core==5.7.2
81 | jupyterlab_pygments==0.3.0
82 | kiwisolver==1.4.5
83 | labmaze==1.0.6
84 | linkify-it-py==2.0.3
85 | lit==18.1.3
86 | llvmlite==0.41.1
87 | lxml==5.2.1
88 | Markdown==3.6
89 | markdown-it-py==2.2.0
90 | markdown2==2.4.13
91 | MarkupSafe==2.1.5
92 | matplotlib==3.7.5
93 | matplotlib-inline==0.1.7
94 | mdit-py-plugins==0.3.3
95 | mdurl==0.1.2
96 | mistune==3.0.2
97 | mpmath==1.3.0
98 | mujoco==2.3.7
99 | multidict==6.1.0
100 | nbclient==0.10.0
101 | nbconvert==7.16.4
102 | nbformat==5.10.4
103 | networkx==3.1
104 | ninja==1.11.1.1
105 | numba==0.58.1
106 | numpy==1.24.4
107 | nvidia-cublas-cu11==11.10.3.66
108 | nvidia-cuda-cupti-cu11==11.7.101
109 | nvidia-cuda-nvrtc-cu11==11.7.99
110 | nvidia-cuda-runtime-cu11==11.7.99
111 | nvidia-cudnn-cu11==8.5.0.96
112 | nvidia-cufft-cu11==10.9.0.58
113 | nvidia-curand-cu11==10.2.10.91
114 | nvidia-cusolver-cu11==11.4.0.1
115 | nvidia-cusparse-cu11==11.7.4.91
116 | nvidia-nccl-cu11==2.14.3
117 | nvidia-nvtx-cu11==11.7.91
118 | oauthlib==3.2.2
119 | opencv-python==4.6.0.66
120 | orjson==3.10.1
121 | packaging==24.0
122 | pandas==2.0.3
123 | pandocfilters==1.5.1
124 | parso==0.8.4
125 | peft==0.4.0
126 | pexpect==4.9.0
127 | pickleshare==0.7.5
128 | pillow==10.3.0
129 | pipreqs==0.5.0
130 | pkgutil_resolve_name==1.3.10
131 | platformdirs==4.3.6
132 | prompt_toolkit==3.0.47
133 | propcache==0.2.0
134 | protobuf==3.19.6
135 | psutil==6.1.1
136 | ptyprocess==0.7.0
137 | pure-eval==0.2.2
138 | py-cpuinfo==9.0.0
139 | pyasn1==0.6.0
140 | pyasn1_modules==0.4.0
141 | pydantic==1.10.15
142 | pydub==0.25.1
143 | Pygments==2.17.2
144 | pynput==1.7.6
145 | PyOpenGL==3.1.7
146 | pyparsing==3.1.4
147 | pyquaternion==0.9.9
148 | python-dateutil==2.9.0.post0
149 | python-multipart==0.0.9
150 | python-xlib==0.33
151 | pytz==2024.1
152 | PyYAML==6.0.1
153 | pyzmq==26.2.0
154 | referencing==0.34.0
155 | regex==2024.4.16
156 | requests==2.31.0
157 | requests-oauthlib==2.0.0
158 | rospkg==1.5.1
159 | rpds-py==0.18.0
160 | rsa==4.9
161 | safetensors==0.4.3
162 | scikit-learn==1.2.2
163 | scipy==1.10.1
164 | semantic-version==2.10.0
165 | sentencepiece==0.1.99
166 | sentry-sdk==1.45.0
167 | setproctitle==1.3.3
168 | shortuuid==1.0.13
169 | six==1.16.0
170 | smmap==5.0.2
171 | sniffio==1.3.1
172 | snowballstemmer==2.2.0
173 | soupsieve==2.6
174 | stack-data==0.6.3
175 | starlette==0.37.2
176 | svgwrite==1.4.3
177 | sympy==1.12
178 | tensorboard==2.14.0
179 | tensorboard-data-server==0.7.2
180 | tensorboardX==2.6
181 | termcolor==2.4.0
182 | threadpoolctl==3.4.0
183 | tianshou==0.4.10
184 | timm==0.6.13
185 | tinycss2==1.3.0
186 | tokenizers==0.15.0
187 | toolz==0.12.1
188 | torch==2.0.1
189 | torchvision==0.15.2
190 | tornado==6.4.1
191 | tqdm==4.67.1
192 | traitlets==5.14.3
193 | transformers==4.37.1
194 | triton==2.0.0
195 | typing_extensions==4.11.0
196 | tzdata==2025.1
197 | uc-micro-py==1.0.3
198 | urllib3==2.2.3
199 | uvicorn==0.29.0
200 | wandb==0.16.6
201 | wavedrom==2.0.3.post3
202 | wcwidth==0.2.13
203 | webencodings==0.5.1
204 | websockets==13.1
205 | Werkzeug==3.0.2
206 | yarg==0.1.9
207 | yarl==1.15.2
208 | zipp==3.20.2
209 |
--------------------------------------------------------------------------------
/policy_heads/models/position_encoding.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2 | """
3 | Various positional encodings for the transformer.
4 | """
5 | import math
6 | import torch
7 | from torch import nn
8 |
9 | from policy_heads.util.misc import NestedTensor
10 |
11 | import IPython
12 | e = IPython.embed
13 |
14 | class PositionEmbeddingSine(nn.Module):
15 | """
16 | This is a more standard version of the position embedding, very similar to the one
17 | used by the Attention is all you need paper, generalized to work on images.
18 | """
19 | def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
20 | super().__init__()
21 | self.num_pos_feats = num_pos_feats
22 | self.temperature = temperature
23 | self.normalize = normalize
24 | if scale is not None and normalize is False:
25 | raise ValueError("normalize should be True if scale is passed")
26 | if scale is None:
27 | scale = 2 * math.pi
28 | self.scale = scale
29 |
30 | def forward(self, tensor):
31 | x = tensor
32 | # mask = tensor_list.mask
33 | # assert mask is not None
34 | # not_mask = ~mask
35 |
36 | not_mask = torch.ones_like(x[0, [0]])
37 | y_embed = not_mask.cumsum(1, dtype=tensor.dtype)
38 | x_embed = not_mask.cumsum(2, dtype=tensor.dtype)
39 | if self.normalize:
40 | eps = 1e-6
41 | y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
42 | x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
43 |
44 | dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
45 | dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
46 |
47 | pos_x = x_embed[:, :, :, None] / dim_t
48 | pos_y = y_embed[:, :, :, None] / dim_t
49 | pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
50 | pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
51 | pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
52 | return pos
53 |
54 |
55 | class PositionEmbeddingLearned(nn.Module):
56 | """
57 | Computes learned absolute positional embeddings for the input tensor.
58 |
59 | This method generates position embeddings based on the spatial dimensions
60 | (height and width) of the input tensor. The embeddings are learned through
61 | trainable `nn.Embedding` layers for both row (height) and column (width) indices.
62 | """
63 | def __init__(self, num_pos_feats=256):
64 | super().__init__()
65 | self.row_embed = nn.Embedding(50, num_pos_feats)
66 | self.col_embed = nn.Embedding(50, num_pos_feats)
67 | self.reset_parameters()
68 |
69 | def reset_parameters(self):
70 | nn.init.uniform_(self.row_embed.weight)
71 | nn.init.uniform_(self.col_embed.weight)
72 |
73 | def forward(self, tensor_list: NestedTensor):
74 | x = tensor_list.tensors
75 | h, w = x.shape[-2:]
76 | i = torch.arange(w, device=x.device)
77 | j = torch.arange(h, device=x.device)
78 | x_emb = self.col_embed(i)
79 | y_emb = self.row_embed(j)
80 | pos = torch.cat([
81 | x_emb.unsqueeze(0).repeat(h, 1, 1),
82 | y_emb.unsqueeze(1).repeat(1, w, 1),
83 | ], dim=-1).permute(2, 0, 1).unsqueeze(0).repeat(x.shape[0], 1, 1, 1)
84 | return pos
85 |
86 | def position_encoding_1d(x):
87 | """
88 | Generates 1-dimensional positional encoding.
89 |
90 |     Args:
91 |         x: Tensor of shape (batch, seq_len, d_model); only its shape, dtype and
92 |             device are used to build the encoding.
93 | 
94 |     Returns:
95 |         torch.Tensor: 2D tensor of shape (seq_len, d_model) containing sinusoidal positional encodings.
96 | """
97 | seq_len, d_model = x.shape[1:]
98 |     pos_enc = torch.zeros((seq_len, d_model), dtype=x.dtype, device=x.device)
99 |     position = torch.arange(0, seq_len, dtype=x.dtype, device=x.device).unsqueeze(1)
100 | 
101 |     div_term = torch.exp(torch.arange(0, d_model, 2, dtype=x.dtype, device=x.device) * (-math.log(10000.0) / d_model))
102 |
103 | pos_enc[:, 0::2] = torch.sin(position * div_term)
104 | pos_enc[:, 1::2] = torch.cos(position * div_term)
105 |
106 | return pos_enc.to(x.device)
107 |
108 | def build_position_encoding(args):
109 | N_steps = args.hidden_dim // 2
110 | if args.position_embedding in ('v2', 'sine'):
111 | position_embedding = PositionEmbeddingSine(N_steps, normalize=True)
112 | elif args.position_embedding in ('v3', 'learned'):
113 | position_embedding = PositionEmbeddingLearned(N_steps)
114 | else:
115 | raise ValueError(f"not supported {args.position_embedding}")
116 |
117 | return position_embedding
--------------------------------------------------------------------------------
/policy_heads/util/plot_utils.py:
--------------------------------------------------------------------------------
1 | """
2 | Plotting utilities to visualize training logs.
3 | """
4 | import torch
5 | import pandas as pd
6 | import numpy as np
7 | import seaborn as sns
8 | import matplotlib.pyplot as plt
9 |
10 | from pathlib import Path, PurePath
11 |
12 |
13 | def plot_logs(logs, fields=('class_error', 'loss_bbox_unscaled', 'mAP'), ewm_col=0, log_name='log.txt'):
14 | '''
15 | Function to plot specific fields from training log(s). Plots both training and test results.
16 |
17 | :: Inputs - logs = list containing Path objects, each pointing to individual dir with a log file
18 | - fields = which results to plot from each log file - plots both training and test for each field.
19 | - ewm_col = optional, which column to use as the exponential weighted smoothing of the plots
20 | - log_name = optional, name of log file if different than default 'log.txt'.
21 |
22 | :: Outputs - matplotlib plots of results in fields, color coded for each log file.
23 | - solid lines are training results, dashed lines are test results.
24 |
25 | '''
26 | func_name = "plot_utils.py::plot_logs"
27 |
28 | # verify logs is a list of Paths (list[Paths]) or single Pathlib object Path,
29 | # convert single Path to list to avoid 'not iterable' error
30 |
31 | if not isinstance(logs, list):
32 | if isinstance(logs, PurePath):
33 | logs = [logs]
34 | print(f"{func_name} info: logs param expects a list argument, converted to list[Path].")
35 | else:
36 | raise ValueError(f"{func_name} - invalid argument for logs parameter.\n \
37 | Expect list[Path] or single Path obj, received {type(logs)}")
38 |
39 | # Quality checks - verify valid dir(s), that every item in list is Path object, and that log_name exists in each dir
40 | for i, dir in enumerate(logs):
41 | if not isinstance(dir, PurePath):
42 | raise ValueError(f"{func_name} - non-Path object in logs argument of {type(dir)}: \n{dir}")
43 | if not dir.exists():
44 | raise ValueError(f"{func_name} - invalid directory in logs argument:\n{dir}")
45 | # verify log_name exists
46 | fn = Path(dir / log_name)
47 | if not fn.exists():
48 | print(f"-> missing {log_name}. Have you gotten to Epoch 1 in training?")
49 | print(f"--> full path of missing log file: {fn}")
50 | return
51 |
52 | # load log file(s) and plot
53 | dfs = [pd.read_json(Path(p) / log_name, lines=True) for p in logs]
54 |
55 | fig, axs = plt.subplots(ncols=len(fields), figsize=(16, 5))
56 |
57 | for df, color in zip(dfs, sns.color_palette(n_colors=len(logs))):
58 | for j, field in enumerate(fields):
59 | if field == 'mAP':
60 | coco_eval = pd.DataFrame(
61 | np.stack(df.test_coco_eval_bbox.dropna().values)[:, 1]
62 | ).ewm(com=ewm_col).mean()
63 | axs[j].plot(coco_eval, c=color)
64 | else:
65 | df.interpolate().ewm(com=ewm_col).mean().plot(
66 | y=[f'train_{field}', f'test_{field}'],
67 | ax=axs[j],
68 | color=[color] * 2,
69 | style=['-', '--']
70 | )
71 | for ax, field in zip(axs, fields):
72 | ax.legend([Path(p).name for p in logs])
73 | ax.set_title(field)
74 |
75 |
76 | def plot_precision_recall(files, naming_scheme='iter'):
77 | if naming_scheme == 'exp_id':
78 | # name becomes exp_id
79 | names = [f.parts[-3] for f in files]
80 | elif naming_scheme == 'iter':
81 | names = [f.stem for f in files]
82 | else:
83 | raise ValueError(f'not supported {naming_scheme}')
84 | fig, axs = plt.subplots(ncols=2, figsize=(16, 5))
85 | for f, color, name in zip(files, sns.color_palette("Blues", n_colors=len(files)), names):
86 | data = torch.load(f)
87 | # precision is n_iou, n_points, n_cat, n_area, max_det
88 | precision = data['precision']
89 | recall = data['params'].recThrs
90 | scores = data['scores']
91 | # take precision for all classes, all areas and 100 detections
92 | precision = precision[0, :, :, 0, -1].mean(1)
93 | scores = scores[0, :, :, 0, -1].mean(1)
94 | prec = precision.mean()
95 | rec = data['recall'][0, :, 0, -1].mean()
96 | print(f'{naming_scheme} {name}: mAP@50={prec * 100: 05.1f}, ' +
97 | f'score={scores.mean():0.3f}, ' +
98 | f'f1={2 * prec * rec / (prec + rec + 1e-8):0.3f}'
99 | )
100 | axs[0].plot(recall, precision, c=color)
101 | axs[1].plot(recall, scores, c=color)
102 |
103 | axs[0].set_title('Precision / Recall')
104 | axs[0].legend(names)
105 | axs[1].set_title('Scores / Recall')
106 | axs[1].legend(names)
107 | return fig, axs
108 |
--------------------------------------------------------------------------------
/llava-pythia/llava_pythia/mm_utils.py:
--------------------------------------------------------------------------------
1 | from PIL import Image
2 | from io import BytesIO
3 | import base64
4 |
5 | import torch
6 | from transformers import StoppingCriteria
7 | from llava_pythia.constants import IMAGE_TOKEN_INDEX
8 |
9 |
10 | def load_image_from_base64(image):
11 | return Image.open(BytesIO(base64.b64decode(image)))
12 |
13 |
14 | def expand2square(pil_img, background_color):
15 | width, height = pil_img.size
16 | if width == height:
17 | return pil_img
18 | elif width > height:
19 | result = Image.new(pil_img.mode, (width, width), background_color)
20 | result.paste(pil_img, (0, (width - height) // 2))
21 | return result
22 | else:
23 | result = Image.new(pil_img.mode, (height, height), background_color)
24 | result.paste(pil_img, ((height - width) // 2, 0))
25 | return result
26 |
27 |
28 | def process_images(images, image_processor, model_cfg):
29 | image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None)
30 | new_images = []
31 | if image_aspect_ratio == 'pad':
32 | for image in images:
33 | image = expand2square(image, tuple(int(x*255) for x in image_processor.image_mean))
34 | image = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
35 | new_images.append(image)
36 | else:
37 | return image_processor(images, return_tensors='pt')['pixel_values']
38 | if all(x.shape == new_images[0].shape for x in new_images):
39 | new_images = torch.stack(new_images, dim=0)
40 | return new_images
41 |
42 |
43 | def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None):
44 | """
45 | Tokenizes a prompt string that may contain image placeholders and returns the tokenized input IDs.
46 |
47 | Args:
48 |         prompt (str): The input string containing text and '<image>' placeholders.
49 |         tokenizer: The tokenizer object used to tokenize the text.
50 |         image_token_index (int, optional): The token index used to represent the '<image>' placeholder. Defaults to IMAGE_TOKEN_INDEX.
51 | return_tensors (str, optional): If specified, returns the tokenized input as a tensor of the specified type ('pt' for PyTorch). Defaults to None.
52 |
53 | Returns:
54 | list or torch.Tensor: The tokenized input IDs. If `return_tensors` is specified as 'pt', returns a PyTorch tensor; otherwise, returns a list of input IDs.
55 | """
56 |     prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('<image>')]
57 |     # print("####"*100)
58 |     # attention = [tokenizer(chunk).attention_mask for chunk in prompt.split('<image>')]
59 | def insert_separator(X, sep):
60 | return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1]
61 |
62 | input_ids = []
63 | offset = 0
64 | if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
65 | offset = 1
66 | input_ids.append(prompt_chunks[0][0])
67 | for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
68 | input_ids.extend(x[offset:])
69 |
70 | if return_tensors is not None:
71 | if return_tensors == 'pt':
72 | return torch.tensor(input_ids, dtype=torch.long)
73 | raise ValueError(f'Unsupported tensor type: {return_tensors}')
74 | return input_ids
75 |
76 |
77 | def get_model_name_from_path(model_path):
78 | model_path = model_path.strip("/")
79 | model_paths = model_path.split("/")
80 | if model_paths[-1].startswith('checkpoint-'):
81 | return model_paths[-2] + "_" + model_paths[-1]
82 | else:
83 | return model_paths[-1]
84 |
85 |
86 | class KeywordsStoppingCriteria(StoppingCriteria):
87 | def __init__(self, keywords, tokenizer, input_ids):
88 | self.keywords = keywords
89 | self.keyword_ids = []
90 | for keyword in keywords:
91 | cur_keyword_ids = tokenizer(keyword).input_ids
92 | if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id:
93 | cur_keyword_ids = cur_keyword_ids[1:]
94 | self.keyword_ids.append(torch.tensor(cur_keyword_ids))
95 | self.tokenizer = tokenizer
96 | self.start_len = input_ids.shape[1]
97 |
98 | def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
99 | assert output_ids.shape[0] == 1, "Only support batch size 1 (yet)" # TODO
100 | offset = min(output_ids.shape[1] - self.start_len, 3)
101 | self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids]
102 | for keyword_id in self.keyword_ids:
103 |             if (output_ids[0, -keyword_id.shape[0]:] == keyword_id).all():
104 | return True
105 | outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0]
106 | for keyword in self.keywords:
107 | if keyword in outputs:
108 | return True
109 | return False
110 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 | # TinyVLA: Towards Fast, Data-Efficient Vision-Language-Action Models for Robotic Manipulation
3 | 
4 |
5 |
6 | * **TinyVLA: Towards Fast, Data-Efficient Vision-Language-Action Models for Robotic Manipulation**
7 |   [arXiv](https://arxiv.org/abs/2409.12514)
8 |
9 |
10 |
11 | ## 📰 News
12 | * **`Feb. 17th, 2025`**: 🔥🔥🔥Our code is released!
13 | * **`Feb. 9th, 2025`**: 🔥🔥🔥**TinyVLA** is accepted by IEEE Robotics and Automation Letters (RA-L) 2025!
14 | * **`Nov. 19th, 2024`**: **TinyVLA** is out! **Paper** can be found [here](https://arxiv.org/abs/2409.12514). The **project web** can be found [here](https://tiny-vla.github.io/).
15 |
16 | ## Contents
17 | - [📰 News](#-news)
18 | - [Contents](#contents)
19 | - [Install](#install)
20 | - [Data Preparation](#data-preparation)
21 | - [Download Pretrained VLM](#download-pretrained-vlm)
22 | - [Train](#train)
23 | - [Evaluation](#evaluation)
24 | - [Acknowledgement](#acknowledgement)
25 | - [Citation](#citation)
26 |
27 | ## Install
28 |
29 | 1. Clone this repository and navigate to the TinyVLA folder
30 | ```bash
31 | git clone https://github.com/liyaxuanliyaxuan/TinyVLA
32 | ```
33 |
34 | 2. Install Package
35 | ```Shell
36 | conda create -n tinyvla python=3.10 -y
37 | conda activate tinyvla
38 | pip install --upgrade pip
39 | pip install -r requirements.txt
40 | cd policy_heads
41 | pip install -e .
42 | # install llava-pythia
43 | cd ../llava-pythia
44 | pip install -e .
45 | ```
46 |
47 | ## Data Preparation
48 | 1. Our data format is the same as [act](https://github.com/MarkFzp/act-plus-plus), so you need to convert your data into HDF5 format. You can refer to [rlds_to_h5py.py](https://github.com/lesjie-wen/tinyvla/blob/main/data_utils/rlds_to_h5py.py), which converts data from the RLDS format to HDF5. A minimal writer sketch for this layout is shown right after this list.
49 | ```text
50 | # h5 data structure
51 | root
52 | |-action (100,10)
53 | |-language_raw (1,)
54 | |-observations
55 | |-images # multi-view
56 | |-left (100,480,640,3)
57 | |-right (100,480,640,3)
58 | |-wrist (100,480,640,3)
59 | |-joint_positions (100,7)
60 | |-qpos (100,7)
61 | |-qvel (100,7)
62 | ```
63 | 2. Add an entry to [constants.py](https://github.com/lesjie-wen/tinyvla/blob/main/aloha_scripts/constants.py) that specifies the path of your data, as follows.
64 | ```python
65 | 'your_task_name':{
66 | 'dataset_dir': DATA_DIR + '/your_task_path', # define the path of the dataset
67 | 'episode_len': 1000, #max length of the episode,
68 | 'camera_names': ['front', 'wrist'] # define the camera names which are used as the key when reading data
69 | }
70 | ```
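
As referenced in step 1, the sketch below shows one possible way to write a single episode in the HDF5 layout above with `h5py` (already pinned in requirements.txt). The helper name, file name, and dummy shapes are illustrative only; the repo's own converter is `rlds_to_h5py.py`.

```python
# Minimal, hypothetical episode writer matching the layout above
# (100 timesteps, 480x640 RGB images, 7-DoF joints, 10-dim actions).
import h5py
import numpy as np

def save_episode(path, actions, language, images, joint_positions, qpos, qvel):
    """images: dict of camera_name -> (T, 480, 640, 3) uint8 arrays, e.g. 'left'/'right'/'wrist'."""
    with h5py.File(path, "w") as root:
        root.create_dataset("action", data=actions)
        root.create_dataset("language_raw",
                            data=np.array([language], dtype=object),
                            dtype=h5py.string_dtype())
        obs = root.create_group("observations")
        img_grp = obs.create_group("images")
        for cam_name, frames in images.items():
            img_grp.create_dataset(cam_name, data=frames, dtype="uint8")
        obs.create_dataset("joint_positions", data=joint_positions)
        obs.create_dataset("qpos", data=qpos)
        obs.create_dataset("qvel", data=qvel)

# Example call with dummy data of the shapes shown above:
T = 100
save_episode(
    "episode_0.hdf5",
    actions=np.zeros((T, 10), dtype=np.float32),
    language="pick up the red block",
    images={cam: np.zeros((T, 480, 640, 3), dtype=np.uint8) for cam in ["left", "right", "wrist"]},
    joint_positions=np.zeros((T, 7), dtype=np.float32),
    qpos=np.zeros((T, 7), dtype=np.float32),
    qvel=np.zeros((T, 7), dtype=np.float32),
)
```
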
71 | ## Download Pretrained VLM
72 | We construct the VLM backbone by integrating a series of tiny LLMs ([Pythia](https://github.com/EleutherAI/pythia)) into the [LLaVA](https://github.com/haotian-liu/LLaVA) framework, following the standard training pipeline and data provided by [LLaVA](https://github.com/haotian-liu/LLaVA). All VLM weights used in our paper are listed below:
73 |
74 | | Model | Usage | Link |
75 | |---------------------|---------------|----------------------------------------------------------------|
76 | | Llava-Pythia(~400M) | For TinyVLA-S | [huggingface](https://huggingface.co/lesjie/Llava-Pythia-400M) |
77 | | Llava-Pythia(~700M) | For TinyVLA-B | [huggingface](https://huggingface.co/lesjie/Llava-Pythia-700M) |
78 | | Llava-Pythia(~1.3B) | For TinyVLA-H | [huggingface](https://huggingface.co/lesjie/Llava-Pythia-1.3B) |
79 |
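
If you prefer to fetch the weights ahead of time (for example, so that `model_name_or_path` can point to a local directory), one option is `snapshot_download` from `huggingface_hub`, which is already in requirements.txt. The `local_dir` below is just an example path.

```python
# Optional: pre-download a Llava-Pythia backbone from the table above.
from huggingface_hub import snapshot_download

local_path = snapshot_download(
    repo_id="lesjie/Llava-Pythia-400M",            # or Llava-Pythia-700M / Llava-Pythia-1.3B
    local_dir="./checkpoints/Llava-Pythia-400M",   # example location; adjust to your setup
)
print("Downloaded to:", local_path)
```
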
80 |
81 | ## Train
82 | The training script is `scripts/train.sh`, and you need to change the following parameters:
83 | 1. **OUTPUT**: the save directory for training, which must include the keyword "llava_pythia" (and optionally "lora"). If LoRA training is used, the name must include "lora" (e.g., "llava_pythia_lora").
84 | 2. **task_name**: the task used for training, which should correspond to "your_task_name" in `aloha_scripts/constants.py`.
85 | 3. **model_name_or_path**: the path to the pretrained VLM weights.
86 | 4. Other hyperparameters such as "batch_size" and "save_steps" can be customized according to your compute resources.
87 |
88 | Start training with the following command:
89 | ```shell
90 | ./scripts/train.sh
91 | ```
92 |
93 | ## Evaluation
94 | Before evaluation, we provide a post-processing script that generates usable, smaller weights.
95 | The script is `scripts/process_ckpts.sh`, and you need to change the following parameters:
96 | 1. **source_dir**: the path to the trained VLA directory, i.e., **OUTPUT** in train.sh.
97 | 2. **target_dir**: the path where the processed VLA weights are saved.
98 |
99 | You can refer to our evaluation script [eval_real_franka.py](https://github.com/lesjie-wen/tinyvla/blob/main/eval_real_franka.py).
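
To give a sense of how the processed weights are consumed, the rough sketch below loads them with the helpers shipped in `llava_pythia` and prepares one image/instruction pair. The paths and the instruction are placeholders, the conversation template and action de-normalization are omitted, and the full robot loop (action-head inference and control) lives only in `eval_real_franka.py`.

```python
import torch
from PIL import Image

from llava_pythia.constants import IMAGE_TOKEN_INDEX
from llava_pythia.model.builder import load_pretrained_model
from llava_pythia.mm_utils import get_model_name_from_path, process_images, tokenizer_image_token

model_path = "/path/to/target_dir"                 # processed weights from process_ckpts.sh (placeholder)
model_name = get_model_name_from_path(model_path)
tokenizer, model, image_processor, _ = load_pretrained_model(model_path, None, model_name)

# Pre-process one camera frame and one language instruction (both placeholders).
image = Image.open("frame.png").convert("RGB")
images = process_images([image], image_processor, model.config).to(model.device, dtype=torch.float16)
prompt = "<image>\n" + "pick up the red block"
input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
input_ids = input_ids.unsqueeze(0).to(model.device)
# `input_ids` and `images` are then fed to the VLA to predict an action chunk; see eval_real_franka.py.
```
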
100 | ## Acknowledgement
101 | We build our project based on:
102 | - [LLaVA](https://github.com/haotian-liu/LLaVA): an amazing open-source project for vision-language assistants
103 | - [act-plus-plus](https://github.com/MarkFzp/act-plus-plus): an amazing open-source project for robotic visuomotor learning
104 | - [Miphi](https://github.com/zhuyiche/llava-phi): an amazing open-source project for tiny vision-language models
105 |
106 | ## Citation
107 |
108 | If you find Tiny-VLA useful for your research and applications, please cite using this BibTeX:
109 | ```bibtex
110 | @inproceedings{wen2024tinyvla,
112 | title={Tinyvla: Towards fast, data-efficient vision-language-action models for robotic manipulation},
113 | author={Wen, Junjie and Zhu, Yichen and Li, Jinming and Zhu, Minjie and Wu, Kun and Xu, Zhiyuan and Liu, Ning and Cheng, Ran and Shen, Chaomin and Peng, Yaxin and others},
114 | booktitle={IEEE Robotics and Automation Letters (RA-L)},
115 | year={2025}
116 | }
117 | ```
118 |
119 |
120 |
--------------------------------------------------------------------------------
/policy_heads/models/backbone.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2 | """
3 | Backbone modules.
4 | """
5 | from collections import OrderedDict
6 |
7 | import torch
8 | import torch.nn.functional as F
9 | import torchvision
10 | from torch import nn
11 | from torchvision.models._utils import IntermediateLayerGetter
12 | from typing import Dict, List
13 |
14 | from policy_heads.util.misc import is_main_process, NestedTensor
15 | from .position_encoding import build_position_encoding
16 |
17 | import IPython
18 | e = IPython.embed
19 |
20 | class FrozenBatchNorm2d(torch.nn.Module):
21 | """
22 | BatchNorm2d where the batch statistics and the affine parameters are fixed.
23 |
24 | This implementation is a copy-paste from torchvision.misc.ops with added eps before rsqrt,
25 | without which any other models than torchvision.models.resnet[18,34,50,101] produce NaNs.
26 | """
27 | def __init__(self, n):
28 | super(FrozenBatchNorm2d, self).__init__()
29 | self.register_buffer("weight", torch.ones(n))
30 | self.register_buffer("bias", torch.zeros(n))
31 | self.register_buffer("running_mean", torch.zeros(n))
32 | self.register_buffer("running_var", torch.ones(n))
33 |
34 | def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
35 | missing_keys, unexpected_keys, error_msgs):
36 | # remove num_batches_tracked from state_dict if present
37 | num_batches_tracked_key = prefix + 'num_batches_tracked'
38 | if num_batches_tracked_key in state_dict:
39 | del state_dict[num_batches_tracked_key]
40 |
41 | super(FrozenBatchNorm2d, self)._load_from_state_dict(
42 | state_dict, prefix, local_metadata, strict,
43 | missing_keys, unexpected_keys, error_msgs)
44 |
45 | def forward(self, x):
46 | """
47 | Forward pass for the frozen batch normalization.
48 |
49 | Args:
50 | x: Input tensor.
51 |
52 | Returns:
53 | Normalized tensor.
54 | """
55 | # move reshapes to the beginning to make it fuser-friendly
56 | w = self.weight.reshape(1, -1, 1, 1)
57 | b = self.bias.reshape(1, -1, 1, 1)
58 | rv = self.running_var.reshape(1, -1, 1, 1)
59 | rm = self.running_mean.reshape(1, -1, 1, 1)
60 | eps = 1e-5
61 | scale = w * (rv + eps).rsqrt()
62 | bias = b - rm * scale
63 | return x * scale + bias
64 |
65 |
66 | class BackboneBase(nn.Module):
67 | """
68 | Base class for backbone networks.
69 |
70 | Args:
71 | backbone: The backbone model.
72 | train_backbone: Whether to train the backbone.
73 | num_channels: Number of output channels.
74 | return_interm_layers: Whether to return intermediate layers.
75 | """
76 | def __init__(self, backbone: nn.Module, train_backbone: bool, num_channels: int, return_interm_layers: bool):
77 | super().__init__()
78 | # determine which layers to return
79 | if return_interm_layers:
80 | return_layers = {"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"}
81 | else:
82 | return_layers = {'layer4': "0"}
83 | self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)
84 | self.num_channels = num_channels
85 |
86 | def forward(self, tensor):
87 | """
88 | Forward pass for the backbone.
89 |
90 | Args:
91 | tensor: Input tensor.
92 |
93 | Returns:
94 | Dictionary of feature maps from the specified layers.
95 | """
96 | xs = self.body(tensor)
97 | return xs
98 |
99 |
100 | class Backbone(BackboneBase):
101 | """
102 | ResNet backbone with frozen BatchNorm.
103 |
104 | Args:
105 | name: Name of the ResNet model.
106 | train_backbone: Whether to train the backbone.
107 | return_interm_layers: Whether to return intermediate layers.
108 | dilation: Whether to use dilation in the last block.
109 | """
110 | def __init__(self, name: str, train_backbone: bool, return_interm_layers: bool, dilation: bool):
111 | backbone = getattr(torchvision.models, name)(
112 | replace_stride_with_dilation=[False, False, dilation],
113 | pretrained=is_main_process(), norm_layer=FrozenBatchNorm2d) # use pretrained model
114 | num_channels = 512 if name in ('resnet18', 'resnet34') else 2048
115 | super().__init__(backbone, train_backbone, num_channels, return_interm_layers)
116 |
117 |
118 | class Joiner(nn.Sequential):
119 | """
120 | Combines a backbone and a position encoding module.
121 |
122 | Args:
123 | backbone: The backbone model.
124 | position_embedding: The position encoding module.
125 | """
126 | def __init__(self, backbone, position_embedding):
127 | super().__init__(backbone, position_embedding)
128 |
129 | def forward(self, tensor_list: NestedTensor):
130 | """
131 | Forward pass for the joiner.
132 |
133 | Args:
134 | tensor_list: NestedTensor containing input data and mask.
135 |
136 | Returns:
137 | Tuple of feature maps and position encodings.
138 | """
139 | xs = self[0](tensor_list)
140 | out: List[NestedTensor] = []
141 | pos = []
142 | for name, x in xs.items():
143 | out.append(x)
144 | # position encoding
145 | pos.append(self[1](x).to(x.dtype))
146 |
147 | return out, pos
148 |
149 |
150 | def build_backbone(args):
151 | """
152 | Builds the backbone model with position encoding.
153 |
154 | Args:
155 | args: Arguments containing configuration for the backbone.
156 |
157 | Returns:
158 | A model combining the backbone and position encoding.
159 | """
160 | position_embedding = build_position_encoding(args)
161 | train_backbone = args.lr_backbone > 0
162 | return_interm_layers = args.masks
163 | backbone = Backbone(args.backbone, train_backbone, return_interm_layers, args.dilation)
164 | model = Joiner(backbone, position_embedding)
165 | model.num_channels = backbone.num_channels
166 | return model
167 |
--------------------------------------------------------------------------------
/llava-pythia/scripts/tranfer2llava.py:
--------------------------------------------------------------------------------
1 | # Convert the R3M-formatted Franka kitchen data to the LLaVA format
2 | output_path = "/media/rl/HDD/data/data/franka_kitchen/frankakitchen_llava"
3 | import pickle
4 | import json
5 | import os
6 | from PIL import Image
7 | from tqdm import tqdm
8 | import numpy as np
9 | TASK = {
10 | 'kitchen_sdoor_open-v3':"Open the sliding door",
11 | 'kitchen_micro_open-v3':"Open the microwave oven",
12 | 'kitchen_light_on-v3':"Toggle the light switch",
13 | 'kitchen_ldoor_open-v3':"Open the cabinet door",
14 | 'kitchen_knob1_on-v3':"Rotate the round stovetop knob"
15 | }
16 |
17 | def franka_kitchen2llava_format(data_path:str, ratio:float, view):
18 | os.makedirs(os.path.join(output_path, view), exist_ok=True)
19 | os.makedirs(os.path.join(output_path, view, 'images'), exist_ok=True)
20 |
21 | pickle_path = os.listdir(data_path)
22 | pickle_path = [p for p in pickle_path if p.endswith('.pickle')]
23 | all_task_demo_paths = []
24 |
25 | t_bar = tqdm(total=1000 * ratio)
26 |
27 | try:
28 | with open(os.path.join(data_path, f'mean_var_{view}.json'), 'r') as f:
29 | mean_std = json.load(f)
30 |     except (FileNotFoundError, json.JSONDecodeError):  # fall back to empty stats if the file is missing or malformed
31 | mean_std = {'action':{},'state':{}}
32 |
33 |
34 | data_processed = []
35 | eval_data = []
36 | train_data = []
37 | for p in pickle_path:
38 | cur_data = []
39 |
40 | if p.split('/')[-1] not in mean_std['action'].keys():
41 | mean_std['action'][p.split('/')[-1]] = {}
42 | mean_std['state'][p.split('/')[-1]] = {}
43 |
44 | action_all = []
45 | state_all = []
46 | # print(p.split('.')[0])
47 | max_action = [0] * 9
48 | min_action = [1000000] * 9
49 | max_state = [0] * 9
50 | min_state = [1000000] * 9
51 |
52 | demo_paths_loc = os.path.join(data_path, p)
53 | demo_paths = pickle.load(open(demo_paths_loc, 'rb'))
54 | all_task_demo_paths += demo_paths[:int(ratio*200)]
55 | # print(all_task_demo_paths[0].keys())
56 | # print(all_task_demo_paths[0]['actions'].shape)
57 | for idx, each in enumerate(demo_paths[:int(ratio*200)]): # trajectory id
58 | traj_len = each['images'].shape[0]
59 |
60 | # normalize action and state
61 | m_a,v_a = np.array([mean_std['action'][p.split('/')[-1]]['mean']]), np.array([mean_std['action'][p.split('/')[-1]]['var']])
62 | m_s,v_s = np.array([mean_std['state'][p.split('/')[-1]]['mean']]), np.array([mean_std['state'][p.split('/')[-1]]['var']])
63 | each['actions'] = (each['actions'] - m_a) /np.sqrt(v_a)
64 | # print(each['observations'][:,:9].shape)
65 | each['observations'][:,:9] = (each['observations'][:,:9] - m_s) / np.sqrt(v_s)
66 | # #############################
67 | for i in range(traj_len): # frame id
68 | t = {
69 | "id": "",
70 | "image": "",
71 | 'state': [],
72 | 'action': [],
73 |                     "conversations": [{"from": "human", "value": "<image>\n"}, {"from": "gpt", "value": " "}]
74 | }
75 | img_p = os.path.join(output_path, view, 'images', f'{p.split(".")[0]}_{idx}_{i}.png')
76 | # print(each['images'].shape)
77 | # break
78 | if not os.path.exists(img_p):
79 | Image.fromarray(each['images'][i]).save(img_p)
80 | t['image'] = img_p
81 | t['id'] = img_p.split('/')[-1]
82 | t["conversations"][0]["value"] += TASK[p.split('.')[0]]
83 | t['state'] = each['observations'][i].tolist()
84 | t['action'] = each['actions'][i].tolist()
85 | # print(t['action'])
86 |                     ######################################## inspect the value ranges of actions/states (debug)
87 | # for j,a in enumerate(zip(t['action'])):
88 | # if max_action[j] < a:
89 | # max_action[j] = a
90 | # if min_action[j] > a:
91 | # min_action[j] = a
92 | # for j,a in enumerate(t['state'][:9]):
93 | # if max_state[j] < a:
94 | # max_state[j] = a
95 | # if min_state[j] > a:
96 | # min_state[j] = a
97 | action_all.append(t['action'])
98 | state_all.append(t['state'][:9])
99 |
100 | data_processed.append(t)
101 | cur_data.append(t)
102 |
103 | t_bar.update(1)
104 | # print(t)
105 | # break
106 | # break
107 | # print(p.split('/')[-1])
108 | # print(max_action,min_action)
109 | # print(max_state,min_state)
110 |
111 | mean_action = np.mean(np.array(action_all), axis=0)
112 | var_action = np.var(np.array(action_all), axis=0)
113 |
114 | mean_state = np.mean(np.array(state_all), axis=0)
115 | var_state = np.var(np.array(state_all), axis=0)
116 |
117 | mean_std['action'][p.split('/')[-1]]['mean'] = mean_action.tolist()
118 | mean_std['action'][p.split('/')[-1]]['var'] = var_action.tolist()
119 |
120 | mean_std['state'][p.split('/')[-1]]['mean'] = mean_state.tolist()
121 | mean_std['state'][p.split('/')[-1]]['var'] = var_state.tolist()
122 | eval_data += cur_data[int(200*50*0.9):]
123 | train_data += cur_data[:int(200*50*0.9)]
124 | # print("action")
125 | # print(mean_action, var_action)
126 |
127 | # print("state")
128 | # print(mean_state, var_state)
129 | # print(mean_std)
130 |
131 | # with open(f'mean_var_{view}.json', 'w') as f:
132 | # json.dump(mean_std,f)
133 |
134 | # with open('action_all.txt', 'w') as f:
135 | # for a in action_all:
136 | # f.write(str(a) + '\n')
137 | # with open('state_all.txt', 'w') as f:
138 | # for s in state_all:
139 | # f.write(str(s) + '\n')
140 |
141 | # with open(os.path.join(output_path, view, f"std_{view}_50k.json"), "w") as f:
142 | # json.dump(data_processed, f, indent=4)
143 | print(len(train_data), len(eval_data))
144 | with open(os.path.join(output_path, view, f"std_eval_{view}_50k.json"), "w") as f:
145 | json.dump(eval_data, f, indent=4)
146 | with open(os.path.join(output_path, view, f"std_train_{view}_50k.json"), "w") as f:
147 | json.dump(train_data, f, indent=4)
148 |
149 | for view in ['default', 'left_cap2', 'right_cap2']:
150 | data_path = f"/media/rl/HDD/data/data/franka_kitchen/FrankaKitchen/{view}"
151 | franka_kitchen2llava_format(data_path, 1, view)
--------------------------------------------------------------------------------
/llava-pythia/llava_pythia/model/builder.py:
--------------------------------------------------------------------------------
1 | import os
2 | import warnings
3 | import shutil
4 |
5 | from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig, CLIPImageProcessor, SiglipImageProcessor, \
6 | GPTNeoXModel, GPTNeoXPreTrainedModel
7 | import torch
8 | from llava_pythia.model import *
9 | from llava_pythia.constants import DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
10 |
11 |
12 | def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, load_4bit=False, device_map="cuda", device="cuda"):
13 | """
14 | Loads a pretrained model with optional quantization and device mapping.
15 |
16 | Args:
17 | - model_path (str): Path to the model directory or file.
18 | - model_base (str): Base model path, used when loading LoRA models.
19 | - model_name (str): Name of the model to load.
20 | - load_8bit (bool): Whether to load the model in 8-bit precision.
21 | - load_4bit (bool): Whether to load the model in 4-bit precision.
22 | - device_map (str): Device map for model loading, default is "cuda".
23 | - device (str): Device to load the model onto, default is "cuda".
24 |
25 | Returns:
26 | - tokenizer: The tokenizer associated with the model.
27 | - model: The loaded model.
28 | - image_processor: The image processor if applicable.
29 | - context_len (int): The context length of the model.
30 | """
31 | kwargs = {"device_map": device_map}
32 | if load_8bit:
33 | kwargs['load_in_8bit'] = True
34 | elif load_4bit:
35 | kwargs['load_in_4bit'] = True
36 | kwargs['quantization_config'] = BitsAndBytesConfig(
37 | load_in_4bit=True,
38 | bnb_4bit_compute_dtype=torch.float16,
39 | bnb_4bit_use_double_quant=True,
40 | bnb_4bit_quant_type='nf4'
41 | )
42 | else:
43 | kwargs['torch_dtype'] = torch.float16
44 |
45 | if 'pythia' in model_name.lower():
46 |         # Load LLaVA-Pythia model
47 | if 'lora' in model_name.lower() and model_base is None:
48 | warnings.warn('There is `lora` in model name but no `model_base` is provided. If you are loading a LoRA model, please provide the `model_base` argument.')
49 | if 'lora' in model_name.lower() and model_base is not None:
50 |
51 | path = model_path.split('/')[0:-1]
52 | root_path = '/'.join(path)
53 | lora_cfg_pretrained = AutoConfig.from_pretrained(root_path)
54 | config = lora_cfg_pretrained
55 | tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=True) # default use_fast=False
56 | print('Loading LLaVA-Pythia from base model...')
57 | model = LlavaPythiaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=lora_cfg_pretrained, **kwargs)
58 |
59 | # token_num, tokem_dim = model.embed_out.out_features, model.embed_out.in_features
60 | # if model.embed_out.weight.shape[0] != token_num:
61 | # model.embed_out.weight = torch.nn.Parameter(torch.empty(token_num, tokem_dim, device=model.device, dtype=model.dtype))
62 | # model.model.embed_tokens.weight = torch.nn.Parameter(torch.empty(token_num, tokem_dim, device=model.device, dtype=model.dtype))
63 |
64 | print('Loading additional LLaVA-Pythia weights...')
65 | if os.path.exists(os.path.join(model_path, 'non_lora_trainables.bin')):
66 | non_lora_trainables = torch.load(os.path.join(model_path, 'non_lora_trainables.bin'), map_location='cpu')
67 | else:
68 | # this is probably from HF Hub
69 | from huggingface_hub import hf_hub_download
70 | def load_from_hf(repo_id, filename, subfolder=None):
71 | cache_file = hf_hub_download(
72 | repo_id=repo_id,
73 | filename=filename,
74 | subfolder=subfolder)
75 | return torch.load(cache_file, map_location='cpu')
76 | non_lora_trainables = load_from_hf(model_path, 'non_lora_trainables.bin')
77 | non_lora_trainables = {(k[11:] if k.startswith('base_model.') else k): v for k, v in non_lora_trainables.items()}
78 | if any(k.startswith('model.gpt_neox.') for k in non_lora_trainables):
79 | non_lora_trainables = {(k[6:] if k.startswith('model.') else k): v for k, v in non_lora_trainables.items()}
80 |
81 |             # remove LoRA-related parameters
82 | keys_to_del = []
83 | for k,v in non_lora_trainables.items():
84 | if 'lora' in k:
85 | keys_to_del.append(k)
86 | for key in keys_to_del:
87 | del non_lora_trainables[key]
88 |
89 | model.load_state_dict(non_lora_trainables, strict=False)
90 |
91 | from peft import PeftModel
92 | print('Loading LoRA weights...')
93 | model = PeftModel.from_pretrained(model, model_path)
94 | print('Merging LoRA weights...')
95 | model = model.merge_and_unload()
96 | print('Model is loaded...')
97 | elif model_base is not None:
98 | # this may be mm projector only
99 | print('Loading LLaVA-Pythia from base model...')
100 | tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
101 | cfg_pretrained = AutoConfig.from_pretrained(model_path)
102 | model = LlavaPythiaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs)
103 |
104 | mm_projector_weights = torch.load(os.path.join(model_path, 'mm_projector.bin'), map_location='cpu')
105 | mm_projector_weights = {k: v.to(torch.float16) for k, v in mm_projector_weights.items()}
106 | model.load_state_dict(mm_projector_weights, strict=False)
107 | else:
108 | print("load llaVA-Pythia MLLM!!!")
109 | config = LlavaPythiaConfig.from_pretrained(model_path, trust_remote_code=True)
110 | tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
111 | model = LlavaPythiaForCausalLM.from_pretrained(
112 | model_path,
113 | config=config,
114 | use_safetensors=True,
115 | **kwargs).to("cuda")
116 | else:
117 | # Load language model
118 | if model_base is not None:
119 | # PEFT model
120 | from peft import PeftModel
121 | tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
122 | model = AutoModelForCausalLM.from_pretrained(model_base, torch_dtype=torch.float16, low_cpu_mem_usage=True, device_map="auto")
123 | print(f"Loading LoRA weights from {model_path}")
124 | model = PeftModel.from_pretrained(model, model_path)
125 | print(f"Merging weights")
126 | model = model.merge_and_unload()
127 | print('Convert to FP16...')
128 | model.to(torch.float16)
129 | else:
130 | tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
131 | model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
132 | if "clip" in config.vision_config["vision_tower"]["vision_model_name_or_path"]:
133 | image_processor = CLIPImageProcessor.from_pretrained(model_path)
134 | elif "siglip" in config.vision_config["vision_tower"]["vision_model_name_or_path"]:
135 | image_processor = SiglipImageProcessor.from_pretrained(model_path)
136 | else:
137 |         raise NotImplementedError("Unsupported vision tower type")
138 | # image_processor = CLIPImageProcessor.from_pretrained(model_path)
139 |
140 | if 'pythia' in model_name.lower():
141 | mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False)
142 | mm_use_im_patch_token = getattr(model.config, "mm_use_im_patch_token", True)
143 |
144 | # TODO: the tokenizer length of phi-2 is 50295, but the output class of lm_head is 51200
145 | if mm_use_im_patch_token:
146 | tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
147 | if mm_use_im_start_end:
148 | tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
149 | # model.resize_token_embeddings(len(tokenizer))
150 | else:
151 | raise ValueError(f"Unsupported model name: {model_name}")
152 |
153 | if hasattr(model.config, "max_sequence_length"):
154 | context_len = model.config.max_sequence_length
155 | else:
156 | context_len = 2048
157 |     model.to(device=device)
158 | print(kwargs)
159 | return tokenizer, model, image_processor, context_len
160 |
--------------------------------------------------------------------------------
/llava-pythia/llava_pythia/model/language_model/pythia/configuration_llava_pythia.py:
--------------------------------------------------------------------------------
1 | import os
2 | from typing import Union
3 | from transformers import PretrainedConfig, GPTNeoXConfig
4 | from transformers.utils import logging
5 |
6 | logger = logging.get_logger(__name__)
7 |
8 |
9 | class LlavaPythiaVisionConfig(PretrainedConfig):
10 | r"""
11 | This is the configuration class to store the configuration of a [`CLIPVisionModel`]. It is used to instantiate a
12 | CLIP vision encoder according to the specified arguments, defining the model architecture. Instantiating a
13 | configuration with the defaults will yield a similar configuration to that of the vision encoder of the CLIP
14 | [openai/clip-vit-base-patch32](https://huggingface.co/openai/clip-vit-base-patch32) architecture.
15 |
16 | Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
17 | documentation from [`PretrainedConfig`] for more information.
18 |
19 | Args:
20 | hidden_size (`int`, *optional*, defaults to 768):
21 | Dimensionality of the encoder layers and the pooler layer.
22 | intermediate_size (`int`, *optional*, defaults to 3072):
23 | Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
24 | projection_dim (`int`, *optional*, defaults to 512):
25 |             Dimensionality of text and vision projection layers.
26 | num_hidden_layers (`int`, *optional*, defaults to 12):
27 | Number of hidden layers in the Transformer encoder.
28 | num_attention_heads (`int`, *optional*, defaults to 12):
29 | Number of attention heads for each attention layer in the Transformer encoder.
30 | num_channels (`int`, *optional*, defaults to 3):
31 | The number of input channels.
32 | image_size (`int`, *optional*, defaults to 224):
33 | The size (resolution) of each image.
34 | patch_size (`int`, *optional*, defaults to 32):
35 | The size (resolution) of each patch.
36 | hidden_act (`str` or `function`, *optional*, defaults to `"quick_gelu"`):
37 | The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
38 |             `"relu"`, `"selu"`, `"gelu_new"` and `"quick_gelu"` are supported.
39 | layer_norm_eps (`float`, *optional*, defaults to 1e-05):
40 | The epsilon used by the layer normalization layers.
41 | attention_dropout (`float`, *optional*, defaults to 0.0):
42 | The dropout ratio for the attention probabilities.
43 | initializer_range (`float`, *optional*, defaults to 0.02):
44 | The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
45 | initializer_factor (`float`, *optional*, defaults to 1.0):
46 | A factor for initializing all weight matrices (should be kept to 1, used internally for initialization
47 | testing).
48 | mm_vision_select_feature (`str`, *optional*, defaults to `"patch"`):
49 | The feature to select from the vision encoder output. Can be one of `"patch"` or `"cls_patch"`.
50 | mm_vision_select_layer (`int`, *optional*, defaults to `-2`):
51 | The layer to select from the vision encoder output.
52 |
53 | Example:
54 |
55 | ```python
56 | >>> from transformers import CLIPVisionConfig, CLIPVisionModel
57 |
58 | >>> # Initializing a CLIPVisionConfig with openai/clip-vit-base-patch32 style configuration
59 | >>> configuration = CLIPVisionConfig()
60 |
61 | >>> # Initializing a CLIPVisionModel (with random weights) from the openai/clip-vit-base-patch32 style configuration
62 | >>> model = CLIPVisionModel(configuration)
63 |
64 | >>> # Accessing the model configuration
65 | >>> configuration = model.config
66 | ```"""
67 |
68 | model_type = "llava_pythia_clip_vision_model"
69 |
70 | def __init__(
71 | self,
72 | hidden_size=768,
73 | intermediate_size=3072,
74 | projection_dim=512,
75 | num_hidden_layers=12,
76 | num_attention_heads=12,
77 | num_channels=3,
78 | image_size=224,
79 | patch_size=32,
80 | hidden_act="quick_gelu",
81 | layer_norm_eps=1e-5,
82 | attention_dropout=0.0,
83 | initializer_range=0.02,
84 | initializer_factor=1.0,
85 | mm_vision_select_feature="patch",
86 | mm_vision_select_layer=-2,
87 | vision_model_name_or_path="clip",
88 | concat="None",
89 | **kwargs,
90 | ):
91 | super().__init__(**kwargs)
92 |
93 | self.hidden_size = hidden_size
94 | self.intermediate_size = intermediate_size
95 | self.projection_dim = projection_dim
96 | self.num_hidden_layers = num_hidden_layers
97 | self.num_attention_heads = num_attention_heads
98 | self.num_channels = num_channels
99 | self.patch_size = patch_size
100 | self.image_size = image_size
101 | self.initializer_range = initializer_range
102 | self.initializer_factor = initializer_factor
103 | self.attention_dropout = attention_dropout
104 | self.layer_norm_eps = layer_norm_eps
105 | self.hidden_act = hidden_act
106 | self.mm_vision_select_feature = mm_vision_select_feature
107 | self.mm_vision_select_layer = mm_vision_select_layer
108 | self.vision_model_name_or_path = vision_model_name_or_path
109 | self.concat = concat
110 |
111 | @classmethod
112 | def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
113 | cls._set_token_in_kwargs(kwargs)
114 |
115 | config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
116 |
117 | # get the vision config dict if we are loading from CLIPConfig
118 | if config_dict.get("model_type") == "llava_pythia":
119 | config_dict = config_dict["vision_config"]["vision_tower"]
120 |
121 | if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
122 | logger.warning(
123 | f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
124 | f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
125 | )
126 |
127 | return cls.from_dict(config_dict, **kwargs)
128 |
129 |
130 | class ProjectorConfig(PretrainedConfig):
131 | model_type = "llava_pythia_projector"
132 |
133 | def __init__(
134 | self,
135 | mm_projector_type="linear",
136 | mm_hidden_size=768,
137 | hidden_size=2560,
138 | **kwargs
139 | ):
140 | self.mm_projector_type = mm_projector_type
141 | self.mm_hidden_size = mm_hidden_size
142 | self.hidden_size = hidden_size
143 | super().__init__(**kwargs)
144 |
145 | @classmethod
146 | def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
147 | cls._set_token_in_kwargs(kwargs)
148 |
149 | config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
150 |
151 | # pull out the mm_projector config when loading from a full LlavaPythiaConfig
152 | if config_dict.get("model_type") == "llava_pythia":
153 | config_dict = config_dict["vision_config"]["mm_projector"]
154 |
155 | if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
156 | logger.warning(
157 | f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
158 | f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
159 | )
160 |
161 | return cls.from_dict(config_dict, **kwargs)
162 |
163 |
164 |
165 | from typing import List
166 |
167 | # for initializing the action head
168 |
169 | DEFAULT_VISUAL_CONFIG = {
170 | "vision_tower": LlavaPythiaVisionConfig().to_dict(),
171 | "mm_projector": ProjectorConfig().to_dict(),
172 | }
173 |
174 | # print(DEFAULT_ACT_CONFIG['act'])
175 |
176 | class LlavaPythiaConfig(GPTNeoXConfig):
177 | model_type = "llava_pythia"
178 |
179 | def __init__(self, vision_config=None, **kwargs):
180 | if vision_config is None:
181 | self.vision_config = DEFAULT_VISUAL_CONFIG
182 | else:
183 | self.vision_config = vision_config
184 |
185 | self.concat = "None"
186 | super().__init__(**kwargs)
187 |
188 |
189 | if __name__ == "__main__":
190 | print(LlavaPythiaVisionConfig())
191 |
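
For orientation, here is a minimal usage sketch of the three configuration classes defined above. The import path is assumed from the repository tree (not verified here), and the overridden values (336-pixel images, 14-pixel patches, a 1024-dim projector input) are purely illustrative.

```python
# Hypothetical usage sketch; the module path is assumed from the repo tree.
from llava_pythia.model.language_model.pythia.configuration_llava_pythia import (
    LlavaPythiaVisionConfig,
    ProjectorConfig,
    LlavaPythiaConfig,
)

# Build a nested vision config the same way DEFAULT_VISUAL_CONFIG does,
# overriding a few illustrative values.
vision_config = {
    "vision_tower": LlavaPythiaVisionConfig(image_size=336, patch_size=14).to_dict(),
    "mm_projector": ProjectorConfig(mm_hidden_size=1024, hidden_size=2560).to_dict(),
}

config = LlavaPythiaConfig(vision_config=vision_config)
print(config.model_type)                                       # llava_pythia
print(config.vision_config["vision_tower"]["image_size"])      # 336
print(config.vision_config["mm_projector"]["mm_hidden_size"])  # 1024
```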
--------------------------------------------------------------------------------
/llava-pythia/llava_pythia/conversation.py:
--------------------------------------------------------------------------------
1 | import dataclasses
2 | from enum import auto, Enum
3 | from typing import List, Tuple
4 |
5 |
6 | class SeparatorStyle(Enum):
7 | """Different separator style."""
8 | SINGLE = auto()
9 | TWO = auto()
10 | MPT = auto()
11 | PLAIN = auto()
12 | LLAMA_2 = auto()
13 |
14 |
15 | @dataclasses.dataclass
16 | class Conversation:
17 | """A class that keeps all conversation history."""
18 | system: str
19 | roles: List[str]
20 | messages: List[List[str]]
21 | offset: int
22 | sep_style: SeparatorStyle = SeparatorStyle.SINGLE
23 | sep: str = "###"
24 | sep2: str = None
25 | version: str = "Unknown"
26 |
27 | skip_next: bool = False
28 |
29 | def get_prompt(self) -> str:
30 | """Generate a prompt string based on the conversation history and separator style."""
31 | messages = self.messages
32 | # check if the first message is a tuple and process accordingly
33 | if len(messages) > 0 and type(messages[0][1]) is tuple:
34 | messages = self.messages.copy()
35 | init_role, init_msg = messages[0].copy()
36 | init_msg = init_msg[0].replace("<image>", "").strip()
37 | if 'mmtag' in self.version:
38 | messages[0] = (init_role, init_msg)
39 | messages.insert(0, (self.roles[0], "<Image><image></Image>"))
40 | messages.insert(1, (self.roles[1], "Received."))
41 | else:
42 | messages[0] = (init_role, "<image>\n" + init_msg)
43 |
44 | # construct the prompt based on the separator style
45 | if self.sep_style == SeparatorStyle.SINGLE:
46 | ret = self.system + self.sep
47 | for role, message in messages:
48 | if message:
49 | if type(message) is tuple:
50 | message, _, _ = message
51 | ret += role + ": " + message + self.sep
52 | else:
53 | ret += role + ":"
54 | elif self.sep_style == SeparatorStyle.TWO:
55 | seps = [self.sep, self.sep2]
56 | ret = self.system + seps[0]
57 | for i, (role, message) in enumerate(messages):
58 | if message:
59 | if type(message) is tuple:
60 | message, _, _ = message
61 | ret += role + ": " + message + seps[i % 2]
62 | else:
63 | ret += role + ":"
64 | elif self.sep_style == SeparatorStyle.PLAIN:
65 | seps = [self.sep, self.sep2]
66 | ret = self.system
67 | for i, (role, message) in enumerate(messages):
68 | if message:
69 | if type(message) is tuple:
70 | message, _, _ = message
71 | ret += message + seps[i % 2]
72 | else:
73 | ret += ""
74 | else:
75 | raise ValueError(f"Invalid style: {self.sep_style}")
76 |
77 | return ret
78 |
79 | def append_message(self, role: str, message: str) -> None:
80 | """Append a new message to the conversation."""
81 | self.messages.append([role, message])
82 |
83 | def get_images(self, return_pil: bool = False) -> List:
84 | """Extract images from the conversation messages."""
85 | images = []
86 | for i, (role, msg) in enumerate(self.messages[self.offset:]):
87 | if i % 2 == 0:
88 | if type(msg) is tuple:
89 | import base64
90 | from io import BytesIO
91 | from PIL import Image
92 | msg, image, image_process_mode = msg
93 | # process image based on the mode
94 | if image_process_mode == "Pad":
95 | def expand2square(pil_img, background_color=(122, 116, 104)):
96 | width, height = pil_img.size
97 | if width == height:
98 | return pil_img
99 | elif width > height:
100 | result = Image.new(pil_img.mode, (width, width), background_color)
101 | result.paste(pil_img, (0, (width - height) // 2))
102 | return result
103 | else:
104 | result = Image.new(pil_img.mode, (height, height), background_color)
105 | result.paste(pil_img, ((height - width) // 2, 0))
106 | return result
107 | image = expand2square(image)
108 | elif image_process_mode in ["Default", "Crop"]:
109 | pass
110 | elif image_process_mode == "Resize":
111 | image = image.resize((336, 336))
112 | else:
113 | raise ValueError(f"Invalid image_process_mode: {image_process_mode}")
114 | # resize image to maintain aspect ratio
115 | max_hw, min_hw = max(image.size), min(image.size)
116 | aspect_ratio = max_hw / min_hw
117 | max_len, min_len = 800, 400
118 | shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
119 | longest_edge = int(shortest_edge * aspect_ratio)
120 | W, H = image.size
121 | if longest_edge != max(image.size):
122 | if H > W:
123 | H, W = longest_edge, shortest_edge
124 | else:
125 | H, W = shortest_edge, longest_edge
126 | image = image.resize((W, H))
127 | if return_pil:
128 | images.append(image)
129 | else:
130 | buffered = BytesIO()
131 | image.save(buffered, format="PNG")
132 | img_b64_str = base64.b64encode(buffered.getvalue()).decode()
133 | images.append(img_b64_str)
134 | return images
135 |
136 | def to_gradio_chatbot(self):
137 | ret = []
138 | for i, (role, msg) in enumerate(self.messages[self.offset:]):
139 | if i % 2 == 0:
140 | if type(msg) is tuple:
141 | import base64
142 | from io import BytesIO
143 | msg, image, image_process_mode = msg
144 | max_hw, min_hw = max(image.size), min(image.size)
145 | aspect_ratio = max_hw / min_hw
146 | max_len, min_len = 800, 400
147 | shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
148 | longest_edge = int(shortest_edge * aspect_ratio)
149 | W, H = image.size
150 | if H > W:
151 | H, W = longest_edge, shortest_edge
152 | else:
153 | H, W = shortest_edge, longest_edge
154 | image = image.resize((W, H))
155 | buffered = BytesIO()
156 | image.save(buffered, format="JPEG")
157 | img_b64_str = base64.b64encode(buffered.getvalue()).decode()
158 | img_str = f'<img src="data:image/jpeg;base64,{img_b64_str}" alt="user upload image" />'
159 | msg = img_str + msg.replace('<image>', '').strip()
160 | ret.append([msg, None])
161 | else:
162 | ret.append([msg, None])
163 | else:
164 | ret[-1][-1] = msg
165 | return ret
166 |
167 | def copy(self):
168 | return Conversation(
169 | system=self.system,
170 | roles=self.roles,
171 | messages=[[x, y] for x, y in self.messages],
172 | offset=self.offset,
173 | sep_style=self.sep_style,
174 | sep=self.sep,
175 | sep2=self.sep2,
176 | version=self.version)
177 |
178 | def dict(self):
179 | if len(self.get_images()) > 0:
180 | return {
181 | "system": self.system,
182 | "roles": self.roles,
183 | "messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages],
184 | "offset": self.offset,
185 | "sep": self.sep,
186 | "sep2": self.sep2,
187 | }
188 | return {
189 | "system": self.system,
190 | "roles": self.roles,
191 | "messages": self.messages,
192 | "offset": self.offset,
193 | "sep": self.sep,
194 | "sep2": self.sep2,
195 | }
196 |
197 |
198 | conv_pythia = Conversation(
199 | system="A chat between a curious user and an artificial intelligence assistant. "
200 | "The assistant gives helpful, detailed, and polite answers to the user's questions.",
201 | roles=("USER", "ASSISTANT"),
202 | version="v0",
203 | messages=(),
204 | offset=0,
205 | sep_style=SeparatorStyle.TWO,
206 | sep=" ",
207 | sep2="<|endoftext|>",
208 | )
209 |
210 | conv_llava_plain = Conversation(
211 | system="",
212 | roles=("", ""),
213 | messages=(),
214 | offset=0,
215 | sep_style=SeparatorStyle.PLAIN,
216 | sep="\n",
217 | )
218 |
219 | default_conversation = conv_pythia
220 | conv_templates = {
221 | "default": conv_pythia,
222 | "v0": conv_pythia,
223 | "pythia": conv_pythia,
224 |
225 | "plain": conv_llava_plain,
226 | }
227 |
228 |
229 | if __name__ == "__main__":
230 | print(default_conversation.get_prompt())
231 |
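
A short sketch of how these templates are typically consumed, assuming conversation.py is importable as llava_pythia.conversation (path taken from the repo tree):

```python
# Minimal prompt-construction sketch using the conv_pythia template above.
from llava_pythia.conversation import conv_templates

conv = conv_templates["pythia"].copy()
conv.append_message(conv.roles[0], "What is shown in the image?")
conv.append_message(conv.roles[1], None)   # leave the assistant turn open
prompt = conv.get_prompt()
print(prompt)
# With SeparatorStyle.TWO, sep=" " and sep2="<|endoftext|>", this yields:
#   "<system text> USER: What is shown in the image? ASSISTANT:"
```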
--------------------------------------------------------------------------------
/policy_heads/models/droid_unet_diffusion.py:
--------------------------------------------------------------------------------
1 | """
2 | Implementation of Diffusion Policy https://diffusion-policy.cs.columbia.edu/ by Cheng Chi
3 | """
4 | from typing import Callable, Union
5 | import math
6 | from collections import OrderedDict, deque
7 | from packaging.version import parse as parse_version
8 | import random
9 | import torch
10 | import torch.nn as nn
11 | import torch.nn.functional as F
12 | # requires diffusers==0.11.1
13 | from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
14 | from diffusers.schedulers.scheduling_ddim import DDIMScheduler
15 | from diffusers.training_utils import EMAModel
16 |
17 |
18 | # =================== UNet for Diffusion ==============
19 |
20 | class SinusoidalPosEmb(nn.Module):
21 | """
22 | Sinusoidal positional embedding for diffusion models.
23 |
24 | Args:
25 | dim: The dimension of the embedding.
26 | dtype: The data type for the embedding.
27 | """
28 | def __init__(self, dim, dtype):
29 | super().__init__()
30 | self.dim = dim
31 | self.dtype = dtype
32 |
33 | def forward(self, x):
34 | """
35 | Forward pass to compute the sinusoidal positional embedding.
36 |
37 | Args:
38 | x: Input tensor.
39 |
40 | Returns:
41 | The sinusoidal positional embedding.
42 | """
43 | device = x.device
44 | half_dim = self.dim // 2
45 | emb = math.log(10000) / (half_dim - 1)
46 | emb = torch.exp(torch.arange(half_dim, device=device, dtype=self.dtype) * -emb)
47 | emb = x[:, None] * emb[None, :]
48 | emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
49 | return emb
50 |
51 |
52 | class Downsample1d(nn.Module):
53 | """
54 | 1D downsampling layer using convolution.
55 |
56 | Args:
57 | dim: The number of input and output channels.
58 | """
59 | def __init__(self, dim):
60 | super().__init__()
61 | self.conv = nn.Conv1d(dim, dim, 3, 2, 1)
62 |
63 | def forward(self, x):
64 | return self.conv(x)
65 |
66 |
67 | class Upsample1d(nn.Module):
68 | """
69 | 1D upsampling layer using transposed convolution.
70 |
71 | Args:
72 | dim: The number of input and output channels.
73 | """
74 | def __init__(self, dim):
75 | super().__init__()
76 | self.conv = nn.ConvTranspose1d(dim, dim, 4, 2, 1)
77 |
78 | def forward(self, x):
79 | return self.conv(x)
80 |
81 |
82 | class Conv1dBlock(nn.Module):
83 | """
84 | A block consisting of Conv1d, GroupNorm, and Mish activation.
85 |
86 | Args:
87 | inp_channels: Number of input channels.
88 | out_channels: Number of output channels.
89 | kernel_size: Size of the convolutional kernel.
90 | n_groups: Number of groups for GroupNorm.
91 | """
92 | def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):
93 | super().__init__()
94 |
95 | self.block = nn.Sequential(
96 | nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2),
97 | nn.GroupNorm(n_groups, out_channels),
98 | nn.Mish(),
99 | )
100 |
101 | def forward(self, x):
102 | return self.block(x)
103 |
104 |
105 | class ConditionalResidualBlock1D(nn.Module):
106 | """
107 | Conditional residual block with FiLM modulation.
108 |
109 | Args:
110 | in_channels: Number of input channels.
111 | out_channels: Number of output channels.
112 | cond_dim: Dimension of the conditioning input.
113 | kernel_size: Size of the convolutional kernel.
114 | n_groups: Number of groups for GroupNorm.
115 | """
116 | def __init__(self, in_channels, out_channels, cond_dim, kernel_size=3, n_groups=8):
117 | super().__init__()
118 |
119 | self.blocks = nn.ModuleList([
120 | Conv1dBlock(in_channels, out_channels, kernel_size, n_groups=n_groups),
121 | Conv1dBlock(out_channels, out_channels, kernel_size, n_groups=n_groups),
122 | ])
123 |
124 | # FiLM modulation https://arxiv.org/abs/1709.07871
125 | # predicts per-channel scale and bias
126 | cond_channels = out_channels * 2
127 | self.out_channels = out_channels
128 | self.cond_encoder = nn.Sequential(
129 | nn.Mish(),
130 | nn.Linear(cond_dim, cond_channels),
131 | nn.Unflatten(-1, (-1, 1))
132 | )
133 |
134 | # ensure dimensions are compatible
135 | self.residual_conv = nn.Conv1d(in_channels, out_channels, 1) \
136 | if in_channels != out_channels else nn.Identity()
137 |
138 | def forward(self, x, cond):
139 | """
140 | Forward pass for the conditional residual block.
141 |
142 | Args:
143 | x: Input tensor of shape [batch_size, in_channels, horizon].
144 | cond: Conditioning tensor of shape [batch_size, cond_dim].
145 |
146 | Returns:
147 | Output tensor of shape [batch_size, out_channels, horizon].
148 | """
149 | out = self.blocks[0](x)
150 | embed = self.cond_encoder(cond)
151 |
152 | embed = embed.reshape(embed.shape[0], 2, self.out_channels, 1)
153 | scale = embed[:, 0, ...]
154 | bias = embed[:, 1, ...]
155 | out = scale * out + bias
156 |
157 | out = self.blocks[1](out)
158 | out = out + self.residual_conv(x)
159 | return out
160 |
161 |
162 | class ConditionalUnet1D(nn.Module):
163 | """
164 | Conditional 1D UNet for diffusion models.
165 |
166 | Args:
167 | input_dim: Dimension of the input actions.
168 | global_cond_dim: Dimension of global conditioning applied with FiLM.
169 | diffusion_step_embed_dim: Size of positional encoding for diffusion iteration k.
170 | down_dims: Channel size for each UNet level.
171 | kernel_size: Convolutional kernel size.
172 | n_groups: Number of groups for GroupNorm.
173 | state_dim: Dimension of the state input.
174 | """
175 | def __init__(self, input_dim, global_cond_dim, diffusion_step_embed_dim=256,
176 | down_dims=[256, 512, 1024], kernel_size=5, n_groups=8, state_dim=7):
177 | super().__init__()
178 | all_dims = [input_dim] + list(down_dims)
179 | start_dim = down_dims[0]
180 |
181 | self.global_1d_pool = nn.AdaptiveAvgPool1d(1)
182 | self.norm_after_pool = nn.LayerNorm(global_cond_dim)
183 | self.combine = nn.Linear(global_cond_dim + state_dim, global_cond_dim)
184 |
185 | dsed = diffusion_step_embed_dim
186 | diffusion_step_encoder = nn.Sequential(
187 | SinusoidalPosEmb(dsed, torch.bfloat16),
188 | nn.Linear(dsed, dsed * 4),
189 | nn.Mish(),
190 | nn.Linear(dsed * 4, dsed),
191 | )
192 | cond_dim = dsed + global_cond_dim
193 |
194 | in_out = list(zip(all_dims[:-1], all_dims[1:]))
195 | mid_dim = all_dims[-1]
196 | self.mid_modules = nn.ModuleList([
197 | ConditionalResidualBlock1D(
198 | mid_dim, mid_dim, cond_dim=cond_dim,
199 | kernel_size=kernel_size, n_groups=n_groups
200 | ),
201 | ConditionalResidualBlock1D(
202 | mid_dim, mid_dim, cond_dim=cond_dim,
203 | kernel_size=kernel_size, n_groups=n_groups
204 | ),
205 | ])
206 |
207 | down_modules = nn.ModuleList([])
208 | for ind, (dim_in, dim_out) in enumerate(in_out):
209 | is_last = ind >= (len(in_out) - 1)
210 | down_modules.append(nn.ModuleList([
211 | ConditionalResidualBlock1D(
212 | dim_in, dim_out, cond_dim=cond_dim,
213 | kernel_size=kernel_size, n_groups=n_groups),
214 | ConditionalResidualBlock1D(
215 | dim_out, dim_out, cond_dim=cond_dim,
216 | kernel_size=kernel_size, n_groups=n_groups),
217 | Downsample1d(dim_out) if not is_last else nn.Identity()
218 | ]))
219 |
220 | up_modules = nn.ModuleList([])
221 | for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
222 | is_last = ind >= (len(in_out) - 1)
223 | up_modules.append(nn.ModuleList([
224 | ConditionalResidualBlock1D(
225 | dim_out * 2, dim_in, cond_dim=cond_dim,
226 | kernel_size=kernel_size, n_groups=n_groups),
227 | ConditionalResidualBlock1D(
228 | dim_in, dim_in, cond_dim=cond_dim,
229 | kernel_size=kernel_size, n_groups=n_groups),
230 | Upsample1d(dim_in) if not is_last else nn.Identity()
231 | ]))
232 |
233 | final_conv = nn.Sequential(
234 | Conv1dBlock(start_dim, start_dim, kernel_size=kernel_size),
235 | nn.Conv1d(start_dim, input_dim, 1),
236 | )
237 |
238 | self.diffusion_step_encoder = diffusion_step_encoder
239 | self.up_modules = up_modules
240 | self.down_modules = down_modules
241 | self.final_conv = final_conv
242 |
243 | print("number of parameters: {:e}".format(
244 | sum(p.numel() for p in self.parameters()))
245 | )
246 |
247 | def forward(self,
248 | sample: torch.Tensor,
249 | timestep: Union[torch.Tensor, float, int],
250 | global_cond=None,
251 | states=None):
252 | """
253 | Forward pass for the Conditional UNet.
254 |
255 | Args:
256 | sample: Input tensor of shape (B, T, input_dim).
257 | timestep: Diffusion step, can be a tensor or an integer.
258 | global_cond: Global conditioning features of shape (B, num_tokens, global_cond_dim); pooled over the token dimension before use.
259 | states: Optional state tensor.
260 |
261 | Returns:
262 | Output tensor of shape (B, T, input_dim).
263 | """
264 | # move axis for processing
265 | sample = sample.moveaxis(-1, -2)
266 | # process global conditioning
267 | global_cond = self.global_1d_pool(global_cond.permute(0, 2, 1)).squeeze(-1)
268 | global_cond = self.norm_after_pool(global_cond) # layernorm
269 | global_cond = torch.cat([global_cond, states], dim=-1) if states is not None else global_cond
270 | global_cond = self.combine(global_cond)
271 |
272 | timesteps = timestep
273 | if not torch.is_tensor(timesteps):
274 | timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
275 | elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
276 | timesteps = timesteps[None].to(sample.device)
277 | # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
278 | timesteps = timesteps.expand(sample.shape[0])
279 |
280 | global_feature = self.diffusion_step_encoder(timesteps)
281 |
282 | if global_cond is not None:
283 | global_feature = torch.cat([
284 | global_feature, global_cond
285 | ], axis=-1)
286 |
287 | x = sample
288 | h = []
289 | for idx, (resnet, resnet2, downsample) in enumerate(self.down_modules):
290 | x = resnet(x, global_feature)
291 | x = resnet2(x, global_feature)
292 | h.append(x)
293 | x = downsample(x)
294 |
295 | for mid_module in self.mid_modules:
296 | x = mid_module(x, global_feature)
297 |
298 | for idx, (resnet, resnet2, upsample) in enumerate(self.up_modules):
299 | x = torch.cat((x, h.pop()), dim=1)
300 | x = resnet(x, global_feature)
301 | x = resnet2(x, global_feature)
302 | x = upsample(x)
303 |
304 | x = self.final_conv(x)
305 |
306 | # (B,C,T)
307 | x = x.moveaxis(-1, -2)
308 | # (B,T,C)
309 | return x
310 |
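
As a shape sanity check, the sketch below instantiates ConditionalUnet1D with illustrative dimensions (not the project's training defaults) and runs one forward pass. Because SinusoidalPosEmb is hard-coded to torch.bfloat16, the whole module and all inputs are cast to bfloat16 here; this assumes a PyTorch build with bfloat16 CPU kernels, or a GPU. The import path is assumed from the repo tree.

```python
# Hedged shape-check sketch; all dimensions below are illustrative.
import torch
from policy_heads.models.droid_unet_diffusion import ConditionalUnet1D

model = ConditionalUnet1D(
    input_dim=10,          # action dimension
    global_cond_dim=512,   # width of the visual feature tokens
    state_dim=7,           # robot proprioceptive state
).to(torch.bfloat16)

B, T = 2, 16  # batch size, action horizon (divisible by 4 for the two downsamples)
sample = torch.randn(B, T, 10, dtype=torch.bfloat16)           # noisy action chunk
global_cond = torch.randn(B, 197, 512, dtype=torch.bfloat16)   # e.g. vision tokens
states = torch.randn(B, 7, dtype=torch.bfloat16)               # proprioception
timesteps = torch.randint(0, 100, (B,))                        # diffusion steps

out = model(sample, timesteps, global_cond=global_cond, states=states)
print(out.shape)  # torch.Size([2, 16, 10]), same (B, T, input_dim) as the input
```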
--------------------------------------------------------------------------------
/policy_heads/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright 2020 - present, Facebook, Inc
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/llava-pythia/scripts/convert_vqav2_for_submission.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import json
4 |
5 | # from llava_pythia.eval.m4c_evaluator import EvalAIAnswerProcessor
6 |
7 | import re
8 |
9 | from tqdm import tqdm
10 |
11 |
12 | class EvalAIAnswerProcessor:
13 | """
14 | Processes an answer similar to Eval AI
15 | copied from
16 | https://github.com/facebookresearch/mmf/blob/c46b3b3391275b4181567db80943473a89ab98ab/pythia/tasks/processors.py#L897
17 | """
18 |
19 | CONTRACTIONS = {
20 | "aint": "ain't",
21 | "arent": "aren't",
22 | "cant": "can't",
23 | "couldve": "could've",
24 | "couldnt": "couldn't",
25 | "couldn'tve": "couldn't've",
26 | "couldnt've": "couldn't've",
27 | "didnt": "didn't",
28 | "doesnt": "doesn't",
29 | "dont": "don't",
30 | "hadnt": "hadn't",
31 | "hadnt've": "hadn't've",
32 | "hadn'tve": "hadn't've",
33 | "hasnt": "hasn't",
34 | "havent": "haven't",
35 | "hed": "he'd",
36 | "hed've": "he'd've",
37 | "he'dve": "he'd've",
38 | "hes": "he's",
39 | "howd": "how'd",
40 | "howll": "how'll",
41 | "hows": "how's",
42 | "Id've": "I'd've",
43 | "I'dve": "I'd've",
44 | "Im": "I'm",
45 | "Ive": "I've",
46 | "isnt": "isn't",
47 | "itd": "it'd",
48 | "itd've": "it'd've",
49 | "it'dve": "it'd've",
50 | "itll": "it'll",
51 | "let's": "let's",
52 | "maam": "ma'am",
53 | "mightnt": "mightn't",
54 | "mightnt've": "mightn't've",
55 | "mightn'tve": "mightn't've",
56 | "mightve": "might've",
57 | "mustnt": "mustn't",
58 | "mustve": "must've",
59 | "neednt": "needn't",
60 | "notve": "not've",
61 | "oclock": "o'clock",
62 | "oughtnt": "oughtn't",
63 | "ow's'at": "'ow's'at",
64 | "'ows'at": "'ow's'at",
65 | "'ow'sat": "'ow's'at",
66 | "shant": "shan't",
67 | "shed've": "she'd've",
68 | "she'dve": "she'd've",
69 | "she's": "she's",
70 | "shouldve": "should've",
71 | "shouldnt": "shouldn't",
72 | "shouldnt've": "shouldn't've",
73 | "shouldn'tve": "shouldn't've",
74 | "somebody'd": "somebodyd",
75 | "somebodyd've": "somebody'd've",
76 | "somebody'dve": "somebody'd've",
77 | "somebodyll": "somebody'll",
78 | "somebodys": "somebody's",
79 | "someoned": "someone'd",
80 | "someoned've": "someone'd've",
81 | "someone'dve": "someone'd've",
82 | "someonell": "someone'll",
83 | "someones": "someone's",
84 | "somethingd": "something'd",
85 | "somethingd've": "something'd've",
86 | "something'dve": "something'd've",
87 | "somethingll": "something'll",
88 | "thats": "that's",
89 | "thered": "there'd",
90 | "thered've": "there'd've",
91 | "there'dve": "there'd've",
92 | "therere": "there're",
93 | "theres": "there's",
94 | "theyd": "they'd",
95 | "theyd've": "they'd've",
96 | "they'dve": "they'd've",
97 | "theyll": "they'll",
98 | "theyre": "they're",
99 | "theyve": "they've",
100 | "twas": "'twas",
101 | "wasnt": "wasn't",
102 | "wed've": "we'd've",
103 | "we'dve": "we'd've",
104 | "weve": "we've",
105 | "werent": "weren't",
106 | "whatll": "what'll",
107 | "whatre": "what're",
108 | "whats": "what's",
109 | "whatve": "what've",
110 | "whens": "when's",
111 | "whered": "where'd",
112 | "wheres": "where's",
113 | "whereve": "where've",
114 | "whod": "who'd",
115 | "whod've": "who'd've",
116 | "who'dve": "who'd've",
117 | "wholl": "who'll",
118 | "whos": "who's",
119 | "whove": "who've",
120 | "whyll": "why'll",
121 | "whyre": "why're",
122 | "whys": "why's",
123 | "wont": "won't",
124 | "wouldve": "would've",
125 | "wouldnt": "wouldn't",
126 | "wouldnt've": "wouldn't've",
127 | "wouldn'tve": "wouldn't've",
128 | "yall": "y'all",
129 | "yall'll": "y'all'll",
130 | "y'allll": "y'all'll",
131 | "yall'd've": "y'all'd've",
132 | "y'alld've": "y'all'd've",
133 | "y'all'dve": "y'all'd've",
134 | "youd": "you'd",
135 | "youd've": "you'd've",
136 | "you'dve": "you'd've",
137 | "youll": "you'll",
138 | "youre": "you're",
139 | "youve": "you've",
140 | }
141 |
142 | NUMBER_MAP = {
143 | "none": "0",
144 | "zero": "0",
145 | "one": "1",
146 | "two": "2",
147 | "three": "3",
148 | "four": "4",
149 | "five": "5",
150 | "six": "6",
151 | "seven": "7",
152 | "eight": "8",
153 | "nine": "9",
154 | "ten": "10",
155 | }
156 | ARTICLES = ["a", "an", "the"]
157 | PERIOD_STRIP = re.compile(r"(?!<=\d)(\.)(?!\d)")
158 | COMMA_STRIP = re.compile(r"(?<=\d)(\,)+(?=\d)")
159 | PUNCTUATIONS = [
160 | ";",
161 | r"/",
162 | "[",
163 | "]",
164 | '"',
165 | "{",
166 | "}",
167 | "(",
168 | ")",
169 | "=",
170 | "+",
171 | "\\",
172 | "_",
173 | "-",
174 | ">",
175 | "<",
176 | "@",
177 | "`",
178 | ",",
179 | "?",
180 | "!",
181 | ]
182 |
183 | def __init__(self, *args, **kwargs):
184 | pass
185 |
186 | def word_tokenize(self, word):
187 | word = word.lower()
188 | word = word.replace(",", "").replace("?", "").replace("'s", " 's")
189 | return word.strip()
190 |
191 | def process_punctuation(self, in_text):
192 | out_text = in_text
193 | for p in self.PUNCTUATIONS:
194 | if (p + " " in in_text or " " + p in in_text) or (
195 | re.search(self.COMMA_STRIP, in_text) is not None
196 | ):
197 | out_text = out_text.replace(p, "")
198 | else:
199 | out_text = out_text.replace(p, " ")
200 | out_text = self.PERIOD_STRIP.sub("", out_text, re.UNICODE)
201 | return out_text
202 |
203 | def process_digit_article(self, in_text):
204 | out_text = []
205 | temp_text = in_text.lower().split()
206 | for word in temp_text:
207 | word = self.NUMBER_MAP.setdefault(word, word)
208 | if word not in self.ARTICLES:
209 | out_text.append(word)
210 | else:
211 | pass
212 | for word_id, word in enumerate(out_text):
213 | if word in self.CONTRACTIONS:
214 | out_text[word_id] = self.CONTRACTIONS[word]
215 | out_text = " ".join(out_text)
216 | return out_text
217 |
218 | def __call__(self, item):
219 | item = self.word_tokenize(item)
220 | item = item.replace("\n", " ").replace("\t", " ").strip()
221 | item = self.process_punctuation(item)
222 | item = self.process_digit_article(item)
223 | return item
224 |
225 |
226 | class TextVQAAccuracyEvaluator:
227 | def __init__(self):
228 | self.answer_processor = EvalAIAnswerProcessor()
229 |
230 | def _compute_answer_scores(self, raw_answers):
231 | """
232 | compute the accuracy (soft score) of human answers
233 | """
234 | answers = [self.answer_processor(a) for a in raw_answers]
235 | assert len(answers) == 10
236 | gt_answers = list(enumerate(answers))
237 | unique_answers = set(answers)
238 | unique_answer_scores = {}
239 |
240 | for unique_answer in unique_answers:
241 | accs = []
242 | for gt_answer in gt_answers:
243 | other_answers = [item for item in gt_answers if item != gt_answer]
244 | matching_answers = [
245 | item for item in other_answers if item[1] == unique_answer
246 | ]
247 | acc = min(1, float(len(matching_answers)) / 3)
248 | accs.append(acc)
249 | unique_answer_scores[unique_answer] = sum(accs) / len(accs)
250 |
251 | return unique_answer_scores
252 |
253 | def eval_pred_list(self, pred_list):
254 | pred_scores = []
255 | for entry in tqdm(pred_list):
256 | pred_answer = self.answer_processor(entry["pred_answer"])
257 | unique_answer_scores = self._compute_answer_scores(entry["gt_answers"])
258 | score = unique_answer_scores.get(pred_answer, 0.0)
259 | pred_scores.append(score)
260 |
261 | accuracy = sum(pred_scores) / len(pred_scores)
262 | return accuracy
263 |
264 |
265 | class STVQAAccuracyEvaluator:
266 | def __init__(self):
267 | self.answer_processor = EvalAIAnswerProcessor()
268 |
269 | def eval_pred_list(self, pred_list):
270 | pred_scores = []
271 | for entry in pred_list:
272 | pred_answer = self.answer_processor(entry["pred_answer"])
273 | gts = [self.answer_processor(a) for a in entry["gt_answers"]]
274 | score = 1.0 if pred_answer in gts else 0.0
275 | pred_scores.append(score)
276 |
277 | accuracy = sum(pred_scores) / len(pred_scores)
278 | return accuracy
279 |
280 |
281 | class STVQAANLSEvaluator:
282 | def __init__(self):
283 | import editdistance # install with `pip install editdistance`
284 |
285 | self.get_edit_distance = editdistance.eval
286 |
287 | def get_anls(self, s1, s2):
288 | s1 = s1.lower().strip()
289 | s2 = s2.lower().strip()
290 | iou = 1 - self.get_edit_distance(s1, s2) / max(len(s1), len(s2))
291 | anls = iou if iou >= 0.5 else 0.0
292 | return anls
293 |
294 | def eval_pred_list(self, pred_list):
295 | pred_scores = []
296 | for entry in pred_list:
297 | anls = max(
298 | self.get_anls(entry["pred_answer"], gt) for gt in entry["gt_answers"]
299 | )
300 | pred_scores.append(anls)
301 |
302 | accuracy = sum(pred_scores) / len(pred_scores)
303 | return accuracy
304 |
305 |
306 | class TextCapsBleu4Evaluator:
307 | def __init__(self):
308 | # The following script requires Java 1.8.0 and pycocotools installed.
309 | # The pycocoevalcap can be installed with pip as
310 | # pip install git+https://github.com/ronghanghu/coco-caption.git@python23
311 | # Original pycocoevalcap code is at https://github.com/tylin/coco-caption
312 | # but has no python3 support yet.
313 | try:
314 | from pycocoevalcap.bleu.bleu import Bleu
315 | from pycocoevalcap.tokenizer.ptbtokenizer import PTBTokenizer
316 | except ModuleNotFoundError:
317 | print(
318 | "Please install pycocoevalcap module using "
319 | "pip install git+https://github.com/ronghanghu/coco-caption.git@python23" # noqa
320 | )
321 | raise
322 |
323 | self.tokenizer = PTBTokenizer()
324 | self.scorer = Bleu(4)
325 |
326 | def eval_pred_list(self, pred_list):
327 | # Create reference and hypotheses captions.
328 | gts = {}
329 | res = {}
330 | for idx, entry in enumerate(pred_list):
331 | gts[idx] = [{"caption": a} for a in entry["gt_answers"]]
332 | res[idx] = [{"caption": entry["pred_answer"]}]
333 |
334 | gts = self.tokenizer.tokenize(gts)
335 | res = self.tokenizer.tokenize(res)
336 | score, _ = self.scorer.compute_score(gts, res)
337 |
338 | bleu4 = score[3] # score is (Bleu-1, Bleu-2, Bleu-3, Bleu-4)
339 | return bleu4
340 |
341 |
342 | def parse_args():
343 | parser = argparse.ArgumentParser()
344 | parser.add_argument('--dir', type=str, default="./playground/data/eval/vqav2")
345 | parser.add_argument('--ckpt', type=str, required=True)
346 | parser.add_argument('--split', type=str, required=True)
347 | return parser.parse_args()
348 |
349 |
350 | if __name__ == '__main__':
351 |
352 | args = parse_args()
353 |
354 | src = os.path.join(args.dir, 'answers', args.split, args.ckpt, 'merge.jsonl')
355 | test_split = os.path.join(args.dir, 'llava_vqav2_mscoco_test2015.jsonl')
356 | dst = os.path.join(args.dir, 'answers_upload', args.split, f'{args.ckpt}.json')
357 | os.makedirs(os.path.dirname(dst), exist_ok=True)
358 |
359 | results = []
360 | error_line = 0
361 | for line_idx, line in enumerate(open(src)):
362 | try:
363 | results.append(json.loads(line))
364 | except:
365 | error_line += 1
366 |
367 | results = {x['question_id']: x['text'] for x in results}
368 | test_split = [json.loads(line) for line in open(test_split)]
369 | split_ids = set([x['question_id'] for x in test_split])
370 |
371 | print(f'total results: {len(results)}, total split: {len(test_split)}, error_line: {error_line}')
372 |
373 | all_answers = []
374 |
375 | answer_processor = EvalAIAnswerProcessor()
376 |
377 | for x in test_split:
378 | if x['question_id'] not in results:
379 | all_answers.append({
380 | 'question_id': x['question_id'],
381 | 'answer': ''
382 | })
383 | else:
384 | all_answers.append({
385 | 'question_id': x['question_id'],
386 | 'answer': answer_processor(results[x['question_id']])
387 | })
388 |
389 | with open(dst, 'w') as f:
390 | json.dump(all_answers, f)
391 |
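
To illustrate the answer normalization and the min(1, matches/3) soft-accuracy rule used above, a small sketch follows. It assumes this script can be imported as a module named convert_vqav2_for_submission (the argparse block only runs under __main__), and the example strings are invented.

```python
# Hedged illustration; import path assumed, example inputs invented.
from convert_vqav2_for_submission import EvalAIAnswerProcessor, TextVQAAccuracyEvaluator

processor = EvalAIAnswerProcessor()
print(processor("Two"))           # "2"         (number words mapped to digits)
print(processor("a red apple"))   # "red apple" (articles dropped)
print(processor("isnt it?"))      # "isn't it"  (contraction restored, '?' stripped)

# VQA-style soft accuracy: an answer matching >= 3 of the other annotators scores 1.0.
evaluator = TextVQAAccuracyEvaluator()
entry = {"pred_answer": "2", "gt_answers": ["two"] * 7 + ["three"] * 3}
print(evaluator.eval_pred_list([entry]))  # 1.0
```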
--------------------------------------------------------------------------------