├── LICENSE ├── README.md ├── ReConV2 ├── cfgs │ ├── dataset_configs │ │ ├── Hybrid.yaml │ │ ├── HybridLabeled.yaml │ │ ├── ModelNet10.yaml │ │ ├── ModelNet40.yaml │ │ ├── ModelNet40FewShot.yaml │ │ ├── OpenShape.yaml │ │ ├── ScanObjectNN_hardest.yaml │ │ ├── ScanObjectNN_objectbg.yaml │ │ └── ScanObjectNN_objectonly.yaml │ ├── pretrain │ │ ├── base │ │ │ ├── hybrid.yaml │ │ │ ├── hybrid_post.yaml │ │ │ ├── openshape.yaml │ │ │ └── openshape_1k.yaml │ │ ├── giant │ │ │ └── openshape.yaml │ │ ├── large │ │ │ ├── hybrid.yaml │ │ │ ├── hybrid_post.yaml │ │ │ ├── openshape.yaml │ │ │ └── openshape_1k.yaml │ │ └── small │ │ │ ├── hybrid.yaml │ │ │ ├── hybrid_post.yaml │ │ │ ├── openshape.yaml │ │ │ └── openshape_1k.yaml │ ├── svm │ │ └── modelnet40.yaml │ └── transfer │ │ ├── base │ │ ├── fewshot.yaml │ │ ├── finetune_modelnet.yaml │ │ ├── finetune_modelnet_8k.yaml │ │ ├── finetune_scan_hardest.yaml │ │ ├── finetune_scan_objbg.yaml │ │ └── finetune_scan_objonly.yaml │ │ ├── large │ │ ├── fewshot.yaml │ │ ├── finetune_modelnet.yaml │ │ ├── finetune_modelnet_8k.yaml │ │ ├── finetune_scan_hardest.yaml │ │ ├── finetune_scan_objbg.yaml │ │ └── finetune_scan_objonly.yaml │ │ └── small │ │ ├── fewshot.yaml │ │ ├── finetune_modelnet.yaml │ │ ├── finetune_modelnet_8k.yaml │ │ ├── finetune_scan_hardest.yaml │ │ ├── finetune_scan_objbg.yaml │ │ └── finetune_scan_objonly.yaml ├── convert_features.py ├── datasets │ ├── HybridDataset.py │ ├── ModelNetDataset.py │ ├── ModelNetDatasetFewShot.py │ ├── OpenShape.py │ ├── ScanObjectNNDataset.py │ ├── __init__.py │ ├── build.py │ ├── data.py │ ├── data_transforms.py │ ├── io.py │ └── pc_render.py ├── extensions │ └── chamfer_distance │ │ ├── __init__.py │ │ ├── chamfer_distance.cpp │ │ ├── chamfer_distance.cu │ │ └── chamfer_distance.py ├── figure │ └── framework.png ├── generate_depth_map.py ├── main.py ├── models │ ├── ReCon.py │ ├── __init__.py │ ├── build.py │ └── transformer.py ├── multi_cls.py ├── requirements.txt ├── scripts │ ├── downstream │ │ ├── cls.sh │ │ ├── fewshot.sh │ │ ├── svm.sh │ │ ├── test.sh │ │ └── zeroshot.sh │ ├── pretrain_transfer │ │ ├── pretrain_contrast.sh │ │ ├── pretrain_reconstruct.sh │ │ └── pretrain_supervise.sh │ └── pretrain_zeroshot │ │ ├── pretrain.sh │ │ ├── pretrain_contrast.sh │ │ └── pretrain_reconstruct.sh ├── segmentation │ ├── dataset.py │ ├── knn.py │ ├── logger.py │ ├── main.py │ ├── misc.py │ ├── models │ │ ├── pointnet2_utils.py │ │ └── pt.py │ ├── pointnet_util.py │ ├── provider.py │ ├── seg.sh │ └── test.sh ├── tools │ ├── __init__.py │ ├── builder.py │ ├── runner_finetune.py │ ├── runner_pretrain.py │ ├── runner_svm.py │ └── runner_zeroshot.py └── utils │ ├── AverageMeter.py │ ├── checkpoint.py │ ├── config.py │ ├── data.py │ ├── dist_utils.py │ ├── knn.py │ ├── logger.py │ ├── misc.py │ ├── parser.py │ ├── randaugment.py │ ├── registry.py │ └── transforms.py ├── assets ├── framework.jpg └── instrument.npy ├── docs ├── DATA.md ├── LoRA.md ├── MODEL_ZOO.md └── Windows.md ├── llava ├── __init__.py ├── constants.py ├── conversation.py ├── eval │ ├── data │ │ ├── __init__.py │ │ ├── modelnet.py │ │ ├── modelnet_config │ │ │ ├── ModelNet40.yaml │ │ │ └── modelnet40_shape_names_modified.txt │ │ ├── object_point_dataset.py │ │ └── utils.py │ ├── eval_3dmmvet.py │ ├── eval_gapartnet.py │ ├── eval_modelnet_cls.py │ ├── eval_objaverse.py │ ├── evaluator.py │ ├── gpt_eval.py │ ├── model_vqa.py │ ├── traditional_evaluator.py │ └── utils.py ├── mm_utils.py ├── model │ ├── __init__.py │ ├── apply_delta.py │ ├── builder.py │ ├── 
consolidate.py │ ├── language_model │ │ ├── llava_llama.py │ │ ├── llava_mpt.py │ │ └── mpt │ │ │ ├── adapt_tokenizer.py │ │ │ ├── attention.py │ │ │ ├── blocks.py │ │ │ ├── configuration_mpt.py │ │ │ ├── custom_embedding.py │ │ │ ├── flash_attn_triton.py │ │ │ ├── hf_prefixlm_converter.py │ │ │ ├── meta_init_context.py │ │ │ ├── modeling_mpt.py │ │ │ ├── norm.py │ │ │ └── param_init_fns.py │ ├── llava_arch.py │ ├── make_delta.py │ ├── multimodal_encoder │ │ ├── builder.py │ │ └── clip_encoder.py │ ├── multimodal_projector │ │ └── builder.py │ └── utils.py ├── serve │ ├── __init__.py │ ├── cli.py │ ├── controller.py │ ├── examples │ │ ├── extreme_ironing.jpg │ │ └── waterview.jpg │ ├── gradio_web_server.py │ ├── model_worker.py │ ├── register_worker.py │ └── test_message.py ├── train │ ├── llama_flash_attn_monkey_patch.py │ ├── llava_trainer.py │ ├── train.py │ └── train_mem.py └── utils.py ├── playground └── data │ └── eval │ ├── 3d-mm-vet │ ├── gt.jsonl │ └── question.jsonl │ ├── gapartnet │ ├── .DS_Store │ ├── gt.jsonl │ ├── question.jsonl │ └── test_list.json │ └── modelnet40 │ └── modelnet40_shape_names_modified.txt ├── pyproject.toml └── scripts ├── eval ├── eval_gapartnet.sh ├── eval_mmvet.sh ├── eval_modelnet40_cls.sh ├── eval_objaverse_cap.sh ├── eval_objaverse_cls.sh ├── gapartnet_ref.sh ├── mmvet.sh ├── modelnet40_cls.sh ├── objaverse_cap.sh └── objaverse_cls.sh ├── extract_mm_projector.py ├── finetune.sh ├── finetune_lora.sh ├── inference.sh ├── merge_lora_weights.py ├── pretrain.sh ├── zero2.json ├── zero3.json └── zero3_offload.json /ReConV2/cfgs/dataset_configs/Hybrid.yaml: -------------------------------------------------------------------------------- 1 | NAME: Hybrid 2 | DATA_PATH: ReConV2/data/HybridDatasets 3 | IMG_PATH: ReConV2/data/HybridDatasets/depth 4 | IMG_FEATURE_PATH: ReConV2/data/HybridDatasets/depth_feature 5 | ratio: 1.0 6 | -------------------------------------------------------------------------------- /ReConV2/cfgs/dataset_configs/HybridLabeled.yaml: -------------------------------------------------------------------------------- 1 | NAME: HybridLabeled 2 | DATA_PATH: ReConV2/data/HybridDatasets 3 | ratio: 1.0 4 | -------------------------------------------------------------------------------- /ReConV2/cfgs/dataset_configs/ModelNet10.yaml: -------------------------------------------------------------------------------- 1 | NAME: ModelNet 2 | DATA_PATH: ReConV2/data/ModelNet/modelnet40_normal_resampled 3 | N_POINTS: 8192 4 | NUM_CATEGORY: 10 5 | USE_NORMALS: FALSE -------------------------------------------------------------------------------- /ReConV2/cfgs/dataset_configs/ModelNet40.yaml: -------------------------------------------------------------------------------- 1 | NAME: ModelNet 2 | DATA_PATH: ReConV2/data/ModelNet/modelnet40_normal_resampled 3 | N_POINTS: 8192 4 | NUM_CATEGORY: 40 5 | USE_NORMALS: FALSE -------------------------------------------------------------------------------- /ReConV2/cfgs/dataset_configs/ModelNet40FewShot.yaml: -------------------------------------------------------------------------------- 1 | NAME: ModelNetFewShot 2 | DATA_PATH: ReConV2/data/ModelNetFewshot 3 | N_POINTS: 8192 4 | NUM_CATEGORY: 40 5 | USE_NORMALS: FALSE -------------------------------------------------------------------------------- /ReConV2/cfgs/dataset_configs/OpenShape.yaml: -------------------------------------------------------------------------------- 1 | NAME: OpenShape 2 | DATA_PATH: ReConV2/data/openshape 3 | ratio: 1.0 4 | 
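Note: the dataset configs above are referenced from the pretrain/transfer configs later in this dump through a `_base_:` path plus a per-split `others:` override block. Below is a minimal, hypothetical sketch of that merge (it assumes PyYAML and easydict are available; the helper name is illustrative only, and the repo's own loader in ReConV2/utils/config.py may differ):

import yaml
from easydict import EasyDict


def load_dataset_cfg(split_cfg: dict) -> EasyDict:
    # split_cfg example: {'_base_': 'ReConV2/cfgs/dataset_configs/ModelNet40.yaml',
    #                     'others': {'subset': 'train', 'npoints': 1024}}
    with open(split_cfg['_base_']) as f:
        base = yaml.safe_load(f)                    # NAME, DATA_PATH, N_POINTS, ...
    merged = dict(base)
    merged.update(split_cfg.get('others', {}))      # per-split overrides win
    return EasyDict(merged)


# Example: resolve the 'train' split the way a transfer config would.
train_cfg = load_dataset_cfg({
    '_base_': 'ReConV2/cfgs/dataset_configs/ModelNet40.yaml',
    'others': {'subset': 'train', 'npoints': 1024},
})
print(train_cfg.NAME, train_cfg.subset, train_cfg.npoints)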
-------------------------------------------------------------------------------- /ReConV2/cfgs/dataset_configs/ScanObjectNN_hardest.yaml: -------------------------------------------------------------------------------- 1 | NAME: ScanObjectNN_hardest 2 | ROOT: ReConV2/data/ScanObjectNN/main_split -------------------------------------------------------------------------------- /ReConV2/cfgs/dataset_configs/ScanObjectNN_objectbg.yaml: -------------------------------------------------------------------------------- 1 | NAME: ScanObjectNN 2 | ROOT: ReConV2/data/ScanObjectNN/main_split -------------------------------------------------------------------------------- /ReConV2/cfgs/dataset_configs/ScanObjectNN_objectonly.yaml: -------------------------------------------------------------------------------- 1 | NAME: ScanObjectNN 2 | ROOT: ReConV2/data/ScanObjectNN/main_split_nobg -------------------------------------------------------------------------------- /ReConV2/cfgs/pretrain/base/hybrid.yaml: -------------------------------------------------------------------------------- 1 | optimizer: 2 | type: AdamW 3 | kwargs: 4 | lr: 1e-4 5 | weight_decay: 0.05 6 | 7 | scheduler: 8 | type: CosLR 9 | kwargs: 10 | epochs: 300 11 | initial_epochs: 10 12 | 13 | dataset: 14 | train: 15 | _base_: ReConV2/cfgs/dataset_configs/Hybrid.yaml 16 | others: 17 | subset: train 18 | npoints: 1024 19 | 20 | model: 21 | NAME: ReCon2 22 | group_size: 32 23 | num_group: 64 24 | mask_ratio: 0.7 25 | mask_type: causal 26 | embed_dim: 768 27 | depth: 12 28 | drop_path_rate: 0.1 29 | num_heads: 12 30 | decoder_depth: 4 31 | with_color: False 32 | stop_grad: True 33 | large_embedding: True 34 | img_queries: 10 35 | text_queries: 1 36 | contrast_type: byol 37 | pretrained_model_name: vit_base_patch32_clip_384.openai_ft_in12k_in1k 38 | 39 | npoints: 1024 40 | total_bs: 512 41 | step_per_update: 1 42 | max_epoch: 300 43 | -------------------------------------------------------------------------------- /ReConV2/cfgs/pretrain/base/hybrid_post.yaml: -------------------------------------------------------------------------------- 1 | optimizer: 2 | type: AdamW 3 | kwargs: 4 | lr: 1e-4 5 | weight_decay: 0.05 6 | 7 | scheduler: 8 | type: CosLR 9 | kwargs: 10 | epochs: 100 11 | initial_epochs: 10 12 | 13 | dataset: 14 | train: 15 | _base_: ReConV2/cfgs/dataset_configs/HybridLabeled.yaml 16 | others: 17 | subset: train 18 | npoints: 1024 19 | val: 20 | _base_: ReConV2/cfgs/dataset_configs/HybridLabeled.yaml 21 | others: 22 | subset: test 23 | npoints: 1024 24 | 25 | model: 26 | NAME: PointTransformer 27 | embed_dim: 768 28 | depth: 12 29 | drop_path_rate: 0.2 30 | cls_dim: 87 31 | num_heads: 12 32 | group_size: 32 33 | num_group: 64 34 | with_color: False 35 | large_embedding: True 36 | img_queries: 10 37 | text_queries: 1 38 | decoder_depth: 4 39 | 40 | 41 | npoints: 1024 42 | total_bs: 128 43 | step_per_update: 1 44 | max_epoch: 100 45 | grad_norm_clip: 10 46 | -------------------------------------------------------------------------------- /ReConV2/cfgs/pretrain/base/openshape.yaml: -------------------------------------------------------------------------------- 1 | optimizer: 2 | type: AdamW 3 | kwargs: 4 | lr: 1e-4 5 | weight_decay: 0.05 6 | 7 | scheduler: 8 | type: CosLR 9 | kwargs: 10 | epochs: 300 11 | initial_epochs: 10 12 | 13 | dataset: 14 | train: 15 | _base_: ReConV2/cfgs/dataset_configs/OpenShape.yaml 16 | others: 17 | subset: train 18 | npoints: 10000 19 | rgb_random_drop_prob: 0.5 20 | occlusion: False 21 | 22 | model: 23 
| NAME: ReCon2 24 | group_size: 32 25 | num_group: 512 26 | mask_ratio: 0.7 27 | mask_type: rand 28 | embed_dim: 768 29 | depth: 12 30 | drop_path_rate: 0.1 31 | num_heads: 12 32 | decoder_depth: 4 33 | with_color: True 34 | stop_grad: False 35 | large_embedding: False 36 | img_queries: 13 37 | text_queries: 3 38 | contrast_type: byol 39 | pretrained_model_name: vit_base_patch32_clip_384.openai_ft_in12k_in1k 40 | # pretrained_model_name: vit_base_patch32_clip_224.openai 41 | # pretrained_model_name: vit_base_patch14_dinov2.lvd142m 42 | 43 | modelnet40: 44 | test_split: ReConV2/data/openshape/meta_data/modelnet40/test_split.json 45 | test_pc: ReConV2/data/openshape/meta_data/modelnet40/test_pc.npy 46 | clip_feat_path: ReConV2/data/openshape/meta_data/modelnet40/cat_name_pt_feat.npy 47 | num_workers: 8 48 | batch_size: 128 49 | ratio: 0.5 50 | 51 | objaverse_lvis: 52 | split: ReConV2/data/openshape/meta_data/split/lvis.json 53 | clip_feat_path: ReConV2/data/openshape/meta_data/lvis_cat_name_pt_feat.npy 54 | num_workers: 8 55 | batch_size: 128 56 | ratio: 0.5 57 | 58 | scanobjectnn: 59 | data_path: ReConV2/data/openshape/meta_data/scanobjectnn/xyz_label.npy 60 | clip_feat_path: ReConV2/data/openshape/meta_data/scanobjectnn/cat_name_pt_feat.npy 61 | num_workers: 8 62 | batch_size: 128 63 | ratio: 0.3 64 | 65 | npoints: 10000 66 | total_bs: 512 67 | step_per_update: 1 68 | max_epoch: 300 69 | -------------------------------------------------------------------------------- /ReConV2/cfgs/pretrain/base/openshape_1k.yaml: -------------------------------------------------------------------------------- 1 | optimizer: 2 | type: AdamW 3 | kwargs: 4 | lr: 1e-4 5 | weight_decay: 0.05 6 | 7 | scheduler: 8 | type: CosLR 9 | kwargs: 10 | epochs: 300 11 | initial_epochs: 10 12 | 13 | dataset: 14 | train: 15 | _base_: ReConV2/cfgs/dataset_configs/OpenShape.yaml 16 | others: 17 | subset: train 18 | npoints: 1024 19 | rgb_random_drop_prob: 0.5 20 | occlusion: False 21 | 22 | model: 23 | NAME: ReCon2 24 | group_size: 32 25 | num_group: 64 26 | mask_ratio: 0.7 27 | mask_type: causal 28 | embed_dim: 768 29 | depth: 12 30 | drop_path_rate: 0.1 31 | num_heads: 12 32 | decoder_depth: 4 33 | with_color: False 34 | stop_grad: True 35 | large_embedding: True 36 | img_queries: 13 37 | text_queries: 3 38 | contrast_type: byol 39 | pretrained_model_name: vit_base_patch32_clip_384.openai_ft_in12k_in1k 40 | # pretrained_model_name: vit_base_patch32_clip_224.openai 41 | # pretrained_model_name: vit_base_patch14_dinov2.lvd142m 42 | 43 | modelnet40: 44 | test_split: ReConV2/data/openshape/meta_data/modelnet40/test_split.json 45 | test_pc: ReConV2/data/openshape/meta_data/modelnet40/test_pc.npy 46 | clip_feat_path: ReConV2/data/openshape/meta_data/modelnet40/cat_name_pt_feat.npy 47 | num_workers: 8 48 | batch_size: 128 49 | ratio: 0.5 50 | 51 | objaverse_lvis: 52 | split: ReConV2/data/openshape/meta_data/split/lvis.json 53 | clip_feat_path: ReConV2/data/openshape/meta_data/lvis_cat_name_pt_feat.npy 54 | num_workers: 8 55 | batch_size: 128 56 | ratio: 0.5 57 | 58 | scanobjectnn: 59 | data_path: ReConV2/data/openshape/meta_data/scanobjectnn/xyz_label.npy 60 | clip_feat_path: ReConV2/data/openshape/meta_data/scanobjectnn/cat_name_pt_feat.npy 61 | num_workers: 8 62 | batch_size: 128 63 | ratio: 0.3 64 | 65 | npoints: 1024 66 | total_bs: 512 67 | step_per_update: 1 68 | max_epoch: 300 69 | -------------------------------------------------------------------------------- /ReConV2/cfgs/pretrain/giant/openshape.yaml: 
-------------------------------------------------------------------------------- 1 | optimizer: 2 | type: AdamW 3 | kwargs: 4 | lr: 2e-5 5 | weight_decay: 0.05 6 | 7 | scheduler: 8 | type: CosLR 9 | kwargs: 10 | epochs: 200 11 | initial_epochs: 10 12 | 13 | dataset: 14 | train: 15 | _base_: ReConV2/cfgs/dataset_configs/OpenShape.yaml 16 | others: 17 | subset: train 18 | npoints: 10000 19 | rgb_random_drop_prob: 0.5 20 | occlusion: False 21 | 22 | model: 23 | NAME: ReCon2 24 | group_size: 32 25 | num_group: 512 26 | mask_ratio: 0.7 27 | mask_type: rand 28 | embed_dim: 1408 29 | depth: 40 30 | drop_path_rate: 0.1 31 | num_heads: 16 32 | decoder_depth: 4 33 | with_color: True 34 | stop_grad: False 35 | large_embedding: False 36 | img_queries: 13 37 | text_queries: 3 38 | contrast_type: byol 39 | pretrained_model_name: vit_giant_patch14_clip_224.laion2b 40 | 41 | modelnet40: 42 | test_split: ReConV2/data/openshape/meta_data/modelnet40/test_split.json 43 | test_pc: ReConV2/data/openshape/meta_data/modelnet40/test_pc.npy 44 | clip_feat_path: ReConV2/data/openshape/meta_data/modelnet40/cat_name_pt_feat.npy 45 | num_workers: 8 46 | batch_size: 128 47 | ratio: 0.5 48 | 49 | objaverse_lvis: 50 | split: ReConV2/data/openshape/meta_data/split/lvis.json 51 | clip_feat_path: ReConV2/data/openshape/meta_data/lvis_cat_name_pt_feat.npy 52 | num_workers: 8 53 | batch_size: 128 54 | ratio: 0.5 55 | 56 | scanobjectnn: 57 | data_path: ReConV2/data/openshape/meta_data/scanobjectnn/xyz_label.npy 58 | clip_feat_path: ReConV2/data/openshape/meta_data/scanobjectnn/cat_name_pt_feat.npy 59 | num_workers: 8 60 | batch_size: 128 61 | ratio: 0.3 62 | 63 | npoints: 10000 64 | total_bs: 256 65 | step_per_update: 1 66 | max_epoch: 200 67 | -------------------------------------------------------------------------------- /ReConV2/cfgs/pretrain/large/hybrid.yaml: -------------------------------------------------------------------------------- 1 | optimizer: 2 | type: AdamW 3 | kwargs: 4 | lr: 5e-5 5 | weight_decay: 0.05 6 | 7 | scheduler: 8 | type: CosLR 9 | kwargs: 10 | epochs: 300 11 | initial_epochs: 10 12 | 13 | dataset: 14 | train: 15 | _base_: ReConV2/cfgs/dataset_configs/Hybrid.yaml 16 | others: 17 | subset: train 18 | npoints: 1024 19 | 20 | model: 21 | NAME: ReCon2 22 | group_size: 32 23 | num_group: 64 24 | mask_ratio: 0.7 25 | mask_type: causal 26 | embed_dim: 1024 27 | depth: 24 28 | drop_path_rate: 0.1 29 | num_heads: 16 30 | decoder_depth: 4 31 | with_color: False 32 | stop_grad: True 33 | large_embedding: True 34 | img_queries: 10 35 | text_queries: 1 36 | contrast_type: byol 37 | pretrained_model_name: eva_large_patch14_336.in22k_ft_in22k_in1k 38 | 39 | npoints: 1024 40 | total_bs: 512 41 | step_per_update: 1 42 | max_epoch: 300 43 | -------------------------------------------------------------------------------- /ReConV2/cfgs/pretrain/large/hybrid_post.yaml: -------------------------------------------------------------------------------- 1 | optimizer: 2 | type: AdamW 3 | kwargs: 4 | lr: 1e-4 5 | weight_decay: 0.05 6 | 7 | scheduler: 8 | type: CosLR 9 | kwargs: 10 | epochs: 100 11 | initial_epochs: 10 12 | 13 | dataset: 14 | train: 15 | _base_: ReConV2/cfgs/dataset_configs/HybridLabeled.yaml 16 | others: 17 | subset: train 18 | npoints: 1024 19 | val: 20 | _base_: ReConV2/cfgs/dataset_configs/HybridLabeled.yaml 21 | others: 22 | subset: test 23 | npoints: 1024 24 | 25 | model: 26 | NAME: PointTransformer 27 | embed_dim: 1024 28 | depth: 24 29 | drop_path_rate: 0.2 30 | cls_dim: 87 31 | num_heads: 16 32 | 
group_size: 32 33 | num_group: 64 34 | with_color: False 35 | large_embedding: True 36 | img_queries: 10 37 | text_queries: 1 38 | decoder_depth: 4 39 | 40 | 41 | npoints: 1024 42 | total_bs: 128 43 | step_per_update: 1 44 | max_epoch: 100 45 | grad_norm_clip: 10 46 | -------------------------------------------------------------------------------- /ReConV2/cfgs/pretrain/large/openshape.yaml: -------------------------------------------------------------------------------- 1 | optimizer: 2 | type: AdamW 3 | kwargs: 4 | lr: 5e-5 5 | weight_decay: 0.05 6 | 7 | scheduler: 8 | type: CosLR 9 | kwargs: 10 | epochs: 300 11 | initial_epochs: 10 12 | 13 | dataset: 14 | train: 15 | _base_: ReConV2/cfgs/dataset_configs/OpenShape.yaml 16 | others: 17 | subset: train 18 | npoints: 10000 19 | rgb_random_drop_prob: 0.5 20 | occlusion: False 21 | 22 | model: 23 | NAME: ReCon2 24 | group_size: 32 25 | num_group: 512 26 | mask_ratio: 0.7 27 | mask_type: rand 28 | embed_dim: 1024 29 | depth: 24 30 | drop_path_rate: 0.1 31 | num_heads: 16 32 | decoder_depth: 4 33 | with_color: True 34 | stop_grad: False 35 | large_embedding: False 36 | img_queries: 13 37 | text_queries: 3 38 | contrast_type: byol 39 | pretrained_model_name: eva_large_patch14_336.in22k_ft_in22k_in1k 40 | # pretrained_model_name: vit_large_patch14_clip_336.openai 41 | # pretrained_model_name: vit_large_patch14_reg4_dinov2.lvd142m 42 | 43 | modelnet40: 44 | test_split: ReConV2/data/openshape/meta_data/modelnet40/test_split.json 45 | test_pc: ReConV2/data/openshape/meta_data/modelnet40/test_pc.npy 46 | clip_feat_path: ReConV2/data/openshape/meta_data/modelnet40/cat_name_pt_feat.npy 47 | num_workers: 8 48 | batch_size: 128 49 | ratio: 0.5 50 | 51 | objaverse_lvis: 52 | split: ReConV2/data/openshape/meta_data/split/lvis.json 53 | clip_feat_path: ReConV2/data/openshape/meta_data/lvis_cat_name_pt_feat.npy 54 | num_workers: 8 55 | batch_size: 128 56 | ratio: 0.5 57 | 58 | scanobjectnn: 59 | data_path: ReConV2/data/openshape/meta_data/scanobjectnn/xyz_label.npy 60 | clip_feat_path: ReConV2/data/openshape/meta_data/scanobjectnn/cat_name_pt_feat.npy 61 | num_workers: 8 62 | batch_size: 128 63 | ratio: 0.3 64 | 65 | npoints: 10000 66 | total_bs: 512 67 | step_per_update: 1 68 | max_epoch: 300 69 | -------------------------------------------------------------------------------- /ReConV2/cfgs/pretrain/large/openshape_1k.yaml: -------------------------------------------------------------------------------- 1 | optimizer: 2 | type: AdamW 3 | kwargs: 4 | lr: 5e-5 5 | weight_decay: 0.05 6 | 7 | scheduler: 8 | type: CosLR 9 | kwargs: 10 | epochs: 300 11 | initial_epochs: 10 12 | 13 | dataset: 14 | train: 15 | _base_: ReConV2/cfgs/dataset_configs/OpenShape.yaml 16 | others: 17 | subset: train 18 | npoints: 1024 19 | rgb_random_drop_prob: 0.5 20 | occlusion: False 21 | 22 | model: 23 | NAME: ReCon2 24 | group_size: 32 25 | num_group: 64 26 | mask_ratio: 0.7 27 | mask_type: causal 28 | embed_dim: 1024 29 | depth: 24 30 | drop_path_rate: 0.1 31 | num_heads: 16 32 | decoder_depth: 4 33 | with_color: False 34 | stop_grad: True 35 | large_embedding: True 36 | img_queries: 13 37 | text_queries: 3 38 | contrast_type: byol 39 | pretrained_model_name: eva_large_patch14_336.in22k_ft_in22k_in1k 40 | # pretrained_model_name: vit_large_patch14_clip_336.openai 41 | # pretrained_model_name: vit_large_patch14_reg4_dinov2.lvd142m 42 | 43 | modelnet40: 44 | test_split: ReConV2/data/openshape/meta_data/modelnet40/test_split.json 45 | test_pc: 
ReConV2/data/openshape/meta_data/modelnet40/test_pc.npy 46 | clip_feat_path: ReConV2/data/openshape/meta_data/modelnet40/cat_name_pt_feat.npy 47 | num_workers: 8 48 | batch_size: 128 49 | ratio: 0.5 50 | 51 | objaverse_lvis: 52 | split: ReConV2/data/openshape/meta_data/split/lvis.json 53 | clip_feat_path: ReConV2/data/openshape/meta_data/lvis_cat_name_pt_feat.npy 54 | num_workers: 8 55 | batch_size: 128 56 | ratio: 0.5 57 | 58 | scanobjectnn: 59 | data_path: ReConV2/data/openshape/meta_data/scanobjectnn/xyz_label.npy 60 | clip_feat_path: ReConV2/data/openshape/meta_data/scanobjectnn/cat_name_pt_feat.npy 61 | num_workers: 8 62 | batch_size: 128 63 | ratio: 0.3 64 | 65 | npoints: 1024 66 | total_bs: 512 67 | step_per_update: 1 68 | max_epoch: 300 69 | -------------------------------------------------------------------------------- /ReConV2/cfgs/pretrain/small/hybrid.yaml: -------------------------------------------------------------------------------- 1 | optimizer: 2 | type: AdamW 3 | kwargs: 4 | lr: 4e-5 5 | weight_decay: 0.05 6 | 7 | scheduler: 8 | type: CosLR 9 | kwargs: 10 | epochs: 300 11 | initial_epochs: 10 12 | 13 | dataset: 14 | train: 15 | _base_: ReConV2/cfgs/dataset_configs/Hybrid.yaml 16 | others: 17 | subset: train 18 | npoints: 1024 19 | 20 | model: 21 | NAME: ReCon2 22 | group_size: 32 23 | num_group: 64 24 | mask_ratio: 0.6 25 | mask_type: causal 26 | embed_dim: 384 27 | depth: 12 28 | drop_path_rate: 0.1 29 | num_heads: 6 30 | decoder_depth: 4 31 | with_color: False 32 | stop_grad: True 33 | large_embedding: False 34 | img_queries: 10 35 | text_queries: 1 36 | contrast_type: byol 37 | pretrained_model_name: vit_small_patch14_dinov2.lvd142m 38 | 39 | npoints: 1024 40 | total_bs: 512 41 | step_per_update: 1 42 | max_epoch: 300 43 | -------------------------------------------------------------------------------- /ReConV2/cfgs/pretrain/small/hybrid_post.yaml: -------------------------------------------------------------------------------- 1 | optimizer: 2 | type: AdamW 3 | kwargs: 4 | lr: 1e-4 5 | weight_decay: 0.05 6 | 7 | scheduler: 8 | type: CosLR 9 | kwargs: 10 | epochs: 100 11 | initial_epochs: 10 12 | 13 | dataset: 14 | train: 15 | _base_: ReConV2/cfgs/dataset_configs/HybridLabeled.yaml 16 | others: 17 | subset: train 18 | npoints: 1024 19 | val: 20 | _base_: ReConV2/cfgs/dataset_configs/HybridLabeled.yaml 21 | others: 22 | subset: test 23 | npoints: 1024 24 | 25 | model: 26 | NAME: PointTransformer 27 | embed_dim: 384 28 | depth: 12 29 | drop_path_rate: 0.2 30 | cls_dim: 87 31 | num_heads: 6 32 | group_size: 32 33 | num_group: 64 34 | with_color: False 35 | large_embedding: False 36 | img_queries: 10 37 | text_queries: 1 38 | decoder_depth: 4 39 | 40 | 41 | npoints: 1024 42 | total_bs: 128 43 | step_per_update: 1 44 | max_epoch: 100 45 | grad_norm_clip: 10 46 | -------------------------------------------------------------------------------- /ReConV2/cfgs/pretrain/small/openshape.yaml: -------------------------------------------------------------------------------- 1 | optimizer: 2 | type: AdamW 3 | kwargs: 4 | lr: 5e-5 5 | weight_decay: 0.05 6 | 7 | scheduler: 8 | type: CosLR 9 | kwargs: 10 | epochs: 300 11 | initial_epochs: 10 12 | 13 | dataset: 14 | train: 15 | _base_: ReConV2/cfgs/dataset_configs/OpenShape.yaml 16 | others: 17 | subset: train 18 | npoints: 10000 19 | rgb_random_drop_prob: 0.5 20 | 21 | model: 22 | NAME: ReCon2 23 | group_size: 32 24 | num_group: 512 25 | mask_ratio: 0.7 26 | mask_type: rand 27 | embed_dim: 384 28 | depth: 12 29 | drop_path_rate: 0.1 
30 | num_heads: 6 31 | decoder_depth: 4 32 | with_color: True 33 | stop_grad: False 34 | large_embedding: False 35 | img_queries: 13 36 | text_queries: 3 37 | contrast_type: byol 38 | pretrained_model_name: vit_small_patch14_dinov2.lvd142m 39 | 40 | modelnet40: 41 | test_split: ReConV2/data/openshape/meta_data/modelnet40/test_split.json 42 | test_pc: ReConV2/data/openshape/meta_data/modelnet40/test_pc.npy 43 | clip_feat_path: ReConV2/data/openshape/meta_data/modelnet40/cat_name_pt_feat.npy 44 | num_workers: 8 45 | batch_size: 128 46 | ratio: 0.5 47 | 48 | objaverse_lvis: 49 | split: ReConV2/data/openshape/meta_data/split/lvis.json 50 | clip_feat_path: ReConV2/data/openshape/meta_data/lvis_cat_name_pt_feat.npy 51 | num_workers: 8 52 | batch_size: 128 53 | ratio: 0.5 54 | 55 | scanobjectnn: 56 | data_path: ReConV2/data/openshape/meta_data/scanobjectnn/xyz_label.npy 57 | clip_feat_path: ReConV2/data/openshape/meta_data/scanobjectnn/cat_name_pt_feat.npy 58 | num_workers: 8 59 | batch_size: 128 60 | ratio: 0.3 61 | 62 | npoints: 10000 63 | total_bs: 512 64 | step_per_update: 1 65 | max_epoch: 300 66 | -------------------------------------------------------------------------------- /ReConV2/cfgs/pretrain/small/openshape_1k.yaml: -------------------------------------------------------------------------------- 1 | optimizer: 2 | type: AdamW 3 | kwargs: 4 | lr: 5e-5 5 | weight_decay: 0.05 6 | 7 | scheduler: 8 | type: CosLR 9 | kwargs: 10 | epochs: 300 11 | initial_epochs: 10 12 | 13 | dataset: 14 | train: 15 | _base_: ReConV2/cfgs/dataset_configs/OpenShape.yaml 16 | others: 17 | subset: train 18 | npoints: 1024 19 | rgb_random_drop_prob: 0.5 20 | occlusion: False 21 | 22 | model: 23 | NAME: ReCon2 24 | group_size: 32 25 | num_group: 64 26 | mask_ratio: 0.7 27 | mask_type: causal 28 | embed_dim: 384 29 | depth: 12 30 | drop_path_rate: 0.1 31 | num_heads: 6 32 | decoder_depth: 4 33 | with_color: False 34 | stop_grad: True 35 | large_embedding: False 36 | img_queries: 13 37 | text_queries: 3 38 | contrast_type: byol 39 | pretrained_model_name: vit_small_patch14_dinov2.lvd142m 40 | 41 | modelnet40: 42 | test_split: ReConV2/data/openshape/meta_data/modelnet40/test_split.json 43 | test_pc: ReConV2/data/openshape/meta_data/modelnet40/test_pc.npy 44 | clip_feat_path: ReConV2/data/openshape/meta_data/modelnet40/cat_name_pt_feat.npy 45 | num_workers: 8 46 | batch_size: 128 47 | ratio: 0.5 48 | 49 | objaverse_lvis: 50 | split: ReConV2/data/openshape/meta_data/split/lvis.json 51 | clip_feat_path: ReConV2/data/openshape/meta_data/lvis_cat_name_pt_feat.npy 52 | num_workers: 8 53 | batch_size: 128 54 | ratio: 0.5 55 | 56 | scanobjectnn: 57 | data_path: ReConV2/data/openshape/meta_data/scanobjectnn/xyz_label.npy 58 | clip_feat_path: ReConV2/data/openshape/meta_data/scanobjectnn/cat_name_pt_feat.npy 59 | num_workers: 8 60 | batch_size: 128 61 | ratio: 0.3 62 | 63 | npoints: 1024 64 | total_bs: 512 65 | step_per_update: 1 66 | max_epoch: 300 67 | -------------------------------------------------------------------------------- /ReConV2/cfgs/svm/modelnet40.yaml: -------------------------------------------------------------------------------- 1 | dataset: 2 | train: 3 | _base_: ReConV2/cfgs/dataset_configs/ModelNet40.yaml 4 | others: 5 | subset: train 6 | val: 7 | _base_: ReConV2/cfgs/dataset_configs/ModelNet40.yaml 8 | others: 9 | subset: test 10 | test: 11 | _base_: ReConV2/cfgs/dataset_configs/ModelNet40.yaml 12 | others: 13 | subset: test 14 | 15 | model: 16 | NAME: PointTransformerSVM 17 | embed_dim: 1024 18 | 
depth: 16 19 | drop_path_rate: 0.1 20 | cls_dim: 40 21 | num_heads: 16 22 | group_size: 32 23 | num_group: 64 24 | with_color: False 25 | large_embedding: False 26 | img_queries: 13 27 | text_queries: 3 28 | 29 | npoints: 1024 30 | total_bs: 128 31 | grad_norm_clip: 10 32 | -------------------------------------------------------------------------------- /ReConV2/cfgs/transfer/base/fewshot.yaml: -------------------------------------------------------------------------------- 1 | optimizer: 2 | type: AdamW 3 | kwargs: 4 | lr: 5e-4 5 | weight_decay: 0.05 6 | 7 | scheduler: 8 | type: CosLR 9 | kwargs: 10 | epochs: 150 11 | initial_epochs: 10 12 | 13 | 14 | dataset: 15 | train: 16 | _base_: ReConV2/cfgs/dataset_configs/ModelNet40FewShot.yaml 17 | others: 18 | subset: train 19 | val: 20 | _base_: ReConV2/cfgs/dataset_configs/ModelNet40FewShot.yaml 21 | others: 22 | subset: test 23 | test: 24 | _base_: ReConV2/cfgs/dataset_configs/ModelNet40.yaml 25 | others: 26 | subset: test 27 | 28 | model: 29 | NAME: PointTransformer 30 | embed_dim: 768 31 | depth: 12 32 | drop_path_rate: 0.1 33 | cls_dim: 40 34 | num_heads: 12 35 | group_size: 32 36 | num_group: 64 37 | with_color: False 38 | large_embedding: True 39 | img_queries: 13 40 | text_queries: 3 41 | decoder_depth: 4 42 | 43 | npoints: 1024 44 | total_bs: 32 45 | step_per_update: 1 46 | max_epoch: 150 47 | grad_norm_clip: 10 48 | -------------------------------------------------------------------------------- /ReConV2/cfgs/transfer/base/finetune_modelnet.yaml: -------------------------------------------------------------------------------- 1 | optimizer: 2 | type: AdamW 3 | kwargs: 4 | lr: 1e-5 5 | weight_decay: 0.05 6 | 7 | scheduler: 8 | type: CosLR 9 | kwargs: 10 | epochs: 300 11 | initial_epochs: 10 12 | 13 | dataset: 14 | train: 15 | _base_: ReConV2/cfgs/dataset_configs/ModelNet40.yaml 16 | others: 17 | subset: train 18 | val: 19 | _base_: ReConV2/cfgs/dataset_configs/ModelNet40.yaml 20 | others: 21 | subset: test 22 | test: 23 | _base_: ReConV2/cfgs/dataset_configs/ModelNet40.yaml 24 | others: 25 | subset: test 26 | 27 | model: 28 | NAME: PointTransformer 29 | embed_dim: 768 30 | depth: 12 31 | drop_path_rate: 0.2 32 | cls_dim: 40 33 | num_heads: 12 34 | group_size: 32 35 | num_group: 64 36 | with_color: False 37 | large_embedding: True 38 | img_queries: 13 39 | text_queries: 3 40 | decoder_depth: 4 41 | 42 | npoints: 1024 43 | total_bs: 32 44 | step_per_update: 1 45 | max_epoch: 300 46 | grad_norm_clip: 10 47 | -------------------------------------------------------------------------------- /ReConV2/cfgs/transfer/base/finetune_modelnet_8k.yaml: -------------------------------------------------------------------------------- 1 | optimizer: 2 | type: AdamW 3 | kwargs: 4 | lr: 1e-5 5 | weight_decay: 0.05 6 | 7 | scheduler: 8 | type: CosLR 9 | kwargs: 10 | epochs: 300 11 | initial_epochs: 10 12 | 13 | dataset: 14 | train: 15 | _base_: ReConV2/cfgs/dataset_configs/ModelNet40.yaml 16 | others: 17 | subset: 'train' 18 | val: 19 | _base_: ReConV2/cfgs/dataset_configs/ModelNet40.yaml 20 | others: 21 | subset: 'test' 22 | test: 23 | _base_: ReConV2/cfgs/dataset_configs/ModelNet40.yaml 24 | others: 25 | subset: 'test' 26 | 27 | model: 28 | NAME: PointTransformer 29 | embed_dim: 768 30 | depth: 12 31 | drop_path_rate: 0.2 32 | cls_dim: 40 33 | num_heads: 12 34 | group_size: 32 35 | num_group: 512 36 | with_color: False 37 | large_embedding: True 38 | img_queries: 13 39 | text_queries: 3 40 | decoder_depth: 4 41 | 42 | npoints: 8192 43 | total_bs: 32 44 
| step_per_update: 1 45 | max_epoch: 300 46 | grad_norm_clip: 10 47 | -------------------------------------------------------------------------------- /ReConV2/cfgs/transfer/base/finetune_scan_hardest.yaml: -------------------------------------------------------------------------------- 1 | optimizer: 2 | type: AdamW 3 | kwargs: 4 | lr: 2e-5 5 | weight_decay: 0.05 6 | 7 | scheduler: 8 | type: CosLR 9 | kwargs: 10 | epochs: 300 11 | initial_epochs: 10 12 | 13 | dataset: 14 | train: 15 | _base_: ReConV2/cfgs/dataset_configs/ScanObjectNN_hardest.yaml 16 | others: 17 | subset: train 18 | val: 19 | _base_: ReConV2/cfgs/dataset_configs/ScanObjectNN_hardest.yaml 20 | others: 21 | subset: test 22 | test: 23 | _base_: ReConV2/cfgs/dataset_configs/ScanObjectNN_hardest.yaml 24 | others: 25 | subset: test 26 | 27 | model: 28 | NAME: PointTransformer 29 | embed_dim: 768 30 | depth: 12 31 | drop_path_rate: 0.2 32 | cls_dim: 15 33 | num_heads: 12 34 | group_size: 32 35 | num_group: 128 36 | with_color: False 37 | large_embedding: True 38 | img_queries: 13 39 | text_queries: 3 40 | decoder_depth: 4 41 | 42 | npoints: 2048 43 | total_bs: 32 44 | step_per_update: 1 45 | max_epoch: 300 46 | grad_norm_clip: 10 47 | -------------------------------------------------------------------------------- /ReConV2/cfgs/transfer/base/finetune_scan_objbg.yaml: -------------------------------------------------------------------------------- 1 | optimizer: 2 | type: AdamW 3 | kwargs: 4 | lr: 2e-5 5 | weight_decay: 0.05 6 | 7 | scheduler: 8 | type: CosLR 9 | kwargs: 10 | epochs: 300 11 | initial_epochs: 10 12 | 13 | dataset: 14 | train: 15 | _base_: ReConV2/cfgs/dataset_configs/ScanObjectNN_objectbg.yaml 16 | others: 17 | subset: train 18 | val: 19 | _base_: ReConV2/cfgs/dataset_configs/ScanObjectNN_objectbg.yaml 20 | others: 21 | subset: test 22 | test: 23 | _base_: ReConV2/cfgs/dataset_configs/ScanObjectNN_objectbg.yaml 24 | others: 25 | subset: test 26 | 27 | model: 28 | NAME: PointTransformer 29 | embed_dim: 768 30 | depth: 12 31 | drop_path_rate: 0.2 32 | cls_dim: 15 33 | num_heads: 12 34 | group_size: 32 35 | num_group: 128 36 | with_color: False 37 | large_embedding: True 38 | img_queries: 13 39 | text_queries: 3 40 | decoder_depth: 4 41 | 42 | npoints: 2048 43 | total_bs: 32 44 | step_per_update: 1 45 | max_epoch: 300 46 | grad_norm_clip: 10 47 | -------------------------------------------------------------------------------- /ReConV2/cfgs/transfer/base/finetune_scan_objonly.yaml: -------------------------------------------------------------------------------- 1 | optimizer: 2 | type: AdamW 3 | kwargs: 4 | lr: 2e-5 5 | weight_decay: 0.05 6 | 7 | scheduler: 8 | type: CosLR 9 | kwargs: 10 | epochs: 300 11 | initial_epochs: 10 12 | 13 | dataset: 14 | train: 15 | _base_: ReConV2/cfgs/dataset_configs/ScanObjectNN_objectonly.yaml 16 | others: 17 | subset: train 18 | val: 19 | _base_: ReConV2/cfgs/dataset_configs/ScanObjectNN_objectonly.yaml 20 | others: 21 | subset: test 22 | test: 23 | _base_: ReConV2/cfgs/dataset_configs/ScanObjectNN_objectonly.yaml 24 | others: 25 | subset: test 26 | 27 | model: 28 | NAME: PointTransformer 29 | embed_dim: 768 30 | depth: 12 31 | drop_path_rate: 0.2 32 | cls_dim: 15 33 | num_heads: 12 34 | group_size: 32 35 | num_group: 128 36 | with_color: False 37 | large_embedding: True 38 | img_queries: 13 39 | text_queries: 3 40 | decoder_depth: 4 41 | 42 | npoints: 2048 43 | total_bs: 32 44 | step_per_update: 1 45 | max_epoch: 300 46 | grad_norm_clip: 10 47 | 
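Each of the pretrain/transfer configs above declares its optimizer and scheduler as `type` + `kwargs` blocks (AdamW with a cosine schedule and warm-up epochs). A rough sketch of how such a block could be consumed with plain PyTorch is shown below; this is illustrative only, the repo's actual builder (ReConV2/tools/builder.py) is not shown in this dump and its CosLR scheduler may be implemented differently:

import math
import torch


def build_optimizer_and_scheduler(model, cfg):
    # cfg['optimizer']['kwargs'] carries lr / weight_decay from the YAML above
    opt = torch.optim.AdamW(model.parameters(), **cfg['optimizer']['kwargs'])
    total_epochs = cfg['scheduler']['kwargs']['epochs']
    warmup_epochs = cfg['scheduler']['kwargs']['initial_epochs']

    def cosine_with_warmup(epoch):
        if epoch < warmup_epochs:                            # linear warm-up
            return float(epoch + 1) / warmup_epochs
        t = (epoch - warmup_epochs) / max(1, total_epochs - warmup_epochs)
        return 0.5 * (1.0 + math.cos(math.pi * t))           # cosine decay to 0

    sched = torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda=cosine_with_warmup)
    return opt, sched


# Example values, as in finetune_scan_hardest.yaml:
# cfg = {'optimizer': {'type': 'AdamW', 'kwargs': {'lr': 2e-5, 'weight_decay': 0.05}},
#        'scheduler': {'type': 'CosLR', 'kwargs': {'epochs': 300, 'initial_epochs': 10}}}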
-------------------------------------------------------------------------------- /ReConV2/cfgs/transfer/large/fewshot.yaml: -------------------------------------------------------------------------------- 1 | optimizer: 2 | type: AdamW 3 | kwargs: 4 | lr: 5e-4 5 | weight_decay: 0.05 6 | 7 | scheduler: 8 | type: CosLR 9 | kwargs: 10 | epochs: 150 11 | initial_epochs: 10 12 | 13 | 14 | dataset: 15 | train: 16 | _base_: ReConV2/cfgs/dataset_configs/ModelNet40FewShot.yaml 17 | others: 18 | subset: train 19 | val: 20 | _base_: ReConV2/cfgs/dataset_configs/ModelNet40FewShot.yaml 21 | others: 22 | subset: test 23 | test: 24 | _base_: ReConV2/cfgs/dataset_configs/ModelNet40.yaml 25 | others: 26 | subset: test 27 | 28 | model: 29 | NAME: PointTransformer 30 | embed_dim: 1024 31 | depth: 24 32 | drop_path_rate: 0.1 33 | cls_dim: 40 34 | num_heads: 16 35 | group_size: 32 36 | num_group: 64 37 | with_color: False 38 | large_embedding: True 39 | img_queries: 13 40 | text_queries: 3 41 | decoder_depth: 4 42 | 43 | npoints: 1024 44 | total_bs: 32 45 | step_per_update: 1 46 | max_epoch: 150 47 | grad_norm_clip: 10 48 | -------------------------------------------------------------------------------- /ReConV2/cfgs/transfer/large/finetune_modelnet.yaml: -------------------------------------------------------------------------------- 1 | optimizer: 2 | type: AdamW 3 | kwargs: 4 | lr: 1e-5 5 | weight_decay: 0.05 6 | 7 | scheduler: 8 | type: CosLR 9 | kwargs: 10 | epochs: 300 11 | initial_epochs: 10 12 | 13 | dataset: 14 | train: 15 | _base_: ReConV2/cfgs/dataset_configs/ModelNet40.yaml 16 | others: 17 | subset: train 18 | val: 19 | _base_: ReConV2/cfgs/dataset_configs/ModelNet40.yaml 20 | others: 21 | subset: test 22 | test: 23 | _base_: ReConV2/cfgs/dataset_configs/ModelNet40.yaml 24 | others: 25 | subset: test 26 | 27 | model: 28 | NAME: PointTransformer 29 | embed_dim: 1024 30 | depth: 24 31 | drop_path_rate: 0.2 32 | cls_dim: 40 33 | num_heads: 16 34 | group_size: 32 35 | num_group: 64 36 | with_color: False 37 | large_embedding: True 38 | img_queries: 13 39 | text_queries: 3 40 | decoder_depth: 4 41 | 42 | npoints: 1024 43 | total_bs: 32 44 | step_per_update: 1 45 | max_epoch: 300 46 | grad_norm_clip: 10 47 | -------------------------------------------------------------------------------- /ReConV2/cfgs/transfer/large/finetune_modelnet_8k.yaml: -------------------------------------------------------------------------------- 1 | optimizer: 2 | type: AdamW 3 | kwargs: 4 | lr: 1e-5 5 | weight_decay: 0.05 6 | 7 | scheduler: 8 | type: CosLR 9 | kwargs: 10 | epochs: 300 11 | initial_epochs: 10 12 | 13 | dataset: 14 | train: 15 | _base_: ReConV2/cfgs/dataset_configs/ModelNet40.yaml 16 | others: 17 | subset: train 18 | val: 19 | _base_: ReConV2/cfgs/dataset_configs/ModelNet40.yaml 20 | others: 21 | subset: test 22 | test: 23 | _base_: ReConV2/cfgs/dataset_configs/ModelNet40.yaml 24 | others: 25 | subset: test 26 | 27 | model: 28 | NAME: PointTransformer 29 | embed_dim: 1024 30 | depth: 24 31 | drop_path_rate: 0.2 32 | cls_dim: 40 33 | num_heads: 16 34 | group_size: 32 35 | num_group: 512 36 | with_color: False 37 | large_embedding: True 38 | img_queries: 13 39 | text_queries: 3 40 | decoder_depth: 4 41 | 42 | npoints: 8192 43 | total_bs: 32 44 | step_per_update: 1 45 | max_epoch: 300 46 | grad_norm_clip: 10 47 | -------------------------------------------------------------------------------- /ReConV2/cfgs/transfer/large/finetune_scan_hardest.yaml: 
-------------------------------------------------------------------------------- 1 | optimizer: 2 | type: AdamW 3 | kwargs: 4 | lr: 2e-5 5 | weight_decay: 0.05 6 | 7 | scheduler: 8 | type: CosLR 9 | kwargs: 10 | epochs: 300 11 | initial_epochs: 10 12 | 13 | dataset: 14 | train: 15 | _base_: ReConV2/cfgs/dataset_configs/ScanObjectNN_hardest.yaml 16 | others: 17 | subset: train 18 | val: 19 | _base_: ReConV2/cfgs/dataset_configs/ScanObjectNN_hardest.yaml 20 | others: 21 | subset: test 22 | 23 | model: 24 | NAME: PointTransformer 25 | embed_dim: 1024 26 | depth: 24 27 | drop_path_rate: 0.2 28 | cls_dim: 15 29 | num_heads: 16 30 | group_size: 32 31 | num_group: 128 32 | with_color: False 33 | large_embedding: True 34 | img_queries: 13 35 | text_queries: 3 36 | decoder_depth: 4 37 | 38 | npoints: 2048 39 | total_bs: 32 40 | step_per_update: 1 41 | max_epoch: 300 42 | grad_norm_clip: 10 43 | -------------------------------------------------------------------------------- /ReConV2/cfgs/transfer/large/finetune_scan_objbg.yaml: -------------------------------------------------------------------------------- 1 | optimizer: 2 | type: AdamW 3 | kwargs: 4 | lr: 2e-5 5 | weight_decay: 0.05 6 | 7 | scheduler: 8 | type: CosLR 9 | kwargs: 10 | epochs: 300 11 | initial_epochs: 10 12 | 13 | dataset: 14 | train: 15 | _base_: ReConV2/cfgs/dataset_configs/ScanObjectNN_objectbg.yaml 16 | others: 17 | subset: train 18 | val: 19 | _base_: ReConV2/cfgs/dataset_configs/ScanObjectNN_objectbg.yaml 20 | others: 21 | subset: test 22 | test: 23 | _base_: ReConV2/cfgs/dataset_configs/ScanObjectNN_objectbg.yaml 24 | others: 25 | subset: test 26 | 27 | model: 28 | NAME: PointTransformer 29 | embed_dim: 1024 30 | depth: 24 31 | drop_path_rate: 0.2 32 | cls_dim: 15 33 | num_heads: 16 34 | group_size: 32 35 | num_group: 128 36 | with_color: False 37 | large_embedding: True 38 | img_queries: 13 39 | text_queries: 3 40 | decoder_depth: 4 41 | 42 | npoints: 2048 43 | total_bs: 32 44 | step_per_update: 1 45 | max_epoch: 300 46 | grad_norm_clip: 10 47 | -------------------------------------------------------------------------------- /ReConV2/cfgs/transfer/large/finetune_scan_objonly.yaml: -------------------------------------------------------------------------------- 1 | optimizer: 2 | type: AdamW 3 | kwargs: 4 | lr: 2e-5 5 | weight_decay: 0.05 6 | 7 | scheduler: 8 | type: CosLR 9 | kwargs: 10 | epochs: 300 11 | initial_epochs: 10 12 | 13 | dataset: 14 | train: 15 | _base_: ReConV2/cfgs/dataset_configs/ScanObjectNN_objectonly.yaml 16 | others: 17 | subset: train 18 | val: 19 | _base_: ReConV2/cfgs/dataset_configs/ScanObjectNN_objectonly.yaml 20 | others: 21 | subset: test 22 | test: 23 | _base_: ReConV2/cfgs/dataset_configs/ScanObjectNN_objectonly.yaml 24 | others: 25 | subset: test 26 | 27 | model: 28 | NAME: PointTransformer 29 | embed_dim: 1024 30 | depth: 24 31 | drop_path_rate: 0.2 32 | cls_dim: 15 33 | num_heads: 16 34 | group_size: 32 35 | num_group: 128 36 | with_color: False 37 | large_embedding: True 38 | img_queries: 13 39 | text_queries: 3 40 | decoder_depth: 4 41 | 42 | npoints: 2048 43 | total_bs: 32 44 | step_per_update: 1 45 | max_epoch: 300 46 | grad_norm_clip: 10 47 | -------------------------------------------------------------------------------- /ReConV2/cfgs/transfer/small/fewshot.yaml: -------------------------------------------------------------------------------- 1 | optimizer: 2 | type: AdamW 3 | kwargs: 4 | lr: 5e-4 5 | weight_decay: 0.05 6 | 7 | scheduler: 8 | type: CosLR 9 | kwargs: 10 | epochs: 150 11 
| initial_epochs: 10 12 | 13 | 14 | dataset: 15 | train: 16 | _base_: ReConV2/cfgs/dataset_configs/ModelNet40FewShot.yaml 17 | others: 18 | subset: train 19 | val: 20 | _base_: ReConV2/cfgs/dataset_configs/ModelNet40FewShot.yaml 21 | others: 22 | subset: test 23 | test: 24 | _base_: ReConV2/cfgs/dataset_configs/ModelNet40.yaml 25 | others: 26 | subset: test 27 | 28 | model: 29 | NAME: PointTransformer 30 | embed_dim: 384 31 | depth: 12 32 | drop_path_rate: 0.1 33 | cls_dim: 40 34 | num_heads: 6 35 | group_size: 32 36 | num_group: 64 37 | with_color: False 38 | large_embedding: False 39 | img_queries: 13 40 | text_queries: 3 41 | decoder_depth: 4 42 | 43 | npoints: 1024 44 | total_bs: 32 45 | step_per_update: 1 46 | max_epoch: 150 47 | grad_norm_clip: 10 48 | -------------------------------------------------------------------------------- /ReConV2/cfgs/transfer/small/finetune_modelnet.yaml: -------------------------------------------------------------------------------- 1 | optimizer: 2 | type: AdamW 3 | kwargs: 4 | lr: 1e-5 5 | weight_decay: 0.05 6 | 7 | scheduler: 8 | type: CosLR 9 | kwargs: 10 | epochs: 300 11 | initial_epochs: 10 12 | 13 | dataset: 14 | train: 15 | _base_: ReConV2/cfgs/dataset_configs/ModelNet40.yaml 16 | others: 17 | subset: train 18 | val: 19 | _base_: ReConV2/cfgs/dataset_configs/ModelNet40.yaml 20 | others: 21 | subset: test 22 | test: 23 | _base_: ReConV2/cfgs/dataset_configs/ModelNet40.yaml 24 | others: 25 | subset: test 26 | 27 | model: 28 | NAME: PointTransformer 29 | embed_dim: 384 30 | depth: 12 31 | drop_path_rate: 0.2 32 | cls_dim: 40 33 | num_heads: 6 34 | group_size: 32 35 | num_group: 64 36 | with_color: False 37 | large_embedding: False 38 | img_queries: 13 39 | text_queries: 3 40 | decoder_depth: 4 41 | 42 | npoints: 1024 43 | total_bs: 32 44 | step_per_update: 1 45 | max_epoch: 300 46 | grad_norm_clip: 10 47 | -------------------------------------------------------------------------------- /ReConV2/cfgs/transfer/small/finetune_modelnet_8k.yaml: -------------------------------------------------------------------------------- 1 | optimizer: 2 | type: AdamW 3 | kwargs: 4 | lr: 1e-5 5 | weight_decay: 0.05 6 | 7 | scheduler: 8 | type: CosLR 9 | kwargs: 10 | epochs: 300 11 | initial_epochs: 10 12 | 13 | dataset: 14 | train: 15 | _base_: ReConV2/cfgs/dataset_configs/ModelNet40.yaml 16 | others: 17 | subset: train 18 | val: 19 | _base_: ReConV2/cfgs/dataset_configs/ModelNet40.yaml 20 | others: 21 | subset: test 22 | test: 23 | _base_: ReConV2/cfgs/dataset_configs/ModelNet40.yaml 24 | others: 25 | subset: test 26 | 27 | model: 28 | NAME: PointTransformer 29 | embed_dim: 384 30 | depth: 12 31 | drop_path_rate: 0.2 32 | cls_dim: 40 33 | num_heads: 6 34 | group_size: 32 35 | num_group: 512 36 | with_color: False 37 | large_embedding: False 38 | img_queries: 13 39 | text_queries: 3 40 | decoder_depth: 4 41 | 42 | npoints: 8192 43 | total_bs: 32 44 | step_per_update: 1 45 | max_epoch: 300 46 | grad_norm_clip: 10 47 | -------------------------------------------------------------------------------- /ReConV2/cfgs/transfer/small/finetune_scan_hardest.yaml: -------------------------------------------------------------------------------- 1 | optimizer: 2 | type: AdamW 3 | kwargs: 4 | lr: 2e-5 5 | weight_decay: 0.05 6 | 7 | scheduler: 8 | type: CosLR 9 | kwargs: 10 | epochs: 300 11 | initial_epochs: 10 12 | 13 | dataset: 14 | train: 15 | _base_: ReConV2/cfgs/dataset_configs/ScanObjectNN_hardest.yaml 16 | others: 17 | subset: train 18 | val: 19 | _base_: 
ReConV2/cfgs/dataset_configs/ScanObjectNN_hardest.yaml 20 | others: 21 | subset: test 22 | test: 23 | _base_: ReConV2/cfgs/dataset_configs/ScanObjectNN_hardest.yaml 24 | others: 25 | subset: test 26 | 27 | model: 28 | NAME: PointTransformer 29 | embed_dim: 384 30 | depth: 12 31 | drop_path_rate: 0.2 32 | cls_dim: 15 33 | num_heads: 6 34 | group_size: 32 35 | num_group: 128 36 | with_color: False 37 | large_embedding: False 38 | img_queries: 13 39 | text_queries: 3 40 | decoder_depth: 4 41 | 42 | npoints: 2048 43 | total_bs: 32 44 | step_per_update: 1 45 | max_epoch: 300 46 | grad_norm_clip: 10 47 | -------------------------------------------------------------------------------- /ReConV2/cfgs/transfer/small/finetune_scan_objbg.yaml: -------------------------------------------------------------------------------- 1 | optimizer: 2 | type: AdamW 3 | kwargs: 4 | lr: 2e-5 5 | weight_decay: 0.05 6 | 7 | scheduler: 8 | type: CosLR 9 | kwargs: 10 | epochs: 300 11 | initial_epochs: 10 12 | 13 | dataset: 14 | train: 15 | _base_: ReConV2/cfgs/dataset_configs/ScanObjectNN_objectbg.yaml 16 | others: 17 | subset: train 18 | val: 19 | _base_: ReConV2/cfgs/dataset_configs/ScanObjectNN_objectbg.yaml 20 | others: 21 | subset: test 22 | test: 23 | _base_: ReConV2/cfgs/dataset_configs/ScanObjectNN_objectbg.yaml 24 | others: 25 | subset: test 26 | 27 | model: 28 | NAME: PointTransformer 29 | embed_dim: 384 30 | depth: 12 31 | drop_path_rate: 0.2 32 | cls_dim: 15 33 | num_heads: 6 34 | group_size: 32 35 | num_group: 128 36 | with_color: False 37 | large_embedding: False 38 | img_queries: 13 39 | text_queries: 3 40 | decoder_depth: 4 41 | 42 | npoints: 2048 43 | total_bs: 32 44 | step_per_update: 1 45 | max_epoch: 300 46 | grad_norm_clip: 10 47 | -------------------------------------------------------------------------------- /ReConV2/cfgs/transfer/small/finetune_scan_objonly.yaml: -------------------------------------------------------------------------------- 1 | optimizer: 2 | type: AdamW 3 | kwargs: 4 | lr: 2e-5 5 | weight_decay: 0.05 6 | 7 | scheduler: 8 | type: CosLR 9 | kwargs: 10 | epochs: 300 11 | initial_epochs: 10 12 | 13 | dataset: 14 | train: 15 | _base_: ReConV2/cfgs/dataset_configs/ScanObjectNN_objectonly.yaml 16 | others: 17 | subset: train 18 | val: 19 | _base_: ReConV2/cfgs/dataset_configs/ScanObjectNN_objectonly.yaml 20 | others: 21 | subset: test 22 | test: 23 | _base_: ReConV2/cfgs/dataset_configs/ScanObjectNN_objectonly.yaml 24 | others: 25 | subset: test 26 | 27 | model: 28 | NAME: PointTransformer 29 | embed_dim: 384 30 | depth: 12 31 | drop_path_rate: 0.2 32 | cls_dim: 15 33 | num_heads: 6 34 | group_size: 32 35 | num_group: 128 36 | with_color: False 37 | large_embedding: False 38 | img_queries: 13 39 | text_queries: 3 40 | decoder_depth: 4 41 | 42 | npoints: 2048 43 | total_bs: 32 44 | step_per_update: 1 45 | max_epoch: 300 46 | grad_norm_clip: 10 47 | -------------------------------------------------------------------------------- /ReConV2/convert_features.py: -------------------------------------------------------------------------------- 1 | from utils.misc import * 2 | from utils.config import * 3 | from datasets.HybridDataset import Hybrid_depth 4 | 5 | import timm 6 | from tqdm import tqdm 7 | 8 | 9 | device = "cuda" if torch.cuda.is_available() else "cpu" 10 | clip_model = timm.create_model("vit_gigantic_patch14_clip_224.laion2b", pretrained=True).to(device) 11 | 12 | data_root = 'ReConV2/data/HybridDatasets/' 13 | img_path = 'ReConV2/data/HybridDatasets/depth/' 14 | save_path = 
'ReConV2/data/HybridDatasets/depth_feature/' 15 | batch_size = 32 16 | 17 | if not os.path.exists(save_path): 18 | os.makedirs(save_path) 19 | 20 | dataset = Hybrid_depth(data_root, 'train', img_path) 21 | train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, 22 | num_workers=8, 23 | drop_last=False, 24 | worker_init_fn=worker_init_fn, 25 | pin_memory=True) 26 | 27 | for img, id in tqdm(train_dataloader): 28 | B, n, h, w, c = img.shape 29 | img = img.reshape(B * n, h, w, c) 30 | img = img.to(device) 31 | feature = clip_model(img) 32 | feature = feature.reshape(B, n, -1) 33 | for i in range(B): 34 | torch.save(feature[i].cpu(), save_path + id[i] + '.pt') 35 | -------------------------------------------------------------------------------- /ReConV2/datasets/ModelNetDatasetFewShot.py: -------------------------------------------------------------------------------- 1 | ''' 2 | @author: Xu Yan 3 | @file: ModelNet.py 4 | @time: 2021/3/19 15:51 5 | ''' 6 | import os 7 | import numpy as np 8 | import warnings 9 | import pickle 10 | 11 | from torch.utils.data import Dataset 12 | from .build import DATASETS 13 | from ReConV2.utils.logger import * 14 | import torch 15 | 16 | warnings.filterwarnings('ignore') 17 | 18 | 19 | def pc_normalize(pc): 20 | # normalize pc to [-1, 1] 21 | pc = pc - np.mean(pc, axis=0) 22 | if np.max(np.linalg.norm(pc, axis=1)) < 1e-6: 23 | pc = np.zeros_like(pc) 24 | else: 25 | pc = pc / np.max(np.linalg.norm(pc, axis=1)) 26 | return pc 27 | 28 | 29 | @DATASETS.register_module() 30 | class ModelNetFewShot(Dataset): 31 | def __init__(self, config): 32 | self.root = config.DATA_PATH 33 | self.npoints = config.N_POINTS 34 | self.use_normals = config.USE_NORMALS 35 | self.num_category = config.NUM_CATEGORY 36 | self.process_data = True 37 | self.uniform = True 38 | split = config.subset 39 | self.subset = config.subset 40 | 41 | self.way = config.way 42 | self.shot = config.shot 43 | self.fold = config.fold 44 | if self.way == -1 or self.shot == -1 or self.fold == -1: 45 | raise RuntimeError() 46 | 47 | self.pickle_path = os.path.join(self.root, f'{self.way}way_{self.shot}shot', f'{self.fold}.pkl') 48 | 49 | 50 | print_log('Load processed data from %s...' 
% self.pickle_path, logger = 'ModelNetFewShot') 51 | 52 | with open(self.pickle_path, 'rb') as f: 53 | self.dataset = pickle.load(f)[self.subset] 54 | 55 | print_log('The size of %s data is %d' % (split, len(self.dataset)), logger = 'ModelNetFewShot') 56 | 57 | def __len__(self): 58 | return len(self.dataset) 59 | 60 | def __getitem__(self, index): 61 | points, label, _ = self.dataset[index] 62 | 63 | points[:, 0:3] = pc_normalize(points[:, 0:3]) 64 | if not self.use_normals: 65 | points = points[:, 0:3] 66 | 67 | pt_idxs = np.arange(0, points.shape[0]) # 2048 68 | if self.subset == 'train': 69 | np.random.shuffle(pt_idxs) 70 | current_points = points[pt_idxs].copy() 71 | current_points = torch.from_numpy(current_points).float() 72 | return 'ModelNet', 'sample', (current_points, label) -------------------------------------------------------------------------------- /ReConV2/datasets/ScanObjectNNDataset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os, sys, h5py 3 | from torch.utils.data import Dataset 4 | import torch 5 | from .build import DATASETS 6 | 7 | BASE_DIR = os.path.dirname(os.path.abspath(__file__)) 8 | sys.path.append(BASE_DIR) 9 | 10 | 11 | @DATASETS.register_module() 12 | class ScanObjectNN(Dataset): 13 | def __init__(self, config, **kwargs): 14 | super().__init__() 15 | self.subset = config.subset 16 | self.root = config.ROOT 17 | self.use_color = config.with_color 18 | 19 | if self.subset == 'train': 20 | h5 = h5py.File(os.path.join(self.root, 'training_objectdataset.h5'), 'r') 21 | self.points = np.array(h5['data']).astype(np.float32) 22 | self.labels = np.array(h5['label']).astype(int) 23 | h5.close() 24 | elif self.subset == 'test': 25 | h5 = h5py.File(os.path.join(self.root, 'test_objectdataset.h5'), 'r') 26 | self.points = np.array(h5['data']).astype(np.float32) 27 | self.labels = np.array(h5['label']).astype(int) 28 | h5.close() 29 | else: 30 | raise NotImplementedError() 31 | 32 | print(f'Successfully load ScanObjectNN shape of {self.points.shape}') 33 | 34 | def __getitem__(self, idx): 35 | pt_idxs = np.arange(0, self.points.shape[1]) # 2048 36 | if self.subset == 'train': 37 | np.random.shuffle(pt_idxs) 38 | 39 | current_points = self.points[idx, pt_idxs].copy() 40 | if self.use_color: 41 | rgb = np.ones_like(current_points) * 0.4 42 | current_points = np.concatenate([current_points, rgb], axis=-1) 43 | current_points = torch.from_numpy(current_points).float() 44 | label = self.labels[idx] 45 | 46 | return 'ScanObjectNN', 'sample', (current_points, label) 47 | 48 | def __len__(self): 49 | return self.points.shape[0] 50 | 51 | 52 | @DATASETS.register_module() 53 | class ScanObjectNN_hardest(Dataset): 54 | def __init__(self, config, **kwargs): 55 | super().__init__() 56 | self.subset = config.subset 57 | self.root = config.ROOT 58 | self.use_color = config.with_color 59 | 60 | if self.subset == 'train': 61 | h5 = h5py.File(os.path.join(self.root, 'training_objectdataset_augmentedrot_scale75.h5'), 'r') 62 | self.points = np.array(h5['data']).astype(np.float32) 63 | self.labels = np.array(h5['label']).astype(int) 64 | h5.close() 65 | elif self.subset == 'test': 66 | h5 = h5py.File(os.path.join(self.root, 'test_objectdataset_augmentedrot_scale75.h5'), 'r') 67 | self.points = np.array(h5['data']).astype(np.float32) 68 | self.labels = np.array(h5['label']).astype(int) 69 | h5.close() 70 | else: 71 | raise NotImplementedError() 72 | 73 | print(f'Successfully load ScanObjectNN shape of 
{self.points.shape}') 74 | 75 | def __getitem__(self, idx): 76 | pt_idxs = np.arange(0, self.points.shape[1]) # 2048 77 | if self.subset == 'train': 78 | np.random.shuffle(pt_idxs) 79 | 80 | current_points = self.points[idx, pt_idxs].copy() 81 | if self.use_color: 82 | rgb = np.ones_like(current_points) * 0.4 83 | current_points = np.concatenate([current_points, rgb], axis=-1) 84 | current_points = torch.from_numpy(current_points).float() 85 | label = self.labels[idx] 86 | 87 | return 'ScanObjectNN', 'sample', (current_points, label) 88 | 89 | def __len__(self): 90 | return self.points.shape[0] 91 | -------------------------------------------------------------------------------- /ReConV2/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .build import build_dataset_from_cfg 2 | import datasets.ModelNetDataset 3 | import datasets.ModelNetDatasetFewShot 4 | import datasets.ScanObjectNNDataset 5 | import datasets.HybridDataset 6 | import datasets.OpenShape 7 | -------------------------------------------------------------------------------- /ReConV2/datasets/build.py: -------------------------------------------------------------------------------- 1 | from utils import registry 2 | 3 | DATASETS = registry.Registry('dataset') 4 | 5 | 6 | def build_dataset_from_cfg(cfg, default_args=None): 7 | """ 8 | Build a dataset, defined by `dataset_name`. 9 | Args: 10 | cfg (eDICT): 11 | Returns: 12 | Dataset: a constructed dataset specified by dataset_name. 13 | """ 14 | return DATASETS.build(cfg, default_args=default_args) 15 | -------------------------------------------------------------------------------- /ReConV2/datasets/data_transforms.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import random 3 | import numpy as np 4 | 5 | 6 | class PointcloudRotate(object): 7 | def __call__(self, pc): 8 | xyz = pc[:, :, :3] 9 | bsize = xyz.size()[0] 10 | for i in range(bsize): 11 | rotation_angle = np.random.uniform() * 2 * np.pi 12 | cosval = np.cos(rotation_angle) 13 | sinval = np.sin(rotation_angle) 14 | rotation_matrix = np.array([[cosval, 0, sinval], 15 | [0, 1, 0], 16 | [-sinval, 0, cosval]]) 17 | R = torch.from_numpy(rotation_matrix.astype(np.float32)).to(xyz.device) 18 | xyz[i, :, :] = torch.matmul(xyz[i], R) 19 | pc[:, :, :3] = xyz 20 | return pc 21 | 22 | 23 | class PointcloudScaleAndTranslate(object): 24 | def __init__(self, scale_low=2. / 3., scale_high=3. / 2., translate_range=0.2): 25 | self.scale_low = scale_low 26 | self.scale_high = scale_high 27 | self.translate_range = translate_range 28 | 29 | def __call__(self, pc): 30 | bsize = pc.size()[0] 31 | for i in range(bsize): 32 | xyz1 = np.random.uniform(low=self.scale_low, high=self.scale_high, size=[3]) 33 | xyz2 = np.random.uniform(low=-self.translate_range, high=self.translate_range, size=[3]) 34 | 35 | pc[i, :, 0:3] = torch.mul(pc[i, :, 0:3], torch.from_numpy(xyz1).float().cuda()) + torch.from_numpy( 36 | xyz2).float().cuda() 37 | 38 | return pc 39 | 40 | 41 | class PointcloudJitter(object): 42 | def __init__(self, std=0.01, clip=0.05): 43 | self.std, self.clip = std, clip 44 | 45 | def __call__(self, pc): 46 | bsize = pc.size()[0] 47 | for i in range(bsize): 48 | jittered_data = pc.new(pc.size(1), 3).normal_( 49 | mean=0.0, std=self.std 50 | ).clamp_(-self.clip, self.clip) 51 | pc[i, :, 0:3] += jittered_data 52 | 53 | return pc 54 | 55 | 56 | class PointcloudScale(object): 57 | def __init__(self, scale_low=2. 
/ 3., scale_high=3. / 2.): 58 | self.scale_low = scale_low 59 | self.scale_high = scale_high 60 | 61 | def __call__(self, pc): 62 | bsize = pc.size()[0] 63 | for i in range(bsize): 64 | xyz1 = np.random.uniform(low=self.scale_low, high=self.scale_high, size=[3]) 65 | 66 | pc[i, :, 0:3] = torch.mul(pc[i, :, 0:3], torch.from_numpy(xyz1).float().cuda()) 67 | 68 | return pc 69 | 70 | 71 | class PointcloudTranslate(object): 72 | def __init__(self, translate_range=0.2): 73 | self.translate_range = translate_range 74 | 75 | def __call__(self, pc): 76 | bsize = pc.size()[0] 77 | for i in range(bsize): 78 | xyz2 = np.random.uniform(low=-self.translate_range, high=self.translate_range, size=[3]) 79 | 80 | pc[i, :, 0:3] = pc[i, :, 0:3] + torch.from_numpy(xyz2).float().cuda() 81 | 82 | return pc 83 | 84 | 85 | class PointcloudRandomInputDropout(object): 86 | def __init__(self, max_dropout_ratio=0.5): 87 | assert max_dropout_ratio >= 0 and max_dropout_ratio < 1 88 | self.max_dropout_ratio = max_dropout_ratio 89 | 90 | def __call__(self, pc): 91 | bsize = pc.size()[0] 92 | for i in range(bsize): 93 | dropout_ratio = np.random.random() * self.max_dropout_ratio # 0~0.875 94 | drop_idx = np.where(np.random.random((pc.size()[1])) <= dropout_ratio)[0] 95 | if len(drop_idx) > 0: 96 | cur_pc = pc[i, :, :] 97 | cur_pc[drop_idx.tolist(), 0:3] = cur_pc[0, 0:3].repeat(len(drop_idx), 1) # set to the first point 98 | pc[i, :, :] = cur_pc 99 | 100 | return pc 101 | 102 | 103 | class RandomHorizontalFlip(object): 104 | 105 | def __init__(self, upright_axis='z', is_temporal=False): 106 | """ 107 | upright_axis: axis index among x,y,z, i.e. 2 for z 108 | """ 109 | self.is_temporal = is_temporal 110 | self.D = 4 if is_temporal else 3 111 | self.upright_axis = {'x': 0, 'y': 1, 'z': 2}[upright_axis.lower()] 112 | # Use the rest of axes for flipping. 
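        # e.g. with the default upright_axis='z' and is_temporal=False, D=3 and the
        # set built below becomes {0, 1}, so only the x and y axes may be flipped.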
113 | self.horz_axes = set(range(self.D)) - set([self.upright_axis]) 114 | 115 | def __call__(self, coords): 116 | bsize = coords.size()[0] 117 | for i in range(bsize): 118 | if random.random() < 0.95: 119 | for curr_ax in self.horz_axes: 120 | if random.random() < 0.5: 121 | coord_max = torch.max(coords[i, :, curr_ax]) 122 | coords[i, :, curr_ax] = coord_max - coords[i, :, curr_ax] 123 | return coords 124 | -------------------------------------------------------------------------------- /ReConV2/datasets/io.py: -------------------------------------------------------------------------------- 1 | import h5py 2 | import numpy as np 3 | import os 4 | 5 | 6 | class IO: 7 | @classmethod 8 | def get(cls, file_path): 9 | _, file_extension = os.path.splitext(file_path) 10 | 11 | if file_extension in ['.npy']: 12 | return cls._read_npy(file_path) 13 | elif file_extension in ['.npz']: 14 | return cls._read_npz(file_path) 15 | elif file_extension in ['.h5']: 16 | return cls._read_h5(file_path) 17 | elif file_extension in ['.txt']: 18 | return cls._read_txt(file_path) 19 | else: 20 | raise Exception('Unsupported file extension: %s' % file_extension) 21 | 22 | @classmethod 23 | def _read_npy(cls, file_path): 24 | return np.load(file_path) 25 | 26 | @classmethod 27 | def _read_npz(cls, file_path): 28 | return np.load(file_path)['arr_0'] 29 | 30 | @classmethod 31 | def _read_txt(cls, file_path): 32 | return np.loadtxt(file_path) 33 | 34 | @classmethod 35 | def _read_h5(cls, file_path): 36 | f = h5py.File(file_path, 'r') 37 | return f['data'][()] 38 | -------------------------------------------------------------------------------- /ReConV2/extensions/chamfer_distance/__init__.py: -------------------------------------------------------------------------------- 1 | from .chamfer_distance import ChamferDistance 2 | -------------------------------------------------------------------------------- /ReConV2/extensions/chamfer_distance/chamfer_distance.cu: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include 4 | #include 5 | 6 | __global__ 7 | void ChamferDistanceKernel( 8 | int b, 9 | int n, 10 | const float* xyz, 11 | int m, 12 | const float* xyz2, 13 | float* result, 14 | int* result_i) 15 | { 16 | const int batch=512; 17 | __shared__ float buf[batch*3]; 18 | for (int i=blockIdx.x;ibest){ 130 | result[(i*n+j)]=best; 131 | result_i[(i*n+j)]=best_i; 132 | } 133 | } 134 | __syncthreads(); 135 | } 136 | } 137 | } 138 | 139 | void ChamferDistanceKernelLauncher( 140 | const int b, const int n, 141 | const float* xyz, 142 | const int m, 143 | const float* xyz2, 144 | float* result, 145 | int* result_i, 146 | float* result2, 147 | int* result2_i) 148 | { 149 | ChamferDistanceKernel<<>>(b, n, xyz, m, xyz2, result, result_i); 150 | ChamferDistanceKernel<<>>(b, m, xyz2, n, xyz, result2, result2_i); 151 | 152 | cudaError_t err = cudaGetLastError(); 153 | if (err != cudaSuccess) 154 | printf("error in chamfer distance updateOutput: %s\n", cudaGetErrorString(err)); 155 | } 156 | 157 | 158 | __global__ 159 | void ChamferDistanceGradKernel( 160 | int b, int n, 161 | const float* xyz1, 162 | int m, 163 | const float* xyz2, 164 | const float* grad_dist1, 165 | const int* idx1, 166 | float* grad_xyz1, 167 | float* grad_xyz2) 168 | { 169 | for (int i = blockIdx.x; i>>(b, n, xyz1, m, xyz2, grad_dist1, idx1, grad_xyz1, grad_xyz2); 204 | ChamferDistanceGradKernel<<>>(b, m, xyz2, n, xyz1, grad_dist2, idx2, grad_xyz2, grad_xyz1); 205 | 206 | cudaError_t err = 
cudaGetLastError(); 207 | if (err != cudaSuccess) 208 | printf("error in chamfer distance get grad: %s\n", cudaGetErrorString(err)); 209 | } 210 | -------------------------------------------------------------------------------- /ReConV2/extensions/chamfer_distance/chamfer_distance.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | 4 | script_path = os.path.dirname(os.path.abspath(__file__)) 5 | 6 | from torch.utils.cpp_extension import load 7 | 8 | if torch.cuda.is_available(): 9 | cd = load(name="cd", 10 | sources=[os.path.join(script_path, "chamfer_distance.cpp"), 11 | os.path.join(script_path, "chamfer_distance.cu")]) 12 | 13 | 14 | class ChamferDistanceFunction(torch.autograd.Function): 15 | @staticmethod 16 | def forward(ctx, xyz1, xyz2): 17 | batchsize, n, _ = xyz1.size() 18 | _, m, _ = xyz2.size() 19 | xyz1 = xyz1.contiguous() 20 | xyz2 = xyz2.contiguous() 21 | dist1 = torch.zeros(batchsize, n) 22 | dist2 = torch.zeros(batchsize, m) 23 | 24 | idx1 = torch.zeros(batchsize, n, dtype=torch.int) 25 | idx2 = torch.zeros(batchsize, m, dtype=torch.int) 26 | 27 | if not xyz1.is_cuda: 28 | cd.forward(xyz1, xyz2, dist1, dist2, idx1, idx2) 29 | else: 30 | dist1 = dist1.cuda() 31 | dist2 = dist2.cuda() 32 | idx1 = idx1.cuda() 33 | idx2 = idx2.cuda() 34 | cd.forward_cuda(xyz1, xyz2, dist1, dist2, idx1, idx2) 35 | 36 | ctx.save_for_backward(xyz1, xyz2, idx1, idx2) 37 | 38 | return dist1, dist2, idx1 39 | 40 | @staticmethod 41 | def backward(ctx, graddist1, graddist2, _): 42 | xyz1, xyz2, idx1, idx2 = ctx.saved_tensors 43 | 44 | graddist1 = graddist1.contiguous() 45 | graddist2 = graddist2.contiguous() 46 | 47 | gradxyz1 = torch.zeros(xyz1.size()) 48 | gradxyz2 = torch.zeros(xyz2.size()) 49 | 50 | if not graddist1.is_cuda: 51 | cd.backward(xyz1, xyz2, gradxyz1, gradxyz2, graddist1, graddist2, idx1, idx2) 52 | else: 53 | gradxyz1 = gradxyz1.cuda() 54 | gradxyz2 = gradxyz2.cuda() 55 | cd.backward_cuda(xyz1, xyz2, gradxyz1, gradxyz2, graddist1, graddist2, idx1, idx2) 56 | 57 | return gradxyz1, gradxyz2 58 | 59 | 60 | class ChamferDistance(torch.nn.Module): 61 | def forward(self, xyz1, xyz2): 62 | return ChamferDistanceFunction.apply(xyz1, xyz2) 63 | -------------------------------------------------------------------------------- /ReConV2/figure/framework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qizekun/ShapeLLM/f754b0d488f7187a699a549dd0c5e43b2349f051/ReConV2/figure/framework.png -------------------------------------------------------------------------------- /ReConV2/generate_depth_map.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | from tqdm import tqdm 3 | from ReConV2.utils.misc import * 4 | from ReConV2.datasets.HybridDataset import Hybrid_points 5 | from ReConV2.datasets.pc_render import Realistic_Projection 6 | 7 | data_root = 'ReConV2/data/HybridDatasets/' 8 | save_path = 'ReConV2/data/HybridDatasets/depth/' 9 | batch_size = 32 10 | 11 | if not os.path.exists(save_path): 12 | os.makedirs(save_path) 13 | 14 | dataset = Hybrid_points(data_root, 'train', 1024) 15 | train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, 16 | num_workers=8, 17 | drop_last=False, 18 | worker_init_fn=worker_init_fn, 19 | pin_memory=True) 20 | 21 | pc_views = Realistic_Projection() 22 | get_img = pc_views.get_img 23 | 24 | 25 | def real_proj(pc, imsize=256): 26 | img = get_img(pc) 
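    # get_img projects each point cloud into a stack of depth images (one per
    # virtual view of Realistic_Projection); they are then resized to
    # (imsize, imsize) with bilinear interpolation before being saved as PNGs.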
27 | img = torch.nn.functional.interpolate(img, size=(imsize, imsize), mode='bilinear', align_corners=True) 28 | return img 29 | 30 | 31 | for pts, index in tqdm(train_dataloader): 32 | pts = pts.cuda() 33 | img = real_proj(pts) 34 | n, c, w, h = img.shape 35 | batch_size = n // 10 36 | img = img.reshape(batch_size, 10, c, w, h) 37 | 38 | for i in range(batch_size): 39 | for j in range(10): 40 | tensor_image = (img[i, j].cpu().detach().numpy() * 255).astype(np.uint8) 41 | pil_image = Image.fromarray(np.transpose(tensor_image, (1, 2, 0))) 42 | path = save_path + index[i].replace("/", "-")[:-4] + f'-{j}.png' 43 | pil_image.save(path) 44 | -------------------------------------------------------------------------------- /ReConV2/main.py: -------------------------------------------------------------------------------- 1 | import time 2 | import torch 3 | 4 | from ReConV2.utils.logger import * 5 | from ReConV2.utils.config import * 6 | from ReConV2.tools import svm_run_net as svm 7 | from ReConV2.utils import parser, dist_utils, misc 8 | from ReConV2.tools import test_run_net as test_net 9 | from ReConV2.tools import pretrain_run_net as pretrain 10 | from ReConV2.tools import finetune_run_net as finetune 11 | from ReConV2.tools import zeroshot_run_net as zeroshot 12 | 13 | 14 | def main(): 15 | # args 16 | args = parser.get_args() 17 | # CUDA 18 | torch.backends.cudnn.benchmark = True 19 | # init distributed env first, since logger depends on the dist info. 20 | if args.distributed: 21 | dist_utils.init_dist(args.local_rank) 22 | args.world_size = torch.distributed.get_world_size() 23 | 24 | # logger 25 | timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime()) 26 | log_file = os.path.join(args.experiment_path, f'{timestamp}.log') 27 | logger = get_root_logger(log_file=log_file, name=args.log_name) 28 | 29 | # config 30 | config = get_config(args, logger=logger) 31 | # batch size 32 | dist_utils.set_batch_size(args, config) 33 | # log 34 | log_args_to_file(args, 'args', logger=logger) 35 | log_config_to_file(config, 'config', logger=logger) 36 | logger.info(f'Distributed training: {args.distributed}') 37 | # set random seeds 38 | if args.seed is not None: 39 | logger.info(f'Set random seed to {args.seed}, ' 40 | f'deterministic: {args.deterministic}') 41 | misc.set_random_seed(args.seed + args.local_rank, 42 | deterministic=args.deterministic) # seed + rank, for augmentation 43 | 44 | if args.shot != -1: 45 | config.dataset.train.others.shot = args.shot 46 | config.dataset.train.others.way = args.way 47 | config.dataset.train.others.fold = args.fold 48 | config.dataset.val.others.shot = args.shot 49 | config.dataset.val.others.way = args.way 50 | config.dataset.val.others.fold = args.fold 51 | 52 | # run 53 | if args.test: 54 | test_net(args, config) 55 | elif args.zeroshot: 56 | zeroshot(args, config) 57 | elif args.svm: 58 | svm(args, config) 59 | elif args.finetune_model: 60 | finetune(args, config) 61 | else: 62 | pretrain(args, config) 63 | 64 | 65 | if __name__ == '__main__': 66 | main() 67 | -------------------------------------------------------------------------------- /ReConV2/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .build import build_model_from_cfg 2 | import ReConV2.models.ReCon 3 | import ReConV2.models.transformer 4 | -------------------------------------------------------------------------------- /ReConV2/models/build.py: -------------------------------------------------------------------------------- 1 | 
from ReConV2.utils import registry 2 | 3 | 4 | MODELS = registry.Registry('models') 5 | 6 | 7 | def build_model_from_cfg(cfg, **kwargs): 8 | """ 9 | Build a dataset, defined by `dataset_name`. 10 | Args: 11 | cfg (eDICT): 12 | Returns: 13 | Dataset: a constructed dataset specified by dataset_name. 14 | """ 15 | return MODELS.build(cfg, **kwargs) 16 | 17 | 18 | -------------------------------------------------------------------------------- /ReConV2/multi_cls.py: -------------------------------------------------------------------------------- 1 | import os 2 | import multiprocessing 3 | 4 | 5 | def main(i): 6 | os.system(f'CUDA_VISIBLE_DEVICES={i} bash ReConV2/scripts/cls.sh {i} test{i} ReConV2/ckpt-last.pth') 7 | 8 | 9 | if __name__ == '__main__': 10 | 11 | pool = multiprocessing.Pool(processes=8) 12 | for i in range(8): 13 | p = pool.apply_async(main, (i,)) 14 | pool.close() 15 | pool.join() 16 | -------------------------------------------------------------------------------- /ReConV2/requirements.txt: -------------------------------------------------------------------------------- 1 | argparse 2 | easydict 3 | h5py 4 | matplotlib 5 | numpy 6 | opencv-python 7 | pyyaml 8 | scipy 9 | tensorboardX 10 | tqdm 11 | termcolor 12 | pandas 13 | ninja 14 | ftfy 15 | regex 16 | einops 17 | scikit-learn 18 | torch-scatter==2.0.9 19 | -------------------------------------------------------------------------------- /ReConV2/scripts/downstream/cls.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=$1 python ReConV2/main.py \ 2 | --config ReConV2/cfgs/transfer/large/finetune_scan_hardest.yaml \ 3 | --finetune_model \ 4 | --exp_name $2 \ 5 | --ckpts $3 \ 6 | --seed $RANDOM -------------------------------------------------------------------------------- /ReConV2/scripts/downstream/fewshot.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=$1 python ReConV2/main.py \ 2 | --config ReConV2/cfgs/large/fewshot.yaml \ 3 | --finetune_model \ 4 | --exp_name $2 \ 5 | --ckpts $3 \ 6 | --way $4 \ 7 | --shot $5 \ 8 | --fold $6 -------------------------------------------------------------------------------- /ReConV2/scripts/downstream/svm.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=$1 python ReConV2/main.py \ 2 | --config ReConV2/cfgs/svm/modelnet40.yaml \ 3 | --svm \ 4 | --exp_name $2 \ 5 | --ckpts $3 -------------------------------------------------------------------------------- /ReConV2/scripts/downstream/test.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=$1 python ReConV2/main.py \ 2 | --test \ 3 | --config ReConV2/cfgs/full/finetune_modelnet.yaml \ 4 | --exp_name $2 \ 5 | --ckpts $3 \ 6 | --seed $RANDOM -------------------------------------------------------------------------------- /ReConV2/scripts/downstream/zeroshot.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=$1 python ReConV2/main.py \ 2 | --config ReConV2/cfgs/pretrain/large/openshape.yaml \ 3 | --zeroshot \ 4 | --exp_name $2 \ 5 | --ckpts $3 -------------------------------------------------------------------------------- /ReConV2/scripts/pretrain_transfer/pretrain_contrast.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python -m torch.distributed.launch \ 2 | 
--nproc_per_node=8 ReConV2/main.py \ 3 | --config ReConV2/cfgs/pretrain/large/openshape_1k.yaml \ 4 | --exp_name $1 \ 5 | --distributed \ 6 | --contrast \ 7 | --ckpts $2 -------------------------------------------------------------------------------- /ReConV2/scripts/pretrain_transfer/pretrain_reconstruct.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python -m torch.distributed.launch \ 2 | --nproc_per_node=8 ReConV2/main.py \ 3 | --config ReConV2/cfgs/pretrain/large/openshape_1k.yaml \ 4 | --exp_name $1 \ 5 | --distributed \ 6 | --reconstruct -------------------------------------------------------------------------------- /ReConV2/scripts/pretrain_transfer/pretrain_supervise.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=$1 python ReConV2/main.py \ 2 | --config ReConV2/cfgs/pretrain/large/hybrid_post.yaml \ 3 | --finetune_model \ 4 | --exp_name $2 \ 5 | --ckpts $3 \ 6 | --seed $RANDOM -------------------------------------------------------------------------------- /ReConV2/scripts/pretrain_zeroshot/pretrain.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python -m torch.distributed.launch \ 2 | --nproc_per_node=8 ReConV2/main.py \ 3 | --config ReConV2/cfgs/pretrain/large/openshape.yaml \ 4 | --exp_name $1 \ 5 | --distributed -------------------------------------------------------------------------------- /ReConV2/scripts/pretrain_zeroshot/pretrain_contrast.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python -m torch.distributed.launch \ 2 | --nproc_per_node=8 ReConV2/main.py \ 3 | --config ReConV2/cfgs/pretrain/large/openshape.yaml \ 4 | --exp_name $1 \ 5 | --distributed \ 6 | --contrast \ 7 | --ckpts $2 -------------------------------------------------------------------------------- /ReConV2/scripts/pretrain_zeroshot/pretrain_reconstruct.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python -m torch.distributed.launch \ 2 | --nproc_per_node=8 ReConV2/main.py \ 3 | --config ReConV2/cfgs/pretrain/large/openshape.yaml \ 4 | --exp_name $1 \ 5 | --distributed \ 6 | --reconstruct -------------------------------------------------------------------------------- /ReConV2/segmentation/knn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def square_distance(src, dst): 5 | """ 6 | Calculate Euclid distance between each two points. 
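    The pairwise squared distance is obtained from the algebraic expansion below,
    which replaces an explicit double loop with one batched matmul plus two
    per-point squared norms: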
7 | src^T * dst = xn * xm + yn * ym + zn * zm; 8 | sum(src^2, dim=-1) = xn*xn + yn*yn + zn*zn; 9 | sum(dst^2, dim=-1) = xm*xm + ym*ym + zm*zm; 10 | dist = (xn - xm)^2 + (yn - ym)^2 + (zn - zm)^2 11 | = sum(src**2,dim=-1)+sum(dst**2,dim=-1)-2*src^T*dst 12 | Input: 13 | src: source points, [B, N, C] 14 | dst: target points, [B, M, C] 15 | Output: 16 | dist: per-point square distance, [B, N, M] 17 | """ 18 | B, N, _ = src.shape 19 | _, M, _ = dst.shape 20 | dist = -2 * torch.matmul(src, dst.permute(0, 2, 1)) 21 | dist += torch.sum(src ** 2, -1).view(B, N, 1) 22 | dist += torch.sum(dst ** 2, -1).view(B, 1, M) 23 | return dist 24 | 25 | 26 | def knn_point(nsample, xyz, new_xyz): 27 | """ 28 | Input: 29 | nsample: max sample number in local region 30 | xyz: all points, [B, N, C] 31 | new_xyz: query points, [B, S, C] 32 | Return: 33 | group_idx: grouped points index, [B, S, nsample] 34 | """ 35 | sqrdists = square_distance(new_xyz, xyz) 36 | _, group_idx = torch.topk(sqrdists, nsample, dim = -1, largest=False, sorted=False) 37 | return group_idx -------------------------------------------------------------------------------- /ReConV2/segmentation/seg.sh: -------------------------------------------------------------------------------- 1 | python main.py --gpu $1 --log_dir $2 --ckpts $3 --seed $RANDOM 2 | -------------------------------------------------------------------------------- /ReConV2/segmentation/test.sh: -------------------------------------------------------------------------------- 1 | python main.py --gpu $1 --log_dir $2 --ckpts $3 --seed $RANDOM --test 2 | -------------------------------------------------------------------------------- /ReConV2/tools/__init__.py: -------------------------------------------------------------------------------- 1 | from .runner_pretrain import run_net as pretrain_run_net 2 | from .runner_finetune import run_net as finetune_run_net 3 | from .runner_zeroshot import run_net as zeroshot_run_net 4 | from .runner_svm import run_net as svm_run_net 5 | from .runner_finetune import test_net as test_run_net 6 | -------------------------------------------------------------------------------- /ReConV2/tools/runner_svm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from ReConV2.tools import builder 3 | from ReConV2.utils import misc, dist_utils 4 | from ReConV2.utils.logger import * 5 | import numpy as np 6 | from sklearn.svm import LinearSVC 7 | 8 | 9 | class Acc_Metric: 10 | def __init__(self, acc=0.): 11 | if type(acc).__name__ == 'dict': 12 | self.acc = acc['acc'] 13 | elif type(acc).__name__ == 'Acc_Metric': 14 | self.acc = acc.acc 15 | else: 16 | self.acc = acc 17 | 18 | def better_than(self, other): 19 | if self.acc > other.acc: 20 | return True 21 | else: 22 | return False 23 | 24 | def state_dict(self): 25 | _dict = dict() 26 | _dict['acc'] = self.acc 27 | return _dict 28 | 29 | 30 | def itr_merge(*itrs): 31 | for itr in itrs: 32 | for v in itr: 33 | yield v 34 | 35 | 36 | def evaluate_svm(train_features, train_labels, test_features, test_labels): 37 | clf = LinearSVC(C=0.075) 38 | clf.fit(train_features, train_labels) 39 | pred = clf.predict(test_features) 40 | return np.sum(test_labels == pred) * 1. / pred.shape[0] 41 | 42 | 43 | def run_net(args, config, train_writer=None, val_writer=None): 44 | logger = get_logger(args.log_name) 45 | print_log('Start SVM test... 
', logger=logger) 46 | # build dataset 47 | (train_sampler, train_dataloader), (_, test_dataloader), = builder.dataset_builder(args, config.dataset.train), \ 48 | builder.dataset_builder(args, config.dataset.val) 49 | # build model 50 | base_model = builder.model_builder(config.model) 51 | base_model.load_model_from_ckpt(args.ckpts) 52 | 53 | if args.use_gpu: 54 | base_model.to(args.local_rank) 55 | base_model.eval() 56 | 57 | test_features = [] 58 | test_label = [] 59 | 60 | train_features = [] 61 | train_label = [] 62 | npoints = config.npoints 63 | with torch.no_grad(): 64 | for idx, (taxonomy_ids, model_ids, data) in enumerate(train_dataloader): 65 | points = data[0].cuda() 66 | label = data[1].cuda() 67 | 68 | points = misc.fps(points, npoints) 69 | 70 | assert points.size(1) == npoints 71 | feature = base_model(points) 72 | target = label.view(-1) 73 | 74 | train_features.append(feature.detach()) 75 | train_label.append(target.detach()) 76 | 77 | for idx, (taxonomy_ids, model_ids, data) in enumerate(test_dataloader): 78 | points = data[0].cuda() 79 | label = data[1].cuda() 80 | 81 | points = misc.fps(points, npoints) 82 | assert points.size(1) == npoints 83 | feature = base_model(points) 84 | target = label.view(-1) 85 | 86 | test_features.append(feature.detach()) 87 | test_label.append(target.detach()) 88 | 89 | train_features = torch.cat(train_features, dim=0) 90 | train_label = torch.cat(train_label, dim=0) 91 | test_features = torch.cat(test_features, dim=0) 92 | test_label = torch.cat(test_label, dim=0) 93 | 94 | if args.distributed: 95 | train_features = dist_utils.gather_tensor(train_features, args) 96 | train_label = dist_utils.gather_tensor(train_label, args) 97 | test_features = dist_utils.gather_tensor(test_features, args) 98 | test_label = dist_utils.gather_tensor(test_label, args) 99 | 100 | acc = evaluate_svm(train_features.data.cpu().numpy(), train_label.data.cpu().numpy(), 101 | test_features.data.cpu().numpy(), test_label.data.cpu().numpy()) 102 | 103 | print_log('[TEST_SVM] acc = %.4f' % (acc * 100), logger=logger) 104 | 105 | if args.distributed: 106 | torch.cuda.synchronize() 107 | -------------------------------------------------------------------------------- /ReConV2/tools/runner_zeroshot.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from ReConV2.tools import builder 4 | from ReConV2.utils.logger import * 5 | from ReConV2.tools.runner_pretrain import Trainer 6 | from ReConV2.datasets import data 7 | 8 | 9 | def run_net(args, config): 10 | logger = get_logger(args.log_name) 11 | 12 | # build model 13 | device = torch.device("cuda", args.local_rank) 14 | 15 | base_model = builder.model_builder(config.model) 16 | base_model.to(device) 17 | 18 | modelnet40_loader = data.make_modelnet40test(config) 19 | scanobjectnn_loader = data.make_scanobjectnntest(config) 20 | objaverse_lvis_loader = data.make_objaverse_lvis(config) 21 | 22 | base_model.zero_grad() 23 | triner = Trainer(args.local_rank, args, config, base_model, None, None, device, 24 | logger, modelnet40_loader=modelnet40_loader, scanobjectnn_loader=scanobjectnn_loader, 25 | objaverse_lvis_loader=objaverse_lvis_loader) 26 | 27 | triner.load_from_checkpoint(args.ckpts) 28 | triner.model_parallel() 29 | triner.test_modelnet40() 30 | triner.test_scanobjectnn() 31 | triner.test_objaverse_lvis() 32 | -------------------------------------------------------------------------------- /ReConV2/utils/AverageMeter.py: 
-------------------------------------------------------------------------------- 1 | 2 | class AverageMeter(object): 3 | def __init__(self, items=None): 4 | self.items = items 5 | self.n_items = 1 if items is None else len(items) 6 | self.reset() 7 | 8 | def reset(self): 9 | self._val = [0] * self.n_items 10 | self._sum = [0] * self.n_items 11 | self._count = [0] * self.n_items 12 | 13 | def update(self, values): 14 | if type(values).__name__ == 'list': 15 | for idx, v in enumerate(values): 16 | self._val[idx] = v 17 | self._sum[idx] += v 18 | self._count[idx] += 1 19 | else: 20 | self._val[0] = values 21 | self._sum[0] += values 22 | self._count[0] += 1 23 | 24 | def val(self, idx=None): 25 | if idx is None: 26 | return self._val[0] if self.items is None else [self._val[i] for i in range(self.n_items)] 27 | else: 28 | return self._val[idx] 29 | 30 | def count(self, idx=None): 31 | if idx is None: 32 | return self._count[0] if self.items is None else [self._count[i] for i in range(self.n_items)] 33 | else: 34 | return self._count[idx] 35 | 36 | def avg(self, idx=None): 37 | if idx is None: 38 | return self._sum[0] / self._count[0] if self.items is None else [ 39 | self._sum[i] / self._count[i] for i in range(self.n_items) 40 | ] 41 | else: 42 | return self._sum[idx] / self._count[idx] -------------------------------------------------------------------------------- /ReConV2/utils/checkpoint.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 3 | 4 | from collections import defaultdict 5 | import torch.nn as nn 6 | 7 | from typing import Any 8 | from typing import Optional, List, Dict, NamedTuple, Tuple, Iterable 9 | 10 | from termcolor import colored 11 | 12 | def get_missing_parameters_message(keys: List[str]) -> str: 13 | """ 14 | Get a logging-friendly message to report parameter names (keys) that are in 15 | the model but not found in a checkpoint. 16 | Args: 17 | keys (list[str]): List of keys that were not found in the checkpoint. 18 | Returns: 19 | str: message. 20 | """ 21 | groups = _group_checkpoint_keys(keys) 22 | msg = "Some model parameters or buffers are not found in the checkpoint:\n" 23 | msg += "\n".join( 24 | " " + colored(k + _group_to_str(v), "blue") for k, v in groups.items() 25 | ) 26 | return msg 27 | 28 | 29 | def get_unexpected_parameters_message(keys: List[str]) -> str: 30 | """ 31 | Get a logging-friendly message to report parameter names (keys) that are in 32 | the checkpoint but not found in the model. 33 | Args: 34 | keys (list[str]): List of keys that were not found in the model. 35 | Returns: 36 | str: message. 37 | """ 38 | groups = _group_checkpoint_keys(keys) 39 | msg = "The checkpoint state_dict contains keys that are not used by the model:\n" 40 | msg += "\n".join( 41 | " " + colored(k + _group_to_str(v), "magenta") for k, v in groups.items() 42 | ) 43 | return msg 44 | 45 | 46 | def _strip_prefix_if_present(state_dict: Dict[str, Any], prefix: str) -> None: 47 | """ 48 | Strip the prefix in metadata, if any. 49 | Args: 50 | state_dict (OrderedDict): a state-dict to be loaded to the model. 51 | prefix (str): prefix. 52 | """ 53 | keys = sorted(state_dict.keys()) 54 | if not all(len(key) == 0 or key.startswith(prefix) for key in keys): 55 | return 56 | 57 | for key in keys: 58 | newkey = key[len(prefix):] 59 | state_dict[newkey] = state_dict.pop(key) 60 | 61 | # also strip the prefix in metadata, if any.. 
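    # _metadata is a private attribute attached by nn.Module.state_dict(); a plain
    # dict loaded from disk may not carry it, hence the AttributeError guard below.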
62 | try: 63 | metadata = state_dict._metadata # pyre-ignore 64 | except AttributeError: 65 | pass 66 | else: 67 | for key in list(metadata.keys()): 68 | # for the metadata dict, the key can be: 69 | # '': for the DDP module, which we want to remove. 70 | # 'module': for the actual model. 71 | # 'module.xx.xx': for the rest. 72 | 73 | if len(key) == 0: 74 | continue 75 | newkey = key[len(prefix):] 76 | metadata[newkey] = metadata.pop(key) 77 | 78 | 79 | def _group_checkpoint_keys(keys: List[str]) -> Dict[str, List[str]]: 80 | """ 81 | Group keys based on common prefixes. A prefix is the string up to the final 82 | "." in each key. 83 | Args: 84 | keys (list[str]): list of parameter names, i.e. keys in the model 85 | checkpoint dict. 86 | Returns: 87 | dict[list]: keys with common prefixes are grouped into lists. 88 | """ 89 | groups = defaultdict(list) 90 | for key in keys: 91 | pos = key.rfind(".") 92 | if pos >= 0: 93 | head, tail = key[:pos], [key[pos + 1:]] 94 | else: 95 | head, tail = key, [] 96 | groups[head].extend(tail) 97 | return groups 98 | 99 | 100 | def _group_to_str(group: List[str]) -> str: 101 | """ 102 | Format a group of parameter name suffixes into a loggable string. 103 | Args: 104 | group (list[str]): list of parameter name suffixes. 105 | Returns: 106 | str: formated string. 107 | """ 108 | if len(group) == 0: 109 | return "" 110 | 111 | if len(group) == 1: 112 | return "." + group[0] 113 | 114 | return ".{" + ", ".join(group) + "}" 115 | 116 | 117 | def _named_modules_with_dup( 118 | model: nn.Module, prefix: str = "" 119 | ) -> Iterable[Tuple[str, nn.Module]]: 120 | """ 121 | The same as `model.named_modules()`, except that it includes 122 | duplicated modules that have more than one name. 123 | """ 124 | yield prefix, model 125 | for name, module in model._modules.items(): # pyre-ignore 126 | if module is None: 127 | continue 128 | submodule_prefix = prefix + ("." if prefix else "") + name 129 | yield from _named_modules_with_dup(module, submodule_prefix) -------------------------------------------------------------------------------- /ReConV2/utils/config.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | from easydict import EasyDict 3 | import os 4 | from .logger import print_log 5 | 6 | 7 | def log_args_to_file(args, pre='args', logger=None): 8 | for key, val in args.__dict__.items(): 9 | print_log(f'{pre}.{key} : {val}', logger=logger) 10 | 11 | 12 | def log_config_to_file(cfg, pre='cfg', logger=None): 13 | for key, val in cfg.items(): 14 | if isinstance(cfg[key], EasyDict): 15 | print_log(f'{pre}.{key} = edict()', logger=logger) 16 | log_config_to_file(cfg[key], pre=pre + '.' 
+ key, logger=logger) 17 | continue 18 | print_log(f'{pre}.{key} : {val}', logger=logger) 19 | 20 | 21 | def merge_new_config(config, new_config): 22 | for key, val in new_config.items(): 23 | if not isinstance(val, dict): 24 | if key == '_base_': 25 | with open(new_config['_base_'], 'r') as f: 26 | try: 27 | val = yaml.load(f, Loader=yaml.FullLoader) 28 | except: 29 | val = yaml.load(f) 30 | config[key] = EasyDict() 31 | merge_new_config(config[key], val) 32 | else: 33 | config[key] = val 34 | continue 35 | if key not in config: 36 | config[key] = EasyDict() 37 | merge_new_config(config[key], val) 38 | return config 39 | 40 | 41 | def cfg_from_yaml_file(cfg_file): 42 | config = EasyDict() 43 | with open(cfg_file, 'r') as f: 44 | try: 45 | new_config = yaml.load(f, Loader=yaml.FullLoader) 46 | except: 47 | new_config = yaml.load(f) 48 | merge_new_config(config=config, new_config=new_config) 49 | return config 50 | 51 | 52 | def get_config(args, logger=None): 53 | if args.resume: 54 | cfg_path = os.path.join(args.experiment_path, 'config.yaml') 55 | if not os.path.exists(cfg_path): 56 | print_log("Failed to resume", logger=logger) 57 | raise FileNotFoundError() 58 | print_log(f'Resume yaml from {cfg_path}', logger=logger) 59 | args.config = cfg_path 60 | config = cfg_from_yaml_file(args.config) 61 | if not args.resume and args.local_rank == 0: 62 | save_experiment_config(args, config, logger) 63 | return config 64 | 65 | 66 | def save_experiment_config(args, config, logger=None): 67 | config_path = os.path.join(args.experiment_path, 'config.yaml') 68 | os.system('cp %s %s' % (args.config, config_path)) 69 | print_log(f'Copy the Config file from {args.config} to {config_path}', logger=logger) 70 | -------------------------------------------------------------------------------- /ReConV2/utils/data.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def random_rotate_z(pc): 5 | # random roate around z axis 6 | theta = np.random.uniform(0, 2 * np.pi) 7 | R = np.array([[np.cos(theta), -np.sin(theta), 0], 8 | [np.sin(theta), np.cos(theta), 0], 9 | [0, 0, 1]]) 10 | return np.matmul(pc, R) 11 | 12 | 13 | def normalize_pc(pc): 14 | """ pc: NxC, return NxC """ 15 | centroid = np.mean(pc, axis=0) 16 | pc = pc - centroid 17 | m = np.max(np.sqrt(np.sum(pc ** 2, axis=1))) 18 | if m < 1e-6: 19 | pc = np.zeros_like(pc) 20 | else: 21 | pc = pc / m 22 | return pc 23 | 24 | 25 | def random_point_dropout(batch_pc, max_dropout_ratio=0.875): 26 | """ batch_pc: BxNx3 """ 27 | for b in range(batch_pc.shape[0]): 28 | dropout_ratio = np.random.random() * max_dropout_ratio # 0~0.875 29 | drop_idx = np.where(np.random.random((batch_pc.shape[1])) <= dropout_ratio)[0] 30 | if len(drop_idx) > 0: 31 | batch_pc[b, drop_idx, :] = batch_pc[b, 0, :] # set to the first point 32 | return batch_pc 33 | 34 | 35 | def random_scale_point_cloud(data, scale_low=0.8, scale_high=1.25): 36 | 37 | scales = np.random.uniform(scale_low, scale_high) 38 | data *= scales 39 | return data 40 | 41 | 42 | def shift_point_cloud(batch_data, shift_range=0.1): 43 | """ Randomly shift point cloud. Shift is per point cloud. 
44 | Input: 45 | BxNx3 array, original batch of point clouds 46 | Return: 47 | BxNx3 array, shifted batch of point clouds 48 | """ 49 | B, N, C = batch_data.shape 50 | shifts = np.random.uniform(-shift_range, shift_range, (B, 3)) 51 | for batch_index in range(B): 52 | batch_data[batch_index, :, :] += shifts[batch_index, :] 53 | return batch_data 54 | 55 | 56 | def rotate_perturbation_point_cloud(batch_data, angle_sigma=0.06, angle_clip=0.18): 57 | """ Randomly perturb the point clouds by small rotations 58 | Input: 59 | BxNx3 array, original batch of point clouds 60 | Return: 61 | BxNx3 array, rotated batch of point clouds 62 | """ 63 | rotated_data = np.zeros(batch_data.shape, dtype=np.float32) 64 | for k in range(batch_data.shape[0]): 65 | angles = np.clip(angle_sigma * np.random.randn(3), -angle_clip, angle_clip) 66 | Rx = np.array([[1, 0, 0], 67 | [0, np.cos(angles[0]), -np.sin(angles[0])], 68 | [0, np.sin(angles[0]), np.cos(angles[0])]]) 69 | Ry = np.array([[np.cos(angles[1]), 0, np.sin(angles[1])], 70 | [0, 1, 0], 71 | [-np.sin(angles[1]), 0, np.cos(angles[1])]]) 72 | Rz = np.array([[np.cos(angles[2]), -np.sin(angles[2]), 0], 73 | [np.sin(angles[2]), np.cos(angles[2]), 0], 74 | [0, 0, 1]]) 75 | R = np.dot(Rz, np.dot(Ry, Rx)) 76 | shape_pc = batch_data[k, ...] 77 | rotated_data[k, ...] = np.dot(shape_pc.reshape((-1, 3)), R) 78 | return rotated_data 79 | 80 | 81 | def rotate_point_cloud(batch_data): 82 | """ Randomly rotate the point clouds to augument the dataset 83 | rotation is per shape based along up direction 84 | Input: 85 | BxNx3 array, original batch of point clouds 86 | Return: 87 | BxNx3 array, rotated batch of point clouds 88 | """ 89 | rotated_data = np.zeros(batch_data.shape, dtype=np.float32) 90 | for k in range(batch_data.shape[0]): 91 | rotation_angle = np.random.uniform() * 2 * np.pi 92 | cosval = np.cos(rotation_angle) 93 | sinval = np.sin(rotation_angle) 94 | rotation_matrix = np.array([[cosval, 0, sinval], 95 | [0, 1, 0], 96 | [-sinval, 0, cosval]]) 97 | shape_pc = batch_data[k, ...] 98 | rotated_data[k, ...] 
= np.dot(shape_pc.reshape((-1, 3)), rotation_matrix) 99 | return rotated_data 100 | 101 | 102 | def augment_pc(data): 103 | # data = random_point_dropout(data[None, ...]) 104 | data = random_scale_point_cloud(data[None, ...]) 105 | data = shift_point_cloud(data) 106 | data = rotate_perturbation_point_cloud(data) 107 | data = rotate_point_cloud(data) 108 | data = data.squeeze() 109 | return data 110 | -------------------------------------------------------------------------------- /ReConV2/utils/dist_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import distributed as dist 3 | 4 | 5 | def init_dist(local_rank, backend='nccl', **kwargs): 6 | torch.cuda.set_device(local_rank) 7 | dist.init_process_group(backend=backend, **kwargs) 8 | print(f'init distributed in rank {local_rank}') 9 | 10 | 11 | def reduce_tensor(tensor, args): 12 | ''' 13 | for acc kind, get the mean in each gpu 14 | ''' 15 | rt = tensor.clone() 16 | torch.distributed.all_reduce(rt, op=torch.distributed.ReduceOp.SUM) 17 | rt /= args.world_size 18 | return rt 19 | 20 | 21 | def gather_tensor(tensor, args): 22 | output_tensors = [tensor.clone() for _ in range(args.world_size)] 23 | torch.distributed.all_gather(output_tensors, tensor) 24 | concat = torch.cat(output_tensors, dim=0) 25 | return concat 26 | 27 | 28 | def set_batch_size(args, config): 29 | if args.distributed: 30 | assert config.total_bs % args.world_size == 0 31 | if config.dataset.get('train'): 32 | config.dataset.train.others.bs = config.total_bs // args.world_size 33 | if config.dataset.get('extra_train'): 34 | config.dataset.extra_train.others.bs = config.total_bs // args.world_size 35 | if config.dataset.get('val'): 36 | config.dataset.val.others.bs = config.total_bs // args.world_size 37 | if config.dataset.get('test'): 38 | config.dataset.test.others.bs = config.total_bs // args.world_size 39 | else: 40 | if config.dataset.get('train'): 41 | config.dataset.train.others.bs = config.total_bs 42 | if config.dataset.get('extra_train'): 43 | config.dataset.extra_train.others.bs = config.total_bs 44 | if config.dataset.get('extra_val'): 45 | config.dataset.extra_val.others.bs = config.total_bs 46 | if config.dataset.get('val'): 47 | config.dataset.val.others.bs = config.total_bs 48 | if config.dataset.get('test'): 49 | config.dataset.test.others.bs = config.total_bs 50 | -------------------------------------------------------------------------------- /ReConV2/utils/knn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def square_distance(src, dst): 5 | """ 6 | Calculate Euclid distance between each two points. 
7 | src^T * dst = xn * xm + yn * ym + zn * zm; 8 | sum(src^2, dim=-1) = xn*xn + yn*yn + zn*zn; 9 | sum(dst^2, dim=-1) = xm*xm + ym*ym + zm*zm; 10 | dist = (xn - xm)^2 + (yn - ym)^2 + (zn - zm)^2 11 | = sum(src**2,dim=-1)+sum(dst**2,dim=-1)-2*src^T*dst 12 | Input: 13 | src: source points, [B, N, C] 14 | dst: target points, [B, M, C] 15 | Output: 16 | dist: per-point square distance, [B, N, M] 17 | """ 18 | B, N, _ = src.shape 19 | _, M, _ = dst.shape 20 | dist = -2 * torch.matmul(src, dst.permute(0, 2, 1)) 21 | dist += torch.sum(src ** 2, -1).view(B, N, 1) 22 | dist += torch.sum(dst ** 2, -1).view(B, 1, M) 23 | return dist 24 | 25 | 26 | def knn_point(nsample, xyz, new_xyz): 27 | """ 28 | Input: 29 | nsample: max sample number in local region 30 | xyz: all points, [B, N, C] 31 | new_xyz: query points, [B, S, C] 32 | Return: 33 | group_idx: grouped points index, [B, S, nsample] 34 | """ 35 | sqrdists = square_distance(new_xyz, xyz) 36 | _, group_idx = torch.topk(sqrdists, nsample, dim=-1, largest=False, sorted=False) 37 | return group_idx 38 | -------------------------------------------------------------------------------- /ReConV2/utils/logger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import torch.distributed as dist 3 | 4 | logger_initialized = {} 5 | 6 | def get_root_logger(log_file=None, log_level=logging.INFO, name='main'): 7 | """Get root logger and add a keyword filter to it. 8 | The logger will be initialized if it has not been initialized. By default a 9 | StreamHandler will be added. If `log_file` is specified, a FileHandler will 10 | also be added. The name of the root logger is the top-level package name, 11 | e.g., "mmdet3d". 12 | Args: 13 | log_file (str, optional): File path of log. Defaults to None. 14 | log_level (int, optional): The level of logger. 15 | Defaults to logging.INFO. 16 | name (str, optional): The name of the root logger, also used as a 17 | filter keyword. Defaults to 'mmdet3d'. 18 | Returns: 19 | :obj:`logging.Logger`: The obtained logger 20 | """ 21 | logger = get_logger(name=name, log_file=log_file, log_level=log_level) 22 | # add a logging filter 23 | logging_filter = logging.Filter(name) 24 | logging_filter.filter = lambda record: record.find(name) != -1 25 | 26 | return logger 27 | 28 | 29 | def get_logger(name, log_file=None, log_level=logging.INFO, file_mode='w'): 30 | """Initialize and get a logger by name. 31 | If the logger has not been initialized, this method will initialize the 32 | logger by adding one or two handlers, otherwise the initialized logger will 33 | be directly returned. During initialization, a StreamHandler will always be 34 | added. If `log_file` is specified and the process rank is 0, a FileHandler 35 | will also be added. 36 | Args: 37 | name (str): Logger name. 38 | log_file (str | None): The log filename. If specified, a FileHandler 39 | will be added to the logger. 40 | log_level (int): The logger level. Note that only the process of 41 | rank 0 is affected, and other processes will set the level to 42 | "Error" thus be silent most of the time. 43 | file_mode (str): The file mode used in opening log file. 44 | Defaults to 'w'. 45 | Returns: 46 | logging.Logger: The expected logger. 47 | """ 48 | logger = logging.getLogger(name) 49 | if name in logger_initialized: 50 | return logger 51 | # handle hierarchical names 52 | # e.g., logger "a" is initialized, then logger "a.b" will skip the 53 | # initialization since it is a child of "a". 
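    # Returning early here keeps handler setup idempotent: the child logger reuses
    # the parent's handlers via propagation instead of attaching duplicates.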
54 | for logger_name in logger_initialized: 55 | if name.startswith(logger_name): 56 | return logger 57 | 58 | # handle duplicate logs to the console 59 | # Starting in 1.8.0, PyTorch DDP attaches a StreamHandler (NOTSET) 60 | # to the root logger. As logger.propagate is True by default, this root 61 | # level handler causes logging messages from rank>0 processes to 62 | # unexpectedly show up on the console, creating much unwanted clutter. 63 | # To fix this issue, we set the root logger's StreamHandler, if any, to log 64 | # at the ERROR level. 65 | for handler in logger.root.handlers: 66 | if type(handler) is logging.StreamHandler: 67 | handler.setLevel(logging.ERROR) 68 | 69 | stream_handler = logging.StreamHandler() 70 | handlers = [stream_handler] 71 | 72 | if dist.is_available() and dist.is_initialized(): 73 | rank = dist.get_rank() 74 | else: 75 | rank = 0 76 | 77 | # only rank 0 will add a FileHandler 78 | if rank == 0 and log_file is not None: 79 | # Here, the default behaviour of the official logger is 'a'. Thus, we 80 | # provide an interface to change the file mode to the default 81 | # behaviour. 82 | file_handler = logging.FileHandler(log_file, file_mode) 83 | handlers.append(file_handler) 84 | 85 | formatter = logging.Formatter( 86 | '%(asctime)s - %(name)s - %(levelname)s - %(message)s') 87 | for handler in handlers: 88 | handler.setFormatter(formatter) 89 | handler.setLevel(log_level) 90 | logger.addHandler(handler) 91 | 92 | if rank == 0: 93 | logger.setLevel(log_level) 94 | else: 95 | logger.setLevel(logging.ERROR) 96 | 97 | logger_initialized[name] = True 98 | 99 | 100 | return logger 101 | 102 | 103 | def print_log(msg, logger=None, level=logging.INFO): 104 | """Print a log message. 105 | Args: 106 | msg (str): The message to be logged. 107 | logger (logging.Logger | str | None): The logger to be used. 108 | Some special loggers are: 109 | - "silent": no message will be printed. 110 | - other str: the logger obtained with `get_root_logger(logger)`. 111 | - None: The `print()` method will be used to print log messages. 112 | level (int): Logging level. Only available when `logger` is a Logger 113 | object or "root". 
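        Example (hypothetical call): print_log('loaded data', logger='ModelNet')
        routes the message through the logger named 'ModelNet', mirroring how the
        dataset loaders in this repo pass a logger name as a string.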
114 | """ 115 | if logger is None: 116 | print(msg) 117 | elif isinstance(logger, logging.Logger): 118 | logger.log(level, msg) 119 | elif logger == 'silent': 120 | pass 121 | elif isinstance(logger, str): 122 | _logger = get_logger(logger) 123 | _logger.log(level, msg) 124 | else: 125 | raise TypeError( 126 | 'logger should be either a logging.Logger object, str, ' 127 | f'"silent" or None, but got {type(logger)}') -------------------------------------------------------------------------------- /ReConV2/utils/parser.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | from pathlib import Path 4 | 5 | 6 | def get_args(): 7 | parser = argparse.ArgumentParser() 8 | parser.add_argument( 9 | '--config', 10 | type=str, 11 | help='yaml config file') 12 | parser.add_argument('--distributed', action='store_true', default=False) 13 | parser.add_argument('--local-rank', type=int, default=0) 14 | parser.add_argument('--num_workers', type=int, default=8) 15 | # seed 16 | parser.add_argument('--seed', type=int, default=0, help='random seed') 17 | parser.add_argument( 18 | '--deterministic', 19 | action='store_true', 20 | help='whether to set deterministic options for CUDNN backend.') 21 | # bn 22 | parser.add_argument( 23 | '--sync_bn', 24 | action='store_true', 25 | default=False, 26 | help='whether to use sync bn') 27 | # some args 28 | parser.add_argument('--exp_name', type=str, default='default', help='experiment name') 29 | parser.add_argument('--start_ckpts', type=str, default=None, help='reload used ckpt path') 30 | parser.add_argument('--ckpts', type=str, default=None, help='test used ckpt path') 31 | parser.add_argument('--val_freq', type=int, default=1, help='test freq') 32 | parser.add_argument( 33 | '--vote', 34 | action='store_true', 35 | default=False, 36 | help='vote acc') 37 | parser.add_argument( 38 | '--resume', 39 | action='store_true', 40 | default=False, 41 | help='autoresume training (interrupted by accident)') 42 | parser.add_argument( 43 | '--svm', 44 | action='store_true', 45 | default=False, 46 | help='svm') 47 | parser.add_argument( 48 | '--zeroshot', 49 | action='store_true', 50 | default=False, 51 | help='zero-shot') 52 | parser.add_argument( 53 | '--test', 54 | action='store_true', 55 | default=False, 56 | help='test mode for certain ckpt') 57 | parser.add_argument( 58 | '--reconstruct', 59 | action='store_true', 60 | default=False, 61 | help='reconstruct pretraining stage') 62 | parser.add_argument( 63 | '--contrast', 64 | action='store_true', 65 | default=False, 66 | help='contrast pretraining stage') 67 | parser.add_argument( 68 | '--finetune_model', 69 | action='store_true', 70 | default=False, 71 | help='finetune modelnet with pretrained weight') 72 | parser.add_argument( 73 | '--way', type=int, default=-1) 74 | parser.add_argument( 75 | '--shot', type=int, default=-1) 76 | parser.add_argument( 77 | '--fold', type=int, default=-1) 78 | 79 | args = parser.parse_args() 80 | 81 | if args.test and args.resume: 82 | raise ValueError( 83 | '--test and --resume cannot be both activate') 84 | 85 | if args.resume and args.start_ckpts is not None: 86 | raise ValueError( 87 | '--resume and --start_ckpts cannot be both activate') 88 | 89 | if args.test and args.ckpts is None: 90 | raise ValueError( 91 | 'ckpts shouldnt be None while test mode') 92 | 93 | if args.finetune_model and args.ckpts is None: 94 | print( 95 | 'training from scratch') 96 | 97 | if 'LOCAL_RANK' not in os.environ: 98 | 
os.environ['LOCAL_RANK'] = str(args.local_rank) 99 | 100 | if args.test: 101 | args.exp_name = 'test_' + args.exp_name 102 | args.experiment_path = os.path.join('./experiments', Path(args.config).stem, Path(args.config).parent.stem, 103 | args.exp_name) 104 | args.tfboard_path = os.path.join('./experiments', Path(args.config).stem, Path(args.config).parent.stem, 'TFBoard', 105 | args.exp_name) 106 | args.log_name = Path(args.config).stem 107 | create_experiment_dir(args) 108 | return args 109 | 110 | 111 | def create_experiment_dir(args): 112 | if not os.path.exists(args.experiment_path): 113 | os.makedirs(args.experiment_path, exist_ok=True) 114 | print('Create experiment path successfully at %s' % args.experiment_path) 115 | if not os.path.exists(args.tfboard_path): 116 | os.makedirs(args.tfboard_path, exist_ok=True) 117 | print('Create TFBoard path successfully at %s' % args.tfboard_path) 118 | -------------------------------------------------------------------------------- /ReConV2/utils/transforms.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | from torchvision import transforms 3 | from ReConV2.utils.randaugment import RandAugmentMC 4 | __all__ = ['get_transforms'] 5 | 6 | 7 | class ResizeImage(): 8 | def __init__(self, size): 9 | if isinstance(size, int): 10 | self.size = (int(size), int(size)) 11 | else: 12 | self.size = size 13 | 14 | def __call__(self, img): 15 | th, tw = self.size 16 | return img.resize((th, tw)) 17 | 18 | 19 | class PlaceCrop(object): 20 | """Crops the given PIL.Image at the particular index. 21 | Args: 22 | size (sequence or int): Desired output size of the crop. If size is an 23 | int instead of sequence like (w, h), a square crop (size, size) is 24 | made. 25 | """ 26 | 27 | def __init__(self, size, start_x, start_y): 28 | if isinstance(size, int): 29 | self.size = (int(size), int(size)) 30 | else: 31 | self.size = size 32 | self.start_x = start_x 33 | self.start_y = start_y 34 | 35 | def __call__(self, img): 36 | """ 37 | Args: 38 | img (PIL.Image): Image to be cropped. 39 | Returns: 40 | PIL.Image: Cropped image. 41 | """ 42 | th, tw = self.size 43 | return img.crop((self.start_x, self.start_y, self.start_x + tw, self.start_y + th)) 44 | 45 | 46 | class ForceFlip(object): 47 | """Horizontally flip the given PIL.Image randomly with a probability of 0.5.""" 48 | 49 | def __call__(self, img): 50 | """ 51 | Args: 52 | img (PIL.Image): Image to be flipped. 53 | Returns: 54 | PIL.Image: Randomly flipped image. 
55 | """ 56 | return img.transpose(Image.FLIP_LEFT_RIGHT) 57 | 58 | 59 | def transform_train(resize_size=256, crop_size=224): 60 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 61 | std=[0.229, 0.224, 0.225]) 62 | return transforms.Compose([ 63 | # ResizeImage(resize_size), 64 | # transforms.RandomHorizontalFlip(), 65 | # transforms.RandomResizedCrop(crop_size, scale=(0.64, 1.0), interpolation=transforms.InterpolationMode.BICUBIC), 66 | # RandAugmentMC(n=2, m=10), 67 | ResizeImage(crop_size), 68 | transforms.ToTensor(), 69 | normalize 70 | ]) 71 | 72 | 73 | def get_transforms(resize_size=256, crop_size=224): 74 | transforms = { 75 | 'train': transform_train(resize_size, crop_size) 76 | } 77 | 78 | return transforms 79 | -------------------------------------------------------------------------------- /assets/framework.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qizekun/ShapeLLM/f754b0d488f7187a699a549dd0c5e43b2349f051/assets/framework.jpg -------------------------------------------------------------------------------- /assets/instrument.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qizekun/ShapeLLM/f754b0d488f7187a699a549dd0c5e43b2349f051/assets/instrument.npy -------------------------------------------------------------------------------- /docs/DATA.md: -------------------------------------------------------------------------------- 1 | # Data 2 | 3 | ## ShapeLLM 4 | 5 | ShapeLLM utilizes text-3D paired data provided by Cap3D for the first-stage alignment training and employs GPT-4V to construct multi-view-based general data for supervised finetuning. 6 | To equip the model with the capability of 3D Visual Grounding, we also leverage GPT4 based on GAPartNet to build SFT data for Embodied Understanding. 7 | All data except Cap3D_pts can be directly downloaded from [Hugging Face](https://huggingface.co/datasets/qizekun/ShapeLLM). Cap3D_pcs data needs to be obtained in pt form from the [Cap3D repository](https://huggingface.co/datasets/tiange/Cap3D/tree/main). 8 | 9 | | Data file name | Size | 10 | |---------------------------------------------------------------------------------------------------------------------------|------------:| 11 | | [cap3d_objaverse_785k.json](https://huggingface.co/datasets/qizekun/ShapeLLM/blob/main/cap3d_objaverse_785k.json) | 242 MB | 12 | | [cap3d_objaverse_sft_45k.json](https://huggingface.co/datasets/qizekun/ShapeLLM/blob/main/cap3d_objaverse_sft_45k.json) | 16.9 MB | 13 | | [gapartnet_sft_27k_openai.json](https://huggingface.co/datasets/qizekun/ShapeLLM/blob/main/gapartnet_sft_27k_openai.json) | 12.5 MB | 14 | | [gapartnet_pcs.zip](https://huggingface.co/datasets/qizekun/ShapeLLM/blob/main/gapartnet_pcs.zip) | 4.59 GB | 15 | | [cap3d_pcs](https://huggingface.co/datasets/tiange/Cap3D/tree/main/PointCloud_pt_zips) | 173.8 GB | 16 | Organize the data as follows in `./playground/data/shapellm/` 17 | ``` 18 | │playground/data/shapellm/ 19 | ├── cap3d_objaverse_785k.json 20 | ├── cap3d_objaverse_sft_45k.json 21 | ├── gapartnet_sft_27k_openai.json 22 | ├── gapartnet_pcs 23 | │ ├── Box_100129_0_0.npy 24 | │ └── ... 25 | └── cap3d_pcs 26 | ├── 00000054c36d44a2a483bdbff31d8edf.pt 27 | └── ... 
28 | ``` 29 | 30 | ## ReCon++ 31 | The overall directory structure should be: 32 | ``` 33 | │llava/ 34 | │ReConV2/ 35 | │ └──data/ 36 | │ ├──OpenShape/ 37 | │ ├──ModelNet/ 38 | │ ├──ModelNetFewshot/ 39 | │ └──ScanObjectNN/ 40 | ``` 41 | 42 | ### OpenShape Dataset: 43 | 44 | ``` 45 | │OpenShape/ 46 | ├──objaverse-processed/ 47 | │ └── merged_for_training_final/ 48 | │ ├── 3D-FUTURE/ 49 | │ ├── ABO/ 50 | │ ├── Objaverse/ 51 | │ └── ShapeNet/ 52 | ├──meta_data/ 53 | │ ├── modelnet40/ 54 | │ ├── scanobjectnn/ 55 | │ ├── split/ 56 | │ ├── gpt4_filtering.json 57 | │ ├── lvis_cat_name_pt_feat.npy 58 | │ └── point_feat_knn.npy 59 | ``` 60 | Download: You can download the processed data from [OpenShape Hugging Face](https://huggingface.co/datasets/OpenShape/openshape-training-data/tree/main). Note that the rendered image data is not necessary. 61 | 62 | 63 | ### ModelNet40 Dataset: 64 | 65 | ``` 66 | │ModelNet/ 67 | ├──modelnet40_normal_resampled/ 68 | │ ├── modelnet40_shape_names.txt 69 | │ ├── modelnet40_train.txt 70 | │ ├── modelnet40_test.txt 71 | │ ├── modelnet40_train_8192pts_fps.dat 72 | │ └── modelnet40_test_8192pts_fps.dat 73 | ``` 74 | Download: You can download the processed data from [Point-BERT repo](https://github.com/lulutang0608/Point-BERT/blob/49e2c7407d351ce8fe65764bbddd5d9c0e0a4c52/DATASET.md), or download from the [official website](https://modelnet.cs.princeton.edu/#) and process it by yourself. 75 | 76 | ### ModelNet Few-shot Dataset: 77 | ``` 78 | │ModelNetFewshot/ 79 | ├──5way10shot/ 80 | │ ├── 0.pkl 81 | │ ├── ... 82 | │ ├── 9.pkl 83 | ├──5way20shot/ 84 | │ ├── ... 85 | ├──10way10shot/ 86 | │ ├── ... 87 | ├──10way20shot/ 88 | │ ├── ... 89 | ``` 90 | 91 | Download: Please download the data from [Point-BERT repo](https://github.com/lulutang0608/Point-BERT/blob/49e2c7407d351ce8fe65764bbddd5d9c0e0a4c52/DATASET.md). We use the same data split as theirs. 92 | 93 | ### ScanObjectNN Dataset: 94 | ``` 95 | │ScanObjectNN/ 96 | ├──main_split/ 97 | │ ├── training_objectdataset_augmentedrot_scale75.h5 98 | │ ├── test_objectdataset_augmentedrot_scale75.h5 99 | │ ├── training_objectdataset.h5 100 | │ ├── test_objectdataset.h5 101 | ├──main_split_nobg/ 102 | │ ├── training_objectdataset.h5 103 | │ └── test_objectdataset.h5 104 | ``` 105 | Download: Please download the data from the [official website](https://hkust-vgd.github.io/scanobjectnn/). 106 | -------------------------------------------------------------------------------- /docs/LoRA.md: -------------------------------------------------------------------------------- 1 | # LLaVA (LoRA, Preview) 2 | 3 | NOTE: This is a technical preview, and is not yet ready for production use. We are still running hyperparameter search for the LoRA model, and will release the final model soon. If you'd like to contribute to this, please contact us. 4 | 5 | You need latest code base for LoRA support (instructions [here](https://github.com/haotian-liu/LLaVA#upgrade-to-latest-code-base)) 6 | 7 | ## Demo (Web UI) 8 | 9 | Please execute each of the commands below one by one (after the previous one has finished). The commands are the same as launching other demos except for an additional `--model-base` flag to specify the base model to use. Please make sure the base model corresponds to the LoRA checkpoint that you are using. For this technical preview, you need Vicuna v1.1 (7B) checkpoint (if you do not have that already, follow the instructions [here](https://github.com/lm-sys/FastChat#vicuna-weights)). 
10 | 11 | #### Launch a controller 12 | ```Shell 13 | python -m llava.serve.controller --host 0.0.0.0 --port 10000 14 | ``` 15 | 16 | #### Launch a gradio web server. 17 | ```Shell 18 | python -m llava.serve.gradio_web_server --controller http://localhost:10000 --model-list-mode reload 19 | ``` 20 | You just launched the Gradio web interface. Now, you can open the web interface with the URL printed on the screen. You may notice that there is no model in the model list. Do not worry, as we have not launched any model worker yet. It will be automatically updated when you launch a model worker. 21 | 22 | #### Launch a model worker 23 | ```Shell 24 | python -m llava.serve.model_worker --host 0.0.0.0 --controller http://localhost:10000 --port 40000 --worker http://localhost:40000 --model-path liuhaotian/llava-vicuna-7b-v1.1-lcs_558k-instruct_80k_3e-lora-preview-alpha --model-base /path/to/vicuna-v1.1 25 | ``` 26 | Wait until the process finishes loading the model and you see "Uvicorn running on ...". Now, refresh your Gradio web UI, and you will see the model you just launched in the model list. 27 | 28 | You can launch as many workers as you want, and compare between different model checkpoints in the same Gradio interface. Please keep the `--controller` the same, and modify the `--port` and `--worker` to a different port number for each worker. 29 | 30 | 31 | ## Training 32 | 33 | Please see sample training scripts for [LoRA](https://github.com/haotian-liu/LLaVA/blob/main/scripts/finetune_lora.sh) and [QLoRA](https://github.com/haotian-liu/LLaVA/blob/main/scripts/finetune_qlora.sh). 34 | 35 | We provide sample DeepSpeed configs, [`zero3.json`](https://github.com/haotian-liu/LLaVA/blob/main/scripts/zero3.json) is more like PyTorch FSDP, and [`zero3_offload.json`](https://github.com/haotian-liu/LLaVA/blob/main/scripts/zero3_offload.json) can further save memory consumption by offloading parameters to CPU. `zero3.json` is usually faster than `zero3_offload.json` but requires more GPU memory, therefore, we recommend trying `zero3.json` first, and if you run out of GPU memory, try `zero3_offload.json`. You can also tweak the `per_device_train_batch_size` and `gradient_accumulation_steps` in the config to save memory, and just to make sure that `per_device_train_batch_size` and `gradient_accumulation_steps` remains the same. 36 | 37 | If you are having issues with ZeRO-3 configs, and there are enough VRAM, you may try [`zero2.json`](https://github.com/haotian-liu/LLaVA/blob/main/scripts/zero2.json). This consumes slightly more memory than ZeRO-3, and behaves more similar to PyTorch FSDP, while still supporting parameter-efficient tuning. 38 | 39 | ## Create Merged Checkpoints 40 | 41 | ```Shell 42 | python scripts/merge_lora_weights.py \ 43 | --model-path /path/to/lora_model \ 44 | --model-base /path/to/base_model \ 45 | --save-model-path /path/to/merge_model 46 | ``` 47 | -------------------------------------------------------------------------------- /docs/Windows.md: -------------------------------------------------------------------------------- 1 | # Run LLaVA on Windows 2 | 3 | *NOTE: LLaVA on Windows is not fully supported. Currently we only support 16-bit inference. For a more complete support, please use [WSL2](https://learn.microsoft.com/en-us/windows/wsl/install) for now. More functionalities on Windows is to be added soon, stay tuned.* 4 | 5 | ## Installation 6 | 7 | 1. 
Clone this repository and navigate to LLaVA folder 8 | ```bash 9 | git clone https://github.com/haotian-liu/LLaVA.git 10 | cd LLaVA 11 | ``` 12 | 13 | 2. Install Package 14 | ```Shell 15 | conda create -n llava python=3.10 -y 16 | conda activate llava 17 | python -mpip install --upgrade pip # enable PEP 660 support 18 | pip install torch==2.0.1+cu117 torchvision==0.15.2+cu117 torchaudio==2.0.2 --index-url https://download.pytorch.org/whl/cu117 19 | pip install -e . 20 | pip uninstall bitsandbytes 21 | ``` 22 | 23 | ## Run demo 24 | 25 | See instructions [here](https://github.com/haotian-liu/LLaVA#demo). 26 | 27 | Note that quantization (4-bit, 8-bit) is *NOT* supported on Windows. Stay tuned for the 4-bit support on Windows! 28 | -------------------------------------------------------------------------------- /llava/__init__.py: -------------------------------------------------------------------------------- 1 | from .model import LlavaLlamaForCausalLM 2 | -------------------------------------------------------------------------------- /llava/constants.py: -------------------------------------------------------------------------------- 1 | CONTROLLER_HEART_BEAT_EXPIRATION = 30 2 | WORKER_HEART_BEAT_INTERVAL = 15 3 | 4 | LOGDIR = "." 5 | 6 | # Model Constants 7 | IGNORE_INDEX = -100 8 | POINT_TOKEN_INDEX = -200 9 | DEFAULT_POINT_TOKEN = "" 10 | DEFAULT_POINT_PATCH_TOKEN = "" 11 | DEFAULT_PT_START_TOKEN = "" 12 | DEFAULT_PT_END_TOKEN = "" 13 | -------------------------------------------------------------------------------- /llava/eval/data/__init__.py: -------------------------------------------------------------------------------- 1 | from .utils import load_objaverse_point_cloud, pc_norm, farthest_point_sample 2 | from .object_point_dataset import ObjectPointCloudDataset, make_object_point_data_module 3 | from .modelnet import ModelNet -------------------------------------------------------------------------------- /llava/eval/data/modelnet_config/ModelNet40.yaml: -------------------------------------------------------------------------------- 1 | NAME: ModelNet 2 | DATA_PATH: playground/data/eval/modelnet40 3 | NUM_CATEGORY: 40 4 | USE_NORMALS: FALSE 5 | npoints: 8192 6 | random_sampling: TRUE 7 | use_height: FALSE 8 | use_normals: FALSE -------------------------------------------------------------------------------- /llava/eval/data/modelnet_config/modelnet40_shape_names_modified.txt: -------------------------------------------------------------------------------- 1 | airplane 2 | bathtub 3 | bed 4 | bench 5 | bookshelf 6 | bottle 7 | bowl 8 | car 9 | chair 10 | cone 11 | cup 12 | curtain 13 | desk 14 | door 15 | dresser 16 | flower pot 17 | glass box 18 | guitar 19 | keyboard 20 | lamp 21 | laptop 22 | mantel 23 | monitor 24 | night stand 25 | person 26 | piano 27 | plant 28 | radio 29 | range hood 30 | sink 31 | sofa 32 | stairs 33 | stool 34 | table 35 | tent 36 | toilet 37 | tv stand 38 | vase 39 | wardrobe 40 | xbox 41 | -------------------------------------------------------------------------------- /llava/eval/eval_3dmmvet.py: -------------------------------------------------------------------------------- 1 | import json 2 | import argparse 3 | import jsonlines 4 | from tqdm import tqdm 5 | import concurrent.futures 6 | from llava.eval.gpt_eval import gpt_get_average_score 7 | 8 | 9 | def main(args): 10 | ans_lines = args.answers_file.readlines() 11 | gt_lines = args.gt_file.readlines() 12 | assert len(ans_lines) == len(gt_lines) 13 | 14 | model_name = args.model 15 | output_file 
= args.output_file 16 | open(output_file, 'w').write("") 17 | 18 | ans_dict = { 19 | "General Visual Recognition": [], 20 | "Knowledge": [], 21 | "Language Generation": [], 22 | "Spatial Recognition": [], 23 | "Embodied Interaction": [], 24 | "Overall": [], 25 | } 26 | 27 | with tqdm(total=len(ans_lines), desc="Processing tasks", unit="task") as pbar: 28 | def process_task(i): 29 | model_output = json.loads(ans_lines[i]) 30 | gt = json.loads(gt_lines[i]) 31 | question = model_output['prompt'] 32 | model_ans = model_output['text'] 33 | gt_ans = gt['text'] 34 | category = gt['category'] 35 | 36 | score = gpt_get_average_score(question, category, model_ans, gt_ans, model=model_name, times=args.times) 37 | score = round(score, 1) 38 | ans_dict[category].append(score) 39 | ans_dict["Overall"].append(score) 40 | 41 | data = { 42 | "question_id": gt['question_id'], 43 | "answer_id": model_output['answer_id'], 44 | "question": question, 45 | "answer_model": model_ans, 46 | "answer_label": gt_ans, 47 | "model_id": model_name, 48 | "score": score, 49 | } 50 | jsonlines.open(output_file, mode='a').write(data) 51 | pbar.update(1) 52 | return f"{gt['question_id']} {score}%" 53 | 54 | with concurrent.futures.ThreadPoolExecutor(max_workers=args.max_workers) as executor: 55 | futures = [executor.submit(process_task, i) for i in range(len(ans_lines))] 56 | 57 | for future in concurrent.futures.as_completed(futures): 58 | try: 59 | result = future.result() 60 | print("Task completed:", result) 61 | except Exception as e: 62 | print("Task encountered an exception:", e) 63 | 64 | for category in ans_dict: 65 | print(f"{category} Acc: {round(sum(ans_dict[category]) / len(ans_dict[category]), 1)}%") 66 | 67 | 68 | if __name__ == "__main__": 69 | parser = argparse.ArgumentParser() 70 | parser.add_argument("--answers-file", type=argparse.FileType('r'), default="tables/answer.jsonl") 71 | parser.add_argument("--gt-file", type=argparse.FileType('r'), default="tables/gt.jsonl") 72 | parser.add_argument("--output-file", type=str, default="tables/result.jsonl") 73 | parser.add_argument("--model", type=str, default="gpt-3.5-turbo") 74 | parser.add_argument("--max_workers", type=int, default=4) 75 | parser.add_argument("--times", type=int, default=5) 76 | args = parser.parse_args() 77 | 78 | main(args) 79 | -------------------------------------------------------------------------------- /llava/eval/eval_gapartnet.py: -------------------------------------------------------------------------------- 1 | import re 2 | import json 3 | import argparse 4 | import jsonlines 5 | import numpy as np 6 | from tqdm import tqdm 7 | import concurrent.futures 8 | 9 | 10 | def get_iou(bbox1, bbox2): 11 | 12 | def calculate_volume(bbox): 13 | min_coords = np.min(bbox, axis=0) 14 | max_coords = np.max(bbox, axis=0) 15 | side_lengths = np.maximum(0.0, max_coords - min_coords) 16 | volume = np.prod(side_lengths) 17 | return volume 18 | 19 | def calculate_intersection(bbox1, bbox2): 20 | min_coords = np.maximum(np.min(bbox1, axis=0), np.min(bbox2, axis=0)) 21 | max_coords = np.minimum(np.max(bbox1, axis=0), np.max(bbox2, axis=0)) 22 | side_lengths = np.maximum(0.0, max_coords - min_coords) 23 | intersection_volume = np.prod(side_lengths) 24 | return intersection_volume 25 | 26 | volume_bbox1 = calculate_volume(bbox1) 27 | volume_bbox2 = calculate_volume(bbox2) 28 | 29 | intersection_volume = calculate_intersection(bbox1, bbox2) 30 | union_volume = volume_bbox1 + volume_bbox2 - intersection_volume 31 | iou = intersection_volume / union_volume 
if union_volume > 0 else 0.0 32 | 33 | return iou 34 | 35 | 36 | def main(args): 37 | ans_lines = args.answers_file.readlines() 38 | gt_lines = args.gt_file.readlines() 39 | assert len(ans_lines) == len(gt_lines) 40 | 41 | output_file = args.output_file 42 | open(output_file, 'w').write("") 43 | 44 | ans_dict = { 45 | # "Bucket": [], 46 | # "CoffeeMachine": [], 47 | # "Printer": [], 48 | # "Camera": [], 49 | # "Toaster": [], 50 | "StorageFurniture": [], 51 | "Toilet": [], 52 | "Box": [], 53 | "WashingMachine": [], 54 | "Dishwasher": [], 55 | "Microwave": [], 56 | "Overall": [], 57 | } 58 | pattern = r"\[\[.*\], \[.*\], \[.*\], \[.*\], \[.*\], \[.*\], \[.*\], \[.*\]\]" 59 | 60 | with tqdm(total=len(ans_lines), desc="Processing tasks", unit="task") as pbar: 61 | def process_task(i): 62 | model_output = json.loads(ans_lines[i]) 63 | gt = json.loads(gt_lines[i]) 64 | question = model_output['prompt'] 65 | model_ans = model_output['text'] 66 | gt_ans = gt['text'] 67 | gt_bboxes = gt['bboxes'] 68 | category = gt['category'] 69 | 70 | bboxes = [] 71 | for pred in model_ans.split("]]"): 72 | pred = re.findall(pattern, pred + ']]') 73 | if len(pred) > 0: 74 | bbox = json.loads(pred[0]) 75 | bboxes.append(bbox) 76 | 77 | iou_list = [] 78 | acc_list = [] 79 | for j in range(len(bboxes)): 80 | iou = get_iou(bboxes[j], gt_bboxes[j]) 81 | iou_list.append(iou) 82 | if iou > 0.25: 83 | acc_list.append(1) 84 | else: 85 | acc_list.append(0) 86 | miou = round(sum(iou_list) / len(gt_bboxes) * 100, 2) 87 | acc = round(sum(acc_list) / len(gt_bboxes) * 100, 2) 88 | 89 | ans_dict[category].append(acc) 90 | ans_dict["Overall"].append(acc) 91 | 92 | data = { 93 | "question_id": gt['question_id'], 94 | "answer_id": model_output['answer_id'], 95 | "question": question, 96 | "answer_model": model_ans, 97 | "answer_label": gt_ans, 98 | "mIoU": miou, 99 | "acc": acc, 100 | } 101 | jsonlines.open(output_file, mode='a').write(data) 102 | pbar.update(1) 103 | return f"{gt['question_id']} {miou}%" 104 | 105 | with concurrent.futures.ThreadPoolExecutor(max_workers=args.max_workers) as executor: 106 | futures = [executor.submit(process_task, i) for i in range(len(ans_lines))] 107 | 108 | for future in concurrent.futures.as_completed(futures): 109 | try: 110 | result = future.result() 111 | print("Task completed:", result) 112 | except Exception as e: 113 | print("Task encountered an exception:", e) 114 | 115 | category_acc_list = [] 116 | for category in ans_dict: 117 | acc = round(sum(ans_dict[category]) / len(ans_dict[category]), 1) 118 | category_acc_list.append(acc) 119 | print(f"{category} Acc: {acc}%") 120 | print(f"Mean Acc: {round(sum(category_acc_list) / len(category_acc_list), 1)}%") 121 | 122 | 123 | if __name__ == "__main__": 124 | parser = argparse.ArgumentParser() 125 | parser.add_argument("--answers-file", type=argparse.FileType('r'), default="tables/answer.jsonl") 126 | parser.add_argument("--gt-file", type=argparse.FileType('r'), default="tables/gt.jsonl") 127 | parser.add_argument("--output-file", type=str, default="tables/result.jsonl") 128 | parser.add_argument("--max_workers", type=int, default=4) 129 | parser.add_argument("--times", type=int, default=5) 130 | args = parser.parse_args() 131 | 132 | main(args) 133 | -------------------------------------------------------------------------------- /llava/eval/gpt_eval.py: -------------------------------------------------------------------------------- 1 | import time 2 | import openai 3 | 4 | NUM_SECONDS_TO_SLEEP = 1.0 5 | 6 | openai.api_base = "" # if 
needed 7 | openai.api_key = "" 8 | 9 | 10 | def gpt_get_score(question, question_type, answer_model, answer_label, model="gpt-3.5-turbo"): 11 | while True: 12 | try: 13 | completion = openai.ChatCompletion.create( 14 | model=model, 15 | messages=[ 16 | {"role": "system", "content": "You are a helpful AI assistant."}, 17 | {"role": "user", "content": """ 18 | Now I will give you a question, the type of the question, an answer from model, and an answer from label. 19 | All you need to do is focus on these two answers and figure out whether they are saying the same thing about the specific type of question. 20 | Your response should only be a confidence score ranging from 0 to 100. 21 | Remember the confidence score is to evaluate how much two answers are describing the same thing. 22 | Your response confidence score should follow the scoring standard of the prompt I gave. 23 | Firstly I will give you several question & answer pairs as long as their confidence score: 24 | 25 | question1: How many oranges will there be if 1/3 of them are removed? 26 | question type: Knowledge 27 | answer from model: There will be 6 left. 28 | answer from label: As there are 9 oranges in total, there will be 6 oranges left if 1/3 of them are removed. 29 | confidence score: 100 30 | 31 | question2: What is this object? 32 | question type: General Visual Recognition 33 | answer from model: This is a bathtub 34 | answer from label: This is a dirty bathtub. 35 | confidence score: 80 36 | 37 | question3: What is this object? 38 | question type: General Visual Recognition 39 | answer from model: This is a bottle of water 40 | answer from label: This is a bottle of oil 41 | confidence score: 50 42 | 43 | question4: What is holding in this boy's right hand? 44 | question type: Spatial Recognition 45 | answer from model: He is holding a white cup in his right hand. 46 | answer from label: He is holding a sword in his right hand. 47 | confidence score: 0 48 | 49 | Next, I will give you the elements: 50 | question: {}, 51 | question type: {}, 52 | answer from model: {}, 53 | answer from label: {}. 54 | Please remember, while outputting the confidence score, do not include any words, just the number. 
55 | """.format(question, question_type, answer_model, answer_label)}, 56 | ] 57 | ) 58 | response = completion.choices[0].message["content"] 59 | break 60 | except: 61 | pass 62 | time.sleep(NUM_SECONDS_TO_SLEEP) 63 | 64 | return response 65 | 66 | 67 | def is_valid(value): 68 | try: 69 | value = float(value) 70 | if 0 <= value <= 100: 71 | return True 72 | else: 73 | return False 74 | except ValueError: 75 | return False 76 | 77 | 78 | def gpt_get_average_score(question, question_type, answer_model, answer_label, model="gpt-3.5-turbo", times=5): 79 | scores = [] 80 | while len(scores) < times: 81 | score = gpt_get_score(question, question_type, answer_model, answer_label, model) 82 | if is_valid(score): 83 | scores.append(float(score)) 84 | return sum(scores) / len(scores) 85 | 86 | 87 | if __name__ == "__main__": 88 | score = gpt_get_score( 89 | question="How many oranges will there be if 1/3 of them are removed?", 90 | question_type="Knowledge", 91 | answer_model="There will be 6 left.", 92 | answer_label="As there are 9 oranges in total, there will be 6 oranges left if 1/3 of them are removed.", 93 | model="gpt-3.5-turbo" 94 | ) 95 | print(score) 96 | -------------------------------------------------------------------------------- /llava/eval/model_vqa.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import json 4 | import torch 5 | import argparse 6 | import shortuuid 7 | from tqdm import tqdm 8 | 9 | from llava.utils import disable_torch_init 10 | from llava.model.builder import load_pretrained_model 11 | from llava.conversation import conv_templates, SeparatorStyle 12 | from llava.mm_utils import tokenizer_point_token, get_model_name_from_path, load_pts, process_pts 13 | from llava.constants import POINT_TOKEN_INDEX, DEFAULT_POINT_TOKEN, DEFAULT_PT_START_TOKEN, DEFAULT_PT_END_TOKEN 14 | 15 | 16 | def split_list(lst, n): 17 | """Split a list into n (roughly) equal-sized chunks""" 18 | chunk_size = math.ceil(len(lst) / n) # integer division 19 | return [lst[i:i+chunk_size] for i in range(0, len(lst), chunk_size)] 20 | 21 | 22 | def get_chunk(lst, n, k): 23 | chunks = split_list(lst, n) 24 | return chunks[k] 25 | 26 | 27 | def eval_model(args): 28 | # Model 29 | disable_torch_init() 30 | model_name = get_model_name_from_path(args.model_path) 31 | tokenizer, model, context_len = load_pretrained_model(args.model_path, args.model_base, model_name) 32 | 33 | questions = [json.loads(q) for q in open(os.path.expanduser(args.question_file), "r")] 34 | questions = get_chunk(questions, args.num_chunks, args.chunk_idx) 35 | answers_file = os.path.expanduser(args.answers_file) 36 | os.makedirs(os.path.dirname(answers_file), exist_ok=True) 37 | ans_file = open(answers_file, "w") 38 | for line in tqdm(questions): 39 | idx = line["question_id"] 40 | point_file = line["point"] 41 | qs = line["text"] 42 | cur_prompt = qs 43 | if model.config.mm_use_pt_start_end: 44 | qs = DEFAULT_PT_START_TOKEN + DEFAULT_POINT_TOKEN + DEFAULT_PT_END_TOKEN + '\n' + qs 45 | else: 46 | qs = DEFAULT_POINT_TOKEN + '\n' + qs 47 | 48 | conv = conv_templates[args.conv_mode].copy() 49 | conv.append_message(conv.roles[0], qs) 50 | conv.append_message(conv.roles[1], None) 51 | prompt = conv.get_prompt() 52 | 53 | input_ids = tokenizer_point_token(prompt, tokenizer, POINT_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda() 54 | 55 | point = load_pts(os.path.join(args.point_folder, point_file)) 56 | pts_tensor = process_pts(point, model.config).unsqueeze(0) 
57 | pts_tensor = pts_tensor.to(model.device, dtype=torch.float16) 58 | 59 | stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2 60 | 61 | with torch.inference_mode(): 62 | output_ids = model.generate( 63 | input_ids, 64 | points=pts_tensor, 65 | do_sample=True if args.temperature > 0 and args.num_beams == 1 else False, 66 | temperature=args.temperature, 67 | top_k=args.top_k, 68 | top_p=args.top_p, 69 | num_beams=args.num_beams, 70 | max_new_tokens=1024, 71 | use_cache=True) 72 | 73 | input_token_len = input_ids.shape[1] 74 | n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item() 75 | if n_diff_input_output > 0: 76 | print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids') 77 | outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0] 78 | outputs = outputs.strip() 79 | if outputs.endswith(stop_str): 80 | outputs = outputs[:-len(stop_str)] 81 | outputs = outputs.strip() 82 | 83 | ans_id = shortuuid.uuid() 84 | ans_file.write(json.dumps({"question_id": idx, 85 | "prompt": cur_prompt, 86 | "text": outputs, 87 | "answer_id": ans_id, 88 | "model_id": model_name, 89 | "metadata": {}}) + "\n") 90 | ans_file.flush() 91 | ans_file.close() 92 | 93 | 94 | if __name__ == "__main__": 95 | parser = argparse.ArgumentParser() 96 | parser.add_argument("--model-path", type=str, default="facebook/opt-350m") 97 | parser.add_argument("--model-base", type=str, default=None) 98 | parser.add_argument("--point-folder", type=str, default="") 99 | parser.add_argument("--question-file", type=str, default="tables/question.jsonl") 100 | parser.add_argument("--answers-file", type=str, default="tables/answer.jsonl") 101 | parser.add_argument("--conv-mode", type=str, default="llava_v1") 102 | parser.add_argument("--num-chunks", type=int, default=1) 103 | parser.add_argument("--chunk-idx", type=int, default=0) 104 | parser.add_argument("--temperature", type=float, default=0.2) 105 | parser.add_argument("--top_k", type=int, default=1) 106 | parser.add_argument("--top_p", type=float, default=None) 107 | parser.add_argument("--num_beams", type=int, default=1) 108 | args = parser.parse_args() 109 | 110 | eval_model(args) 111 | -------------------------------------------------------------------------------- /llava/eval/utils.py: -------------------------------------------------------------------------------- 1 | import openai 2 | import time 3 | import random 4 | import os 5 | 6 | 7 | def retry_with_exponential_backoff( 8 | func, 9 | initial_delay: float = 1, 10 | exponential_base: float = 2, 11 | jitter: bool = True, 12 | max_retries: int = 40, 13 | max_delay: int = 30, 14 | errors: tuple = (openai.error.RateLimitError, openai.error.ServiceUnavailableError, openai.error.Timeout), 15 | ): 16 | """Retry a function with exponential backoff.""" 17 | 18 | def wrapper(*args, **kwargs): 19 | num_retries = 0 20 | delay = initial_delay 21 | 22 | while True: 23 | try: 24 | return func(*args, **kwargs) 25 | except errors as e: 26 | # * print the error info 27 | num_retries += 1 28 | if num_retries > max_retries: 29 | print(f"[OPENAI] Encounter error: {e}.") 30 | raise Exception( 31 | f"[OPENAI] Maximum number of retries ({max_retries}) exceeded." 
32 | ) 33 | delay *= exponential_base * (1 + jitter * random.random()) 34 | time.sleep(min(delay, max_delay)) 35 | except Exception as e: 36 | raise e 37 | 38 | return wrapper 39 | 40 | 41 | class OpenAIGPT(): 42 | def __init__(self, model="gpt-3.5-turbo-0613", temperature=1, top_p=1, max_tokens=2048, **kwargs) -> None: 43 | setup_openai(model) 44 | self.default_chat_parameters = { 45 | "model": model, 46 | "temperature": temperature, 47 | "top_p": top_p, 48 | "max_tokens": max_tokens, 49 | **kwargs 50 | } 51 | 52 | @retry_with_exponential_backoff 53 | def safe_chat_complete(self, messages, content_only=True, **kwargs): 54 | chat_parameters = self.default_chat_parameters.copy() 55 | if len(kwargs) > 0: 56 | chat_parameters.update(**kwargs) 57 | 58 | response = openai.ChatCompletion.create( 59 | messages=messages, 60 | **chat_parameters 61 | ) 62 | 63 | if content_only: 64 | response = response['choices'][0]["message"]['content'] 65 | 66 | return response 67 | 68 | 69 | def setup_openai(model_name): 70 | # Setup OpenAI API Key 71 | if 'gpt-3.5' in model_name: 72 | openai.api_base = "http://openai.group-megvii-aic-research-hardware.megvii-aic.svc.hh-d.brainpp.local:5000/v1" 73 | else: 74 | openai.api_base = "http://openai.group-megvii-aic-research-hardware.megvii-aic.svc.hh-d.brainpp.local:5000/gpt4/v1" 75 | openai.api_key = "sk-ZXlKMGVYQWlPaUpLVjFRaUxDSmhiR2NpT2lKSVV6STFOaUo5LmV5SjFjMlZ5Ym1GdFpTSTZJbWhoYm1Ob2RXNXlkV2tpZlEucWdWaldsZW1hNWFmTlhMSExoNFZVTnY3VHVkeW9rdGVENFpjZDVyZHpsNA==" 76 | # print("[OPENAI] Setting OpenAI api_key...") 77 | # openai.api_key = os.getenv('OPENAI_API_KEY') 78 | # print(f"[OPENAI] OpenAI organization: {openai.organization}") 79 | # print(f"[OPENAI] Using MODEL: {model_name}") 80 | -------------------------------------------------------------------------------- /llava/model/__init__.py: -------------------------------------------------------------------------------- 1 | from .language_model.llava_llama import LlavaLlamaForCausalLM, LlavaConfig 2 | from .language_model.llava_mpt import LlavaMPTForCausalLM, LlavaMPTConfig 3 | -------------------------------------------------------------------------------- /llava/model/apply_delta.py: -------------------------------------------------------------------------------- 1 | """ 2 | Usage: 3 | python3 -m fastchat.model.apply_delta --base ~/model_weights/llama-7b --target ~/model_weights/vicuna-7b --delta lmsys/vicuna-7b-delta 4 | """ 5 | import argparse 6 | 7 | import torch 8 | from tqdm import tqdm 9 | from transformers import AutoTokenizer, AutoModelForCausalLM 10 | from llava import LlavaLlamaForCausalLM 11 | 12 | 13 | def apply_delta(base_model_path, target_model_path, delta_path): 14 | print("Loading base model") 15 | base = AutoModelForCausalLM.from_pretrained( 16 | base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) 17 | 18 | print("Loading delta") 19 | delta = LlavaLlamaForCausalLM.from_pretrained(delta_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) 20 | delta_tokenizer = AutoTokenizer.from_pretrained(delta_path) 21 | 22 | print("Applying delta") 23 | for name, param in tqdm(delta.state_dict().items(), desc="Applying delta"): 24 | if name not in base.state_dict(): 25 | assert name in ['model.mm_projector.weight', 'model.mm_projector.bias'], f'{name} not in base model' 26 | continue 27 | if param.data.shape == base.state_dict()[name].shape: 28 | param.data += base.state_dict()[name] 29 | else: 30 | assert name in ['model.embed_tokens.weight', 'lm_head.weight'], \ 31 | f'{name} dimension 
mismatch: {param.data.shape} vs {base.state_dict()[name].shape}' 32 | bparam = base.state_dict()[name] 33 | param.data[:bparam.shape[0], :bparam.shape[1]] += bparam 34 | 35 | print("Saving target model") 36 | delta.save_pretrained(target_model_path) 37 | delta_tokenizer.save_pretrained(target_model_path) 38 | 39 | 40 | if __name__ == "__main__": 41 | parser = argparse.ArgumentParser() 42 | parser.add_argument("--base-model-path", type=str, required=True) 43 | parser.add_argument("--target-model-path", type=str, required=True) 44 | parser.add_argument("--delta-path", type=str, required=True) 45 | 46 | args = parser.parse_args() 47 | 48 | apply_delta(args.base_model_path, args.target_model_path, args.delta_path) 49 | -------------------------------------------------------------------------------- /llava/model/consolidate.py: -------------------------------------------------------------------------------- 1 | """ 2 | Usage: 3 | python3 -m llava.model.consolidate --src ~/model_weights/llava-7b --dst ~/model_weights/llava-7b_consolidate 4 | """ 5 | import argparse 6 | 7 | import torch 8 | from transformers import AutoTokenizer, AutoModelForCausalLM 9 | from llava.model import * 10 | from llava.model.utils import auto_upgrade 11 | 12 | 13 | def consolidate_ckpt(src_path, dst_path): 14 | print("Loading model") 15 | auto_upgrade(src_path) 16 | src_model = AutoModelForCausalLM.from_pretrained(src_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) 17 | src_tokenizer = AutoTokenizer.from_pretrained(src_path, use_fast=False) 18 | src_model.save_pretrained(dst_path) 19 | src_tokenizer.save_pretrained(dst_path) 20 | 21 | 22 | if __name__ == "__main__": 23 | parser = argparse.ArgumentParser() 24 | parser.add_argument("--src", type=str, required=True) 25 | parser.add_argument("--dst", type=str, required=True) 26 | 27 | args = parser.parse_args() 28 | 29 | consolidate_ckpt(args.src, args.dst) 30 | -------------------------------------------------------------------------------- /llava/model/language_model/llava_llama.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Haotian Liu 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
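# LlavaLlamaForCausalLM below wires the multimodal (point cloud) inputs prepared by
# prepare_inputs_labels_for_multimodal into a standard LLaMA causal language model.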
14 | 15 | 16 | from typing import List, Optional, Tuple, Union 17 | 18 | import torch 19 | import torch.nn as nn 20 | from torch.nn import CrossEntropyLoss 21 | 22 | from transformers import AutoConfig, AutoModelForCausalLM, \ 23 | LlamaConfig, LlamaModel, LlamaForCausalLM 24 | 25 | from transformers.modeling_outputs import CausalLMOutputWithPast 26 | 27 | from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM 28 | 29 | 30 | class LlavaConfig(LlamaConfig): 31 | model_type = "llava" 32 | 33 | 34 | class LlavaLlamaModel(LlavaMetaModel, LlamaModel): 35 | config_class = LlavaConfig 36 | 37 | def __init__(self, config: LlamaConfig): 38 | super(LlavaLlamaModel, self).__init__(config) 39 | 40 | 41 | class LlavaLlamaForCausalLM(LlamaForCausalLM, LlavaMetaForCausalLM): 42 | config_class = LlavaConfig 43 | 44 | def __init__(self, config): 45 | super(LlamaForCausalLM, self).__init__(config) 46 | self.model = LlavaLlamaModel(config) 47 | 48 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) 49 | 50 | # Initialize weights and apply final processing 51 | self.post_init() 52 | 53 | def get_model(self): 54 | return self.model 55 | 56 | def forward( 57 | self, 58 | input_ids: torch.LongTensor = None, 59 | attention_mask: Optional[torch.Tensor] = None, 60 | past_key_values: Optional[List[torch.FloatTensor]] = None, 61 | inputs_embeds: Optional[torch.FloatTensor] = None, 62 | labels: Optional[torch.LongTensor] = None, 63 | use_cache: Optional[bool] = None, 64 | output_attentions: Optional[bool] = None, 65 | output_hidden_states: Optional[bool] = None, 66 | points: Optional[torch.FloatTensor] = None, 67 | return_dict: Optional[bool] = None, 68 | ) -> Union[Tuple, CausalLMOutputWithPast]: 69 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 70 | output_hidden_states = ( 71 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 72 | ) 73 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 74 | 75 | input_ids, attention_mask, past_key_values, inputs_embeds, labels = self.prepare_inputs_labels_for_multimodal(input_ids, attention_mask, past_key_values, labels, points) 76 | 77 | # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) 78 | outputs = self.model( 79 | input_ids=input_ids, 80 | attention_mask=attention_mask, 81 | past_key_values=past_key_values, 82 | inputs_embeds=inputs_embeds, 83 | use_cache=use_cache, 84 | output_attentions=output_attentions, 85 | output_hidden_states=output_hidden_states, 86 | return_dict=return_dict 87 | ) 88 | 89 | hidden_states = outputs[0] 90 | logits = self.lm_head(hidden_states) 91 | 92 | loss = None 93 | if labels is not None: 94 | # Shift so that tokens < n predict n 95 | shift_logits = logits[..., :-1, :].contiguous() 96 | shift_labels = labels[..., 1:].contiguous() 97 | # Flatten the tokens 98 | loss_fct = CrossEntropyLoss() 99 | shift_logits = shift_logits.view(-1, self.config.vocab_size) 100 | shift_labels = shift_labels.view(-1) 101 | # Enable model/pipeline parallelism 102 | shift_labels = shift_labels.to(shift_logits.device) 103 | loss = loss_fct(shift_logits, shift_labels) 104 | 105 | if not return_dict: 106 | output = (logits,) + outputs[1:] 107 | return (loss,) + output if loss is not None else output 108 | 109 | return CausalLMOutputWithPast( 110 | loss=loss, 111 | logits=logits, 112 | past_key_values=outputs.past_key_values, 113 | hidden_states=outputs.hidden_states, 114 
| attentions=outputs.attentions, 115 | ) 116 | 117 | def prepare_inputs_for_generation( 118 | self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs 119 | ): 120 | if past_key_values: 121 | input_ids = input_ids[:, -1:] 122 | 123 | # if `inputs_embeds` are passed, we only want to use them in the 1st generation step 124 | if inputs_embeds is not None and past_key_values is None: 125 | model_inputs = {"inputs_embeds": inputs_embeds} 126 | else: 127 | model_inputs = {"input_ids": input_ids} 128 | 129 | model_inputs.update( 130 | { 131 | "past_key_values": past_key_values, 132 | "use_cache": kwargs.get("use_cache"), 133 | "attention_mask": attention_mask, 134 | "points": kwargs.get("points", None), 135 | } 136 | ) 137 | return model_inputs 138 | 139 | 140 | AutoConfig.register("llava", LlavaConfig) 141 | AutoModelForCausalLM.register(LlavaConfig, LlavaLlamaForCausalLM) 142 | -------------------------------------------------------------------------------- /llava/model/language_model/mpt/adapt_tokenizer.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast 3 | Tokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast] 4 | NUM_SENTINEL_TOKENS: int = 100 5 | 6 | def adapt_tokenizer_for_denoising(tokenizer: Tokenizer): 7 | """Adds sentinel tokens and padding token (if missing). 8 | 9 | Expands the tokenizer vocabulary to include sentinel tokens 10 | used in mixture-of-denoiser tasks as well as a padding token. 11 | 12 | All added tokens are added as special tokens. No tokens are 13 | added if sentinel tokens and padding token already exist. 14 | """ 15 | sentinels_to_add = [f'' for i in range(NUM_SENTINEL_TOKENS)] 16 | tokenizer.add_tokens(sentinels_to_add, special_tokens=True) 17 | if tokenizer.pad_token is None: 18 | tokenizer.add_tokens('', special_tokens=True) 19 | tokenizer.pad_token = '' 20 | assert tokenizer.pad_token_id is not None 21 | sentinels = ''.join([f'' for i in range(NUM_SENTINEL_TOKENS)]) 22 | _sentinel_token_ids = tokenizer(sentinels, add_special_tokens=False).input_ids 23 | tokenizer.sentinel_token_ids = _sentinel_token_ids 24 | 25 | class AutoTokenizerForMOD(AutoTokenizer): 26 | """AutoTokenizer + Adaptation for MOD. 27 | 28 | A simple wrapper around AutoTokenizer to make instantiating 29 | an MOD-adapted tokenizer a bit easier. 30 | 31 | MOD-adapted tokenizers have sentinel tokens (e.g., ), 32 | a padding token, and a property to get the token ids of the 33 | sentinel tokens. 
34 | """ 35 | 36 | @classmethod 37 | def from_pretrained(cls, *args, **kwargs): 38 | """See `AutoTokenizer.from_pretrained` docstring.""" 39 | tokenizer = super().from_pretrained(*args, **kwargs) 40 | adapt_tokenizer_for_denoising(tokenizer) 41 | return tokenizer -------------------------------------------------------------------------------- /llava/model/language_model/mpt/blocks.py: -------------------------------------------------------------------------------- 1 | """GPT Blocks used for the GPT Model.""" 2 | from typing import Dict, Optional, Tuple 3 | import torch 4 | import torch.nn as nn 5 | from .attention import ATTN_CLASS_REGISTRY 6 | from .norm import NORM_CLASS_REGISTRY 7 | 8 | class MPTMLP(nn.Module): 9 | 10 | def __init__(self, d_model: int, expansion_ratio: int, device: Optional[str]=None): 11 | super().__init__() 12 | self.up_proj = nn.Linear(d_model, expansion_ratio * d_model, device=device) 13 | self.act = nn.GELU(approximate='none') 14 | self.down_proj = nn.Linear(expansion_ratio * d_model, d_model, device=device) 15 | self.down_proj._is_residual = True 16 | 17 | def forward(self, x): 18 | return self.down_proj(self.act(self.up_proj(x))) 19 | 20 | class MPTBlock(nn.Module): 21 | 22 | def __init__(self, d_model: int, n_heads: int, expansion_ratio: int, attn_config: Dict={'attn_type': 'multihead_attention', 'attn_pdrop': 0.0, 'attn_impl': 'triton', 'qk_ln': False, 'clip_qkv': None, 'softmax_scale': None, 'prefix_lm': False, 'attn_uses_sequence_id': False, 'alibi': False, 'alibi_bias_max': 8}, resid_pdrop: float=0.0, norm_type: str='low_precision_layernorm', verbose: int=0, device: Optional[str]=None, **kwargs): 23 | del kwargs 24 | super().__init__() 25 | norm_class = NORM_CLASS_REGISTRY[norm_type.lower()] 26 | attn_class = ATTN_CLASS_REGISTRY[attn_config['attn_type']] 27 | self.norm_1 = norm_class(d_model, device=device) 28 | self.attn = attn_class(attn_impl=attn_config['attn_impl'], clip_qkv=attn_config['clip_qkv'], qk_ln=attn_config['qk_ln'], softmax_scale=attn_config['softmax_scale'], attn_pdrop=attn_config['attn_pdrop'], d_model=d_model, n_heads=n_heads, verbose=verbose, device=device) 29 | self.norm_2 = norm_class(d_model, device=device) 30 | self.ffn = MPTMLP(d_model=d_model, expansion_ratio=expansion_ratio, device=device) 31 | self.resid_attn_dropout = nn.Dropout(resid_pdrop) 32 | self.resid_ffn_dropout = nn.Dropout(resid_pdrop) 33 | 34 | def forward(self, x: torch.Tensor, past_key_value: Optional[Tuple[torch.Tensor]]=None, attn_bias: Optional[torch.Tensor]=None, attention_mask: Optional[torch.ByteTensor]=None, is_causal: bool=True) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor]]]: 35 | a = self.norm_1(x) 36 | (b, attn_weights, past_key_value) = self.attn(a, past_key_value=past_key_value, attn_bias=attn_bias, attention_mask=attention_mask, is_causal=is_causal) 37 | x = x + self.resid_attn_dropout(b) 38 | m = self.norm_2(x) 39 | n = self.ffn(m) 40 | x = x + self.resid_ffn_dropout(n) 41 | return (x, attn_weights, past_key_value) -------------------------------------------------------------------------------- /llava/model/language_model/mpt/custom_embedding.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch import Tensor 5 | 6 | class SharedEmbedding(nn.Embedding): 7 | 8 | def forward(self, input: Tensor, unembed: bool=False) -> Tensor: 9 | if unembed: 10 | return F.linear(input, self.weight) 11 | return super().forward(input) 
-------------------------------------------------------------------------------- /llava/model/language_model/mpt/meta_init_context.py: -------------------------------------------------------------------------------- 1 | from contextlib import contextmanager 2 | import torch 3 | import torch.nn as nn 4 | 5 | @contextmanager 6 | def init_empty_weights(include_buffers: bool=False): 7 | """Meta initialization context manager. 8 | 9 | A context manager under which models are initialized with all parameters 10 | on the meta device, therefore creating an empty model. Useful when just 11 | initializing the model would blow the available RAM. 12 | 13 | Args: 14 | include_buffers (`bool`, *optional*, defaults to `False`): Whether or 15 | not to also put all buffers on the meta device while initializing. 16 | 17 | Example: 18 | ```python 19 | import torch.nn as nn 20 | 21 | # Initialize a model with 100 billions parameters in no time and without using any RAM. 22 | with init_empty_weights(): 23 | tst = nn.Sequential(*[nn.Linear(10000, 10000) for _ in range(1000)]) 24 | ``` 25 | 26 | 27 | 28 | Any model created under this context manager has no weights. As such you can't do something like 29 | `model.to(some_device)` with it. To load weights inside your empty model, see [`load_checkpoint_and_dispatch`]. 30 | 31 | 32 | """ 33 | with init_on_device(torch.device('meta'), include_buffers=include_buffers) as f: 34 | yield f 35 | 36 | @contextmanager 37 | def init_on_device(device: torch.device, include_buffers: bool=False): 38 | """Device initialization context manager. 39 | 40 | A context manager under which models are initialized with all parameters 41 | on the specified device. 42 | 43 | Args: 44 | device (`torch.device`): Device to initialize all parameters on. 45 | include_buffers (`bool`, *optional*, defaults to `False`): Whether or 46 | not to also put all buffers on the meta device while initializing. 
47 | 48 | Example: 49 | ```python 50 | import torch.nn as nn 51 | 52 | with init_on_device(device=torch.device("cuda")): 53 | tst = nn.Liner(100, 100) # on `cuda` device 54 | ``` 55 | """ 56 | old_register_parameter = nn.Module.register_parameter 57 | if include_buffers: 58 | old_register_buffer = nn.Module.register_buffer 59 | 60 | def register_empty_parameter(module, name, param): 61 | old_register_parameter(module, name, param) 62 | if param is not None: 63 | param_cls = type(module._parameters[name]) 64 | kwargs = module._parameters[name].__dict__ 65 | module._parameters[name] = param_cls(module._parameters[name].to(device), **kwargs) 66 | 67 | def register_empty_buffer(module, name, buffer): 68 | old_register_buffer(module, name, buffer) 69 | if buffer is not None: 70 | module._buffers[name] = module._buffers[name].to(device) 71 | if include_buffers: 72 | tensor_constructors_to_patch = {torch_function_name: getattr(torch, torch_function_name) for torch_function_name in ['empty', 'zeros', 'ones', 'full']} 73 | else: 74 | tensor_constructors_to_patch = {} 75 | 76 | def patch_tensor_constructor(fn): 77 | 78 | def wrapper(*args, **kwargs): 79 | kwargs['device'] = device 80 | return fn(*args, **kwargs) 81 | return wrapper 82 | try: 83 | nn.Module.register_parameter = register_empty_parameter 84 | if include_buffers: 85 | nn.Module.register_buffer = register_empty_buffer 86 | for torch_function_name in tensor_constructors_to_patch.keys(): 87 | setattr(torch, torch_function_name, patch_tensor_constructor(getattr(torch, torch_function_name))) 88 | yield 89 | finally: 90 | nn.Module.register_parameter = old_register_parameter 91 | if include_buffers: 92 | nn.Module.register_buffer = old_register_buffer 93 | for (torch_function_name, old_torch_function) in tensor_constructors_to_patch.items(): 94 | setattr(torch, torch_function_name, old_torch_function) -------------------------------------------------------------------------------- /llava/model/language_model/mpt/norm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def _cast_if_autocast_enabled(tensor): 4 | if torch.is_autocast_enabled(): 5 | if tensor.device.type == 'cuda': 6 | dtype = torch.get_autocast_gpu_dtype() 7 | elif tensor.device.type == 'cpu': 8 | dtype = torch.get_autocast_cpu_dtype() 9 | else: 10 | raise NotImplementedError() 11 | return tensor.to(dtype=dtype) 12 | return tensor 13 | 14 | class LPLayerNorm(torch.nn.LayerNorm): 15 | 16 | def __init__(self, normalized_shape, eps=1e-05, elementwise_affine=True, device=None, dtype=None): 17 | super().__init__(normalized_shape=normalized_shape, eps=eps, elementwise_affine=elementwise_affine, device=device, dtype=dtype) 18 | 19 | def forward(self, x): 20 | module_device = x.device 21 | downcast_x = _cast_if_autocast_enabled(x) 22 | downcast_weight = _cast_if_autocast_enabled(self.weight) if self.weight is not None else self.weight 23 | downcast_bias = _cast_if_autocast_enabled(self.bias) if self.bias is not None else self.bias 24 | with torch.autocast(enabled=False, device_type=module_device.type): 25 | return torch.nn.functional.layer_norm(downcast_x, self.normalized_shape, downcast_weight, downcast_bias, self.eps) 26 | 27 | def rms_norm(x, weight=None, eps=1e-05): 28 | output = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) 29 | if weight is not None: 30 | return output * weight 31 | return output 32 | 33 | class RMSNorm(torch.nn.Module): 34 | 35 | def __init__(self, normalized_shape, eps=1e-05, weight=True, 
dtype=None, device=None): 36 | super().__init__() 37 | self.eps = eps 38 | if weight: 39 | self.weight = torch.nn.Parameter(torch.ones(normalized_shape, dtype=dtype, device=device)) 40 | else: 41 | self.register_parameter('weight', None) 42 | 43 | def forward(self, x): 44 | return rms_norm(x.float(), self.weight, self.eps).to(dtype=x.dtype) 45 | 46 | class LPRMSNorm(RMSNorm): 47 | 48 | def __init__(self, normalized_shape, eps=1e-05, weight=True, dtype=None, device=None): 49 | super().__init__(normalized_shape=normalized_shape, eps=eps, weight=weight, dtype=dtype, device=device) 50 | 51 | def forward(self, x): 52 | downcast_x = _cast_if_autocast_enabled(x) 53 | downcast_weight = _cast_if_autocast_enabled(self.weight) if self.weight is not None else self.weight 54 | with torch.autocast(enabled=False, device_type=x.device.type): 55 | return rms_norm(downcast_x, downcast_weight, self.eps).to(dtype=x.dtype) 56 | NORM_CLASS_REGISTRY = {'layernorm': torch.nn.LayerNorm, 'low_precision_layernorm': LPLayerNorm, 'rmsnorm': RMSNorm, 'low_precision_rmsnorm': LPRMSNorm} -------------------------------------------------------------------------------- /llava/model/make_delta.py: -------------------------------------------------------------------------------- 1 | """ 2 | Usage: 3 | python3 -m llava.model.make_delta --base ~/model_weights/llama-7b --target ~/model_weights/llava-7b --delta ~/model_weights/llava-7b-delta --hub-repo-id liuhaotian/llava-7b-delta 4 | """ 5 | import argparse 6 | 7 | import torch 8 | from tqdm import tqdm 9 | from transformers import AutoTokenizer, AutoModelForCausalLM 10 | from llava.model.utils import auto_upgrade 11 | 12 | 13 | def make_delta(base_model_path, target_model_path, delta_path, hub_repo_id): 14 | print("Loading base model") 15 | base = AutoModelForCausalLM.from_pretrained( 16 | base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) 17 | 18 | print("Loading target model") 19 | auto_upgrade(target_model_path) 20 | target = AutoModelForCausalLM.from_pretrained(target_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) 21 | 22 | print("Calculating delta") 23 | for name, param in tqdm(target.state_dict().items(), desc="Calculating delta"): 24 | if name not in base.state_dict(): 25 | assert name in ['model.mm_projector.weight', 'model.mm_projector.bias'], f'{name} not in base model' 26 | continue 27 | if param.data.shape == base.state_dict()[name].shape: 28 | param.data -= base.state_dict()[name] 29 | else: 30 | assert name in ['model.embed_tokens.weight', 'lm_head.weight'], f'{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}' 31 | bparam = base.state_dict()[name] 32 | param.data[:bparam.shape[0], :bparam.shape[1]] -= bparam 33 | 34 | print("Saving delta") 35 | if hub_repo_id: 36 | kwargs = {"push_to_hub": True, "repo_id": hub_repo_id} 37 | else: 38 | kwargs = {} 39 | target.save_pretrained(delta_path, **kwargs) 40 | target_tokenizer = AutoTokenizer.from_pretrained(target_model_path) 41 | target_tokenizer.save_pretrained(delta_path, **kwargs) 42 | 43 | 44 | if __name__ == "__main__": 45 | parser = argparse.ArgumentParser() 46 | parser.add_argument("--base-model-path", type=str, required=True) 47 | parser.add_argument("--target-model-path", type=str, required=True) 48 | parser.add_argument("--delta-path", type=str, required=True) 49 | parser.add_argument("--hub-repo-id", type=str, default=None) 50 | args = parser.parse_args() 51 | 52 | make_delta(args.base_model_path, args.target_model_path, args.delta_path, 
args.hub_repo_id) 53 | -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/builder.py: -------------------------------------------------------------------------------- 1 | from .clip_encoder import CLIPVisionTower 2 | 3 | 4 | def build_vision_tower(vision_tower_cfg, **kwargs): 5 | vision_tower = getattr(vision_tower_cfg, 'mm_vision_tower', getattr(vision_tower_cfg, 'vision_tower', None)) 6 | return CLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs) 7 | -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/clip_encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from ReConV2.models.ReCon import ReCon2 4 | from ReConV2.utils.config import cfg_from_yaml_file 5 | 6 | 7 | class CLIPVisionTower(nn.Module): 8 | def __init__(self, vision_tower, args, delay_load=False): 9 | super().__init__() 10 | 11 | self.cfg_path = vision_tower 12 | self.vision_tower_path = args.vision_tower_path 13 | self.config = cfg_from_yaml_file(self.cfg_path) 14 | self.config.with_color = args.with_color 15 | self.select_layer = args.mm_vision_select_layer 16 | self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch') 17 | 18 | self.vision_tower = ReCon2(self.config.model) 19 | self.hidden_size = self.vision_tower.embed_dim 20 | self.global_query_num = self.vision_tower.global_query_num 21 | self.is_loaded = False 22 | 23 | def load_model(self): 24 | ckpt = torch.load(self.vision_tower_path, map_location='cpu') 25 | state_dict = {k.replace("module.", ""): v for k, v in ckpt['base_model'].items()} 26 | self.vision_tower.load_state_dict(state_dict, strict=True) 27 | self.vision_tower.requires_grad_(False) 28 | self.is_loaded = True 29 | 30 | @torch.no_grad() 31 | def forward(self, pts): 32 | 33 | if type(pts) is list: 34 | pos_features = [] 35 | local_features = [] 36 | global_features = [] 37 | for pt in pts: 38 | pos_feature, local_feature, global_feature = self.vision_tower.model.inference(pt.to(device=self.device, dtype=self.dtype).unsqueeze(0)) 39 | pos_features.append(pos_feature.to(pts.dtype)) 40 | local_features.append(local_feature.to(pts.dtype)) 41 | global_features.append(global_feature.to(pts.dtype)) 42 | else: 43 | pos_features, local_features, global_features = self.vision_tower.model.inference(pts.to(device=self.device, dtype=self.dtype)) 44 | local_features = local_features.to(pts.dtype) 45 | global_features = global_features.to(pts.dtype) 46 | 47 | return pos_features, local_features, global_features 48 | 49 | @property 50 | def dummy_feature(self): 51 | return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype) 52 | 53 | @property 54 | def dtype(self): 55 | return self.vision_tower.dtype 56 | 57 | @property 58 | def device(self): 59 | return self.vision_tower.device 60 | 61 | @property 62 | def num_patches(self): 63 | return self.config.num_group 64 | -------------------------------------------------------------------------------- /llava/model/multimodal_projector/builder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import re 4 | 5 | 6 | class IdentityMap(nn.Module): 7 | def __init__(self): 8 | super().__init__() 9 | 10 | def forward(self, x, y, *args, **kwargs): 11 | return x, y 12 | 13 | @property 14 | def config(self): 15 | return {"mm_projector_type": 'identity'} 16 | 
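# The projectors below keep three parallel branches, one per feature stream returned
# by the point-cloud tower (pos, local, global); each stream is projected separately
# into the LLM embedding space before concatenation.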
17 | 18 | class LinearProjector(nn.Module): 19 | def __init__(self, in_channels, out_channels): 20 | super().__init__() 21 | self.proj1 = nn.Linear(in_channels, out_channels) 22 | self.proj2 = nn.Linear(in_channels, out_channels) 23 | self.proj3 = nn.Linear(in_channels, out_channels) 24 | 25 | def forward(self, x, y, z, *args, **kwargs): 26 | x = self.proj1(x) 27 | y = self.proj2(y) 28 | z = self.proj3(z) 29 | 30 | return torch.cat([x, y, z], dim=1) 31 | 32 | @property 33 | def config(self): 34 | return {"mm_projector_type": 'linear'} 35 | 36 | 37 | class ReConProjector(nn.Module): 38 | def __init__(self, in_channels, out_channels, mlp_depth, prompt_token_num, 39 | with_ape=True, with_local=True, with_global=True): 40 | super().__init__() 41 | 42 | self.in_channels = in_channels 43 | self.out_channels = out_channels 44 | self.mlp_depth = mlp_depth 45 | self.hidden_size = [1024 * 2 ** i for i in range(mlp_depth)] 46 | self.prompt_token_num = prompt_token_num 47 | self.with_ape = with_ape 48 | self.with_local = with_local 49 | self.with_global = with_global 50 | 51 | if prompt_token_num > 0: 52 | self.prompt1 = nn.Parameter(torch.zeros(1, prompt_token_num, out_channels)) 53 | self.prompt2 = nn.Parameter(torch.zeros(1, prompt_token_num, out_channels)) 54 | self.prompt3 = nn.Parameter(torch.zeros(1, prompt_token_num, out_channels)) 55 | 56 | self.proj1 = self.set_proj() 57 | self.proj2 = self.set_proj() 58 | self.proj3 = self.set_proj() 59 | 60 | def set_proj(self): 61 | modules = [nn.Linear(self.in_channels, self.hidden_size[0])] 62 | for i in range(1, self.mlp_depth): 63 | modules.append(nn.LayerNorm(self.hidden_size[i - 1])) 64 | modules.append(nn.GELU()) 65 | modules.append(nn.Linear(self.hidden_size[i - 1], self.hidden_size[i])) 66 | modules.append(nn.LayerNorm(self.hidden_size[-1])) 67 | modules.append(nn.GELU()) 68 | modules.append(nn.Linear(self.hidden_size[-1], self.out_channels)) 69 | return nn.Sequential(*modules) 70 | 71 | def forward(self, pos_feat, local_feat, global_feat, *args, **kwargs): 72 | B = pos_feat.shape[0] 73 | pos_feat = self.proj1(pos_feat) 74 | local_feat = self.proj2(local_feat) 75 | global_feat = self.proj3(global_feat) 76 | 77 | if self.prompt_token_num > 0: 78 | pos_feat = torch.cat([pos_feat, self.prompt1.expand(B, -1, -1)], dim=1) 79 | local_feat = torch.cat([local_feat, self.prompt2.expand(B, -1, -1)], dim=1) 80 | global_feat = torch.cat([global_feat, self.prompt3.expand(B, -1, -1)], dim=1) 81 | 82 | pts_feat = [feat for feat, flag in [(pos_feat, self.with_ape), (local_feat, self.with_local), (global_feat, self.with_global)] if flag] 83 | pts_feat = torch.cat(pts_feat, dim=1) 84 | 85 | return pts_feat 86 | 87 | @property 88 | def config(self): 89 | return {"mm_projector_type": 'mlp'} 90 | 91 | 92 | def build_vision_projector(config, delay_load=False, **kwargs): 93 | projector_type = getattr(config, 'mm_projector_type', 'linear') 94 | 95 | if projector_type == 'linear': 96 | return LinearProjector(config.mm_hidden_size, config.hidden_size) 97 | 98 | mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu$', projector_type) 99 | if mlp_gelu_match: 100 | mlp_depth = int(mlp_gelu_match.group(1)) 101 | return ReConProjector(config.mm_hidden_size, config.hidden_size, mlp_depth, config.prompt_token_num, 102 | config.with_ape, config.with_local, config.with_global) 103 | 104 | if projector_type == 'identity': 105 | return IdentityMap() 106 | 107 | raise ValueError(f'Unknown projector type: {projector_type}') 108 | 
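# Usage sketch (the attribute values here are illustrative placeholders, not the
# shipped defaults):
#   from types import SimpleNamespace
#   cfg = SimpleNamespace(mm_projector_type='mlp2x_gelu', mm_hidden_size=1024,
#                         hidden_size=4096, prompt_token_num=32,
#                         with_ape=True, with_local=True, with_global=True)
#   projector = build_vision_projector(cfg)  # ReConProjector with one MLP per stream (mlp_depth=2)
#   # forward(pos_feat, local_feat, global_feat), each of shape (B, N, mm_hidden_size),
#   # returns the selected streams (with their prompt tokens appended) concatenated on dim 1.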
-------------------------------------------------------------------------------- /llava/model/utils.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoConfig 2 | 3 | 4 | def auto_upgrade(config): 5 | cfg = AutoConfig.from_pretrained(config) 6 | if 'llava' in config and 'llava' not in cfg.model_type: 7 | assert cfg.model_type == 'llama' 8 | print("You are using newer LLaVA code base, while the checkpoint of v0 is from older code base.") 9 | print("You must upgrade the checkpoint to the new code base (this can be done automatically).") 10 | confirm = input("Please confirm that you want to upgrade the checkpoint. [Y/N]") 11 | if confirm.lower() in ["y", "yes"]: 12 | print("Upgrading checkpoint...") 13 | assert len(cfg.architectures) == 1 14 | setattr(cfg.__class__, "model_type", "llava") 15 | cfg.architectures[0] = 'LlavaLlamaForCausalLM' 16 | cfg.save_pretrained(config) 17 | print("Checkpoint upgraded.") 18 | else: 19 | print("Checkpoint upgrade aborted.") 20 | exit(1) 21 | -------------------------------------------------------------------------------- /llava/serve/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qizekun/ShapeLLM/f754b0d488f7187a699a549dd0c5e43b2349f051/llava/serve/__init__.py -------------------------------------------------------------------------------- /llava/serve/cli.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import argparse 3 | 4 | from transformers import TextStreamer 5 | from llava.utils import disable_torch_init 6 | from llava.model.builder import load_pretrained_model 7 | from llava.conversation import conv_templates, SeparatorStyle 8 | from llava.constants import POINT_TOKEN_INDEX, DEFAULT_POINT_TOKEN, DEFAULT_PT_START_TOKEN, DEFAULT_PT_END_TOKEN 9 | from llava.mm_utils import load_pts, process_pts, rotation, tokenizer_point_token, get_model_name_from_path, \ 10 | KeywordsStoppingCriteria 11 | 12 | 13 | def main(args): 14 | # Model 15 | disable_torch_init() 16 | 17 | model_name = get_model_name_from_path(args.model_path) 18 | tokenizer, model, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, args.load_8bit, 19 | args.load_4bit, device=args.device) 20 | 21 | conv_mode = "llava_v1" 22 | 23 | if args.conv_mode is not None and conv_mode != args.conv_mode: 24 | print( 25 | '[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}'.format(conv_mode, 26 | args.conv_mode, 27 | args.conv_mode)) 28 | else: 29 | args.conv_mode = conv_mode 30 | 31 | conv = conv_templates[args.conv_mode].copy() 32 | roles = conv.roles 33 | 34 | if args.pts_file is not None: 35 | pts = load_pts(args.pts_file) 36 | if args.objaverse: 37 | pts[:, :3] = rotation(pts[:, :3], [0, 0, -90]) 38 | pts_tensor = process_pts(pts, model.config).unsqueeze(0) 39 | pts_tensor = pts_tensor.to(model.device, dtype=torch.float16) 40 | else: 41 | pts = None 42 | pts_tensor = None 43 | 44 | while True: 45 | try: 46 | inp = input(f"{roles[0]}: ") 47 | except EOFError: 48 | inp = "" 49 | if not inp: 50 | print("exit...") 51 | break 52 | 53 | print(f"{roles[1]}: ", end="") 54 | 55 | if pts is not None: 56 | # first message 57 | if model.config.mm_use_pt_start_end: 58 | inp = DEFAULT_PT_START_TOKEN + DEFAULT_POINT_TOKEN + DEFAULT_PT_END_TOKEN + '\n' + inp 59 | else: 60 | inp = DEFAULT_POINT_TOKEN + '\n' + inp 61 | conv.append_message(conv.roles[0], 
inp) 62 | pts = None 63 | else: 64 | # later messages 65 | conv.append_message(conv.roles[0], inp) 66 | conv.append_message(conv.roles[1], None) 67 | prompt = conv.get_prompt() 68 | 69 | input_ids = tokenizer_point_token(prompt, tokenizer, POINT_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda() 70 | stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2 71 | keywords = [stop_str] 72 | stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids) 73 | streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True) 74 | 75 | with torch.inference_mode(): 76 | output_ids = model.generate( 77 | input_ids, 78 | points=pts_tensor, 79 | do_sample=True, 80 | temperature=args.temperature, 81 | max_new_tokens=args.max_new_tokens, 82 | streamer=streamer, 83 | use_cache=True, 84 | stopping_criteria=[stopping_criteria]) 85 | 86 | outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip() 87 | conv.messages[-1][-1] = outputs 88 | 89 | if args.debug: 90 | print("\n", {"prompt": prompt, "outputs": outputs}, "\n") 91 | 92 | 93 | if __name__ == "__main__": 94 | parser = argparse.ArgumentParser() 95 | parser.add_argument("--model-path", type=str, default="facebook/opt-350m") 96 | parser.add_argument("--model-base", type=str, default=None) 97 | parser.add_argument("--pts-file", type=str, required=False) 98 | parser.add_argument("--device", type=str, default="cuda") 99 | parser.add_argument("--conv-mode", type=str, default=None) 100 | parser.add_argument("--temperature", type=float, default=0.2) 101 | parser.add_argument("--max-new-tokens", type=int, default=512) 102 | parser.add_argument("--load-8bit", action="store_true") 103 | parser.add_argument("--load-4bit", action="store_true") 104 | parser.add_argument("--objaverse", action="store_true") 105 | parser.add_argument("--debug", action="store_true") 106 | 107 | args = parser.parse_args() 108 | main(args) 109 | -------------------------------------------------------------------------------- /llava/serve/examples/extreme_ironing.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qizekun/ShapeLLM/f754b0d488f7187a699a549dd0c5e43b2349f051/llava/serve/examples/extreme_ironing.jpg -------------------------------------------------------------------------------- /llava/serve/examples/waterview.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qizekun/ShapeLLM/f754b0d488f7187a699a549dd0c5e43b2349f051/llava/serve/examples/waterview.jpg -------------------------------------------------------------------------------- /llava/serve/register_worker.py: -------------------------------------------------------------------------------- 1 | """ 2 | Manually register workers. 
3 | 4 | Usage: 5 | python3 -m fastchat.serve.register_worker --controller http://localhost:21001 --worker-name http://localhost:21002 6 | """ 7 | 8 | import argparse 9 | 10 | import requests 11 | 12 | if __name__ == "__main__": 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument("--controller-address", type=str) 15 | parser.add_argument("--worker-name", type=str) 16 | parser.add_argument("--check-heart-beat", action="store_true") 17 | args = parser.parse_args() 18 | 19 | url = args.controller_address + "/register_worker" 20 | data = { 21 | "worker_name": args.worker_name, 22 | "check_heart_beat": args.check_heart_beat, 23 | "worker_status": None, 24 | } 25 | r = requests.post(url, json=data) 26 | assert r.status_code == 200 27 | -------------------------------------------------------------------------------- /llava/serve/test_message.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | 4 | import requests 5 | 6 | from llava.conversation import default_conversation 7 | 8 | 9 | def main(): 10 | if args.worker_address: 11 | worker_addr = args.worker_address 12 | else: 13 | controller_addr = args.controller_address 14 | ret = requests.post(controller_addr + "/refresh_all_workers") 15 | ret = requests.post(controller_addr + "/list_models") 16 | models = ret.json()["models"] 17 | models.sort() 18 | print(f"Models: {models}") 19 | 20 | ret = requests.post(controller_addr + "/get_worker_address", 21 | json={"model": args.model_name}) 22 | worker_addr = ret.json()["address"] 23 | print(f"worker_addr: {worker_addr}") 24 | 25 | if worker_addr == "": 26 | return 27 | 28 | conv = default_conversation.copy() 29 | conv.append_message(conv.roles[0], args.message) 30 | prompt = conv.get_prompt() 31 | 32 | headers = {"User-Agent": "LLaVA Client"} 33 | pload = { 34 | "model": args.model_name, 35 | "prompt": prompt, 36 | "max_new_tokens": args.max_new_tokens, 37 | "temperature": 0.7, 38 | "stop": conv.sep, 39 | } 40 | response = requests.post(worker_addr + "/worker_generate_stream", headers=headers, 41 | json=pload, stream=True) 42 | 43 | print(prompt.replace(conv.sep, "\n"), end="") 44 | for chunk in response.iter_lines(chunk_size=8192, decode_unicode=False, delimiter=b"\0"): 45 | if chunk: 46 | data = json.loads(chunk.decode("utf-8")) 47 | output = data["text"].split(conv.sep)[-1] 48 | print(output, end="\r") 49 | print("") 50 | 51 | 52 | if __name__ == "__main__": 53 | parser = argparse.ArgumentParser() 54 | parser.add_argument("--controller-address", type=str, default="http://localhost:21001") 55 | parser.add_argument("--worker-address", type=str) 56 | parser.add_argument("--model-name", type=str, default="facebook/opt-350m") 57 | parser.add_argument("--max-new-tokens", type=int, default=32) 58 | parser.add_argument("--message", type=str, default= 59 | "Tell me a story with more than 1000 words.") 60 | args = parser.parse_args() 61 | 62 | main() 63 | -------------------------------------------------------------------------------- /llava/train/llama_flash_attn_monkey_patch.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Tuple 2 | import warnings 3 | 4 | import torch 5 | 6 | import transformers 7 | from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv 8 | 9 | try: 10 | from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func 11 | except ImportError: 12 | from flash_attn.flash_attn_interface import 
flash_attn_varlen_qkvpacked_func as flash_attn_unpadded_qkvpacked_func 13 | from flash_attn.bert_padding import unpad_input, pad_input 14 | 15 | 16 | def forward( 17 | self, 18 | hidden_states: torch.Tensor, 19 | attention_mask: Optional[torch.Tensor] = None, 20 | position_ids: Optional[torch.Tensor] = None, 21 | past_key_value: Optional[Tuple[torch.Tensor]] = None, 22 | output_attentions: bool = False, 23 | use_cache: bool = False, 24 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: 25 | if output_attentions: 26 | warnings.warn( 27 | "Output attentions is not supported for patched `LlamaAttention`, returning `None` instead." 28 | ) 29 | 30 | bsz, q_len, _ = hidden_states.size() 31 | 32 | query_states = ( 33 | self.q_proj(hidden_states) 34 | .view(bsz, q_len, self.num_heads, self.head_dim) 35 | .transpose(1, 2) 36 | ) 37 | key_states = ( 38 | self.k_proj(hidden_states) 39 | .view(bsz, q_len, self.num_key_value_heads, self.head_dim) 40 | .transpose(1, 2) 41 | ) 42 | value_states = ( 43 | self.v_proj(hidden_states) 44 | .view(bsz, q_len, self.num_key_value_heads, self.head_dim) 45 | .transpose(1, 2) 46 | ) # shape: (b, num_heads, s, head_dim) 47 | 48 | kv_seq_len = key_states.shape[-2] 49 | if past_key_value is not None: 50 | kv_seq_len += past_key_value[0].shape[-2] 51 | 52 | cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) 53 | query_states, key_states = apply_rotary_pos_emb( 54 | query_states, key_states, cos, sin, position_ids 55 | ) 56 | 57 | if past_key_value is not None: 58 | # reuse k, v 59 | key_states = torch.cat([past_key_value[0], key_states], dim=2) 60 | value_states = torch.cat([past_key_value[1], value_states], dim=2) 61 | 62 | past_key_value = (key_states, value_states) if use_cache else None 63 | 64 | # repeat k/v heads if n_kv_heads < n_heads 65 | key_states = repeat_kv(key_states, self.num_key_value_groups) 66 | value_states = repeat_kv(value_states, self.num_key_value_groups) 67 | 68 | # Transform the data into the format required by flash attention 69 | qkv = torch.stack([query_states, key_states, value_states], dim=2) 70 | qkv = qkv.transpose(1, 3) # shape: [b, s, 3, num_heads, head_dim] 71 | key_padding_mask = attention_mask 72 | 73 | if key_padding_mask is None: 74 | qkv = qkv.reshape(-1, 3, self.num_heads, self.head_dim) 75 | cu_q_lens = torch.arange( 76 | 0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=qkv.device 77 | ) 78 | max_s = q_len 79 | output = flash_attn_unpadded_qkvpacked_func( 80 | qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True 81 | ) 82 | output = output.view(bsz, q_len, -1) 83 | else: 84 | qkv = qkv.reshape(bsz, q_len, -1) 85 | qkv, indices, cu_q_lens, max_s = unpad_input(qkv, key_padding_mask) 86 | qkv = qkv.view(-1, 3, self.num_heads, self.head_dim) 87 | output_unpad = flash_attn_unpadded_qkvpacked_func( 88 | qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True 89 | ) 90 | output_unpad = output_unpad.reshape(-1, self.num_heads * self.head_dim) 91 | output = pad_input(output_unpad, indices, bsz, q_len) 92 | 93 | return self.o_proj(output), None, past_key_value 94 | 95 | 96 | # Disable the transformation of the attention mask in LlamaModel as the flash attention 97 | # requires the attention mask to be the same as the key_padding_mask 98 | def _prepare_decoder_attention_mask( 99 | self, attention_mask, input_shape, inputs_embeds, past_key_values_length 100 | ): 101 | # [bsz, seq_len] 102 | return attention_mask 103 | 104 | 105 | def replace_llama_attn_with_flash_attn(): 106 | 
cuda_major, cuda_minor = torch.cuda.get_device_capability() 107 | if cuda_major < 8: 108 | warnings.warn( 109 | "Flash attention is only supported on A100 or H100 GPU during training due to head dim > 64 backward." 110 | "ref: https://github.com/HazyResearch/flash-attention/issues/190#issuecomment-1523359593" 111 | ) 112 | transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = ( 113 | _prepare_decoder_attention_mask 114 | ) 115 | transformers.models.llama.modeling_llama.LlamaAttention.forward = forward 116 | -------------------------------------------------------------------------------- /llava/train/train_mem.py: -------------------------------------------------------------------------------- 1 | # Adopted from https://github.com/lm-sys/FastChat. Below is the original copyright: 2 | # Adopted from tatsu-lab@stanford_alpaca. Below is the original copyright: 3 | # Make it more memory efficient by monkey patching the LLaMA model with FlashAttn. 4 | 5 | # Need to call this before importing transformers. 6 | from llava.train.llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn 7 | 8 | replace_llama_attn_with_flash_attn() 9 | 10 | from llava.train.train import train 11 | 12 | if __name__ == "__main__": 13 | train() 14 | -------------------------------------------------------------------------------- /llava/utils.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import logging 3 | import logging.handlers 4 | import os 5 | import sys 6 | 7 | import requests 8 | 9 | from llava.constants import LOGDIR 10 | 11 | server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**" 12 | moderation_msg = "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE TRY AGAIN." 13 | 14 | handler = None 15 | 16 | 17 | def build_logger(logger_name, logger_filename): 18 | global handler 19 | 20 | formatter = logging.Formatter( 21 | fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s", 22 | datefmt="%Y-%m-%d %H:%M:%S", 23 | ) 24 | 25 | # Set the format of root handlers 26 | if not logging.getLogger().handlers: 27 | logging.basicConfig(level=logging.INFO) 28 | logging.getLogger().handlers[0].setFormatter(formatter) 29 | 30 | # Redirect stdout and stderr to loggers 31 | stdout_logger = logging.getLogger("stdout") 32 | stdout_logger.setLevel(logging.INFO) 33 | sl = StreamToLogger(stdout_logger, logging.INFO) 34 | sys.stdout = sl 35 | 36 | stderr_logger = logging.getLogger("stderr") 37 | stderr_logger.setLevel(logging.ERROR) 38 | sl = StreamToLogger(stderr_logger, logging.ERROR) 39 | sys.stderr = sl 40 | 41 | # Get logger 42 | logger = logging.getLogger(logger_name) 43 | logger.setLevel(logging.INFO) 44 | 45 | # Add a file handler for all loggers 46 | if handler is None: 47 | os.makedirs(LOGDIR, exist_ok=True) 48 | filename = os.path.join(LOGDIR, logger_filename) 49 | handler = logging.handlers.TimedRotatingFileHandler( 50 | filename, when='D', utc=True) 51 | handler.setFormatter(formatter) 52 | 53 | for name, item in logging.root.manager.loggerDict.items(): 54 | if isinstance(item, logging.Logger): 55 | item.addHandler(handler) 56 | 57 | return logger 58 | 59 | 60 | class StreamToLogger(object): 61 | """ 62 | Fake file-like stream object that redirects writes to a logger instance. 
63 | """ 64 | def __init__(self, logger, log_level=logging.INFO): 65 | self.terminal = sys.stdout 66 | self.logger = logger 67 | self.log_level = log_level 68 | self.linebuf = '' 69 | 70 | def __getattr__(self, attr): 71 | return getattr(self.terminal, attr) 72 | 73 | def write(self, buf): 74 | temp_linebuf = self.linebuf + buf 75 | self.linebuf = '' 76 | for line in temp_linebuf.splitlines(True): 77 | # From the io.TextIOWrapper docs: 78 | # On output, if newline is None, any '\n' characters written 79 | # are translated to the system default line separator. 80 | # By default sys.stdout.write() expects '\n' newlines and then 81 | # translates them so this is still cross platform. 82 | if line[-1] == '\n': 83 | self.logger.log(self.log_level, line.rstrip()) 84 | else: 85 | self.linebuf += line 86 | 87 | def flush(self): 88 | if self.linebuf != '': 89 | self.logger.log(self.log_level, self.linebuf.rstrip()) 90 | self.linebuf = '' 91 | 92 | 93 | def disable_torch_init(): 94 | """ 95 | Disable the redundant torch default initialization to accelerate model creation. 96 | """ 97 | import torch 98 | setattr(torch.nn.Linear, "reset_parameters", lambda self: None) 99 | setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None) 100 | 101 | 102 | def violates_moderation(text): 103 | """ 104 | Check whether the text violates OpenAI moderation API. 105 | """ 106 | url = "https://api.openai.com/v1/moderations" 107 | headers = {"Content-Type": "application/json", 108 | "Authorization": "Bearer " + os.environ["OPENAI_API_KEY"]} 109 | text = text.replace("\n", "") 110 | data = "{" + '"input": ' + f'"{text}"' + "}" 111 | data = data.encode("utf-8") 112 | try: 113 | ret = requests.post(url, headers=headers, data=data, timeout=5) 114 | flagged = ret.json()["results"][0]["flagged"] 115 | except requests.exceptions.RequestException as e: 116 | flagged = False 117 | except KeyError as e: 118 | flagged = False 119 | 120 | return flagged 121 | 122 | 123 | def pretty_print_semaphore(semaphore): 124 | if semaphore is None: 125 | return "None" 126 | return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})" 127 | -------------------------------------------------------------------------------- /playground/data/eval/gapartnet/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qizekun/ShapeLLM/f754b0d488f7187a699a549dd0c5e43b2349f051/playground/data/eval/gapartnet/.DS_Store -------------------------------------------------------------------------------- /playground/data/eval/gapartnet/test_list.json: -------------------------------------------------------------------------------- 1 | ["Toilet_102706", "Box_100247", "Camera_102536", "Microwave_7296", "StorageFurniture_45676", "Toilet_102697", "StorageFurniture_47182", "StorageFurniture_45173", "StorageFurniture_48797", "WashingMachine_103369", "Camera_102398", "Microwave_7263", "StorageFurniture_46277", "StorageFurniture_47419", "StorageFurniture_41004", "StorageFurniture_45007", "Camera_102882", "AKBBox_64", "Remote_100385", "StorageFurniture_45403", "AKBBox_45", "Toaster_103524", "Toilet_102646", "AKBBucket_2", "Toilet_102689", "Camera_102873", "StorageFurniture_46380", "StorageFurniture_40453", "AKBBox_55", "Toilet_102630", "Toilet_102631", "AKBDrawer_289", "Dishwasher_12559", "AKBDrawer_300", "Microwave_7349", "Bucket_100438", "AKBBox_41", "CoffeeMachine_103129", "Dishwasher_12530", "Keyboard_13082", "StorageFurniture_45134", "AKBBox_63", "AKBBucket_3", 
"Printer_104011", "Bucket_100439", "Toilet_102670", "AKBBucket_0", "Camera_102874", "AKBBox_40", "StorageFurniture_46889", "StorageFurniture_46179", "AKBTrashCan_219", "AKBTrashCan_224", "StorageFurniture_45212", "StorageFurniture_47021", "Toaster_103486", "Printer_103863", "StorageFurniture_47133", "Remote_101133", "AKBBucket_1", "Remote_104038", "StorageFurniture_49188", "StorageFurniture_47585", "StorageFurniture_45372", "Dishwasher_12484", "AKBDrawer_282", "StorageFurniture_48381", "AKBDrawer_294", "CoffeeMachine_103082", "Toaster_103561", "Remote_100706", "CoffeeMachine_103074", "Printer_104016", "Camera_102472", "Bucket_100486", "StorageFurniture_45949", "Box_100243", "CoffeeMachine_103037", "Box_100191", "AKBBucket_6", "WashingMachine_100283", "StorageFurniture_46107", "Toilet_102651", "AKBDrawer_283", "StorageFurniture_46699", "AKBTrashCan_213", "StorageFurniture_46145", "Toilet_102708", "Keyboard_12977", "StorageFurniture_45691", "Bucket_100431", "StorageFurniture_45623", "StorageFurniture_45444", "CoffeeMachine_103128", "AKBBox_60", "StorageFurniture_45783", "Printer_100279", "StorageFurniture_46481", "Keyboard_12738", "StorageFurniture_45159", "AKBTrashCan_227", "StorageFurniture_45759", "Toilet_101323", "StorageFurniture_47443", "StorageFurniture_45855", "Printer_104000", "StorageFurniture_46598", "StorageFurniture_45746", "StorageFurniture_46655", "Dishwasher_12065", "AKBTrashCan_225", "StorageFurniture_45767", "StorageFurniture_46019", "StorageFurniture_35059", "Remote_100395", "StorageFurniture_47233", "Keyboard_12917", "StorageFurniture_45385", "Keyboard_13075", "StorageFurniture_45606", "Camera_102403", "Remote_101104", "CoffeeMachine_103143", "AKBDrawer_288", "Box_100189", "Dishwasher_11700", "StorageFurniture_46437", "StorageFurniture_45841", "StorageFurniture_46172", "Toaster_103469", "StorageFurniture_46744", "Remote_101121", "AKBBucket_4", "StorageFurniture_45261", "Toilet_103234", "Remote_101028", "StorageFurniture_46653", "Toilet_102636", "StorageFurniture_46641", "Box_100221", "StorageFurniture_48036", "Dishwasher_12428", "StorageFurniture_45910", "AKBBox_58", "WashingMachine_103528", "Toilet_102676", "Bucket_100432", "Bucket_100435", "StorageFurniture_46014", "Dishwasher_12558", "AKBBucket_5", "Dishwasher_12597", "StorageFurniture_46109", "StorageFurniture_48721", "StorageFurniture_45633", "StorageFurniture_45779", "Remote_101004", "StorageFurniture_49038", "StorageFurniture_45963", "Remote_104040", "CoffeeMachine_103092", "Keyboard_13086", "Printer_103972", "StorageFurniture_45166", "StorageFurniture_45671", "StorageFurniture_48497", "CoffeeMachine_102901"] -------------------------------------------------------------------------------- /playground/data/eval/modelnet40/modelnet40_shape_names_modified.txt: -------------------------------------------------------------------------------- 1 | airplane 2 | bathtub 3 | bed 4 | bench 5 | bookshelf 6 | bottle 7 | bowl 8 | car 9 | chair 10 | cone 11 | cup 12 | curtain 13 | desk 14 | door 15 | dresser 16 | flower pot 17 | glass box 18 | guitar 19 | keyboard 20 | lamp 21 | laptop 22 | mantel 23 | monitor 24 | night stand 25 | person 26 | piano 27 | plant 28 | radio 29 | range hood 30 | sink 31 | sofa 32 | stairs 33 | stool 34 | table 35 | tent 36 | toilet 37 | tv stand 38 | vase 39 | wardrobe 40 | xbox 41 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = 
["setuptools>=61.0"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "llava" 7 | version = "1.1.3" 8 | description = "Towards GPT-4 like large language and visual assistant." 9 | readme = "README.md" 10 | requires-python = ">=3.8" 11 | classifiers = [ 12 | "Programming Language :: Python :: 3", 13 | "License :: OSI Approved :: Apache Software License", 14 | ] 15 | dependencies = [ 16 | "torch==2.0.1", "torchvision==0.15.2", 17 | "transformers==4.31.0", "tokenizers>=0.12.1,<0.14", "sentencepiece==0.1.99", "shortuuid", 18 | "accelerate==0.21.0", "peft==0.4.0", "bitsandbytes==0.41.0", 19 | "pydantic<2,>=1", "markdown2[all]", "numpy==1.26.4", "scikit-learn==1.2.2", 20 | "gradio==3.35.2", "gradio_client==0.2.9", 21 | "requests", "httpx==0.24.0", "uvicorn", "fastapi", 22 | "einops==0.6.1", "einops-exts==0.0.4", "timm==0.9.12", 23 | "argparse", "easydict", "h5py", "matplotlib", "tqdm", 24 | "opencv-python==4.10.0.84", "pyyaml", "scipy", "tensorboardX", 25 | "termcolor", "pandas", "ftfy", "regex", "plyfile", "ipdb", 26 | "jsonlines", "openai", "nltk", "rouge", "py-rouge" 27 | ] 28 | 29 | [project.optional-dependencies] 30 | train = ["deepspeed==0.9.5", "ninja", "wandb", "torch-scatter==2.0.9"] 31 | 32 | [project.urls] 33 | "Homepage" = "https://qizekun.github.io/shapellm" 34 | "Bug Tracker" = "https://github.com/qizekun/ShapeLLM/issues" 35 | 36 | [tool.setuptools.packages.find] 37 | exclude = ["assets*", "benchmark*", "docs", "dist*", "playground*", "scripts*", "tests*"] 38 | 39 | [tool.wheel] 40 | exclude = ["assets*", "benchmark*", "docs", "dist*", "playground*", "scripts*", "tests*"] 41 | -------------------------------------------------------------------------------- /scripts/eval/eval_gapartnet.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | MODEL_VERSION=shapellm-13b 4 | TAG=gapartnet-v1.0 5 | 6 | mkdir -p ./playground/data/eval/gapartnet/results 7 | 8 | python -m llava.eval.eval_gapartnet \ 9 | --answers-file ./playground/data/eval/gapartnet/answers/$MODEL_VERSION-$TAG.jsonl \ 10 | --gt-file ./playground/data/eval/gapartnet/gt.jsonl \ 11 | --output-file ./playground/data/eval/gapartnet/results/$MODEL_VERSION-$TAG.jsonl \ 12 | --max_workers 16 -------------------------------------------------------------------------------- /scripts/eval/eval_mmvet.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | MODEL_VERSION=shapellm-13b 4 | TAG=general-v1.0 5 | 6 | mkdir -p ./playground/data/eval/3d-mm-vet/results 7 | 8 | python -m llava.eval.eval_3dmmvet \ 9 | --answers-file ./playground/data/eval/3d-mm-vet/answers/$MODEL_VERSION-$TAG.jsonl \ 10 | --gt-file ./playground/data/eval/3d-mm-vet/gt.jsonl \ 11 | --output-file ./playground/data/eval/3d-mm-vet/results/$MODEL_VERSION-$TAG.jsonl \ 12 | --model gpt-4-0125-preview \ 13 | --max_workers 16 \ 14 | --times 5 -------------------------------------------------------------------------------- /scripts/eval/eval_modelnet40_cls.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python llava/eval/evaluator.py \ 4 | --results_path ModelNet_classification_prompt0.json \ 5 | --model_type gpt-3.5-turbo-0613 \ 6 | --eval_type modelnet-close-set-classification \ 7 | --parallel \ 8 | --num_workers 15 -------------------------------------------------------------------------------- /scripts/eval/eval_objaverse_cap.sh: 
-------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python llava/eval/evaluator.py \ 4 | --results_path PointLLM_brief_description_val_200_GT_Objaverse_captioning_prompt2.json \ 5 | --model_type gpt-4-0613 \ 6 | --eval_type object-captioning \ 7 | --parallel \ 8 | --num_workers 15 9 | 10 | python llava/eval/traditional_evaluator.py \ 11 | --results_path PointLLM_brief_description_val_200_GT_Objaverse_captioning_prompt2.json -------------------------------------------------------------------------------- /scripts/eval/eval_objaverse_cls.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python llava/eval/evaluator.py \ 4 | --results_path PointLLM_brief_description_val_200_GT_Objaverse_captioning_prompt0.json \ 5 | --model_type gpt-4-0613 \ 6 | --eval_type open-free-form-classification \ 7 | --parallel \ 8 | --num_workers 15 -------------------------------------------------------------------------------- /scripts/eval/gapartnet_ref.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | MODEL_VERSION=shapellm-13b 4 | TAG=gapartnet-v1.0 5 | 6 | CUDA_VISIBLE_DEVICES=0 python -m llava.eval.model_vqa \ 7 | --model-path ./checkpoints/$MODEL_VERSION-$TAG-finetune \ 8 | --question-file ./playground/data/eval/gapartnet/question.jsonl \ 9 | --point-folder ./playground/data/shapellm/gapartnet_pcs \ 10 | --answers-file ./playground/data/eval/gapartnet/answers/$MODEL_VERSION-$TAG.jsonl \ 11 | --conv-mode vicuna_v1 \ 12 | --num_beams 5 -------------------------------------------------------------------------------- /scripts/eval/mmvet.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | MODEL_VERSION=shapellm-13b 4 | TAG=general-v1.0 5 | 6 | CUDA_VISIBLE_DEVICES=0 python -m llava.eval.model_vqa \ 7 | --model-path ./checkpoints/$MODEL_VERSION-$TAG-finetune \ 8 | --question-file ./playground/data/eval/3d-mm-vet/question.jsonl \ 9 | --point-folder ./playground/data/eval/3d-mm-vet/points \ 10 | --answers-file ./playground/data/eval/3d-mm-vet/answers/$MODEL_VERSION-$TAG.jsonl \ 11 | --conv-mode vicuna_v1 \ 12 | --num_beams 5 -------------------------------------------------------------------------------- /scripts/eval/modelnet40_cls.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | MODEL_VERSION=llama-vicuna-7b 4 | TAG=v1.0 5 | 6 | CUDA_VISIBLE_DEVICES=0 python -m llava.eval.eval_modelnet_cls \ 7 | --model-path ./checkpoints/$MODEL_VERSION-$TAG-finetune \ 8 | --prompt_index 0 \ 9 | --num_beams 5 -------------------------------------------------------------------------------- /scripts/eval/objaverse_cap.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | MODEL_VERSION=llama-vicuna-7b 4 | TAG=v1.0 5 | 6 | CUDA_VISIBLE_DEVICES=0 python -m llava.eval.eval_objaverse \ 7 | --model-path ./checkpoints/$MODEL_VERSION-$TAG-finetune \ 8 | --task_type captioning \ 9 | --prompt_index 2 \ 10 | --num_beams 5 -------------------------------------------------------------------------------- /scripts/eval/objaverse_cls.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | MODEL_VERSION=llama-vicuna-7b 4 | TAG=v1.0 5 | 6 | CUDA_VISIBLE_DEVICES=0 python -m llava.eval.eval_objaverse \ 7 | --model-path ./checkpoints/$MODEL_VERSION-$TAG-finetune \ 8 
| --task_type classification \ 9 | --prompt_index 0 \ 10 | --num_beams 5 -------------------------------------------------------------------------------- /scripts/extract_mm_projector.py: -------------------------------------------------------------------------------- 1 | """ 2 | This is just a utility that I use to extract the projector for quantized models. 3 | It is NOT necessary at all to train, or run inference/serve demos. 4 | Use this script ONLY if you fully understand its implications. 5 | """ 6 | 7 | 8 | import os 9 | import argparse 10 | import torch 11 | import json 12 | from collections import defaultdict 13 | 14 | 15 | def parse_args(): 16 | parser = argparse.ArgumentParser(description='Extract MMProjector weights') 17 | parser.add_argument('--model-path', type=str, help='model folder') 18 | parser.add_argument('--output', type=str, help='output file') 19 | args = parser.parse_args() 20 | return args 21 | 22 | 23 | if __name__ == '__main__': 24 | args = parse_args() 25 | 26 | keys_to_match = ['mm_projector'] 27 | ckpt_to_key = defaultdict(list) 28 | try: 29 | model_indices = json.load(open(os.path.join(args.model_path, 'pytorch_model.bin.index.json'))) 30 | for k, v in model_indices['weight_map'].items(): 31 | if any(key_match in k for key_match in keys_to_match): 32 | ckpt_to_key[v].append(k) 33 | except FileNotFoundError: 34 | # Smaller models or model checkpoints saved by DeepSpeed. 35 | v = 'pytorch_model.bin' 36 | for k in torch.load(os.path.join(args.model_path, v), map_location='cpu').keys(): 37 | if any(key_match in k for key_match in keys_to_match): 38 | ckpt_to_key[v].append(k) 39 | 40 | loaded_weights = {} 41 | 42 | for ckpt_name, weight_keys in ckpt_to_key.items(): 43 | ckpt = torch.load(os.path.join(args.model_path, ckpt_name), map_location='cpu') 44 | for k in weight_keys: 45 | loaded_weights[k] = ckpt[k] 46 | 47 | torch.save(loaded_weights, args.output) 48 | -------------------------------------------------------------------------------- /scripts/finetune.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | LLM_VERSION=lmsys/vicuna-13b-v1.1 4 | MODEL_VERSION=shapellm-13b 5 | PRETRAIN_TAG=v1.0 6 | TAG=v1.0 7 | 8 | type=general 9 | 10 | if [ $type = "general" ]; then 11 | meta_path="./playground/data/shapellm/cap3d_objaverse_sft_45k.json" 12 | pcs_path="./playground/data/shapellm/cap3d_pcs" 13 | elif [ $type = "gapartnet" ]; then 14 | meta_path="./playground/data/shapellm/gapartnet_sft_27k_openai.json" 15 | pcs_path="./playground/data/shapellm/gapartnet_pcs" 16 | else 17 | echo "Unknown type" 18 | exit 1 19 | fi 20 | 21 | deepspeed llava/train/train_mem.py \ 22 | --deepspeed ./scripts/zero2.json \ 23 | --model_name_or_path $LLM_VERSION \ 24 | --version v1 \ 25 | --data_path $meta_path \ 26 | --point_folder $pcs_path \ 27 | --vision_tower ReConV2/cfgs/pretrain/large/openshape.yaml \ 28 | --vision_tower_path ./checkpoints/recon/large.pth \ 29 | --sample_points_num 10000 \ 30 | --with_color True \ 31 | --occlusion False \ 32 | --prompt_token_num 32 \ 33 | --with_ape True \ 34 | --with_local True \ 35 | --with_global True \ 36 | --pretrain_mm_mlp_adapter ./checkpoints/$MODEL_VERSION-$PRETRAIN_TAG-pretrain/mm_projector.bin \ 37 | --mm_projector_type mlp2x_gelu \ 38 | --mm_vision_select_layer -2 \ 39 | --mm_use_pt_start_end False \ 40 | --mm_use_pt_patch_token False \ 41 | --bf16 True \ 42 | --output_dir ./checkpoints/$MODEL_VERSION-$type-$TAG-finetune \ 43 | --num_train_epochs 1 \ 44 | 
--per_device_train_batch_size 16 \ 45 | --per_device_eval_batch_size 4 \ 46 | --gradient_accumulation_steps 1 \ 47 | --evaluation_strategy "no" \ 48 | --save_strategy "steps" \ 49 | --save_steps 50000 \ 50 | --save_total_limit 1 \ 51 | --learning_rate 2e-5 \ 52 | --weight_decay 0. \ 53 | --warmup_ratio 0.03 \ 54 | --lr_scheduler_type "cosine" \ 55 | --logging_steps 1 \ 56 | --tf32 True \ 57 | --model_max_length 2048 \ 58 | --gradient_checkpointing True \ 59 | --dataloader_num_workers 4 \ 60 | --lazy_preprocess True \ 61 | --report_to wandb -------------------------------------------------------------------------------- /scripts/finetune_lora.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | LLM_VERSION=lmsys/vicuna-13b-v1.1 4 | MODEL_VERSION=shapellm-13b 5 | PRETRAIN_TAG=v1.0 6 | TAG=v1.0 7 | 8 | type=general 9 | 10 | if [ $type = "general" ]; then 11 | meta_path="./playground/data/shapellm/cap3d_objaverse_sft_45k.json" 12 | pcs_path="./playground/data/shapellm/cap3d_pcs" 13 | elif [ $type = "gapartnet" ]; then 14 | meta_path="./playground/data/shapellm/gapartnet_sft_27k_openai.json" 15 | pcs_path="./playground/data/shapellm/gapartnet_pcs" 16 | else 17 | echo "Unknown type" 18 | exit 1 19 | fi 20 | 21 | deepspeed llava/train/train_mem.py \ 22 | --lora_enable True --lora_r 128 --lora_alpha 256 --mm_projector_lr 2e-5 \ 23 | --deepspeed ./scripts/zero2.json \ 24 | --model_name_or_path $LLM_VERSION \ 25 | --version v1 \ 26 | --data_path $meta_path \ 27 | --point_folder $pcs_path \ 28 | --vision_tower ReConV2/cfgs/pretrain/large/openshape.yaml \ 29 | --vision_tower_path ./checkpoints/recon/large.pth \ 30 | --sample_points_num 10000 \ 31 | --with_color True \ 32 | --occlusion False \ 33 | --prompt_token_num 32 \ 34 | --with_ape True \ 35 | --with_local True \ 36 | --with_global True \ 37 | --pretrain_mm_mlp_adapter ./checkpoints/$MODEL_VERSION-$PRETRAIN_TAG-pretrain/mm_projector.bin \ 38 | --mm_projector_type mlp2x_gelu \ 39 | --mm_vision_select_layer -2 \ 40 | --mm_use_pt_start_end False \ 41 | --mm_use_pt_patch_token False \ 42 | --point_aspect_ratio pad \ 43 | --group_by_modality_length True \ 44 | --bf16 True \ 45 | --output_dir ./checkpoints/$MODEL_VERSION-$type-$TAG-lora \ 46 | --num_train_epochs 1 \ 47 | --per_device_train_batch_size 16 \ 48 | --per_device_eval_batch_size 4 \ 49 | --gradient_accumulation_steps 1 \ 50 | --evaluation_strategy "no" \ 51 | --save_strategy "steps" \ 52 | --save_steps 50000 \ 53 | --save_total_limit 1 \ 54 | --learning_rate 2e-4 \ 55 | --weight_decay 0. 
\ 56 | --warmup_ratio 0.03 \ 57 | --lr_scheduler_type "cosine" \ 58 | --logging_steps 1 \ 59 | --tf32 True \ 60 | --model_max_length 2048 \ 61 | --gradient_checkpointing True \ 62 | --dataloader_num_workers 4 \ 63 | --lazy_preprocess True \ 64 | --report_to wandb -------------------------------------------------------------------------------- /scripts/inference.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | MODEL_VERSION=llama-vicuna-7b 4 | TAG=8k 5 | 6 | CUDA_VISIBLE_DEVICES=0 python llava/serve/cli.py \ 7 | --model-path ./checkpoints/$MODEL_VERSION-$TAG-finetune \ 8 | --pts-file $1 -------------------------------------------------------------------------------- /scripts/merge_lora_weights.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from llava.model.builder import load_pretrained_model 3 | from llava.mm_utils import get_model_name_from_path 4 | 5 | 6 | def merge_lora(args): 7 | model_name = get_model_name_from_path(args.model_path) 8 | tokenizer, model, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, device_map='cpu') 9 | 10 | model.save_pretrained(args.save_model_path) 11 | tokenizer.save_pretrained(args.save_model_path) 12 | 13 | 14 | if __name__ == "__main__": 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument("--model-path", type=str, required=True) 17 | parser.add_argument("--model-base", type=str, required=True) 18 | parser.add_argument("--save-model-path", type=str, required=True) 19 | 20 | args = parser.parse_args() 21 | 22 | merge_lora(args) 23 | -------------------------------------------------------------------------------- /scripts/pretrain.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | LLM_VERSION=lmsys/vicuna-13b-v1.1 4 | MODEL_VERSION=shapellm-13b 5 | TAG=v1.0 6 | 7 | deepspeed llava/train/train_mem.py \ 8 | --deepspeed ./scripts/zero2.json \ 9 | --model_name_or_path $LLM_VERSION \ 10 | --version plain \ 11 | --data_path ./playground/data/shapellm/cap3d_objaverse_785k.json \ 12 | --point_folder ./playground/data/shapellm/cap3d_pcs \ 13 | --vision_tower ReConV2/cfgs/pretrain/large/openshape.yaml \ 14 | --vision_tower_path ./checkpoints/recon/large.pth \ 15 | --sample_points_num 10000 \ 16 | --with_color True \ 17 | --occlusion False \ 18 | --prompt_token_num 32 \ 19 | --with_ape True \ 20 | --with_local True \ 21 | --with_global True \ 22 | --mm_projector_type mlp2x_gelu \ 23 | --tune_mm_mlp_adapter True \ 24 | --mm_vision_select_layer -2 \ 25 | --mm_use_pt_start_end False \ 26 | --mm_use_pt_patch_token False \ 27 | --bf16 True \ 28 | --output_dir ./checkpoints/$MODEL_VERSION-$TAG-pretrain \ 29 | --num_train_epochs 1 \ 30 | --per_device_train_batch_size 32 \ 31 | --per_device_eval_batch_size 4 \ 32 | --gradient_accumulation_steps 1 \ 33 | --evaluation_strategy "no" \ 34 | --save_strategy "steps" \ 35 | --save_steps 24000 \ 36 | --save_total_limit 1 \ 37 | --learning_rate 1e-3 \ 38 | --weight_decay 0. 
\ 39 | --warmup_ratio 0.03 \ 40 | --lr_scheduler_type "cosine" \ 41 | --logging_steps 1 \ 42 | --tf32 True \ 43 | --model_max_length 2048 \ 44 | --gradient_checkpointing True \ 45 | --dataloader_num_workers 4 \ 46 | --lazy_preprocess True \ 47 | --report_to wandb -------------------------------------------------------------------------------- /scripts/zero2.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | "bf16": { 11 | "enabled": "auto" 12 | }, 13 | "train_micro_batch_size_per_gpu": "auto", 14 | "train_batch_size": "auto", 15 | "gradient_accumulation_steps": "auto", 16 | "zero_optimization": { 17 | "stage": 2, 18 | "overlap_comm": true, 19 | "contiguous_gradients": true, 20 | "sub_group_size": 1e9, 21 | "reduce_bucket_size": "auto" 22 | } 23 | } -------------------------------------------------------------------------------- /scripts/zero3.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | "bf16": { 11 | "enabled": "auto" 12 | }, 13 | "train_micro_batch_size_per_gpu": "auto", 14 | "train_batch_size": "auto", 15 | "gradient_accumulation_steps": "auto", 16 | "zero_optimization": { 17 | "stage": 3, 18 | "overlap_comm": true, 19 | "contiguous_gradients": true, 20 | "sub_group_size": 1e9, 21 | "reduce_bucket_size": "auto", 22 | "stage3_prefetch_bucket_size": "auto", 23 | "stage3_param_persistence_threshold": "auto", 24 | "stage3_max_live_parameters": 1e9, 25 | "stage3_max_reuse_distance": 1e9, 26 | "stage3_gather_16bit_weights_on_model_save": true 27 | } 28 | } -------------------------------------------------------------------------------- /scripts/zero3_offload.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | "bf16": { 11 | "enabled": "auto" 12 | }, 13 | "optimizer": { 14 | "type": "AdamW", 15 | "params": { 16 | "lr": "auto", 17 | "betas": "auto", 18 | "eps": "auto", 19 | "weight_decay": "auto" 20 | } 21 | }, 22 | "scheduler": { 23 | "type": "WarmupLR", 24 | "params": { 25 | "warmup_min_lr": "auto", 26 | "warmup_max_lr": "auto", 27 | "warmup_num_steps": "auto" 28 | } 29 | }, 30 | "zero_optimization": { 31 | "stage": 3, 32 | "offload_optimizer": { 33 | "device": "cpu", 34 | "pin_memory": true 35 | }, 36 | "offload_param": { 37 | "device": "cpu", 38 | "pin_memory": true 39 | }, 40 | "overlap_comm": true, 41 | "contiguous_gradients": true, 42 | "sub_group_size": 1e9, 43 | "reduce_bucket_size": "auto", 44 | "stage3_prefetch_bucket_size": "auto", 45 | "stage3_param_persistence_threshold": "auto", 46 | "stage3_max_live_parameters": 1e9, 47 | "stage3_max_reuse_distance": 1e9, 48 | "gather_16bit_weights_on_model_save": true 49 | }, 50 | "gradient_accumulation_steps": "auto", 51 | "gradient_clipping": "auto", 52 | "train_batch_size": "auto", 53 | "train_micro_batch_size_per_gpu": "auto", 54 | "steps_per_print": 1e5, 55 | "wall_clock_breakdown": false 56 | } --------------------------------------------------------------------------------
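The three ZeRO configs above are the ones the training scripts reference through DeepSpeed's --deepspeed flag; scripts/pretrain.sh, scripts/finetune.sh, and scripts/finetune_lora.sh all point at ./scripts/zero2.json, and every "auto" field is resolved from the HuggingFace Trainer arguments passed on the command line, so the JSON files do not need per-run editing. Below is a minimal sketch, assuming the remaining launch arguments are forwarded unchanged from scripts/finetune.sh, of selecting a different ZeRO stage when ZeRO-2 does not fit in GPU memory (this wrapper and the DS_CONFIG variable are illustrative, not part of the repository):

    #!/bin/bash
    # Hypothetical wrapper: pick one of the provided ZeRO configs, then forward
    # every remaining argument (data paths, vision tower, output dir, ...) as-is.
    DS_CONFIG=./scripts/zero3_offload.json   # or ./scripts/zero2.json / ./scripts/zero3.json
    deepspeed llava/train/train_mem.py --deepspeed "$DS_CONFIG" "$@"

ZeRO-3 with optimizer and parameter offload trades training throughput for a lower per-GPU memory footprint, which is the usual reason to move off the default zero2.json.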
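Relatedly, scripts/finetune_lora.sh saves LoRA adapter weights to ./checkpoints/$MODEL_VERSION-$type-$TAG-lora (./checkpoints/shapellm-13b-general-v1.0-lora with the defaults shown) rather than a full model, so the checkpoint is normally merged with its base LLM before downstream use. A short usage sketch of scripts/merge_lora_weights.py under that assumption; the --save-model-path value is only an illustrative name, and the merge relies on load_pretrained_model resolving the LoRA checkpoint against the required --model-base weights:

    python scripts/merge_lora_weights.py \
        --model-path ./checkpoints/shapellm-13b-general-v1.0-lora \
        --model-base lmsys/vicuna-13b-v1.1 \
        --save-model-path ./checkpoints/shapellm-13b-general-v1.0-merged

The merged weights and tokenizer written to --save-model-path can then be passed as --model-path to the eval and serve scripts in place of a fully fine-tuned checkpoint.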