├── imgs ├── result.png ├── teaser.png └── overview.png ├── requirements.txt ├── evaluation ├── __pycache__ │ ├── ba.cpython-310.pyc │ ├── ba.cpython-39.pyc │ ├── compare.cpython-310.pyc │ ├── compare.cpython-39.pyc │ ├── tensor_to_pycolmap.cpython-39.pyc │ └── tensor_to_pycolmap.cpython-310.pyc ├── quarot │ ├── __pycache__ │ │ ├── utils.cpython-310.pyc │ │ ├── utils.cpython-39.pyc │ │ ├── args_utils.cpython-310.pyc │ │ ├── args_utils.cpython-39.pyc │ │ ├── quant_utils.cpython-39.pyc │ │ ├── train_utils.cpython-39.pyc │ │ ├── quant_utils.cpython-310.pyc │ │ ├── quarot_linear.cpython-39.pyc │ │ ├── quarot_utils.cpython-310.pyc │ │ ├── quarot_utils.cpython-39.pyc │ │ ├── qvdit_linear.cpython-310.pyc │ │ ├── qvdit_utils.cpython-310.pyc │ │ ├── train_utils.cpython-310.pyc │ │ ├── function_utils.cpython-310.pyc │ │ ├── function_utils.cpython-39.pyc │ │ └── quarot_linear.cpython-310.pyc │ ├── function_utils.py │ ├── quarot_linear.py │ ├── utils.py │ ├── quant_utils.py │ └── args_utils.py ├── test.sh ├── .gradio │ └── certificate.pem ├── README.md ├── preprocess_co3d.py └── tensor_to_pycolmap.py ├── vggt ├── heads │ ├── __pycache__ │ │ ├── utils.cpython-39.pyc │ │ ├── dpt_head.cpython-39.pyc │ │ ├── head_act.cpython-39.pyc │ │ ├── utils.cpython-310.pyc │ │ ├── dpt_head.cpython-310.pyc │ │ ├── head_act.cpython-310.pyc │ │ ├── track_head.cpython-39.pyc │ │ ├── camera_head.cpython-310.pyc │ │ ├── camera_head.cpython-39.pyc │ │ └── track_head.cpython-310.pyc │ ├── track_modules │ │ ├── __pycache__ │ │ │ ├── blocks.cpython-39.pyc │ │ │ ├── utils.cpython-310.pyc │ │ │ ├── utils.cpython-39.pyc │ │ │ ├── __init__.cpython-310.pyc │ │ │ ├── __init__.cpython-39.pyc │ │ │ ├── blocks.cpython-310.pyc │ │ │ ├── modules.cpython-310.pyc │ │ │ ├── modules.cpython-39.pyc │ │ │ ├── base_track_predictor.cpython-310.pyc │ │ │ └── base_track_predictor.cpython-39.pyc │ │ ├── __init__.py │ │ ├── modules.py │ │ ├── base_track_predictor.py │ │ ├── utils.py │ │ └── blocks.py │ ├── head_act.py │ ├── utils.py │ ├── track_head.py │ └── camera_head.py ├── layers │ ├── __pycache__ │ │ ├── mlp.cpython-310.pyc │ │ ├── mlp.cpython-39.pyc │ │ ├── rope.cpython-39.pyc │ │ ├── block.cpython-310.pyc │ │ ├── block.cpython-39.pyc │ │ ├── rope.cpython-310.pyc │ │ ├── __init__.cpython-310.pyc │ │ ├── __init__.cpython-39.pyc │ │ ├── attention.cpython-39.pyc │ │ ├── drop_path.cpython-39.pyc │ │ ├── attention.cpython-310.pyc │ │ ├── drop_path.cpython-310.pyc │ │ ├── layer_scale.cpython-310.pyc │ │ ├── layer_scale.cpython-39.pyc │ │ ├── patch_embed.cpython-310.pyc │ │ ├── patch_embed.cpython-39.pyc │ │ ├── swiglu_ffn.cpython-310.pyc │ │ ├── swiglu_ffn.cpython-39.pyc │ │ ├── quarot_utils.cpython-310.pyc │ │ ├── vision_transformer.cpython-310.pyc │ │ └── vision_transformer.cpython-39.pyc │ ├── __init__.py │ ├── layer_scale.py │ ├── drop_path.py │ ├── mlp.py │ ├── swiglu_ffn.py │ ├── patch_embed.py │ ├── attention.py │ ├── rope.py │ └── block.py ├── models │ ├── __pycache__ │ │ ├── vggt.cpython-39.pyc │ │ ├── vggt.cpython-310.pyc │ │ ├── aggregator.cpython-310.pyc │ │ └── aggregator.cpython-39.pyc │ └── vggt.py └── utils │ ├── __pycache__ │ ├── geometry.cpython-39.pyc │ ├── load_fn.cpython-310.pyc │ ├── load_fn.cpython-39.pyc │ ├── pose_enc.cpython-39.pyc │ ├── rotation.cpython-39.pyc │ ├── geometry.cpython-310.pyc │ ├── pose_enc.cpython-310.pyc │ └── rotation.cpython-310.pyc │ ├── rotation.py │ ├── pose_enc.py │ ├── load_fn.py │ ├── geometry.py │ └── visual_track.py ├── LICENSE └── README.md /imgs/result.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wlfeng0509/QuantVGGT/HEAD/imgs/result.png -------------------------------------------------------------------------------- /imgs/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wlfeng0509/QuantVGGT/HEAD/imgs/teaser.png -------------------------------------------------------------------------------- /imgs/overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wlfeng0509/QuantVGGT/HEAD/imgs/overview.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==2.3.1 2 | torchvision==0.18.1 3 | numpy==1.26.1 4 | Pillow 5 | huggingface_hub 6 | einops 7 | safetensors 8 | -------------------------------------------------------------------------------- /evaluation/__pycache__/ba.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wlfeng0509/QuantVGGT/HEAD/evaluation/__pycache__/ba.cpython-310.pyc -------------------------------------------------------------------------------- /evaluation/__pycache__/ba.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wlfeng0509/QuantVGGT/HEAD/evaluation/__pycache__/ba.cpython-39.pyc -------------------------------------------------------------------------------- /vggt/heads/__pycache__/utils.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wlfeng0509/QuantVGGT/HEAD/vggt/heads/__pycache__/utils.cpython-39.pyc -------------------------------------------------------------------------------- /vggt/layers/__pycache__/mlp.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wlfeng0509/QuantVGGT/HEAD/vggt/layers/__pycache__/mlp.cpython-310.pyc -------------------------------------------------------------------------------- /vggt/layers/__pycache__/mlp.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wlfeng0509/QuantVGGT/HEAD/vggt/layers/__pycache__/mlp.cpython-39.pyc -------------------------------------------------------------------------------- /vggt/layers/__pycache__/rope.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wlfeng0509/QuantVGGT/HEAD/vggt/layers/__pycache__/rope.cpython-39.pyc -------------------------------------------------------------------------------- /vggt/models/__pycache__/vggt.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wlfeng0509/QuantVGGT/HEAD/vggt/models/__pycache__/vggt.cpython-39.pyc -------------------------------------------------------------------------------- /evaluation/__pycache__/compare.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wlfeng0509/QuantVGGT/HEAD/evaluation/__pycache__/compare.cpython-310.pyc -------------------------------------------------------------------------------- /evaluation/__pycache__/compare.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wlfeng0509/QuantVGGT/HEAD/evaluation/__pycache__/compare.cpython-39.pyc -------------------------------------------------------------------------------- /vggt/heads/__pycache__/dpt_head.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wlfeng0509/QuantVGGT/HEAD/vggt/heads/__pycache__/dpt_head.cpython-39.pyc -------------------------------------------------------------------------------- /vggt/heads/__pycache__/head_act.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wlfeng0509/QuantVGGT/HEAD/vggt/heads/__pycache__/head_act.cpython-39.pyc -------------------------------------------------------------------------------- /vggt/heads/__pycache__/utils.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wlfeng0509/QuantVGGT/HEAD/vggt/heads/__pycache__/utils.cpython-310.pyc -------------------------------------------------------------------------------- /vggt/layers/__pycache__/block.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wlfeng0509/QuantVGGT/HEAD/vggt/layers/__pycache__/block.cpython-310.pyc -------------------------------------------------------------------------------- /vggt/layers/__pycache__/block.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wlfeng0509/QuantVGGT/HEAD/vggt/layers/__pycache__/block.cpython-39.pyc -------------------------------------------------------------------------------- /vggt/layers/__pycache__/rope.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wlfeng0509/QuantVGGT/HEAD/vggt/layers/__pycache__/rope.cpython-310.pyc -------------------------------------------------------------------------------- /vggt/models/__pycache__/vggt.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wlfeng0509/QuantVGGT/HEAD/vggt/models/__pycache__/vggt.cpython-310.pyc -------------------------------------------------------------------------------- /vggt/utils/__pycache__/geometry.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wlfeng0509/QuantVGGT/HEAD/vggt/utils/__pycache__/geometry.cpython-39.pyc -------------------------------------------------------------------------------- /vggt/utils/__pycache__/load_fn.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wlfeng0509/QuantVGGT/HEAD/vggt/utils/__pycache__/load_fn.cpython-310.pyc -------------------------------------------------------------------------------- /vggt/utils/__pycache__/load_fn.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wlfeng0509/QuantVGGT/HEAD/vggt/utils/__pycache__/load_fn.cpython-39.pyc -------------------------------------------------------------------------------- /vggt/utils/__pycache__/pose_enc.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wlfeng0509/QuantVGGT/HEAD/vggt/utils/__pycache__/pose_enc.cpython-39.pyc -------------------------------------------------------------------------------- /vggt/utils/__pycache__/rotation.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wlfeng0509/QuantVGGT/HEAD/vggt/utils/__pycache__/rotation.cpython-39.pyc -------------------------------------------------------------------------------- /vggt/heads/__pycache__/dpt_head.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wlfeng0509/QuantVGGT/HEAD/vggt/heads/__pycache__/dpt_head.cpython-310.pyc -------------------------------------------------------------------------------- /vggt/heads/__pycache__/head_act.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wlfeng0509/QuantVGGT/HEAD/vggt/heads/__pycache__/head_act.cpython-310.pyc -------------------------------------------------------------------------------- /vggt/heads/__pycache__/track_head.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wlfeng0509/QuantVGGT/HEAD/vggt/heads/__pycache__/track_head.cpython-39.pyc -------------------------------------------------------------------------------- /vggt/layers/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wlfeng0509/QuantVGGT/HEAD/vggt/layers/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /vggt/layers/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wlfeng0509/QuantVGGT/HEAD/vggt/layers/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /vggt/layers/__pycache__/attention.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wlfeng0509/QuantVGGT/HEAD/vggt/layers/__pycache__/attention.cpython-39.pyc -------------------------------------------------------------------------------- /vggt/layers/__pycache__/drop_path.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wlfeng0509/QuantVGGT/HEAD/vggt/layers/__pycache__/drop_path.cpython-39.pyc -------------------------------------------------------------------------------- /vggt/utils/__pycache__/geometry.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wlfeng0509/QuantVGGT/HEAD/vggt/utils/__pycache__/geometry.cpython-310.pyc -------------------------------------------------------------------------------- /vggt/utils/__pycache__/pose_enc.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wlfeng0509/QuantVGGT/HEAD/vggt/utils/__pycache__/pose_enc.cpython-310.pyc -------------------------------------------------------------------------------- /vggt/utils/__pycache__/rotation.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wlfeng0509/QuantVGGT/HEAD/vggt/utils/__pycache__/rotation.cpython-310.pyc -------------------------------------------------------------------------------- /evaluation/quarot/__pycache__/utils.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wlfeng0509/QuantVGGT/HEAD/evaluation/quarot/__pycache__/utils.cpython-310.pyc -------------------------------------------------------------------------------- /evaluation/quarot/__pycache__/utils.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wlfeng0509/QuantVGGT/HEAD/evaluation/quarot/__pycache__/utils.cpython-39.pyc -------------------------------------------------------------------------------- /vggt/heads/__pycache__/camera_head.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wlfeng0509/QuantVGGT/HEAD/vggt/heads/__pycache__/camera_head.cpython-310.pyc -------------------------------------------------------------------------------- /vggt/heads/__pycache__/camera_head.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wlfeng0509/QuantVGGT/HEAD/vggt/heads/__pycache__/camera_head.cpython-39.pyc -------------------------------------------------------------------------------- /vggt/heads/__pycache__/track_head.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wlfeng0509/QuantVGGT/HEAD/vggt/heads/__pycache__/track_head.cpython-310.pyc -------------------------------------------------------------------------------- /vggt/layers/__pycache__/attention.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wlfeng0509/QuantVGGT/HEAD/vggt/layers/__pycache__/attention.cpython-310.pyc -------------------------------------------------------------------------------- /vggt/layers/__pycache__/drop_path.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wlfeng0509/QuantVGGT/HEAD/vggt/layers/__pycache__/drop_path.cpython-310.pyc -------------------------------------------------------------------------------- /vggt/layers/__pycache__/layer_scale.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wlfeng0509/QuantVGGT/HEAD/vggt/layers/__pycache__/layer_scale.cpython-310.pyc -------------------------------------------------------------------------------- /vggt/layers/__pycache__/layer_scale.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wlfeng0509/QuantVGGT/HEAD/vggt/layers/__pycache__/layer_scale.cpython-39.pyc -------------------------------------------------------------------------------- /vggt/layers/__pycache__/patch_embed.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wlfeng0509/QuantVGGT/HEAD/vggt/layers/__pycache__/patch_embed.cpython-310.pyc -------------------------------------------------------------------------------- /vggt/layers/__pycache__/patch_embed.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wlfeng0509/QuantVGGT/HEAD/vggt/layers/__pycache__/patch_embed.cpython-39.pyc -------------------------------------------------------------------------------- /vggt/layers/__pycache__/swiglu_ffn.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wlfeng0509/QuantVGGT/HEAD/vggt/layers/__pycache__/swiglu_ffn.cpython-310.pyc -------------------------------------------------------------------------------- /vggt/layers/__pycache__/swiglu_ffn.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wlfeng0509/QuantVGGT/HEAD/vggt/layers/__pycache__/swiglu_ffn.cpython-39.pyc -------------------------------------------------------------------------------- /vggt/models/__pycache__/aggregator.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wlfeng0509/QuantVGGT/HEAD/vggt/models/__pycache__/aggregator.cpython-310.pyc -------------------------------------------------------------------------------- /vggt/models/__pycache__/aggregator.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wlfeng0509/QuantVGGT/HEAD/vggt/models/__pycache__/aggregator.cpython-39.pyc -------------------------------------------------------------------------------- /vggt/layers/__pycache__/quarot_utils.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wlfeng0509/QuantVGGT/HEAD/vggt/layers/__pycache__/quarot_utils.cpython-310.pyc -------------------------------------------------------------------------------- /evaluation/__pycache__/tensor_to_pycolmap.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wlfeng0509/QuantVGGT/HEAD/evaluation/__pycache__/tensor_to_pycolmap.cpython-39.pyc -------------------------------------------------------------------------------- /evaluation/quarot/__pycache__/args_utils.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wlfeng0509/QuantVGGT/HEAD/evaluation/quarot/__pycache__/args_utils.cpython-310.pyc -------------------------------------------------------------------------------- /evaluation/quarot/__pycache__/args_utils.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wlfeng0509/QuantVGGT/HEAD/evaluation/quarot/__pycache__/args_utils.cpython-39.pyc -------------------------------------------------------------------------------- /evaluation/quarot/__pycache__/quant_utils.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wlfeng0509/QuantVGGT/HEAD/evaluation/quarot/__pycache__/quant_utils.cpython-39.pyc -------------------------------------------------------------------------------- /evaluation/quarot/__pycache__/train_utils.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wlfeng0509/QuantVGGT/HEAD/evaluation/quarot/__pycache__/train_utils.cpython-39.pyc -------------------------------------------------------------------------------- /evaluation/__pycache__/tensor_to_pycolmap.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wlfeng0509/QuantVGGT/HEAD/evaluation/__pycache__/tensor_to_pycolmap.cpython-310.pyc -------------------------------------------------------------------------------- /evaluation/quarot/__pycache__/quant_utils.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wlfeng0509/QuantVGGT/HEAD/evaluation/quarot/__pycache__/quant_utils.cpython-310.pyc -------------------------------------------------------------------------------- /evaluation/quarot/__pycache__/quarot_linear.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wlfeng0509/QuantVGGT/HEAD/evaluation/quarot/__pycache__/quarot_linear.cpython-39.pyc -------------------------------------------------------------------------------- /evaluation/quarot/__pycache__/quarot_utils.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wlfeng0509/QuantVGGT/HEAD/evaluation/quarot/__pycache__/quarot_utils.cpython-310.pyc -------------------------------------------------------------------------------- /evaluation/quarot/__pycache__/quarot_utils.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wlfeng0509/QuantVGGT/HEAD/evaluation/quarot/__pycache__/quarot_utils.cpython-39.pyc -------------------------------------------------------------------------------- /evaluation/quarot/__pycache__/qvdit_linear.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wlfeng0509/QuantVGGT/HEAD/evaluation/quarot/__pycache__/qvdit_linear.cpython-310.pyc -------------------------------------------------------------------------------- /evaluation/quarot/__pycache__/qvdit_utils.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wlfeng0509/QuantVGGT/HEAD/evaluation/quarot/__pycache__/qvdit_utils.cpython-310.pyc -------------------------------------------------------------------------------- /evaluation/quarot/__pycache__/train_utils.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wlfeng0509/QuantVGGT/HEAD/evaluation/quarot/__pycache__/train_utils.cpython-310.pyc -------------------------------------------------------------------------------- /vggt/heads/track_modules/__pycache__/blocks.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wlfeng0509/QuantVGGT/HEAD/vggt/heads/track_modules/__pycache__/blocks.cpython-39.pyc -------------------------------------------------------------------------------- /vggt/heads/track_modules/__pycache__/utils.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wlfeng0509/QuantVGGT/HEAD/vggt/heads/track_modules/__pycache__/utils.cpython-310.pyc -------------------------------------------------------------------------------- /vggt/heads/track_modules/__pycache__/utils.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wlfeng0509/QuantVGGT/HEAD/vggt/heads/track_modules/__pycache__/utils.cpython-39.pyc -------------------------------------------------------------------------------- /vggt/layers/__pycache__/vision_transformer.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wlfeng0509/QuantVGGT/HEAD/vggt/layers/__pycache__/vision_transformer.cpython-310.pyc -------------------------------------------------------------------------------- /vggt/layers/__pycache__/vision_transformer.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wlfeng0509/QuantVGGT/HEAD/vggt/layers/__pycache__/vision_transformer.cpython-39.pyc -------------------------------------------------------------------------------- /evaluation/quarot/__pycache__/function_utils.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wlfeng0509/QuantVGGT/HEAD/evaluation/quarot/__pycache__/function_utils.cpython-310.pyc -------------------------------------------------------------------------------- /evaluation/quarot/__pycache__/function_utils.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wlfeng0509/QuantVGGT/HEAD/evaluation/quarot/__pycache__/function_utils.cpython-39.pyc -------------------------------------------------------------------------------- /evaluation/quarot/__pycache__/quarot_linear.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wlfeng0509/QuantVGGT/HEAD/evaluation/quarot/__pycache__/quarot_linear.cpython-310.pyc -------------------------------------------------------------------------------- /vggt/heads/track_modules/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wlfeng0509/QuantVGGT/HEAD/vggt/heads/track_modules/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /vggt/heads/track_modules/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wlfeng0509/QuantVGGT/HEAD/vggt/heads/track_modules/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /vggt/heads/track_modules/__pycache__/blocks.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wlfeng0509/QuantVGGT/HEAD/vggt/heads/track_modules/__pycache__/blocks.cpython-310.pyc -------------------------------------------------------------------------------- /vggt/heads/track_modules/__pycache__/modules.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wlfeng0509/QuantVGGT/HEAD/vggt/heads/track_modules/__pycache__/modules.cpython-310.pyc -------------------------------------------------------------------------------- /vggt/heads/track_modules/__pycache__/modules.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wlfeng0509/QuantVGGT/HEAD/vggt/heads/track_modules/__pycache__/modules.cpython-39.pyc -------------------------------------------------------------------------------- /vggt/heads/track_modules/__pycache__/base_track_predictor.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wlfeng0509/QuantVGGT/HEAD/vggt/heads/track_modules/__pycache__/base_track_predictor.cpython-310.pyc -------------------------------------------------------------------------------- /vggt/heads/track_modules/__pycache__/base_track_predictor.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wlfeng0509/QuantVGGT/HEAD/vggt/heads/track_modules/__pycache__/base_track_predictor.cpython-39.pyc -------------------------------------------------------------------------------- /vggt/heads/track_modules/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /vggt/layers/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from .mlp import Mlp 8 | from .patch_embed import PatchEmbed 9 | from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused 10 | from .block import NestedTensorBlock 11 | from .attention import MemEffAttention 12 | -------------------------------------------------------------------------------- /evaluation/test.sh: -------------------------------------------------------------------------------- 1 | 2 | CUDA_VISIBLE_DEVICES=6 python test_co3d_1.py \ 3 | --model_path /data/fwl/VGGT_quant/VGGT-1B/model_tracker_fixed_e20.pt \ 4 | --co3d_dir /data1/3d_datasets/ \ 5 | --co3d_anno_dir /data1/fwl/datasets/co3d_v2_annotations/ \ 6 | --dtype quantvggt_w4a4\ 7 | --seed 0 \ 8 | --lac \ 9 | --lwc \ 10 | --exp_name a44_quant \ 11 | --debug_mode all \ 12 | --fast_eval \ 13 | --resume_qs \ 14 | 15 | -------------------------------------------------------------------------------- /vggt/layers/layer_scale.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 5 | 6 | # Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110 7 | 8 | from typing import Union 9 | 10 | import torch 11 | from torch import Tensor 12 | from torch import nn 13 | 14 | 15 | class LayerScale(nn.Module): 16 | def __init__( 17 | self, 18 | dim: int, 19 | init_values: Union[float, Tensor] = 1e-5, 20 | inplace: bool = False, 21 | ) -> None: 22 | super().__init__() 23 | self.inplace = inplace 24 | self.gamma = nn.Parameter(init_values * torch.ones(dim)) 25 | 26 | def forward(self, x: Tensor) -> Tensor: 27 | return x.mul_(self.gamma) if self.inplace else x * self.gamma 28 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 wlfeng0509 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /vggt/layers/drop_path.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 5 | 6 | # References: 7 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py 8 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py 9 | 10 | 11 | from torch import nn 12 | 13 | 14 | def drop_path(x, drop_prob: float = 0.0, training: bool = False): 15 | if drop_prob == 0.0 or not training: 16 | return x 17 | keep_prob = 1 - drop_prob 18 | shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets 19 | random_tensor = x.new_empty(shape).bernoulli_(keep_prob) 20 | if keep_prob > 0.0: 21 | random_tensor.div_(keep_prob) 22 | output = x * random_tensor 23 | return output 24 | 25 | 26 | class DropPath(nn.Module): 27 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" 28 | 29 | def __init__(self, drop_prob=None): 30 | super(DropPath, self).__init__() 31 | self.drop_prob = drop_prob 32 | 33 | def forward(self, x): 34 | return drop_path(x, self.drop_prob, self.training) 35 | -------------------------------------------------------------------------------- /vggt/layers/mlp.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 5 | 6 | # References: 7 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py 8 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py 9 | 10 | 11 | from typing import Callable, Optional 12 | 13 | from torch import Tensor, nn 14 | # from .quarot_utils import random_hadamard_matrix 15 | import torch 16 | 17 | 18 | class Mlp(nn.Module): 19 | def __init__( 20 | self, 21 | in_features: int, 22 | hidden_features: Optional[int] = None, 23 | out_features: Optional[int] = None, 24 | act_layer: Callable[..., nn.Module] = nn.GELU, 25 | drop: float = 0.0, 26 | bias: bool = True, 27 | ) -> None: 28 | super().__init__() 29 | out_features = out_features or in_features 30 | hidden_features = hidden_features or in_features 31 | self.fc1 = nn.Linear(in_features, hidden_features, bias=bias) 32 | self.act = act_layer() 33 | self.fc2 = nn.Linear(hidden_features, out_features, bias=bias) 34 | self.drop = nn.Dropout(drop) 35 | 36 | def forward(self, x: Tensor) -> Tensor: 37 | x = self.fc1(x) 38 | x = self.act(x) 39 | x = self.drop(x) 40 | x = self.fc2(x) 41 | x = self.drop(x) 42 | return x 43 | -------------------------------------------------------------------------------- /evaluation/.gradio/certificate.pem: -------------------------------------------------------------------------------- 1 | -----BEGIN CERTIFICATE----- 2 | MIIFazCCA1OgAwIBAgIRAIIQz7DSQONZRGPgu2OCiwAwDQYJKoZIhvcNAQELBQAw 3 | TzELMAkGA1UEBhMCVVMxKTAnBgNVBAoTIEludGVybmV0IFNlY3VyaXR5IFJlc2Vh 4 | cmNoIEdyb3VwMRUwEwYDVQQDEwxJU1JHIFJvb3QgWDEwHhcNMTUwNjA0MTEwNDM4 5 | WhcNMzUwNjA0MTEwNDM4WjBPMQswCQYDVQQGEwJVUzEpMCcGA1UEChMgSW50ZXJu 6 | ZXQgU2VjdXJpdHkgUmVzZWFyY2ggR3JvdXAxFTATBgNVBAMTDElTUkcgUm9vdCBY 7 | MTCCAiIwDQYJKoZIhvcNAQEBBQADggIPADCCAgoCggIBAK3oJHP0FDfzm54rVygc 8 | h77ct984kIxuPOZXoHj3dcKi/vVqbvYATyjb3miGbESTtrFj/RQSa78f0uoxmyF+ 9 | 0TM8ukj13Xnfs7j/EvEhmkvBioZxaUpmZmyPfjxwv60pIgbz5MDmgK7iS4+3mX6U 10 | A5/TR5d8mUgjU+g4rk8Kb4Mu0UlXjIB0ttov0DiNewNwIRt18jA8+o+u3dpjq+sW 11 | T8KOEUt+zwvo/7V3LvSye0rgTBIlDHCNAymg4VMk7BPZ7hm/ELNKjD+Jo2FR3qyH 12 | B5T0Y3HsLuJvW5iB4YlcNHlsdu87kGJ55tukmi8mxdAQ4Q7e2RCOFvu396j3x+UC 13 | B5iPNgiV5+I3lg02dZ77DnKxHZu8A/lJBdiB3QW0KtZB6awBdpUKD9jf1b0SHzUv 14 | KBds0pjBqAlkd25HN7rOrFleaJ1/ctaJxQZBKT5ZPt0m9STJEadao0xAH0ahmbWn 15 | OlFuhjuefXKnEgV4We0+UXgVCwOPjdAvBbI+e0ocS3MFEvzG6uBQE3xDk3SzynTn 16 | jh8BCNAw1FtxNrQHusEwMFxIt4I7mKZ9YIqioymCzLq9gwQbooMDQaHWBfEbwrbw 17 | qHyGO0aoSCqI3Haadr8faqU9GY/rOPNk3sgrDQoo//fb4hVC1CLQJ13hef4Y53CI 18 | rU7m2Ys6xt0nUW7/vGT1M0NPAgMBAAGjQjBAMA4GA1UdDwEB/wQEAwIBBjAPBgNV 19 | HRMBAf8EBTADAQH/MB0GA1UdDgQWBBR5tFnme7bl5AFzgAiIyBpY9umbbjANBgkq 20 | hkiG9w0BAQsFAAOCAgEAVR9YqbyyqFDQDLHYGmkgJykIrGF1XIpu+ILlaS/V9lZL 21 | ubhzEFnTIZd+50xx+7LSYK05qAvqFyFWhfFQDlnrzuBZ6brJFe+GnY+EgPbk6ZGQ 22 | 3BebYhtF8GaV0nxvwuo77x/Py9auJ/GpsMiu/X1+mvoiBOv/2X/qkSsisRcOj/KK 23 | NFtY2PwByVS5uCbMiogziUwthDyC3+6WVwW6LLv3xLfHTjuCvjHIInNzktHCgKQ5 24 | ORAzI4JMPJ+GslWYHb4phowim57iaztXOoJwTdwJx4nLCgdNbOhdjsnvzqvHu7Ur 25 | TkXWStAmzOVyyghqpZXjFaH3pO3JLF+l+/+sKAIuvtd7u+Nxe5AW0wdeRlN8NwdC 26 | jNPElpzVmbUq4JUagEiuTDkHzsxHpFKVK7q4+63SM1N95R1NbdWhscdCb+ZAJzVc 27 | oyi3B43njTOQ5yOf+1CceWxG1bQVs5ZufpsMljq4Ui0/1lvh+wjChP4kqKOJ2qxq 28 | 4RgqsahDYVvTH9w7jXbyLeiNdd8XM2w9U/t7y0Ff/9yi0GE44Za4rF2LN9d11TPA 29 | mRGunUHBcnWEvgJBQl9nJEiU0Zsnvgc/ubhPgXRR4Xq37Z0j4r7g1SgEEzwxA57d 30 | emyPxgcYxn/eR44/KJ4EBs+lVDR3veyJm+kXQ99b21/+jh5Xos1AnX5iItreGCc= 31 | -----END CERTIFICATE----- 32 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Quantized Visual Geometry Grounded Transformer 2 | 3 | [arXiv](https://arxiv.org/abs/2509.21302) | [BibTeX](#bibtex) 4 | 5 | ------ 6 | 7 | This project is the official implementation of our QuantVGGT: "Quantized Visual Geometry Grounded Transformer". 8 | 9 | ![teaser](imgs/teaser.png) 10 | 11 | ![overview](imgs/overview.png) 12 | 13 | ------ 14 | 15 | ## Results 16 | 17 | ![result](imgs/result.png) 18 | 19 | ## Updates 20 | 21 | - [October 10, 2025] Evaluation code for reproducing our camera pose estimation results on Co3D is now available. 22 | 23 | ## Quick Start 24 | 25 | First, clone this repository to your local machine, and install the dependencies (torch, torchvision, numpy, Pillow, and huggingface_hub). 26 | 27 | ``` 28 | git clone git@github.com:wlfeng0509/QuantVGGT.git 29 | cd QuantVGGT 30 | pip install -r requirements.txt 31 | ``` 32 | 33 | Then download the pre trained weights provided by [VGGT](https://github.com/facebookresearch/vggt) and prepare Co3D dataset following [this](https://github.com/facebookresearch/vggt/tree/evaluation/evaluation). 34 | 35 | Then download the pre trained W4A4 quantization parameters from [huggingface](https://huggingface.co/wlfeng/QuantVGGT/tree/main) and place the downloaded folder under *evaluation\outputs\w4a4* branch. 36 | 37 | We can now use the provided script for inference **(remember to change the data path within the script)**. 38 | 39 | ``` 40 | cd evaluation 41 | bash test.sh 42 | ``` 43 | 44 | Also, you can use the quantized model for predicting other 3D attributes following the guidance [here](https://github.com/facebookresearch/vggt/tree/evaluation#detailed-usage). 45 | 46 | ## Comments 47 | 48 | * Our codebase is heavily builds on [VGGT](https://github.com/facebookresearch/vggt) and [QuaRot](https://github.com/spcl/QuaRot). Thanks for open-sourcing! 49 | 50 | ## BibTeX 51 | 52 | If you find *QuantVGGT* is useful and helpful to your work, please kindly cite this paper: 53 | 54 | ``` 55 | @article{feng2025quantized, 56 | title={Quantized Visual Geometry Grounded Transformer}, 57 | author={Feng, Weilun and Qin, Haotong and Wu, Mingqiang and Yang, Chuanguang and Li, Yuqi and Li, Xiangqi and An, Zhulin and Huang, Libo and Zhang, Yulun and Magno, Michele and others}, 58 | journal={arXiv preprint arXiv:2509.21302}, 59 | year={2025} 60 | } 61 | ``` 62 | 63 | -------------------------------------------------------------------------------- /vggt/layers/swiglu_ffn.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 5 | 6 | import os 7 | from typing import Callable, Optional 8 | import warnings 9 | 10 | from torch import Tensor, nn 11 | import torch.nn.functional as F 12 | 13 | 14 | class SwiGLUFFN(nn.Module): 15 | def __init__( 16 | self, 17 | in_features: int, 18 | hidden_features: Optional[int] = None, 19 | out_features: Optional[int] = None, 20 | act_layer: Callable[..., nn.Module] = None, 21 | drop: float = 0.0, 22 | bias: bool = True, 23 | ) -> None: 24 | super().__init__() 25 | out_features = out_features or in_features 26 | hidden_features = hidden_features or in_features 27 | self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias) 28 | self.w3 = nn.Linear(hidden_features, out_features, bias=bias) 29 | 30 | def forward(self, x: Tensor) -> Tensor: 31 | x12 = self.w12(x) 32 | x1, x2 = x12.chunk(2, dim=-1) 33 | hidden = F.silu(x1) * x2 34 | return self.w3(hidden) 35 | 36 | 37 | XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None 38 | # try: 39 | # if XFORMERS_ENABLED: 40 | # from xformers.ops import SwiGLU 41 | 42 | # XFORMERS_AVAILABLE = True 43 | # warnings.warn("xFormers is available (SwiGLU)") 44 | # else: 45 | # warnings.warn("xFormers is disabled (SwiGLU)") 46 | # raise ImportError 47 | # except ImportError: 48 | SwiGLU = SwiGLUFFN 49 | XFORMERS_AVAILABLE = False 50 | 51 | # warnings.warn("xFormers is not available (SwiGLU)") 52 | 53 | 54 | class SwiGLUFFNFused(SwiGLU): 55 | def __init__( 56 | self, 57 | in_features: int, 58 | hidden_features: Optional[int] = None, 59 | out_features: Optional[int] = None, 60 | act_layer: Callable[..., nn.Module] = None, 61 | drop: float = 0.0, 62 | bias: bool = True, 63 | ) -> None: 64 | out_features = out_features or in_features 65 | hidden_features = hidden_features or in_features 66 | hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8 67 | super().__init__( 68 | in_features=in_features, 69 | hidden_features=hidden_features, 70 | out_features=out_features, 71 | bias=bias, 72 | ) 73 | -------------------------------------------------------------------------------- /evaluation/quarot/function_utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import numpy as np 4 | from scipy.linalg import qr 5 | from collections import OrderedDict 6 | 7 | def get_init_scale(w_smax, x_smax, alpha=0.5): 8 | return (w_smax.pow(1 - alpha) / x_smax.pow(alpha)).clamp(min=1e-5) 9 | 10 | 11 | def get_decompose_dim(n): 12 | a = int(math.sqrt(n)) 13 | if a * a < n: 14 | a += 1 15 | while True: 16 | tmp = a*a - n 17 | b = int(math.sqrt(tmp)) 18 | if b * b == tmp: 19 | break 20 | a += 1 21 | return a - b, a + b 22 | 23 | 24 | def get_random_orthg(size): 25 | H = np.random.randn(size, size) 26 | Q, R = qr(H) 27 | Q_modified = Q @ np.diag(np.sign(np.diag(R))) 28 | return torch.from_numpy(Q_modified) 29 | 30 | 31 | def get_init_weight(dim, ): 32 | return get_random_orthg(dim) 33 | 34 | 35 | def get_inverse(matrix): 36 | dtype = matrix.dtype 37 | return matrix.double().inverse().to(dtype) 38 | 39 | 40 | def get_n_set_parameters_byname(model, required_names): 41 | params = [] 42 | for r_name in required_names: 43 | for name, param in model.named_parameters(): 44 | if name.find(r_name) > -1: 45 | param.requires_grad = True 46 | params.append(param) 47 | return params 48 | 49 | def get_paras_dict_by_name(model, required_names, destination=None, prefix=''): 50 | if destination is None: 51 | destination = OrderedDict() 52 | for r_name in required_names: 53 | for name, param in model.named_parameters(): 54 | if name.find(r_name) > -1: 55 | destination[prefix + name] = param.detach() 56 | return destination 57 | 58 | 59 | def check_params_grad(model): 60 | for name, param in model.named_parameters(): 61 | print(name, ':{}'.format(param.requires_grad)) 62 | return 63 | 64 | 65 | def print_trainable_parameters(model): 66 | """ 67 | Prints the number of trainable parameters in the model. 68 | """ 69 | trainable_params = 0 70 | all_param = 0 71 | for _, param in model.named_parameters(): 72 | all_param += param.numel() 73 | if param.requires_grad: 74 | trainable_params += param.numel() 75 | print(f"trainable params: {trainable_params} || all params: {all_param} || trainable: {100 * trainable_params / all_param:.2f}%") 76 | 77 | 78 | def set_require_grad_all(model, requires_grad): 79 | for name, param in model.named_parameters(): 80 | param.requires_grad = requires_grad 81 | return -------------------------------------------------------------------------------- /vggt/layers/patch_embed.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 5 | 6 | # References: 7 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py 8 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py 9 | 10 | from typing import Callable, Optional, Tuple, Union 11 | 12 | from torch import Tensor 13 | import torch.nn as nn 14 | 15 | 16 | def make_2tuple(x): 17 | if isinstance(x, tuple): 18 | assert len(x) == 2 19 | return x 20 | 21 | assert isinstance(x, int) 22 | return (x, x) 23 | 24 | 25 | class PatchEmbed(nn.Module): 26 | """ 27 | 2D image to patch embedding: (B,C,H,W) -> (B,N,D) 28 | 29 | Args: 30 | img_size: Image size. 31 | patch_size: Patch token size. 32 | in_chans: Number of input image channels. 33 | embed_dim: Number of linear projection output channels. 34 | norm_layer: Normalization layer. 35 | """ 36 | 37 | def __init__( 38 | self, 39 | img_size: Union[int, Tuple[int, int]] = 224, 40 | patch_size: Union[int, Tuple[int, int]] = 16, 41 | in_chans: int = 3, 42 | embed_dim: int = 768, 43 | norm_layer: Optional[Callable] = None, 44 | flatten_embedding: bool = True, 45 | ) -> None: 46 | super().__init__() 47 | 48 | image_HW = make_2tuple(img_size) 49 | patch_HW = make_2tuple(patch_size) 50 | patch_grid_size = ( 51 | image_HW[0] // patch_HW[0], 52 | image_HW[1] // patch_HW[1], 53 | ) 54 | 55 | self.img_size = image_HW 56 | self.patch_size = patch_HW 57 | self.patches_resolution = patch_grid_size 58 | self.num_patches = patch_grid_size[0] * patch_grid_size[1] 59 | 60 | self.in_chans = in_chans 61 | self.embed_dim = embed_dim 62 | 63 | self.flatten_embedding = flatten_embedding 64 | 65 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW) 66 | self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() 67 | 68 | def forward(self, x: Tensor) -> Tensor: 69 | _, _, H, W = x.shape 70 | patch_H, patch_W = self.patch_size 71 | 72 | assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}" 73 | assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}" 74 | 75 | x = self.proj(x) # B C H W 76 | H, W = x.size(2), x.size(3) 77 | x = x.flatten(2).transpose(1, 2) # B HW C 78 | x = self.norm(x) 79 | if not self.flatten_embedding: 80 | x = x.reshape(-1, H, W, self.embed_dim) # B H W C 81 | return x 82 | 83 | def flops(self) -> float: 84 | Ho, Wo = self.patches_resolution 85 | flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) 86 | if self.norm is not None: 87 | flops += Ho * Wo * self.embed_dim 88 | return flops 89 | -------------------------------------------------------------------------------- /vggt/layers/attention.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 5 | 6 | # References: 7 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py 8 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py 9 | 10 | import logging 11 | import os 12 | import warnings 13 | 14 | from torch import Tensor 15 | from torch import nn 16 | import torch.nn.functional as F 17 | # from .quarot_utils import random_hadamard_matrix 18 | import torch 19 | 20 | XFORMERS_AVAILABLE = False 21 | 22 | 23 | class Attention(nn.Module): 24 | def __init__( 25 | self, 26 | dim: int, 27 | num_heads: int = 8, 28 | qkv_bias: bool = True, 29 | proj_bias: bool = True, 30 | attn_drop: float = 0.0, 31 | proj_drop: float = 0.0, 32 | norm_layer: nn.Module = nn.LayerNorm, 33 | qk_norm: bool = False, 34 | fused_attn: bool = True, # use F.scaled_dot_product_attention or not 35 | rope=None, 36 | ) -> None: 37 | super().__init__() 38 | assert dim % num_heads == 0, "dim should be divisible by num_heads" 39 | self.num_heads = num_heads 40 | self.head_dim = dim // num_heads 41 | self.scale = self.head_dim**-0.5 42 | self.fused_attn = fused_attn 43 | 44 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 45 | self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() 46 | self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() 47 | self.attn_drop = nn.Dropout(attn_drop) 48 | self.proj = nn.Linear(dim, dim, bias=proj_bias) 49 | self.proj_drop = nn.Dropout(proj_drop) 50 | self.rope = rope 51 | 52 | def forward(self, x: Tensor, pos=None) -> Tensor: 53 | B, N, C = x.shape 54 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) 55 | q, k, v = qkv.unbind(0) 56 | q, k = self.q_norm(q), self.k_norm(k) 57 | 58 | if self.rope is not None: 59 | q = self.rope(q, pos) 60 | k = self.rope(k, pos) 61 | 62 | if self.fused_attn: 63 | x = F.scaled_dot_product_attention( 64 | q, 65 | k, 66 | v, 67 | dropout_p=self.attn_drop.p if self.training else 0.0, 68 | ) 69 | else: 70 | q = q * self.scale 71 | attn = q @ k.transpose(-2, -1) 72 | attn = attn.softmax(dim=-1) 73 | attn = self.attn_drop(attn) 74 | x = attn @ v 75 | 76 | x = x.transpose(1, 2).reshape(B, N, C) 77 | x = self.proj(x) 78 | x = self.proj_drop(x) 79 | return x 80 | 81 | 82 | class MemEffAttention(Attention): 83 | def forward(self, x: Tensor, attn_bias=None, pos=None) -> Tensor: 84 | assert pos is None 85 | if not XFORMERS_AVAILABLE: 86 | if attn_bias is not None: 87 | raise AssertionError("xFormers is required for using nested tensors") 88 | return super().forward(x) 89 | 90 | B, N, C = x.shape 91 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads) 92 | 93 | q, k, v = unbind(qkv, 2) 94 | 95 | x = memory_efficient_attention(q, k, v, attn_bias=attn_bias) 96 | x = x.reshape([B, N, C]) 97 | 98 | x = self.proj(x) 99 | x = self.proj_drop(x) 100 | return x 101 | -------------------------------------------------------------------------------- /vggt/heads/head_act.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | 8 | import torch 9 | import torch.nn.functional as F 10 | 11 | 12 | def activate_pose(pred_pose_enc, trans_act="linear", quat_act="linear", fl_act="linear"): 13 | """ 14 | Activate pose parameters with specified activation functions. 15 | 16 | Args: 17 | pred_pose_enc: Tensor containing encoded pose parameters [translation, quaternion, focal length] 18 | trans_act: Activation type for translation component 19 | quat_act: Activation type for quaternion component 20 | fl_act: Activation type for focal length component 21 | 22 | Returns: 23 | Activated pose parameters tensor 24 | """ 25 | T = pred_pose_enc[..., :3] 26 | quat = pred_pose_enc[..., 3:7] 27 | fl = pred_pose_enc[..., 7:] # or fov 28 | 29 | T = base_pose_act(T, trans_act) 30 | quat = base_pose_act(quat, quat_act) 31 | fl = base_pose_act(fl, fl_act) # or fov 32 | 33 | pred_pose_enc = torch.cat([T, quat, fl], dim=-1) 34 | 35 | return pred_pose_enc 36 | 37 | 38 | def base_pose_act(pose_enc, act_type="linear"): 39 | """ 40 | Apply basic activation function to pose parameters. 41 | 42 | Args: 43 | pose_enc: Tensor containing encoded pose parameters 44 | act_type: Activation type ("linear", "inv_log", "exp", "relu") 45 | 46 | Returns: 47 | Activated pose parameters 48 | """ 49 | if act_type == "linear": 50 | return pose_enc 51 | elif act_type == "inv_log": 52 | return inverse_log_transform(pose_enc) 53 | elif act_type == "exp": 54 | return torch.exp(pose_enc) 55 | elif act_type == "relu": 56 | return F.relu(pose_enc) 57 | else: 58 | raise ValueError(f"Unknown act_type: {act_type}") 59 | 60 | 61 | def activate_head(out, activation="norm_exp", conf_activation="expp1"): 62 | """ 63 | Process network output to extract 3D points and confidence values. 64 | 65 | Args: 66 | out: Network output tensor (B, C, H, W) 67 | activation: Activation type for 3D points 68 | conf_activation: Activation type for confidence values 69 | 70 | Returns: 71 | Tuple of (3D points tensor, confidence tensor) 72 | """ 73 | # Move channels from last dim to the 4th dimension => (B, H, W, C) 74 | fmap = out.permute(0, 2, 3, 1) # B,H,W,C expected 75 | 76 | # Split into xyz (first C-1 channels) and confidence (last channel) 77 | xyz = fmap[:, :, :, :-1] 78 | conf = fmap[:, :, :, -1] 79 | 80 | if activation == "norm_exp": 81 | d = xyz.norm(dim=-1, keepdim=True).clamp(min=1e-8) 82 | xyz_normed = xyz / d 83 | pts3d = xyz_normed * torch.expm1(d) 84 | elif activation == "norm": 85 | pts3d = xyz / xyz.norm(dim=-1, keepdim=True) 86 | elif activation == "exp": 87 | pts3d = torch.exp(xyz) 88 | elif activation == "relu": 89 | pts3d = F.relu(xyz) 90 | elif activation == "inv_log": 91 | pts3d = inverse_log_transform(xyz) 92 | elif activation == "xy_inv_log": 93 | xy, z = xyz.split([2, 1], dim=-1) 94 | z = inverse_log_transform(z) 95 | pts3d = torch.cat([xy * z, z], dim=-1) 96 | elif activation == "sigmoid": 97 | pts3d = torch.sigmoid(xyz) 98 | elif activation == "linear": 99 | pts3d = xyz 100 | else: 101 | raise ValueError(f"Unknown activation: {activation}") 102 | 103 | if conf_activation == "expp1": 104 | conf_out = 1 + conf.exp() 105 | elif conf_activation == "expp0": 106 | conf_out = conf.exp() 107 | elif conf_activation == "sigmoid": 108 | conf_out = torch.sigmoid(conf) 109 | else: 110 | raise ValueError(f"Unknown conf_activation: {conf_activation}") 111 | 112 | return pts3d, conf_out 113 | 114 | 115 | def inverse_log_transform(y): 116 | """ 117 | Apply inverse log transform: sign(y) * (exp(|y|) - 1) 118 | 119 | Args: 120 | y: Input tensor 121 | 122 | Returns: 123 | Transformed tensor 124 | """ 125 | return torch.sign(y) * (torch.expm1(torch.abs(y))) 126 | -------------------------------------------------------------------------------- /vggt/heads/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | import torch.nn as nn 9 | 10 | 11 | def position_grid_to_embed(pos_grid: torch.Tensor, embed_dim: int, omega_0: float = 100) -> torch.Tensor: 12 | """ 13 | Convert 2D position grid (HxWx2) to sinusoidal embeddings (HxWxC) 14 | 15 | Args: 16 | pos_grid: Tensor of shape (H, W, 2) containing 2D coordinates 17 | embed_dim: Output channel dimension for embeddings 18 | 19 | Returns: 20 | Tensor of shape (H, W, embed_dim) with positional embeddings 21 | """ 22 | H, W, grid_dim = pos_grid.shape 23 | assert grid_dim == 2 24 | pos_flat = pos_grid.reshape(-1, grid_dim) # Flatten to (H*W, 2) 25 | 26 | # Process x and y coordinates separately 27 | emb_x = make_sincos_pos_embed(embed_dim // 2, pos_flat[:, 0], omega_0=omega_0) # [1, H*W, D/2] 28 | emb_y = make_sincos_pos_embed(embed_dim // 2, pos_flat[:, 1], omega_0=omega_0) # [1, H*W, D/2] 29 | 30 | # Combine and reshape 31 | emb = torch.cat([emb_x, emb_y], dim=-1) # [1, H*W, D] 32 | 33 | return emb.view(H, W, embed_dim) # [H, W, D] 34 | 35 | 36 | def make_sincos_pos_embed(embed_dim: int, pos: torch.Tensor, omega_0: float = 100) -> torch.Tensor: 37 | """ 38 | This function generates a 1D positional embedding from a given grid using sine and cosine functions. 39 | 40 | Args: 41 | - embed_dim: The embedding dimension. 42 | - pos: The position to generate the embedding from. 43 | 44 | Returns: 45 | - emb: The generated 1D positional embedding. 46 | """ 47 | assert embed_dim % 2 == 0 48 | device = pos.device 49 | omega = torch.arange(embed_dim // 2, dtype=torch.float32 if device.type == "mps" else torch.double, device=device) 50 | omega /= embed_dim / 2.0 51 | omega = 1.0 / omega_0**omega # (D/2,) 52 | 53 | pos = pos.reshape(-1) # (M,) 54 | out = torch.einsum("m,d->md", pos, omega) # (M, D/2), outer product 55 | 56 | emb_sin = torch.sin(out) # (M, D/2) 57 | emb_cos = torch.cos(out) # (M, D/2) 58 | 59 | emb = torch.cat([emb_sin, emb_cos], dim=1) # (M, D) 60 | return emb.float() 61 | 62 | 63 | # Inspired by https://github.com/microsoft/moge 64 | 65 | 66 | def create_uv_grid( 67 | width: int, height: int, aspect_ratio: float = None, dtype: torch.dtype = None, device: torch.device = None 68 | ) -> torch.Tensor: 69 | """ 70 | Create a normalized UV grid of shape (width, height, 2). 71 | 72 | The grid spans horizontally and vertically according to an aspect ratio, 73 | ensuring the top-left corner is at (-x_span, -y_span) and the bottom-right 74 | corner is at (x_span, y_span), normalized by the diagonal of the plane. 75 | 76 | Args: 77 | width (int): Number of points horizontally. 78 | height (int): Number of points vertically. 79 | aspect_ratio (float, optional): Width-to-height ratio. Defaults to width/height. 80 | dtype (torch.dtype, optional): Data type of the resulting tensor. 81 | device (torch.device, optional): Device on which the tensor is created. 82 | 83 | Returns: 84 | torch.Tensor: A (width, height, 2) tensor of UV coordinates. 85 | """ 86 | # Derive aspect ratio if not explicitly provided 87 | if aspect_ratio is None: 88 | aspect_ratio = float(width) / float(height) 89 | 90 | # Compute normalized spans for X and Y 91 | diag_factor = (aspect_ratio**2 + 1.0) ** 0.5 92 | span_x = aspect_ratio / diag_factor 93 | span_y = 1.0 / diag_factor 94 | 95 | # Establish the linspace boundaries 96 | left_x = -span_x * (width - 1) / width 97 | right_x = span_x * (width - 1) / width 98 | top_y = -span_y * (height - 1) / height 99 | bottom_y = span_y * (height - 1) / height 100 | 101 | # Generate 1D coordinates 102 | x_coords = torch.linspace(left_x, right_x, steps=width, dtype=dtype, device=device) 103 | y_coords = torch.linspace(top_y, bottom_y, steps=height, dtype=dtype, device=device) 104 | 105 | # Create 2D meshgrid (width x height) and stack into UV 106 | uu, vv = torch.meshgrid(x_coords, y_coords, indexing="xy") 107 | uv_grid = torch.stack((uu, vv), dim=-1) 108 | 109 | return uv_grid 110 | -------------------------------------------------------------------------------- /evaluation/README.md: -------------------------------------------------------------------------------- 1 | # VGGT Evaluation 2 | 3 | This repository contains code to reproduce the evaluation results presented in the VGGT paper. 4 | 5 | ## Table of Contents 6 | 7 | - [Camera Pose Estimation on Co3D](#camera-pose-estimation-on-co3d) 8 | - [Model Weights](#model-weights) 9 | - [Setup](#setup) 10 | - [Dataset Preparation](#dataset-preparation) 11 | - [Running the Evaluation](#running-the-evaluation) 12 | - [Expected Results](#expected-results) 13 | - [Checklist](#checklist) 14 | 15 | ## Camera Pose Estimation on Co3D 16 | 17 | ### Model Weights 18 | 19 | We have addressed a minor bug in the publicly released checkpoint related to the TrackHead configuration. Specifically, the `pos_embed` flag was incorrectly set to `False`. The following checkpoint incorporates this fix by fine-tuning the tracker head with `pos_embed` as `True` while preserving all other parameters. This fix will be merged into the main branch in a future update. 20 | 21 | ```bash 22 | wget https://huggingface.co/facebook/VGGT_tracker_fixed/resolve/main/model_tracker_fixed_e20.pt 23 | ``` 24 | 25 | Note: The default checkpoint remains functional, though you may observe a slight performance decrease (approximately 0.3% in AUC@30) when using Bundle Adjustment (BA). If using the default checkpoint, ensure you set `pos_embed` to `False` for the TrackHead. This modification only affects tracking-based evaluations and has no impact on feed-forward estimation performance, as tracking is not utilized in the feed-forward approach. 26 | 27 | ### Setup 28 | 29 | Install the required dependencies: 30 | 31 | ```bash 32 | # Install VGGT as a package 33 | pip install -e . 34 | 35 | # Install evaluation dependencies 36 | pip install pycolmap==3.10.0 pyceres==2.3 37 | 38 | # Install LightGlue for keypoint detection 39 | git clone https://github.com/cvg/LightGlue.git 40 | cd LightGlue 41 | python -m pip install -e . 42 | cd .. 43 | ``` 44 | 45 | ### Dataset Preparation 46 | 47 | 1. Download the Co3D dataset from the [official repository](https://github.com/facebookresearch/co3d) 48 | 49 | 2. Preprocess the dataset (approximately 5 minutes): 50 | ```bash 51 | python preprocess_co3d.py --category all \ 52 | --co3d_v2_dir /YOUR/CO3D/PATH \ 53 | --output_dir /YOUR/CO3D/ANNO/PATH 54 | ``` 55 | 56 | Replace `/YOUR/CO3D/PATH` with the path to your downloaded Co3D dataset, and `/YOUR/CO3D/ANNO/PATH` with the desired output directory for the processed annotations. 57 | 58 | ### Running the Evaluation 59 | 60 | Choose one of these evaluation modes: 61 | 62 | ```bash 63 | # Standard VGGT evaluation 64 | python test_co3d.py \ 65 | --model_path /YOUR/MODEL/PATH \ 66 | --co3d_dir /YOUR/CO3D/PATH \ 67 | --co3d_anno_dir /YOUR/CO3D/ANNO/PATH \ 68 | --seed 0 69 | 70 | # VGGT with Bundle Adjustment 71 | python test_co3d.py \ 72 | --model_path /YOUR/MODEL/PATH \ 73 | --co3d_dir /YOUR/CO3D/PATH \ 74 | --co3d_anno_dir /YOUR/CO3D/ANNO/PATH \ 75 | --seed 0 \ 76 | --use_ba 77 | ``` 78 | 79 | 80 | 81 | 82 | ### Expected Results 83 | 84 | #### Quick Evaluation 85 | Full evaluation on Co3D can take a long time. For faster trials, you can run with ```--fast_eval```. This does exactly the same but limiting to evaluate over at most 10 sequence per category. 86 | 87 | Use `--fast_eval` to test on a subset of data (max 10 sequences per category): 88 | 89 | - Feed-forward estimation: 90 | - AUC@30: 89.45 91 | - AUC@15: 83.29 92 | - AUC@5: 66.86 93 | - AUC@3: 56.08 94 | 95 | - With Bundle Adjustment (`--use_ba`): 96 | - AUC@30: 90.11 97 | - AUC@15: 84.39 98 | - AUC@5: 70.02 99 | - AUC@3: 60.51 100 | 101 | #### Full Evaluation 102 | 103 | - Feedforward estimation achieves a Mean AUC@30 of 89.5% (slightly higher than the 88.2% reported in the paper due to implementation differences) 104 | - With Bundle Adjustment, you can expect a Mean AUC@30 between 90.5% and 92.5% 105 | 106 | > **Note:** For simplicity, this script did not optimize the inference speed, so timing results may differ from those reported in the paper. For example, when using ba, keypoint extractor models are re-initialized for each sequence rather than being loaded once. 107 | 108 | ## Checklist 109 | 110 | The following features are planned for future releases: 111 | 112 | - [x] Camera pose estimation code on Co3D 113 | - [x] VGGT+BA (Bundle Adjustment) on Co3D 114 | - [ ] Evaluation on Re10K dataset 115 | - [ ] Evaluation on IMC dataset 116 | - [ ] Evaluation of multi-view depth estimation 117 | 118 | --- 119 | -------------------------------------------------------------------------------- /vggt/heads/track_head.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch.nn as nn 8 | from .dpt_head import DPTHead 9 | from .track_modules.base_track_predictor import BaseTrackerPredictor 10 | 11 | 12 | class TrackHead(nn.Module): 13 | """ 14 | Track head that uses DPT head to process tokens and BaseTrackerPredictor for tracking. 15 | The tracking is performed iteratively, refining predictions over multiple iterations. 16 | """ 17 | 18 | def __init__( 19 | self, 20 | dim_in, 21 | patch_size=14, 22 | features=128, 23 | iters=4, 24 | predict_conf=True, 25 | stride=2, 26 | corr_levels=7, 27 | corr_radius=4, 28 | hidden_size=384, 29 | pos_embed=True, 30 | ): 31 | """ 32 | Initialize the TrackHead module. 33 | 34 | Args: 35 | dim_in (int): Input dimension of tokens from the backbone. 36 | patch_size (int): Size of image patches used in the vision transformer. 37 | features (int): Number of feature channels in the feature extractor output. 38 | iters (int): Number of refinement iterations for tracking predictions. 39 | predict_conf (bool): Whether to predict confidence scores for tracked points. 40 | stride (int): Stride value for the tracker predictor. 41 | corr_levels (int): Number of correlation pyramid levels 42 | corr_radius (int): Radius for correlation computation, controlling the search area. 43 | hidden_size (int): Size of hidden layers in the tracker network. 44 | """ 45 | super().__init__() 46 | 47 | self.patch_size = patch_size 48 | 49 | # Feature extractor based on DPT architecture 50 | # Processes tokens into feature maps for tracking 51 | self.feature_extractor = DPTHead( 52 | dim_in=dim_in, 53 | patch_size=patch_size, 54 | features=features, 55 | feature_only=True, # Only output features, no activation 56 | down_ratio=2, # Reduces spatial dimensions by factor of 2 57 | pos_embed=pos_embed, 58 | ) 59 | 60 | # Tracker module that predicts point trajectories 61 | # Takes feature maps and predicts coordinates and visibility 62 | self.tracker = BaseTrackerPredictor( 63 | latent_dim=features, # Match the output_dim of feature extractor 64 | predict_conf=predict_conf, 65 | stride=stride, 66 | corr_levels=corr_levels, 67 | corr_radius=corr_radius, 68 | hidden_size=hidden_size, 69 | ) 70 | 71 | self.iters = iters 72 | 73 | def forward(self, aggregated_tokens_list, images, patch_start_idx, query_points=None, iters=None): 74 | """ 75 | Forward pass of the TrackHead. 76 | 77 | Args: 78 | aggregated_tokens_list (list): List of aggregated tokens from the backbone. 79 | images (torch.Tensor): Input images of shape (B, S, C, H, W) where: 80 | B = batch size, S = sequence length. 81 | patch_start_idx (int): Starting index for patch tokens. 82 | query_points (torch.Tensor, optional): Initial query points to track. 83 | If None, points are initialized by the tracker. 84 | iters (int, optional): Number of refinement iterations. If None, uses self.iters. 85 | 86 | Returns: 87 | tuple: 88 | - coord_preds (torch.Tensor): Predicted coordinates for tracked points. 89 | - vis_scores (torch.Tensor): Visibility scores for tracked points. 90 | - conf_scores (torch.Tensor): Confidence scores for tracked points (if predict_conf=True). 91 | """ 92 | B, S, _, H, W = images.shape 93 | 94 | # Extract features from tokens 95 | # feature_maps has shape (B, S, C, H//2, W//2) due to down_ratio=2 96 | feature_maps = self.feature_extractor(aggregated_tokens_list, images, patch_start_idx) 97 | 98 | # Use default iterations if not specified 99 | if iters is None: 100 | iters = self.iters 101 | 102 | # Perform tracking using the extracted features 103 | coord_preds, vis_scores, conf_scores = self.tracker( 104 | query_points=query_points, 105 | fmaps=feature_maps, 106 | iters=iters, 107 | ) 108 | 109 | return coord_preds, vis_scores, conf_scores 110 | -------------------------------------------------------------------------------- /vggt/models/vggt.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | import torch.nn as nn 9 | from huggingface_hub import PyTorchModelHubMixin # used for model hub 10 | 11 | from vggt.models.aggregator import Aggregator 12 | from vggt.heads.camera_head import CameraHead 13 | from vggt.heads.dpt_head import DPTHead 14 | from vggt.heads.track_head import TrackHead 15 | 16 | 17 | class VGGT(nn.Module, PyTorchModelHubMixin): 18 | def __init__(self, img_size=518, patch_size=14, embed_dim=1024): 19 | super().__init__() 20 | 21 | self.aggregator = Aggregator(img_size=img_size, patch_size=patch_size, embed_dim=embed_dim) 22 | self.camera_head = CameraHead(dim_in=2 * embed_dim) 23 | self.point_head = DPTHead(dim_in=2 * embed_dim, output_dim=4, activation="inv_log", conf_activation="expp1") 24 | self.depth_head = DPTHead(dim_in=2 * embed_dim, output_dim=2, activation="exp", conf_activation="expp1") 25 | self.track_head = TrackHead(dim_in=2 * embed_dim, patch_size=patch_size) 26 | 27 | def forward( 28 | self, 29 | images: torch.Tensor, 30 | query_points: torch.Tensor = None, 31 | ): 32 | """ 33 | Forward pass of the VGGT model. 34 | 35 | Args: 36 | images (torch.Tensor): Input images with shape [S, 3, H, W] or [B, S, 3, H, W], in range [0, 1]. 37 | B: batch size, S: sequence length, 3: RGB channels, H: height, W: width 38 | query_points (torch.Tensor, optional): Query points for tracking, in pixel coordinates. 39 | Shape: [N, 2] or [B, N, 2], where N is the number of query points. 40 | Default: None 41 | 42 | Returns: 43 | dict: A dictionary containing the following predictions: 44 | - pose_enc (torch.Tensor): Camera pose encoding with shape [B, S, 9] (from the last iteration) 45 | - depth (torch.Tensor): Predicted depth maps with shape [B, S, H, W, 1] 46 | - depth_conf (torch.Tensor): Confidence scores for depth predictions with shape [B, S, H, W] 47 | - world_points (torch.Tensor): 3D world coordinates for each pixel with shape [B, S, H, W, 3] 48 | - world_points_conf (torch.Tensor): Confidence scores for world points with shape [B, S, H, W] 49 | - images (torch.Tensor): Original input images, preserved for visualization 50 | 51 | If query_points is provided, also includes: 52 | - track (torch.Tensor): Point tracks with shape [B, S, N, 2] (from the last iteration), in pixel coordinates 53 | - vis (torch.Tensor): Visibility scores for tracked points with shape [B, S, N] 54 | - conf (torch.Tensor): Confidence scores for tracked points with shape [B, S, N] 55 | """ 56 | 57 | # If without batch dimension, add it 58 | if len(images.shape) == 4: 59 | images = images.unsqueeze(0) 60 | if query_points is not None and len(query_points.shape) == 2: 61 | query_points = query_points.unsqueeze(0) 62 | 63 | aggregated_tokens_list, patch_start_idx = self.aggregator(images) 64 | 65 | predictions = {} 66 | 67 | with torch.cuda.amp.autocast(enabled=False): 68 | if self.camera_head is not None: 69 | pose_enc_list = self.camera_head(aggregated_tokens_list) 70 | predictions["pose_enc"] = pose_enc_list[-1] # pose encoding of the last iteration 71 | 72 | if self.depth_head is not None: 73 | depth, depth_conf = self.depth_head( 74 | aggregated_tokens_list, images=images, patch_start_idx=patch_start_idx 75 | ) 76 | predictions["depth"] = depth 77 | predictions["depth_conf"] = depth_conf 78 | 79 | if self.point_head is not None: 80 | pts3d, pts3d_conf = self.point_head( 81 | aggregated_tokens_list, images=images, patch_start_idx=patch_start_idx 82 | ) 83 | predictions["world_points"] = pts3d 84 | predictions["world_points_conf"] = pts3d_conf 85 | 86 | if self.track_head is not None and query_points is not None: 87 | track_list, vis, conf = self.track_head( 88 | aggregated_tokens_list, images=images, patch_start_idx=patch_start_idx, query_points=query_points 89 | ) 90 | predictions["track"] = track_list[-1] # track of the last iteration 91 | predictions["vis"] = vis 92 | predictions["conf"] = conf 93 | 94 | predictions["images"] = images 95 | 96 | return predictions 97 | -------------------------------------------------------------------------------- /evaluation/preprocess_co3d.py: -------------------------------------------------------------------------------- 1 | # Modified from https://github.com/amyxlase/relpose-plus-plus/blob/main/preprocess/preprocess_co3d.py 2 | 3 | 4 | """ 5 | Usage: 6 | python -m preprocess.preprocess_co3d --category all \ 7 | --co3d_v2_dir /path/to/co3d_v2 8 | """ 9 | import argparse 10 | import gzip 11 | import json 12 | import os 13 | import os.path as osp 14 | from glob import glob 15 | 16 | # import ipdb 17 | import matplotlib.pyplot as plt 18 | import numpy as np 19 | from tqdm.auto import tqdm 20 | 21 | # fmt: off 22 | CATEGORIES = [ 23 | "apple", "backpack", "ball", "banana", "baseballbat", "baseballglove", 24 | "bench", "bicycle", "book", "bottle", "bowl", "broccoli", "cake", "car", "carrot", 25 | "cellphone", "chair", "couch", "cup", "donut", "frisbee", "hairdryer", "handbag", 26 | "hotdog", "hydrant", "keyboard", "kite", "laptop", "microwave", "motorcycle", 27 | "mouse", "orange", "parkingmeter", "pizza", "plant", "remote", "sandwich", 28 | "skateboard", "stopsign", "suitcase", "teddybear", "toaster", "toilet", "toybus", 29 | "toyplane", "toytrain", "toytruck", "tv", "umbrella", "vase", "wineglass", 30 | ] 31 | # fmt: on 32 | 33 | 34 | def get_parser(): 35 | parser = argparse.ArgumentParser() 36 | parser.add_argument("--category", type=str, default="all") 37 | parser.add_argument("--output_dir", type=str, default="/data2/fwl/datasets/co3d_v2_annotations/") 38 | parser.add_argument("--co3d_v2_dir", type=str, default="/data1/3d_datasets/") 39 | parser.add_argument( 40 | "--min_quality", 41 | type=float, 42 | default=0.5, 43 | help="Minimum viewpoint quality score.", 44 | ) 45 | return parser 46 | 47 | 48 | 49 | 50 | def process_poses(co3d_dir, category, output_dir, min_quality): 51 | category_dir = osp.join(co3d_dir, category) 52 | print("Processing category:", category) 53 | frame_file = osp.join(category_dir, "frame_annotations.jgz") 54 | sequence_file = osp.join(category_dir, "sequence_annotations.jgz") 55 | subset_lists_file = osp.join(category_dir, "set_lists/set_lists_fewview_dev.json") 56 | 57 | # bbox_file = osp.join(output_dir, f"{category}_bbox.jgz") 58 | 59 | with open(subset_lists_file) as f: 60 | subset_lists_data = json.load(f) 61 | 62 | with gzip.open(sequence_file, "r") as fin: 63 | sequence_data = json.loads(fin.read()) 64 | 65 | with gzip.open(frame_file, "r") as fin: 66 | frame_data = json.loads(fin.read()) 67 | 68 | # with gzip.open(bbox_file, "r") as fin: 69 | # bbox_data = json.loads(fin.read()) 70 | 71 | frame_data_processed = {} 72 | for f_data in frame_data: 73 | sequence_name = f_data["sequence_name"] 74 | if sequence_name not in frame_data_processed: 75 | frame_data_processed[sequence_name] = {} 76 | frame_data_processed[sequence_name][f_data["frame_number"]] = f_data 77 | 78 | good_quality_sequences = set() 79 | for seq_data in sequence_data: 80 | if seq_data["viewpoint_quality_score"] > min_quality: 81 | good_quality_sequences.add(seq_data["sequence_name"]) 82 | 83 | for subset in ["train", "test"]: 84 | category_data = {} # {sequence_name: [{filepath, R, T}]} 85 | for seq_name, frame_number, filepath in subset_lists_data[subset]: 86 | if seq_name not in good_quality_sequences: 87 | continue 88 | 89 | if seq_name not in category_data: 90 | category_data[seq_name] = [] 91 | 92 | # mask_path = filepath.replace("images", "masks").replace(".jpg", ".png") 93 | # bbox = bbox_data[mask_path] 94 | # if bbox == []: 95 | # Mask did not include any object. 96 | # continue 97 | 98 | frame_data = frame_data_processed[seq_name][frame_number] 99 | category_data[seq_name].append( 100 | { 101 | "filepath": filepath, 102 | "R": frame_data["viewpoint"]["R"], 103 | "T": frame_data["viewpoint"]["T"], 104 | "focal_length": frame_data["viewpoint"]["focal_length"], 105 | "principal_point": frame_data["viewpoint"]["principal_point"], 106 | # "bbox": bbox, 107 | } 108 | ) 109 | 110 | output_file = osp.join(output_dir, f"{category}_{subset}.jgz") 111 | with gzip.open(output_file, "w") as f: 112 | f.write(json.dumps(category_data).encode("utf-8")) 113 | 114 | 115 | 116 | if __name__ == "__main__": 117 | parser = get_parser() 118 | args = parser.parse_args() 119 | if args.category == "all": 120 | categories = CATEGORIES 121 | else: 122 | categories = [args.category] 123 | for category in categories: 124 | process_poses( 125 | co3d_dir=args.co3d_v2_dir, 126 | category=category, 127 | output_dir=args.output_dir, 128 | min_quality=args.min_quality, 129 | ) 130 | -------------------------------------------------------------------------------- /vggt/utils/rotation.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # Modified from PyTorch3D, https://github.com/facebookresearch/pytorch3d 8 | 9 | import torch 10 | import numpy as np 11 | import torch.nn.functional as F 12 | 13 | 14 | def quat_to_mat(quaternions: torch.Tensor) -> torch.Tensor: 15 | """ 16 | Quaternion Order: XYZW or say ijkr, scalar-last 17 | 18 | Convert rotations given as quaternions to rotation matrices. 19 | Args: 20 | quaternions: quaternions with real part last, 21 | as tensor of shape (..., 4). 22 | 23 | Returns: 24 | Rotation matrices as tensor of shape (..., 3, 3). 25 | """ 26 | i, j, k, r = torch.unbind(quaternions, -1) 27 | # pyre-fixme[58]: `/` is not supported for operand types `float` and `Tensor`. 28 | two_s = 2.0 / (quaternions * quaternions).sum(-1) 29 | 30 | o = torch.stack( 31 | ( 32 | 1 - two_s * (j * j + k * k), 33 | two_s * (i * j - k * r), 34 | two_s * (i * k + j * r), 35 | two_s * (i * j + k * r), 36 | 1 - two_s * (i * i + k * k), 37 | two_s * (j * k - i * r), 38 | two_s * (i * k - j * r), 39 | two_s * (j * k + i * r), 40 | 1 - two_s * (i * i + j * j), 41 | ), 42 | -1, 43 | ) 44 | return o.reshape(quaternions.shape[:-1] + (3, 3)) 45 | 46 | 47 | def mat_to_quat(matrix: torch.Tensor) -> torch.Tensor: 48 | """ 49 | Convert rotations given as rotation matrices to quaternions. 50 | 51 | Args: 52 | matrix: Rotation matrices as tensor of shape (..., 3, 3). 53 | 54 | Returns: 55 | quaternions with real part last, as tensor of shape (..., 4). 56 | Quaternion Order: XYZW or say ijkr, scalar-last 57 | """ 58 | if matrix.size(-1) != 3 or matrix.size(-2) != 3: 59 | raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.") 60 | 61 | batch_dim = matrix.shape[:-2] 62 | m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind(matrix.reshape(batch_dim + (9,)), dim=-1) 63 | 64 | q_abs = _sqrt_positive_part( 65 | torch.stack( 66 | [ 67 | 1.0 + m00 + m11 + m22, 68 | 1.0 + m00 - m11 - m22, 69 | 1.0 - m00 + m11 - m22, 70 | 1.0 - m00 - m11 + m22, 71 | ], 72 | dim=-1, 73 | ) 74 | ) 75 | 76 | # we produce the desired quaternion multiplied by each of r, i, j, k 77 | quat_by_rijk = torch.stack( 78 | [ 79 | # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and 80 | # `int`. 81 | torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1), 82 | # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and 83 | # `int`. 84 | torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1), 85 | # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and 86 | # `int`. 87 | torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1), 88 | # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and 89 | # `int`. 90 | torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1), 91 | ], 92 | dim=-2, 93 | ) 94 | 95 | # We floor here at 0.1 but the exact level is not important; if q_abs is small, 96 | # the candidate won't be picked. 97 | flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device) 98 | quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr)) 99 | 100 | # if not for numerical problems, quat_candidates[i] should be same (up to a sign), 101 | # forall i; we pick the best-conditioned one (with the largest denominator) 102 | out = quat_candidates[F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, :].reshape(batch_dim + (4,)) 103 | 104 | # Convert from rijk to ijkr 105 | out = out[..., [1, 2, 3, 0]] 106 | 107 | out = standardize_quaternion(out) 108 | 109 | return out 110 | 111 | 112 | def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor: 113 | """ 114 | Returns torch.sqrt(torch.max(0, x)) 115 | but with a zero subgradient where x is 0. 116 | """ 117 | ret = torch.zeros_like(x) 118 | positive_mask = x > 0 119 | if torch.is_grad_enabled(): 120 | ret[positive_mask] = torch.sqrt(x[positive_mask]) 121 | else: 122 | ret = torch.where(positive_mask, torch.sqrt(x), ret) 123 | return ret 124 | 125 | 126 | def standardize_quaternion(quaternions: torch.Tensor) -> torch.Tensor: 127 | """ 128 | Convert a unit quaternion to a standard form: one in which the real 129 | part is non negative. 130 | 131 | Args: 132 | quaternions: Quaternions with real part last, 133 | as tensor of shape (..., 4). 134 | 135 | Returns: 136 | Standardized quaternions as tensor of shape (..., 4). 137 | """ 138 | return torch.where(quaternions[..., 3:4] < 0, -quaternions, quaternions) 139 | -------------------------------------------------------------------------------- /vggt/utils/pose_enc.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | from .rotation import quat_to_mat, mat_to_quat 9 | 10 | 11 | def extri_intri_to_pose_encoding( 12 | extrinsics, 13 | intrinsics, 14 | image_size_hw=None, # e.g., (256, 512) 15 | pose_encoding_type="absT_quaR_FoV", 16 | ): 17 | """Convert camera extrinsics and intrinsics to a compact pose encoding. 18 | 19 | This function transforms camera parameters into a unified pose encoding format, 20 | which can be used for various downstream tasks like pose prediction or representation. 21 | 22 | Args: 23 | extrinsics (torch.Tensor): Camera extrinsic parameters with shape BxSx3x4, 24 | where B is batch size and S is sequence length. 25 | In OpenCV coordinate system (x-right, y-down, z-forward), representing camera from world transformation. 26 | The format is [R|t] where R is a 3x3 rotation matrix and t is a 3x1 translation vector. 27 | intrinsics (torch.Tensor): Camera intrinsic parameters with shape BxSx3x3. 28 | Defined in pixels, with format: 29 | [[fx, 0, cx], 30 | [0, fy, cy], 31 | [0, 0, 1]] 32 | where fx, fy are focal lengths and (cx, cy) is the principal point 33 | image_size_hw (tuple): Tuple of (height, width) of the image in pixels. 34 | Required for computing field of view values. For example: (256, 512). 35 | pose_encoding_type (str): Type of pose encoding to use. Currently only 36 | supports "absT_quaR_FoV" (absolute translation, quaternion rotation, field of view). 37 | 38 | Returns: 39 | torch.Tensor: Encoded camera pose parameters with shape BxSx9. 40 | For "absT_quaR_FoV" type, the 9 dimensions are: 41 | - [:3] = absolute translation vector T (3D) 42 | - [3:7] = rotation as quaternion quat (4D) 43 | - [7:] = field of view (2D) 44 | """ 45 | 46 | # extrinsics: BxSx3x4 47 | # intrinsics: BxSx3x3 48 | 49 | if pose_encoding_type == "absT_quaR_FoV": 50 | R = extrinsics[:, :, :3, :3] # BxSx3x3 51 | T = extrinsics[:, :, :3, 3] # BxSx3 52 | 53 | quat = mat_to_quat(R) 54 | # Note the order of h and w here 55 | H, W = image_size_hw 56 | fov_h = 2 * torch.atan((H / 2) / intrinsics[..., 1, 1]) 57 | fov_w = 2 * torch.atan((W / 2) / intrinsics[..., 0, 0]) 58 | pose_encoding = torch.cat([T, quat, fov_h[..., None], fov_w[..., None]], dim=-1).float() 59 | else: 60 | raise NotImplementedError 61 | 62 | return pose_encoding 63 | 64 | 65 | def pose_encoding_to_extri_intri( 66 | pose_encoding, 67 | image_size_hw=None, # e.g., (256, 512) 68 | pose_encoding_type="absT_quaR_FoV", 69 | build_intrinsics=True, 70 | ): 71 | """Convert a pose encoding back to camera extrinsics and intrinsics. 72 | 73 | This function performs the inverse operation of extri_intri_to_pose_encoding, 74 | reconstructing the full camera parameters from the compact encoding. 75 | 76 | Args: 77 | pose_encoding (torch.Tensor): Encoded camera pose parameters with shape BxSx9, 78 | where B is batch size and S is sequence length. 79 | For "absT_quaR_FoV" type, the 9 dimensions are: 80 | - [:3] = absolute translation vector T (3D) 81 | - [3:7] = rotation as quaternion quat (4D) 82 | - [7:] = field of view (2D) 83 | image_size_hw (tuple): Tuple of (height, width) of the image in pixels. 84 | Required for reconstructing intrinsics from field of view values. 85 | For example: (256, 512). 86 | pose_encoding_type (str): Type of pose encoding used. Currently only 87 | supports "absT_quaR_FoV" (absolute translation, quaternion rotation, field of view). 88 | build_intrinsics (bool): Whether to reconstruct the intrinsics matrix. 89 | If False, only extrinsics are returned and intrinsics will be None. 90 | 91 | Returns: 92 | tuple: (extrinsics, intrinsics) 93 | - extrinsics (torch.Tensor): Camera extrinsic parameters with shape BxSx3x4. 94 | In OpenCV coordinate system (x-right, y-down, z-forward), representing camera from world 95 | transformation. The format is [R|t] where R is a 3x3 rotation matrix and t is 96 | a 3x1 translation vector. 97 | - intrinsics (torch.Tensor or None): Camera intrinsic parameters with shape BxSx3x3, 98 | or None if build_intrinsics is False. Defined in pixels, with format: 99 | [[fx, 0, cx], 100 | [0, fy, cy], 101 | [0, 0, 1]] 102 | where fx, fy are focal lengths and (cx, cy) is the principal point, 103 | assumed to be at the center of the image (W/2, H/2). 104 | """ 105 | 106 | intrinsics = None 107 | 108 | if pose_encoding_type == "absT_quaR_FoV": 109 | T = pose_encoding[..., :3] 110 | quat = pose_encoding[..., 3:7] 111 | fov_h = pose_encoding[..., 7] 112 | fov_w = pose_encoding[..., 8] 113 | 114 | R = quat_to_mat(quat) 115 | extrinsics = torch.cat([R, T[..., None]], dim=-1) 116 | 117 | if build_intrinsics: 118 | H, W = image_size_hw 119 | fy = (H / 2.0) / torch.tan(fov_h / 2.0) 120 | fx = (W / 2.0) / torch.tan(fov_w / 2.0) 121 | intrinsics = torch.zeros(pose_encoding.shape[:2] + (3, 3), device=pose_encoding.device) 122 | intrinsics[..., 0, 0] = fx 123 | intrinsics[..., 1, 1] = fy 124 | intrinsics[..., 0, 2] = W / 2 125 | intrinsics[..., 1, 2] = H / 2 126 | intrinsics[..., 2, 2] = 1.0 # Set the homogeneous coordinate to 1 127 | else: 128 | raise NotImplementedError 129 | 130 | return extrinsics, intrinsics 131 | -------------------------------------------------------------------------------- /evaluation/quarot/quarot_linear.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from tkinter import NO 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from .quarot_utils import random_hadamard_matrix 7 | 8 | from .quant_utils import WeightQuantizer, ActivationQuantizer 9 | 10 | 11 | class VGGTQuantizedLinear(nn.Module): 12 | def __init__(self, args, linear: nn.Linear): 13 | super(VGGTQuantizedLinear, self).__init__() 14 | 15 | self.args = args 16 | self.not_rot = args.not_rot 17 | self.not_smooth = args.not_smooth 18 | self.lwc = args.lwc 19 | self.lac = args.lac 20 | self.rv = args.rv 21 | self.linear = linear 22 | self.weight_quantizer = WeightQuantizer() 23 | self.weight_quantizer.configure(args.w_bits, perchannel=True, sym=not(args.w_asym), mse=False) 24 | 25 | self.act_quantizer = ActivationQuantizer(bits=args.a_bits, sym=not(args.a_asym), lac=self.lac, 26 | groupsize=args.a_groupsize, ) 27 | if self.lwc: 28 | lwc_dim = self.linear.weight.shape[0] if self.lwc else -1 29 | init_value = 4. 30 | self.clip_factor_w_max = nn.Parameter(torch.ones((lwc_dim, 1)) * init_value, requires_grad=True) 31 | self.clip_factor_w_min = nn.Parameter(torch.ones((lwc_dim, 1)) * init_value, requires_grad=True) 32 | self.sigmoid = nn.Sigmoid() 33 | 34 | if not args.not_smooth: 35 | self.act_quantizer.register_buffer("act_scale", None) 36 | self.register_parameter("channel_wise_scale", 37 | nn.Parameter(torch.ones((1, self.linear.weight.shape[1])))) 38 | self.smooth_quant_momentum = 0.95 39 | self.smooth_quant_alpha = 0.5 40 | self.smooth_quant_running_stat = True 41 | 42 | if not self.not_rot: 43 | self.register_parameter("rotation_matrix", 44 | torch.nn.Parameter(random_hadamard_matrix(self.linear.weight.shape[1], "cuda").to(dtype=torch.float32))) 45 | 46 | self.ori_mode = True 47 | self.train_mode = False 48 | self.eval_mode = False 49 | 50 | 51 | def apply_wclip(self, weight): 52 | wmin, wmax = weight.min(1, keepdim=True)[0], weight.max(1, keepdim=True)[0] 53 | wmax *= self.sigmoid(self.clip_factor_w_max) 54 | wmin *= self.sigmoid(self.clip_factor_w_min) 55 | weight = torch.clamp(weight, min=wmin, max=wmax) 56 | return weight 57 | 58 | 59 | def _ori_forward(self, hidden_states): 60 | weight = self.linear.weight.data 61 | bias = self.linear.bias 62 | if not self.not_smooth : 63 | if self.smooth_quant_running_stat: 64 | if not self.not_rot: 65 | hidden_states = torch.matmul(hidden_states.float(), self.rotation_matrix) 66 | weight = torch.matmul(weight.float(), self.rotation_matrix) 67 | 68 | cur_act_scale = hidden_states.abs().max(dim=-2)[0].mean(dim=0, keepdim=True) 69 | if self.act_quantizer.act_scale is None: 70 | self.act_quantizer.act_scale = torch.zeros(1).to(hidden_states) 71 | if self.act_quantizer.act_scale.abs().mean() == 0: 72 | self.act_quantizer.act_scale = cur_act_scale 73 | else: 74 | self.act_quantizer.act_scale = self.act_quantizer.act_scale * self.smooth_quant_momentum + cur_act_scale * ( 75 | 1 - self.smooth_quant_momentum) 76 | else: 77 | assert self.act_quantizer.act_scale is not None 78 | assert self.act_quantizer.act_scale.mean() != 0 79 | 80 | return F.linear(hidden_states, weight, bias) 81 | 82 | def _train_forward(self, hidden_states): 83 | weight = self.linear.weight.data 84 | if not self.not_rot: 85 | weight = torch.matmul(weight, self.rotation_matrix) 86 | if not self.not_smooth: 87 | weight = weight * self.channel_wise_scale 88 | if self.lwc: 89 | weight = self.apply_wclip(weight) 90 | 91 | self.weight_quantizer.find_params(weight) 92 | weight = self.weight_quantizer(weight) 93 | 94 | if not self.not_rot: 95 | hidden_states = torch.matmul(hidden_states, self.rotation_matrix) 96 | if not self.not_smooth: 97 | hidden_states = hidden_states / self.channel_wise_scale 98 | 99 | hidden_states = self.act_quantizer(hidden_states) 100 | bias = self.linear.bias 101 | output = F.linear(hidden_states, weight, bias) 102 | 103 | return output 104 | 105 | def _eval_forward(self, hidden_states): 106 | x_dtype = hidden_states.dtype 107 | if not self.not_rot: 108 | hidden_states = torch.matmul(hidden_states.float(), self.rotation_matrix) 109 | if not self.not_smooth: 110 | hidden_states = hidden_states / self.channel_wise_scale 111 | hidden_states = self.act_quantizer(hidden_states).to(x_dtype) 112 | output = self.linear(hidden_states) 113 | return output 114 | 115 | def forward(self, hidden_states): 116 | if self.ori_mode: 117 | return self._ori_forward(hidden_states) 118 | if self.train_mode: 119 | return self._train_forward(hidden_states) 120 | if self.eval_mode: 121 | return self._eval_forward(hidden_states) 122 | 123 | def reparameterize(self): 124 | target_device = self.linear.weight.device 125 | ori_dtype = self.linear.weight.dtype 126 | weight = self.linear.weight.data.detach().to(torch.float32) 127 | if not self.not_rot: 128 | weight = torch.matmul(weight, self.rotation_matrix.to(weight.device)) 129 | if not self.not_smooth: 130 | weight = weight * self.channel_wise_scale.to(weight.device) 131 | 132 | if self.lwc: 133 | weight = self.apply_wclip(weight) 134 | 135 | self.weight_quantizer.find_params(weight) 136 | weight = self.weight_quantizer(weight) 137 | self.linear.weight.data = weight.to(device=target_device, dtype=ori_dtype) 138 | 139 | self.ori_mode = False 140 | self.train_mode = False 141 | self.eval_mode = True 142 | -------------------------------------------------------------------------------- /vggt/utils/load_fn.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | from PIL import Image 9 | from torchvision import transforms as TF 10 | 11 | 12 | def load_and_preprocess_images(image_path_list, mode="crop"): 13 | """ 14 | A quick start function to load and preprocess images for model input. 15 | This assumes the images should have the same shape for easier batching, but our model can also work well with different shapes. 16 | 17 | Args: 18 | image_path_list (list): List of paths to image files 19 | mode (str, optional): Preprocessing mode, either "crop" or "pad". 20 | - "crop" (default): Sets width to 518px and center crops height if needed. 21 | - "pad": Preserves all pixels by making the largest dimension 518px 22 | and padding the smaller dimension to reach a square shape. 23 | 24 | Returns: 25 | torch.Tensor: Batched tensor of preprocessed images with shape (N, 3, H, W) 26 | 27 | Raises: 28 | ValueError: If the input list is empty or if mode is invalid 29 | 30 | Notes: 31 | - Images with different dimensions will be padded with white (value=1.0) 32 | - A warning is printed when images have different shapes 33 | - When mode="crop": The function ensures width=518px while maintaining aspect ratio 34 | and height is center-cropped if larger than 518px 35 | - When mode="pad": The function ensures the largest dimension is 518px while maintaining aspect ratio 36 | and the smaller dimension is padded to reach a square shape (518x518) 37 | - Dimensions are adjusted to be divisible by 14 for compatibility with model requirements 38 | """ 39 | # Check for empty list 40 | if len(image_path_list) == 0: 41 | raise ValueError("At least 1 image is required") 42 | 43 | # Validate mode 44 | if mode not in ["crop", "pad"]: 45 | raise ValueError("Mode must be either 'crop' or 'pad'") 46 | 47 | images = [] 48 | shapes = set() 49 | to_tensor = TF.ToTensor() 50 | target_size = 518 51 | 52 | # First process all images and collect their shapes 53 | for image_path in image_path_list: 54 | 55 | # Open image 56 | img = Image.open(image_path) 57 | 58 | # If there's an alpha channel, blend onto white background: 59 | if img.mode == "RGBA": 60 | # Create white background 61 | background = Image.new("RGBA", img.size, (255, 255, 255, 255)) 62 | # Alpha composite onto the white background 63 | img = Image.alpha_composite(background, img) 64 | 65 | # Now convert to "RGB" (this step assigns white for transparent areas) 66 | img = img.convert("RGB") 67 | 68 | width, height = img.size 69 | 70 | if mode == "pad": 71 | # Make the largest dimension 518px while maintaining aspect ratio 72 | if width >= height: 73 | new_width = target_size 74 | new_height = round(height * (new_width / width) / 14) * 14 # Make divisible by 14 75 | else: 76 | new_height = target_size 77 | new_width = round(width * (new_height / height) / 14) * 14 # Make divisible by 14 78 | else: # mode == "crop" 79 | # Original behavior: set width to 518px 80 | new_width = target_size 81 | # Calculate height maintaining aspect ratio, divisible by 14 82 | new_height = round(height * (new_width / width) / 14) * 14 83 | 84 | # Resize with new dimensions (width, height) 85 | img = img.resize((new_width, new_height), Image.Resampling.BICUBIC) 86 | img = to_tensor(img) # Convert to tensor (0, 1) 87 | 88 | # Center crop height if it's larger than 518 (only in crop mode) 89 | if mode == "crop" and new_height > target_size: 90 | start_y = (new_height - target_size) // 2 91 | img = img[:, start_y : start_y + target_size, :] 92 | 93 | # For pad mode, pad to make a square of target_size x target_size 94 | if mode == "pad": 95 | h_padding = target_size - img.shape[1] 96 | w_padding = target_size - img.shape[2] 97 | 98 | if h_padding > 0 or w_padding > 0: 99 | pad_top = h_padding // 2 100 | pad_bottom = h_padding - pad_top 101 | pad_left = w_padding // 2 102 | pad_right = w_padding - pad_left 103 | 104 | # Pad with white (value=1.0) 105 | img = torch.nn.functional.pad( 106 | img, (pad_left, pad_right, pad_top, pad_bottom), mode="constant", value=1.0 107 | ) 108 | 109 | shapes.add((img.shape[1], img.shape[2])) 110 | images.append(img) 111 | 112 | # Check if we have different shapes 113 | # In theory our model can also work well with different shapes 114 | if len(shapes) > 1: 115 | print(f"Warning: Found images with different shapes: {shapes}") 116 | # Find maximum dimensions 117 | max_height = max(shape[0] for shape in shapes) 118 | max_width = max(shape[1] for shape in shapes) 119 | 120 | # Pad images if necessary 121 | padded_images = [] 122 | for img in images: 123 | h_padding = max_height - img.shape[1] 124 | w_padding = max_width - img.shape[2] 125 | 126 | if h_padding > 0 or w_padding > 0: 127 | pad_top = h_padding // 2 128 | pad_bottom = h_padding - pad_top 129 | pad_left = w_padding // 2 130 | pad_right = w_padding - pad_left 131 | 132 | img = torch.nn.functional.pad( 133 | img, (pad_left, pad_right, pad_top, pad_bottom), mode="constant", value=1.0 134 | ) 135 | padded_images.append(img) 136 | images = padded_images 137 | 138 | images = torch.stack(images) # concatenate images 139 | 140 | # Ensure correct shape when single image 141 | if len(image_path_list) == 1: 142 | # Verify shape is (1, C, H, W) 143 | if images.dim() == 3: 144 | images = images.unsqueeze(0) 145 | 146 | return images 147 | -------------------------------------------------------------------------------- /vggt/utils/geometry.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import os 8 | import torch 9 | import numpy as np 10 | 11 | 12 | def unproject_depth_map_to_point_map( 13 | depth_map: np.ndarray, extrinsics_cam: np.ndarray, intrinsics_cam: np.ndarray 14 | ) -> np.ndarray: 15 | """ 16 | Unproject a batch of depth maps to 3D world coordinates. 17 | 18 | Args: 19 | depth_map (np.ndarray): Batch of depth maps of shape (S, H, W, 1) or (S, H, W) 20 | extrinsics_cam (np.ndarray): Batch of camera extrinsic matrices of shape (S, 3, 4) 21 | intrinsics_cam (np.ndarray): Batch of camera intrinsic matrices of shape (S, 3, 3) 22 | 23 | Returns: 24 | np.ndarray: Batch of 3D world coordinates of shape (S, H, W, 3) 25 | """ 26 | if isinstance(depth_map, torch.Tensor): 27 | depth_map = depth_map.cpu().numpy() 28 | if isinstance(extrinsics_cam, torch.Tensor): 29 | extrinsics_cam = extrinsics_cam.cpu().numpy() 30 | if isinstance(intrinsics_cam, torch.Tensor): 31 | intrinsics_cam = intrinsics_cam.cpu().numpy() 32 | 33 | world_points_list = [] 34 | for frame_idx in range(depth_map.shape[0]): 35 | cur_world_points, _, _ = depth_to_world_coords_points( 36 | depth_map[frame_idx].squeeze(-1), extrinsics_cam[frame_idx], intrinsics_cam[frame_idx] 37 | ) 38 | world_points_list.append(cur_world_points) 39 | world_points_array = np.stack(world_points_list, axis=0) 40 | 41 | return world_points_array 42 | 43 | 44 | def depth_to_world_coords_points( 45 | depth_map: np.ndarray, 46 | extrinsic: np.ndarray, 47 | intrinsic: np.ndarray, 48 | eps=1e-8, 49 | ) -> tuple[np.ndarray, np.ndarray, np.ndarray]: 50 | """ 51 | Convert a depth map to world coordinates. 52 | 53 | Args: 54 | depth_map (np.ndarray): Depth map of shape (H, W). 55 | intrinsic (np.ndarray): Camera intrinsic matrix of shape (3, 3). 56 | extrinsic (np.ndarray): Camera extrinsic matrix of shape (3, 4). OpenCV camera coordinate convention, cam from world. 57 | 58 | Returns: 59 | tuple[np.ndarray, np.ndarray]: World coordinates (H, W, 3) and valid depth mask (H, W). 60 | """ 61 | if depth_map is None: 62 | return None, None, None 63 | 64 | # Valid depth mask 65 | point_mask = depth_map > eps 66 | 67 | # Convert depth map to camera coordinates 68 | cam_coords_points = depth_to_cam_coords_points(depth_map, intrinsic) 69 | 70 | # Multiply with the inverse of extrinsic matrix to transform to world coordinates 71 | # extrinsic_inv is 4x4 (note closed_form_inverse_OpenCV is batched, the output is (N, 4, 4)) 72 | cam_to_world_extrinsic = closed_form_inverse_se3(extrinsic[None])[0] 73 | 74 | R_cam_to_world = cam_to_world_extrinsic[:3, :3] 75 | t_cam_to_world = cam_to_world_extrinsic[:3, 3] 76 | 77 | # Apply the rotation and translation to the camera coordinates 78 | world_coords_points = np.dot(cam_coords_points, R_cam_to_world.T) + t_cam_to_world # HxWx3, 3x3 -> HxWx3 79 | # world_coords_points = np.einsum("ij,hwj->hwi", R_cam_to_world, cam_coords_points) + t_cam_to_world 80 | 81 | return world_coords_points, cam_coords_points, point_mask 82 | 83 | 84 | def depth_to_cam_coords_points(depth_map: np.ndarray, intrinsic: np.ndarray) -> tuple[np.ndarray, np.ndarray]: 85 | """ 86 | Convert a depth map to camera coordinates. 87 | 88 | Args: 89 | depth_map (np.ndarray): Depth map of shape (H, W). 90 | intrinsic (np.ndarray): Camera intrinsic matrix of shape (3, 3). 91 | 92 | Returns: 93 | tuple[np.ndarray, np.ndarray]: Camera coordinates (H, W, 3) 94 | """ 95 | H, W = depth_map.shape 96 | assert intrinsic.shape == (3, 3), "Intrinsic matrix must be 3x3" 97 | assert intrinsic[0, 1] == 0 and intrinsic[1, 0] == 0, "Intrinsic matrix must have zero skew" 98 | 99 | # Intrinsic parameters 100 | fu, fv = intrinsic[0, 0], intrinsic[1, 1] 101 | cu, cv = intrinsic[0, 2], intrinsic[1, 2] 102 | 103 | # Generate grid of pixel coordinates 104 | u, v = np.meshgrid(np.arange(W), np.arange(H)) 105 | 106 | # Unproject to camera coordinates 107 | x_cam = (u - cu) * depth_map / fu 108 | y_cam = (v - cv) * depth_map / fv 109 | z_cam = depth_map 110 | 111 | # Stack to form camera coordinates 112 | cam_coords = np.stack((x_cam, y_cam, z_cam), axis=-1).astype(np.float32) 113 | 114 | return cam_coords 115 | 116 | 117 | def closed_form_inverse_se3(se3, R=None, T=None): 118 | """ 119 | Compute the inverse of each 4x4 (or 3x4) SE3 matrix in a batch. 120 | 121 | If `R` and `T` are provided, they must correspond to the rotation and translation 122 | components of `se3`. Otherwise, they will be extracted from `se3`. 123 | 124 | Args: 125 | se3: Nx4x4 or Nx3x4 array or tensor of SE3 matrices. 126 | R (optional): Nx3x3 array or tensor of rotation matrices. 127 | T (optional): Nx3x1 array or tensor of translation vectors. 128 | 129 | Returns: 130 | Inverted SE3 matrices with the same type and device as `se3`. 131 | 132 | Shapes: 133 | se3: (N, 4, 4) 134 | R: (N, 3, 3) 135 | T: (N, 3, 1) 136 | """ 137 | # Check if se3 is a numpy array or a torch tensor 138 | is_numpy = isinstance(se3, np.ndarray) 139 | 140 | # Validate shapes 141 | if se3.shape[-2:] != (4, 4) and se3.shape[-2:] != (3, 4): 142 | raise ValueError(f"se3 must be of shape (N,4,4), got {se3.shape}.") 143 | 144 | # Extract R and T if not provided 145 | if R is None: 146 | R = se3[:, :3, :3] # (N,3,3) 147 | if T is None: 148 | T = se3[:, :3, 3:] # (N,3,1) 149 | 150 | # Transpose R 151 | if is_numpy: 152 | # Compute the transpose of the rotation for NumPy 153 | R_transposed = np.transpose(R, (0, 2, 1)) 154 | # -R^T t for NumPy 155 | top_right = -np.matmul(R_transposed, T) 156 | inverted_matrix = np.tile(np.eye(4), (len(R), 1, 1)) 157 | else: 158 | R_transposed = R.transpose(1, 2) # (N,3,3) 159 | top_right = -torch.bmm(R_transposed, T) # (N,3,1) 160 | inverted_matrix = torch.eye(4, 4)[None].repeat(len(R), 1, 1) 161 | inverted_matrix = inverted_matrix.to(R.dtype).to(R.device) 162 | 163 | inverted_matrix[:, :3, :3] = R_transposed 164 | inverted_matrix[:, :3, 3:] = top_right 165 | 166 | return inverted_matrix 167 | -------------------------------------------------------------------------------- /vggt/heads/camera_head.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import math 8 | import numpy as np 9 | 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | 14 | from vggt.layers import Mlp 15 | from vggt.layers.block import Block 16 | from vggt.heads.head_act import activate_pose 17 | 18 | 19 | class CameraHead(nn.Module): 20 | """ 21 | CameraHead predicts camera parameters from token representations using iterative refinement. 22 | 23 | It applies a series of transformer blocks (the "trunk") to dedicated camera tokens. 24 | """ 25 | 26 | def __init__( 27 | self, 28 | dim_in: int = 2048, 29 | trunk_depth: int = 4, 30 | pose_encoding_type: str = "absT_quaR_FoV", 31 | num_heads: int = 16, 32 | mlp_ratio: int = 4, 33 | init_values: float = 0.01, 34 | trans_act: str = "linear", 35 | quat_act: str = "linear", 36 | fl_act: str = "relu", # Field of view activations: ensures FOV values are positive. 37 | ): 38 | super().__init__() 39 | 40 | if pose_encoding_type == "absT_quaR_FoV": 41 | self.target_dim = 9 42 | else: 43 | raise ValueError(f"Unsupported camera encoding type: {pose_encoding_type}") 44 | 45 | self.trans_act = trans_act 46 | self.quat_act = quat_act 47 | self.fl_act = fl_act 48 | self.trunk_depth = trunk_depth 49 | 50 | # Build the trunk using a sequence of transformer blocks. 51 | self.trunk = nn.Sequential( 52 | *[ 53 | Block( 54 | dim=dim_in, 55 | num_heads=num_heads, 56 | mlp_ratio=mlp_ratio, 57 | init_values=init_values, 58 | ) 59 | for _ in range(trunk_depth) 60 | ] 61 | ) 62 | 63 | # Normalizations for camera token and trunk output. 64 | self.token_norm = nn.LayerNorm(dim_in) 65 | self.trunk_norm = nn.LayerNorm(dim_in) 66 | 67 | # Learnable empty camera pose token. 68 | self.empty_pose_tokens = nn.Parameter(torch.zeros(1, 1, self.target_dim)) 69 | self.embed_pose = nn.Linear(self.target_dim, dim_in) 70 | 71 | # Module for producing modulation parameters: shift, scale, and a gate. 72 | self.poseLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(dim_in, 3 * dim_in, bias=True)) 73 | 74 | # Adaptive layer normalization without affine parameters. 75 | self.adaln_norm = nn.LayerNorm(dim_in, elementwise_affine=False, eps=1e-6) 76 | self.pose_branch = Mlp( 77 | in_features=dim_in, 78 | hidden_features=dim_in // 2, 79 | out_features=self.target_dim, 80 | drop=0, 81 | ) 82 | 83 | def forward(self, aggregated_tokens_list: list, num_iterations: int = 4) -> list: 84 | """ 85 | Forward pass to predict camera parameters. 86 | 87 | Args: 88 | aggregated_tokens_list (list): List of token tensors from the network; 89 | the last tensor is used for prediction. 90 | num_iterations (int, optional): Number of iterative refinement steps. Defaults to 4. 91 | 92 | Returns: 93 | list: A list of predicted camera encodings (post-activation) from each iteration. 94 | """ 95 | # Use tokens from the last block for camera prediction. 96 | tokens = aggregated_tokens_list[-1] 97 | 98 | # Extract the camera tokens 99 | pose_tokens = tokens[:, :, 0] 100 | pose_tokens = self.token_norm(pose_tokens) 101 | 102 | pred_pose_enc_list = self.trunk_fn(pose_tokens, num_iterations) 103 | return pred_pose_enc_list 104 | 105 | def trunk_fn(self, pose_tokens: torch.Tensor, num_iterations: int) -> list: 106 | """ 107 | Iteratively refine camera pose predictions. 108 | 109 | Args: 110 | pose_tokens (torch.Tensor): Normalized camera tokens with shape [B, 1, C]. 111 | num_iterations (int): Number of refinement iterations. 112 | 113 | Returns: 114 | list: List of activated camera encodings from each iteration. 115 | """ 116 | B, S, C = pose_tokens.shape # S is expected to be 1. 117 | pred_pose_enc = None 118 | pred_pose_enc_list = [] 119 | 120 | for _ in range(num_iterations): 121 | # Use a learned empty pose for the first iteration. 122 | if pred_pose_enc is None: 123 | module_input = self.embed_pose(self.empty_pose_tokens.expand(B, S, -1)) 124 | else: 125 | # Detach the previous prediction to avoid backprop through time. 126 | pred_pose_enc = pred_pose_enc.detach() 127 | module_input = self.embed_pose(pred_pose_enc) 128 | 129 | # Generate modulation parameters and split them into shift, scale, and gate components. 130 | shift_msa, scale_msa, gate_msa = self.poseLN_modulation(module_input).chunk(3, dim=-1) 131 | 132 | # Adaptive layer normalization and modulation. 133 | pose_tokens_modulated = gate_msa * modulate(self.adaln_norm(pose_tokens), shift_msa, scale_msa) 134 | pose_tokens_modulated = pose_tokens_modulated + pose_tokens 135 | 136 | pose_tokens_modulated = self.trunk(pose_tokens_modulated) 137 | # Compute the delta update for the pose encoding. 138 | pred_pose_enc_delta = self.pose_branch(self.trunk_norm(pose_tokens_modulated)) 139 | 140 | if pred_pose_enc is None: 141 | pred_pose_enc = pred_pose_enc_delta 142 | else: 143 | pred_pose_enc = pred_pose_enc + pred_pose_enc_delta 144 | 145 | # Apply final activation functions for translation, quaternion, and field-of-view. 146 | activated_pose = activate_pose( 147 | pred_pose_enc, 148 | trans_act=self.trans_act, 149 | quat_act=self.quat_act, 150 | fl_act=self.fl_act, 151 | ) 152 | pred_pose_enc_list.append(activated_pose) 153 | 154 | return pred_pose_enc_list 155 | 156 | 157 | def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor: 158 | """ 159 | Modulate the input tensor using scaling and shifting parameters. 160 | """ 161 | # modified from https://github.com/facebookresearch/DiT/blob/796c29e532f47bba17c5b9c5eb39b9354b8b7c64/models.py#L19 162 | return x * (1 + scale) + shift 163 | -------------------------------------------------------------------------------- /vggt/heads/track_modules/modules.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | from functools import partial 12 | from typing import Callable 13 | import collections 14 | from torch import Tensor 15 | from itertools import repeat 16 | 17 | 18 | # From PyTorch internals 19 | def _ntuple(n): 20 | def parse(x): 21 | if isinstance(x, collections.abc.Iterable) and not isinstance(x, str): 22 | return tuple(x) 23 | return tuple(repeat(x, n)) 24 | 25 | return parse 26 | 27 | 28 | def exists(val): 29 | return val is not None 30 | 31 | 32 | def default(val, d): 33 | return val if exists(val) else d 34 | 35 | 36 | to_2tuple = _ntuple(2) 37 | 38 | 39 | class ResidualBlock(nn.Module): 40 | """ 41 | ResidualBlock: construct a block of two conv layers with residual connections 42 | """ 43 | 44 | def __init__(self, in_planes, planes, norm_fn="group", stride=1, kernel_size=3): 45 | super(ResidualBlock, self).__init__() 46 | 47 | self.conv1 = nn.Conv2d( 48 | in_planes, 49 | planes, 50 | kernel_size=kernel_size, 51 | padding=1, 52 | stride=stride, 53 | padding_mode="zeros", 54 | ) 55 | self.conv2 = nn.Conv2d( 56 | planes, 57 | planes, 58 | kernel_size=kernel_size, 59 | padding=1, 60 | padding_mode="zeros", 61 | ) 62 | self.relu = nn.ReLU(inplace=True) 63 | 64 | num_groups = planes // 8 65 | 66 | if norm_fn == "group": 67 | self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) 68 | self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) 69 | if not stride == 1: 70 | self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) 71 | 72 | elif norm_fn == "batch": 73 | self.norm1 = nn.BatchNorm2d(planes) 74 | self.norm2 = nn.BatchNorm2d(planes) 75 | if not stride == 1: 76 | self.norm3 = nn.BatchNorm2d(planes) 77 | 78 | elif norm_fn == "instance": 79 | self.norm1 = nn.InstanceNorm2d(planes) 80 | self.norm2 = nn.InstanceNorm2d(planes) 81 | if not stride == 1: 82 | self.norm3 = nn.InstanceNorm2d(planes) 83 | 84 | elif norm_fn == "none": 85 | self.norm1 = nn.Sequential() 86 | self.norm2 = nn.Sequential() 87 | if not stride == 1: 88 | self.norm3 = nn.Sequential() 89 | else: 90 | raise NotImplementedError 91 | 92 | if stride == 1: 93 | self.downsample = None 94 | else: 95 | self.downsample = nn.Sequential( 96 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), 97 | self.norm3, 98 | ) 99 | 100 | def forward(self, x): 101 | y = x 102 | y = self.relu(self.norm1(self.conv1(y))) 103 | y = self.relu(self.norm2(self.conv2(y))) 104 | 105 | if self.downsample is not None: 106 | x = self.downsample(x) 107 | 108 | return self.relu(x + y) 109 | 110 | 111 | class Mlp(nn.Module): 112 | """MLP as used in Vision Transformer, MLP-Mixer and related networks""" 113 | 114 | def __init__( 115 | self, 116 | in_features, 117 | hidden_features=None, 118 | out_features=None, 119 | act_layer=nn.GELU, 120 | norm_layer=None, 121 | bias=True, 122 | drop=0.0, 123 | use_conv=False, 124 | ): 125 | super().__init__() 126 | out_features = out_features or in_features 127 | hidden_features = hidden_features or in_features 128 | bias = to_2tuple(bias) 129 | drop_probs = to_2tuple(drop) 130 | linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear 131 | 132 | self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0]) 133 | self.act = act_layer() 134 | self.drop1 = nn.Dropout(drop_probs[0]) 135 | self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1]) 136 | self.drop2 = nn.Dropout(drop_probs[1]) 137 | 138 | def forward(self, x): 139 | x = self.fc1(x) 140 | x = self.act(x) 141 | x = self.drop1(x) 142 | x = self.fc2(x) 143 | x = self.drop2(x) 144 | return x 145 | 146 | 147 | class AttnBlock(nn.Module): 148 | def __init__( 149 | self, 150 | hidden_size, 151 | num_heads, 152 | attn_class: Callable[..., nn.Module] = nn.MultiheadAttention, 153 | mlp_ratio=4.0, 154 | **block_kwargs 155 | ): 156 | """ 157 | Self attention block 158 | """ 159 | super().__init__() 160 | 161 | self.norm1 = nn.LayerNorm(hidden_size) 162 | self.norm2 = nn.LayerNorm(hidden_size) 163 | 164 | self.attn = attn_class(embed_dim=hidden_size, num_heads=num_heads, batch_first=True, **block_kwargs) 165 | 166 | mlp_hidden_dim = int(hidden_size * mlp_ratio) 167 | 168 | self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, drop=0) 169 | 170 | def forward(self, x, mask=None): 171 | # Prepare the mask for PyTorch's attention (it expects a different format) 172 | # attn_mask = mask if mask is not None else None 173 | # Normalize before attention 174 | x = self.norm1(x) 175 | 176 | # PyTorch's MultiheadAttention returns attn_output, attn_output_weights 177 | # attn_output, _ = self.attn(x, x, x, attn_mask=attn_mask) 178 | 179 | attn_output, _ = self.attn(x, x, x) 180 | 181 | # Add & Norm 182 | x = x + attn_output 183 | x = x + self.mlp(self.norm2(x)) 184 | return x 185 | 186 | 187 | class CrossAttnBlock(nn.Module): 188 | def __init__(self, hidden_size, context_dim, num_heads=1, mlp_ratio=4.0, **block_kwargs): 189 | """ 190 | Cross attention block 191 | """ 192 | super().__init__() 193 | 194 | self.norm1 = nn.LayerNorm(hidden_size) 195 | self.norm_context = nn.LayerNorm(hidden_size) 196 | self.norm2 = nn.LayerNorm(hidden_size) 197 | 198 | self.cross_attn = nn.MultiheadAttention( 199 | embed_dim=hidden_size, num_heads=num_heads, batch_first=True, **block_kwargs 200 | ) 201 | 202 | mlp_hidden_dim = int(hidden_size * mlp_ratio) 203 | 204 | self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, drop=0) 205 | 206 | def forward(self, x, context, mask=None): 207 | # Normalize inputs 208 | x = self.norm1(x) 209 | context = self.norm_context(context) 210 | 211 | # Apply cross attention 212 | # Note: nn.MultiheadAttention returns attn_output, attn_output_weights 213 | attn_output, _ = self.cross_attn(x, context, context, attn_mask=mask) 214 | 215 | # Add & Norm 216 | x = x + attn_output 217 | x = x + self.mlp(self.norm2(x)) 218 | return x 219 | -------------------------------------------------------------------------------- /evaluation/quarot/utils.py: -------------------------------------------------------------------------------- 1 | from .args_utils import parser_gen, get_config, create_logger 2 | from .quarot_linear import VGGTQuantizedLinear 3 | # from .train_utils import cali_qs_quant 4 | from .function_utils import get_paras_dict_by_name 5 | import torch.nn as nn 6 | import torch 7 | import logging 8 | import os 9 | import math 10 | from collections import defaultdict 11 | from torch.utils.data import Dataset, DataLoader 12 | 13 | def load_qs_parameters(args, model, path=None): 14 | if path is None: 15 | qs_frame_parameters = torch.load(os.path.join(args.exp_dir, f"qs_frame_parameters_total.pth")) 16 | qs_global_parameters = torch.load(os.path.join(args.exp_dir, f"qs_global_parameters_total.pth")) 17 | else: 18 | qs_frame_parameters = torch.load(os.path.join(path, f"qs_frame_parameters_total.pth")) 19 | qs_global_parameters = torch.load(os.path.join(path, f"qs_global_parameters_total.pth")) 20 | frame_layers = model.aggregator.frame_blocks 21 | global_layers = model.aggregator.global_blocks 22 | 23 | for i in range(len(qs_frame_parameters.keys())): 24 | qs_frame_param = qs_frame_parameters[i] 25 | frame_layers[i].load_state_dict(qs_frame_param, strict=False) 26 | 27 | for i in range(len(qs_global_parameters.keys())): 28 | qs_global_param = qs_global_parameters[i] 29 | global_layers[i].load_state_dict(qs_global_param, strict=False) 30 | 31 | model.eval() 32 | return model 33 | 34 | 35 | def mark_ignore(module,ignore_quantize): 36 | for child in module.children(): 37 | if isinstance(child, torch.nn.Linear): 38 | child.ignore_quantize = ignore_quantize 39 | else: 40 | mark_ignore(child,ignore_quantize) 41 | 42 | 43 | def set_ignore_quantize(model, ignore_quantize=True): 44 | model.aggregator.patch_embed.patch_embed.ignore_quantize = True 45 | model.aggregator.patch_embed.norm.ignore_quantize = True 46 | model.aggregator.patch_embed.head.ignore_quantize = True 47 | 48 | for head in [model.aggregator.patch_embed.blocks , 49 | model.camera_head, 50 | model.point_head, 51 | model.depth_head, 52 | model.track_head]: 53 | mark_ignore(head,ignore_quantize) 54 | 55 | 56 | def is_power_of_two(n): 57 | return (n & (n - 1)) == 0 and n != 0 58 | 59 | 60 | def check_linear_dims(model): 61 | non_power_of_two = defaultdict(list) 62 | 63 | for name, module in model.named_modules(): 64 | if isinstance(module, torch.nn.Linear): 65 | in_features = module.in_features 66 | out_features = module.out_features 67 | 68 | if not is_power_of_two(in_features): 69 | non_power_of_two[in_features].append({ 70 | "layer_name": name, 71 | "in_features": in_features, 72 | "out_features": out_features 73 | }) 74 | 75 | return non_power_of_two 76 | 77 | def quantize_linear(module, device="cuda", args=None): 78 | if isinstance(module, nn.Linear): 79 | use_rot = True 80 | if device is not None: 81 | module = module.to(device) 82 | 83 | if getattr(module, 'ignore_quantize', False): 84 | return module 85 | 86 | if getattr(module, 'higher_bits', False): 87 | original_a_bits = args.a_bits 88 | original_w_bits = args.w_bits 89 | args.w_bits = 8 90 | args.a_bits = 8 91 | if use_rot: 92 | new_layer = VGGTQuantizedLinear(args, module) 93 | args.a_bits = original_a_bits 94 | args.w_bits = original_w_bits 95 | else: 96 | if use_rot: 97 | new_layer = VGGTQuantizedLinear(args, module) 98 | 99 | return new_layer 100 | else: 101 | for name, child in module.named_children(): 102 | new_child = quantize_linear( 103 | child, device, args 104 | ) 105 | if new_child is not child: 106 | setattr(module, name, new_child) 107 | if device is not None: 108 | module.to(device=device) 109 | return module 110 | 111 | 112 | def save_hadamard_matrix(model, path): 113 | model_params = get_paras_dict_by_name(model, required_names=["rotation_matrix"]) 114 | torch.save(model_params, os.path.join(path, f"hadamard_matrix.pth")) 115 | 116 | def model_reparameterize(model): 117 | for name, module in model.named_modules(): 118 | if isinstance(module,VGGTQuantizedLinear): 119 | module.reparameterize() 120 | 121 | def init_logger(log_dir="logs", log_file="app.log"): 122 | logger = logging.getLogger("my_app") 123 | logger.setLevel(logging.DEBUG) # 全局最低级别 124 | 125 | os.makedirs(log_dir, exist_ok=True) 126 | log_path = os.path.join(log_dir, log_file) 127 | 128 | console_handler = logging.StreamHandler() 129 | console_handler.setLevel(logging.INFO) 130 | 131 | file_handler = logging.FileHandler(log_path, encoding="utf-8") 132 | file_handler.setLevel(logging.DEBUG) 133 | 134 | formatter = logging.Formatter( 135 | "%(asctime)s [%(levelname)s] %(name)s: %(message)s", 136 | datefmt="%Y-%m-%d %H:%M:%S" 137 | ) 138 | console_handler.setFormatter(formatter) 139 | file_handler.setFormatter(formatter) 140 | 141 | logger.addHandler(console_handler) 142 | logger.addHandler(file_handler) 143 | 144 | return logger 145 | 146 | def VggtQuantModel(config,model,calib_data, wbit, abit, resume_qs=False, 147 | use_gptq=False, resume_gptq=False, model_id=None,exp_name=None): 148 | 149 | device = next(model.parameters()).device 150 | model.to(device) 151 | set_ignore_quantize(model) 152 | quantize_linear(model, args=config) 153 | 154 | if resume_qs: 155 | model = load_qs_parameters(config, model) 156 | print("resume!") 157 | after_resume_qs(model) 158 | return 159 | 160 | if not config.not_smooth : 161 | with torch.no_grad(): 162 | num_calib = max(1, math.floor(config.nsamples / 5)) 163 | for idx, batch in enumerate(calib_data[:num_calib]): 164 | input_dict = {} 165 | for key in batch.keys(): 166 | if key == "images": 167 | if torch.is_tensor(batch[key]): 168 | input_dict[key] = batch[key].to(device) 169 | else: 170 | input_dict[key] = batch[key] 171 | 172 | output = model(**input_dict) 173 | after_smoothfactor_init(model) 174 | 175 | if not config.not_smooth or config.lac or config.lwc: 176 | logger = init_logger(config.exp_dir) 177 | cali_qs_quant(config, model, calib_data, device, logger) 178 | print("calib done") 179 | else: 180 | print("Do not need calib") 181 | 182 | after_resume_qs(model) 183 | return 184 | 185 | 186 | def after_smoothfactor_init(model): 187 | for name, module in model.named_modules(): 188 | if isinstance(module, (VGGTQuantizedLinear)): 189 | module.smooth_quant_running_stat = False 190 | module.ori_mode = True 191 | module.train_mode = False 192 | module.eval_mode = False 193 | module.channel_wise_scale = nn.Parameter(module.act_quantizer.act_scale.pow(module.smooth_quant_alpha) / module.linear.weight.abs().max(dim=0)[0].pow(1 - module.smooth_quant_alpha)) 194 | module.channel_wise_scale.data = torch.clamp(module.channel_wise_scale.data, min=1e-5) 195 | 196 | 197 | def after_resume_qs(model): 198 | for name, module in model.named_modules(): 199 | if isinstance(module, (VGGTQuantizedLinear)): 200 | module.smooth_quant_running_stat = False 201 | module.ori_mode = False 202 | module.train_mode = False 203 | module.eval_mode = True 204 | module.reparameterize() 205 | 206 | 207 | 208 | 209 | -------------------------------------------------------------------------------- /vggt/layers/rope.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 5 | 6 | 7 | # Implementation of 2D Rotary Position Embeddings (RoPE). 8 | 9 | # This module provides a clean implementation of 2D Rotary Position Embeddings, 10 | # which extends the original RoPE concept to handle 2D spatial positions. 11 | 12 | # Inspired by: 13 | # https://github.com/meta-llama/codellama/blob/main/llama/model.py 14 | # https://github.com/naver-ai/rope-vit 15 | 16 | 17 | import numpy as np 18 | import torch 19 | import torch.nn as nn 20 | import torch.nn.functional as F 21 | from typing import Dict, Tuple 22 | 23 | 24 | class PositionGetter: 25 | """Generates and caches 2D spatial positions for patches in a grid. 26 | 27 | This class efficiently manages the generation of spatial coordinates for patches 28 | in a 2D grid, caching results to avoid redundant computations. 29 | 30 | Attributes: 31 | position_cache: Dictionary storing precomputed position tensors for different 32 | grid dimensions. 33 | """ 34 | 35 | def __init__(self): 36 | """Initializes the position generator with an empty cache.""" 37 | self.position_cache: Dict[Tuple[int, int], torch.Tensor] = {} 38 | 39 | def __call__(self, batch_size: int, height: int, width: int, device: torch.device) -> torch.Tensor: 40 | """Generates spatial positions for a batch of patches. 41 | 42 | Args: 43 | batch_size: Number of samples in the batch. 44 | height: Height of the grid in patches. 45 | width: Width of the grid in patches. 46 | device: Target device for the position tensor. 47 | 48 | Returns: 49 | Tensor of shape (batch_size, height*width, 2) containing y,x coordinates 50 | for each position in the grid, repeated for each batch item. 51 | """ 52 | if (height, width) not in self.position_cache: 53 | y_coords = torch.arange(height, device=device) 54 | x_coords = torch.arange(width, device=device) 55 | positions = torch.cartesian_prod(y_coords, x_coords) 56 | self.position_cache[height, width] = positions 57 | 58 | cached_positions = self.position_cache[height, width] 59 | return cached_positions.view(1, height * width, 2).expand(batch_size, -1, -1).clone() 60 | 61 | 62 | class RotaryPositionEmbedding2D(nn.Module): 63 | """2D Rotary Position Embedding implementation. 64 | 65 | This module applies rotary position embeddings to input tokens based on their 66 | 2D spatial positions. It handles the position-dependent rotation of features 67 | separately for vertical and horizontal dimensions. 68 | 69 | Args: 70 | frequency: Base frequency for the position embeddings. Default: 100.0 71 | scaling_factor: Scaling factor for frequency computation. Default: 1.0 72 | 73 | Attributes: 74 | base_frequency: Base frequency for computing position embeddings. 75 | scaling_factor: Factor to scale the computed frequencies. 76 | frequency_cache: Cache for storing precomputed frequency components. 77 | """ 78 | 79 | def __init__(self, frequency: float = 100.0, scaling_factor: float = 1.0): 80 | """Initializes the 2D RoPE module.""" 81 | super().__init__() 82 | self.base_frequency = frequency 83 | self.scaling_factor = scaling_factor 84 | self.frequency_cache: Dict[Tuple, Tuple[torch.Tensor, torch.Tensor]] = {} 85 | 86 | def _compute_frequency_components( 87 | self, dim: int, seq_len: int, device: torch.device, dtype: torch.dtype 88 | ) -> Tuple[torch.Tensor, torch.Tensor]: 89 | """Computes frequency components for rotary embeddings. 90 | 91 | Args: 92 | dim: Feature dimension (must be even). 93 | seq_len: Maximum sequence length. 94 | device: Target device for computations. 95 | dtype: Data type for the computed tensors. 96 | 97 | Returns: 98 | Tuple of (cosine, sine) tensors for frequency components. 99 | """ 100 | cache_key = (dim, seq_len, device, dtype) 101 | if cache_key not in self.frequency_cache: 102 | # Compute frequency bands 103 | exponents = torch.arange(0, dim, 2, device=device).float() / dim 104 | inv_freq = 1.0 / (self.base_frequency**exponents) 105 | 106 | # Generate position-dependent frequencies 107 | positions = torch.arange(seq_len, device=device, dtype=inv_freq.dtype) 108 | angles = torch.einsum("i,j->ij", positions, inv_freq) 109 | 110 | # Compute and cache frequency components 111 | angles = angles.to(dtype) 112 | angles = torch.cat((angles, angles), dim=-1) 113 | cos_components = angles.cos().to(dtype) 114 | sin_components = angles.sin().to(dtype) 115 | self.frequency_cache[cache_key] = (cos_components, sin_components) 116 | 117 | return self.frequency_cache[cache_key] 118 | 119 | @staticmethod 120 | def _rotate_features(x: torch.Tensor) -> torch.Tensor: 121 | """Performs feature rotation by splitting and recombining feature dimensions. 122 | 123 | Args: 124 | x: Input tensor to rotate. 125 | 126 | Returns: 127 | Rotated feature tensor. 128 | """ 129 | feature_dim = x.shape[-1] 130 | x1, x2 = x[..., : feature_dim // 2], x[..., feature_dim // 2 :] 131 | return torch.cat((-x2, x1), dim=-1) 132 | 133 | def _apply_1d_rope( 134 | self, tokens: torch.Tensor, positions: torch.Tensor, cos_comp: torch.Tensor, sin_comp: torch.Tensor 135 | ) -> torch.Tensor: 136 | """Applies 1D rotary position embeddings along one dimension. 137 | 138 | Args: 139 | tokens: Input token features. 140 | positions: Position indices. 141 | cos_comp: Cosine components for rotation. 142 | sin_comp: Sine components for rotation. 143 | 144 | Returns: 145 | Tokens with applied rotary position embeddings. 146 | """ 147 | # Embed positions with frequency components 148 | cos = F.embedding(positions, cos_comp)[:, None, :, :] 149 | sin = F.embedding(positions, sin_comp)[:, None, :, :] 150 | 151 | # Apply rotation 152 | return (tokens * cos) + (self._rotate_features(tokens) * sin) 153 | 154 | def forward(self, tokens: torch.Tensor, positions: torch.Tensor) -> torch.Tensor: 155 | """Applies 2D rotary position embeddings to input tokens. 156 | 157 | Args: 158 | tokens: Input tensor of shape (batch_size, n_heads, n_tokens, dim). 159 | The feature dimension (dim) must be divisible by 4. 160 | positions: Position tensor of shape (batch_size, n_tokens, 2) containing 161 | the y and x coordinates for each token. 162 | 163 | Returns: 164 | Tensor of same shape as input with applied 2D rotary position embeddings. 165 | 166 | Raises: 167 | AssertionError: If input dimensions are invalid or positions are malformed. 168 | """ 169 | # Validate inputs 170 | assert tokens.size(-1) % 2 == 0, "Feature dimension must be even" 171 | assert positions.ndim == 3 and positions.shape[-1] == 2, "Positions must have shape (batch_size, n_tokens, 2)" 172 | 173 | # Compute feature dimension for each spatial direction 174 | feature_dim = tokens.size(-1) // 2 175 | 176 | # Get frequency components 177 | max_position = int(positions.max()) + 1 178 | cos_comp, sin_comp = self._compute_frequency_components(feature_dim, max_position, tokens.device, tokens.dtype) 179 | 180 | # Split features for vertical and horizontal processing 181 | vertical_features, horizontal_features = tokens.chunk(2, dim=-1) 182 | 183 | # Apply RoPE separately for each dimension 184 | vertical_features = self._apply_1d_rope(vertical_features, positions[..., 0], cos_comp, sin_comp) 185 | horizontal_features = self._apply_1d_rope(horizontal_features, positions[..., 1], cos_comp, sin_comp) 186 | 187 | # Combine processed features 188 | return torch.cat((vertical_features, horizontal_features), dim=-1) 189 | -------------------------------------------------------------------------------- /vggt/heads/track_modules/base_track_predictor.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | import torch.nn as nn 9 | from einops import rearrange, repeat 10 | 11 | 12 | from .blocks import EfficientUpdateFormer, CorrBlock 13 | from .utils import sample_features4d, get_2d_embedding, get_2d_sincos_pos_embed 14 | from .modules import Mlp 15 | 16 | 17 | class BaseTrackerPredictor(nn.Module): 18 | def __init__( 19 | self, 20 | stride=1, 21 | corr_levels=5, 22 | corr_radius=4, 23 | latent_dim=128, 24 | hidden_size=384, 25 | use_spaceatt=True, 26 | depth=6, 27 | max_scale=518, 28 | predict_conf=True, 29 | ): 30 | super(BaseTrackerPredictor, self).__init__() 31 | """ 32 | The base template to create a track predictor 33 | 34 | Modified from https://github.com/facebookresearch/co-tracker/ 35 | and https://github.com/facebookresearch/vggsfm 36 | """ 37 | 38 | self.stride = stride 39 | self.latent_dim = latent_dim 40 | self.corr_levels = corr_levels 41 | self.corr_radius = corr_radius 42 | self.hidden_size = hidden_size 43 | self.max_scale = max_scale 44 | self.predict_conf = predict_conf 45 | 46 | self.flows_emb_dim = latent_dim // 2 47 | 48 | self.corr_mlp = Mlp( 49 | in_features=self.corr_levels * (self.corr_radius * 2 + 1) ** 2, 50 | hidden_features=self.hidden_size, 51 | out_features=self.latent_dim, 52 | ) 53 | 54 | self.transformer_dim = self.latent_dim + self.latent_dim + self.latent_dim + 4 55 | 56 | self.query_ref_token = nn.Parameter(torch.randn(1, 2, self.transformer_dim)) 57 | 58 | space_depth = depth if use_spaceatt else 0 59 | time_depth = depth 60 | 61 | self.updateformer = EfficientUpdateFormer( 62 | space_depth=space_depth, 63 | time_depth=time_depth, 64 | input_dim=self.transformer_dim, 65 | hidden_size=self.hidden_size, 66 | output_dim=self.latent_dim + 2, 67 | mlp_ratio=4.0, 68 | add_space_attn=use_spaceatt, 69 | ) 70 | 71 | self.fmap_norm = nn.LayerNorm(self.latent_dim) 72 | self.ffeat_norm = nn.GroupNorm(1, self.latent_dim) 73 | 74 | # A linear layer to update track feats at each iteration 75 | self.ffeat_updater = nn.Sequential(nn.Linear(self.latent_dim, self.latent_dim), nn.GELU()) 76 | 77 | self.vis_predictor = nn.Sequential(nn.Linear(self.latent_dim, 1)) 78 | 79 | if predict_conf: 80 | self.conf_predictor = nn.Sequential(nn.Linear(self.latent_dim, 1)) 81 | 82 | def forward(self, query_points, fmaps=None, iters=6, return_feat=False, down_ratio=1, apply_sigmoid=True): 83 | """ 84 | query_points: B x N x 2, the number of batches, tracks, and xy 85 | fmaps: B x S x C x HH x WW, the number of batches, frames, and feature dimension. 86 | note HH and WW is the size of feature maps instead of original images 87 | """ 88 | B, N, D = query_points.shape 89 | B, S, C, HH, WW = fmaps.shape 90 | 91 | assert D == 2, "Input points must be 2D coordinates" 92 | 93 | # apply a layernorm to fmaps here 94 | fmaps = self.fmap_norm(fmaps.permute(0, 1, 3, 4, 2)) 95 | fmaps = fmaps.permute(0, 1, 4, 2, 3) 96 | 97 | # Scale the input query_points because we may downsample the images 98 | # by down_ratio or self.stride 99 | # e.g., if a 3x1024x1024 image is processed to a 128x256x256 feature map 100 | # its query_points should be query_points/4 101 | if down_ratio > 1: 102 | query_points = query_points / float(down_ratio) 103 | 104 | query_points = query_points / float(self.stride) 105 | 106 | # Init with coords as the query points 107 | # It means the search will start from the position of query points at the reference frames 108 | coords = query_points.clone().reshape(B, 1, N, 2).repeat(1, S, 1, 1) 109 | 110 | # Sample/extract the features of the query points in the query frame 111 | query_track_feat = sample_features4d(fmaps[:, 0], coords[:, 0]) 112 | 113 | # init track feats by query feats 114 | track_feats = query_track_feat.unsqueeze(1).repeat(1, S, 1, 1) # B, S, N, C 115 | # back up the init coords 116 | coords_backup = coords.clone() 117 | 118 | fcorr_fn = CorrBlock(fmaps, num_levels=self.corr_levels, radius=self.corr_radius) 119 | 120 | coord_preds = [] 121 | 122 | # Iterative Refinement 123 | for _ in range(iters): 124 | # Detach the gradients from the last iteration 125 | # (in my experience, not very important for performance) 126 | coords = coords.detach() 127 | 128 | fcorrs = fcorr_fn.corr_sample(track_feats, coords) 129 | 130 | corr_dim = fcorrs.shape[3] 131 | fcorrs_ = fcorrs.permute(0, 2, 1, 3).reshape(B * N, S, corr_dim) 132 | fcorrs_ = self.corr_mlp(fcorrs_) 133 | 134 | # Movement of current coords relative to query points 135 | flows = (coords - coords[:, 0:1]).permute(0, 2, 1, 3).reshape(B * N, S, 2) 136 | 137 | flows_emb = get_2d_embedding(flows, self.flows_emb_dim, cat_coords=False) 138 | 139 | # (In my trials, it is also okay to just add the flows_emb instead of concat) 140 | flows_emb = torch.cat([flows_emb, flows / self.max_scale, flows / self.max_scale], dim=-1) 141 | 142 | track_feats_ = track_feats.permute(0, 2, 1, 3).reshape(B * N, S, self.latent_dim) 143 | 144 | # Concatenate them as the input for the transformers 145 | transformer_input = torch.cat([flows_emb, fcorrs_, track_feats_], dim=2) 146 | 147 | # 2D positional embed 148 | # TODO: this can be much simplified 149 | pos_embed = get_2d_sincos_pos_embed(self.transformer_dim, grid_size=(HH, WW)).to(query_points.device) 150 | sampled_pos_emb = sample_features4d(pos_embed.expand(B, -1, -1, -1), coords[:, 0]) 151 | 152 | sampled_pos_emb = rearrange(sampled_pos_emb, "b n c -> (b n) c").unsqueeze(1) 153 | 154 | x = transformer_input + sampled_pos_emb 155 | 156 | # Add the query ref token to the track feats 157 | query_ref_token = torch.cat( 158 | [self.query_ref_token[:, 0:1], self.query_ref_token[:, 1:2].expand(-1, S - 1, -1)], dim=1 159 | ) 160 | x = x + query_ref_token.to(x.device).to(x.dtype) 161 | 162 | # B, N, S, C 163 | x = rearrange(x, "(b n) s d -> b n s d", b=B) 164 | 165 | # Compute the delta coordinates and delta track features 166 | delta, _ = self.updateformer(x) 167 | 168 | # BN, S, C 169 | delta = rearrange(delta, " b n s d -> (b n) s d", b=B) 170 | delta_coords_ = delta[:, :, :2] 171 | delta_feats_ = delta[:, :, 2:] 172 | 173 | track_feats_ = track_feats_.reshape(B * N * S, self.latent_dim) 174 | delta_feats_ = delta_feats_.reshape(B * N * S, self.latent_dim) 175 | 176 | # Update the track features 177 | track_feats_ = self.ffeat_updater(self.ffeat_norm(delta_feats_)) + track_feats_ 178 | 179 | track_feats = track_feats_.reshape(B, N, S, self.latent_dim).permute(0, 2, 1, 3) # BxSxNxC 180 | 181 | # B x S x N x 2 182 | coords = coords + delta_coords_.reshape(B, N, S, 2).permute(0, 2, 1, 3) 183 | 184 | # Force coord0 as query 185 | # because we assume the query points should not be changed 186 | coords[:, 0] = coords_backup[:, 0] 187 | 188 | # The predicted tracks are in the original image scale 189 | if down_ratio > 1: 190 | coord_preds.append(coords * self.stride * down_ratio) 191 | else: 192 | coord_preds.append(coords * self.stride) 193 | 194 | # B, S, N 195 | vis_e = self.vis_predictor(track_feats.reshape(B * S * N, self.latent_dim)).reshape(B, S, N) 196 | if apply_sigmoid: 197 | vis_e = torch.sigmoid(vis_e) 198 | 199 | if self.predict_conf: 200 | conf_e = self.conf_predictor(track_feats.reshape(B * S * N, self.latent_dim)).reshape(B, S, N) 201 | if apply_sigmoid: 202 | conf_e = torch.sigmoid(conf_e) 203 | else: 204 | conf_e = None 205 | 206 | if return_feat: 207 | return coord_preds, vis_e, track_feats, query_track_feat, conf_e 208 | else: 209 | return coord_preds, vis_e, conf_e 210 | -------------------------------------------------------------------------------- /vggt/heads/track_modules/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # Modified from https://github.com/facebookresearch/vggsfm 8 | # and https://github.com/facebookresearch/co-tracker/tree/main 9 | 10 | 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.functional as F 14 | 15 | from typing import Optional, Tuple, Union 16 | 17 | 18 | def get_2d_sincos_pos_embed(embed_dim: int, grid_size: Union[int, Tuple[int, int]], return_grid=False) -> torch.Tensor: 19 | """ 20 | This function initializes a grid and generates a 2D positional embedding using sine and cosine functions. 21 | It is a wrapper of get_2d_sincos_pos_embed_from_grid. 22 | Args: 23 | - embed_dim: The embedding dimension. 24 | - grid_size: The grid size. 25 | Returns: 26 | - pos_embed: The generated 2D positional embedding. 27 | """ 28 | if isinstance(grid_size, tuple): 29 | grid_size_h, grid_size_w = grid_size 30 | else: 31 | grid_size_h = grid_size_w = grid_size 32 | grid_h = torch.arange(grid_size_h, dtype=torch.float) 33 | grid_w = torch.arange(grid_size_w, dtype=torch.float) 34 | grid = torch.meshgrid(grid_w, grid_h, indexing="xy") 35 | grid = torch.stack(grid, dim=0) 36 | grid = grid.reshape([2, 1, grid_size_h, grid_size_w]) 37 | pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) 38 | if return_grid: 39 | return ( 40 | pos_embed.reshape(1, grid_size_h, grid_size_w, -1).permute(0, 3, 1, 2), 41 | grid, 42 | ) 43 | return pos_embed.reshape(1, grid_size_h, grid_size_w, -1).permute(0, 3, 1, 2) 44 | 45 | 46 | def get_2d_sincos_pos_embed_from_grid(embed_dim: int, grid: torch.Tensor) -> torch.Tensor: 47 | """ 48 | This function generates a 2D positional embedding from a given grid using sine and cosine functions. 49 | 50 | Args: 51 | - embed_dim: The embedding dimension. 52 | - grid: The grid to generate the embedding from. 53 | 54 | Returns: 55 | - emb: The generated 2D positional embedding. 56 | """ 57 | assert embed_dim % 2 == 0 58 | 59 | # use half of dimensions to encode grid_h 60 | emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) 61 | emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) 62 | 63 | emb = torch.cat([emb_h, emb_w], dim=2) # (H*W, D) 64 | return emb 65 | 66 | 67 | def get_1d_sincos_pos_embed_from_grid(embed_dim: int, pos: torch.Tensor) -> torch.Tensor: 68 | """ 69 | This function generates a 1D positional embedding from a given grid using sine and cosine functions. 70 | 71 | Args: 72 | - embed_dim: The embedding dimension. 73 | - pos: The position to generate the embedding from. 74 | 75 | Returns: 76 | - emb: The generated 1D positional embedding. 77 | """ 78 | assert embed_dim % 2 == 0 79 | omega = torch.arange(embed_dim // 2, dtype=torch.double) 80 | omega /= embed_dim / 2.0 81 | omega = 1.0 / 10000**omega # (D/2,) 82 | 83 | pos = pos.reshape(-1) # (M,) 84 | out = torch.einsum("m,d->md", pos, omega) # (M, D/2), outer product 85 | 86 | emb_sin = torch.sin(out) # (M, D/2) 87 | emb_cos = torch.cos(out) # (M, D/2) 88 | 89 | emb = torch.cat([emb_sin, emb_cos], dim=1) # (M, D) 90 | return emb[None].float() 91 | 92 | 93 | def get_2d_embedding(xy: torch.Tensor, C: int, cat_coords: bool = True) -> torch.Tensor: 94 | """ 95 | This function generates a 2D positional embedding from given coordinates using sine and cosine functions. 96 | 97 | Args: 98 | - xy: The coordinates to generate the embedding from. 99 | - C: The size of the embedding. 100 | - cat_coords: A flag to indicate whether to concatenate the original coordinates to the embedding. 101 | 102 | Returns: 103 | - pe: The generated 2D positional embedding. 104 | """ 105 | B, N, D = xy.shape 106 | assert D == 2 107 | 108 | x = xy[:, :, 0:1] 109 | y = xy[:, :, 1:2] 110 | div_term = (torch.arange(0, C, 2, device=xy.device, dtype=torch.float32) * (1000.0 / C)).reshape(1, 1, int(C / 2)) 111 | 112 | pe_x = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32) 113 | pe_y = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32) 114 | 115 | pe_x[:, :, 0::2] = torch.sin(x * div_term) 116 | pe_x[:, :, 1::2] = torch.cos(x * div_term) 117 | 118 | pe_y[:, :, 0::2] = torch.sin(y * div_term) 119 | pe_y[:, :, 1::2] = torch.cos(y * div_term) 120 | 121 | pe = torch.cat([pe_x, pe_y], dim=2) # (B, N, C*3) 122 | if cat_coords: 123 | pe = torch.cat([xy, pe], dim=2) # (B, N, C*3+3) 124 | return pe 125 | 126 | 127 | def bilinear_sampler(input, coords, align_corners=True, padding_mode="border"): 128 | r"""Sample a tensor using bilinear interpolation 129 | 130 | `bilinear_sampler(input, coords)` samples a tensor :attr:`input` at 131 | coordinates :attr:`coords` using bilinear interpolation. It is the same 132 | as `torch.nn.functional.grid_sample()` but with a different coordinate 133 | convention. 134 | 135 | The input tensor is assumed to be of shape :math:`(B, C, H, W)`, where 136 | :math:`B` is the batch size, :math:`C` is the number of channels, 137 | :math:`H` is the height of the image, and :math:`W` is the width of the 138 | image. The tensor :attr:`coords` of shape :math:`(B, H_o, W_o, 2)` is 139 | interpreted as an array of 2D point coordinates :math:`(x_i,y_i)`. 140 | 141 | Alternatively, the input tensor can be of size :math:`(B, C, T, H, W)`, 142 | in which case sample points are triplets :math:`(t_i,x_i,y_i)`. Note 143 | that in this case the order of the components is slightly different 144 | from `grid_sample()`, which would expect :math:`(x_i,y_i,t_i)`. 145 | 146 | If `align_corners` is `True`, the coordinate :math:`x` is assumed to be 147 | in the range :math:`[0,W-1]`, with 0 corresponding to the center of the 148 | left-most image pixel :math:`W-1` to the center of the right-most 149 | pixel. 150 | 151 | If `align_corners` is `False`, the coordinate :math:`x` is assumed to 152 | be in the range :math:`[0,W]`, with 0 corresponding to the left edge of 153 | the left-most pixel :math:`W` to the right edge of the right-most 154 | pixel. 155 | 156 | Similar conventions apply to the :math:`y` for the range 157 | :math:`[0,H-1]` and :math:`[0,H]` and to :math:`t` for the range 158 | :math:`[0,T-1]` and :math:`[0,T]`. 159 | 160 | Args: 161 | input (Tensor): batch of input images. 162 | coords (Tensor): batch of coordinates. 163 | align_corners (bool, optional): Coordinate convention. Defaults to `True`. 164 | padding_mode (str, optional): Padding mode. Defaults to `"border"`. 165 | 166 | Returns: 167 | Tensor: sampled points. 168 | """ 169 | coords = coords.detach().clone() 170 | ############################################################ 171 | # IMPORTANT: 172 | coords = coords.to(input.device).to(input.dtype) 173 | ############################################################ 174 | 175 | sizes = input.shape[2:] 176 | 177 | assert len(sizes) in [2, 3] 178 | 179 | if len(sizes) == 3: 180 | # t x y -> x y t to match dimensions T H W in grid_sample 181 | coords = coords[..., [1, 2, 0]] 182 | 183 | if align_corners: 184 | scale = torch.tensor( 185 | [2 / max(size - 1, 1) for size in reversed(sizes)], device=coords.device, dtype=coords.dtype 186 | ) 187 | else: 188 | scale = torch.tensor([2 / size for size in reversed(sizes)], device=coords.device, dtype=coords.dtype) 189 | 190 | coords.mul_(scale) # coords = coords * scale 191 | coords.sub_(1) # coords = coords - 1 192 | 193 | return F.grid_sample(input, coords, align_corners=align_corners, padding_mode=padding_mode) 194 | 195 | 196 | def sample_features4d(input, coords): 197 | r"""Sample spatial features 198 | 199 | `sample_features4d(input, coords)` samples the spatial features 200 | :attr:`input` represented by a 4D tensor :math:`(B, C, H, W)`. 201 | 202 | The field is sampled at coordinates :attr:`coords` using bilinear 203 | interpolation. :attr:`coords` is assumed to be of shape :math:`(B, R, 204 | 2)`, where each sample has the format :math:`(x_i, y_i)`. This uses the 205 | same convention as :func:`bilinear_sampler` with `align_corners=True`. 206 | 207 | The output tensor has one feature per point, and has shape :math:`(B, 208 | R, C)`. 209 | 210 | Args: 211 | input (Tensor): spatial features. 212 | coords (Tensor): points. 213 | 214 | Returns: 215 | Tensor: sampled features. 216 | """ 217 | 218 | B, _, _, _ = input.shape 219 | 220 | # B R 2 -> B R 1 2 221 | coords = coords.unsqueeze(2) 222 | 223 | # B C R 1 224 | feats = bilinear_sampler(input, coords) 225 | 226 | return feats.permute(0, 2, 1, 3).view(B, -1, feats.shape[1] * feats.shape[3]) # B C R 1 -> B R C 227 | -------------------------------------------------------------------------------- /vggt/utils/visual_track.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import cv2 8 | import torch 9 | import numpy as np 10 | import os 11 | 12 | 13 | def color_from_xy(x, y, W, H, cmap_name="hsv"): 14 | """ 15 | Map (x, y) -> color in (R, G, B). 16 | 1) Normalize x,y to [0,1]. 17 | 2) Combine them into a single scalar c in [0,1]. 18 | 3) Use matplotlib's colormap to convert c -> (R,G,B). 19 | 20 | You can customize step 2, e.g., c = (x + y)/2, or some function of (x, y). 21 | """ 22 | import matplotlib.cm 23 | import matplotlib.colors 24 | 25 | x_norm = x / max(W - 1, 1) 26 | y_norm = y / max(H - 1, 1) 27 | # Simple combination: 28 | c = (x_norm + y_norm) / 2.0 29 | 30 | cmap = matplotlib.cm.get_cmap(cmap_name) 31 | # cmap(c) -> (r,g,b,a) in [0,1] 32 | rgba = cmap(c) 33 | r, g, b = rgba[0], rgba[1], rgba[2] 34 | return (r, g, b) # in [0,1], RGB order 35 | 36 | 37 | def get_track_colors_by_position(tracks_b, vis_mask_b=None, image_width=None, image_height=None, cmap_name="hsv"): 38 | """ 39 | Given all tracks in one sample (b), compute a (N,3) array of RGB color values 40 | in [0,255]. The color is determined by the (x,y) position in the first 41 | visible frame for each track. 42 | 43 | Args: 44 | tracks_b: Tensor of shape (S, N, 2). (x,y) for each track in each frame. 45 | vis_mask_b: (S, N) boolean mask; if None, assume all are visible. 46 | image_width, image_height: used for normalizing (x, y). 47 | cmap_name: for matplotlib (e.g., 'hsv', 'rainbow', 'jet'). 48 | 49 | Returns: 50 | track_colors: np.ndarray of shape (N, 3), each row is (R,G,B) in [0,255]. 51 | """ 52 | S, N, _ = tracks_b.shape 53 | track_colors = np.zeros((N, 3), dtype=np.uint8) 54 | 55 | if vis_mask_b is None: 56 | # treat all as visible 57 | vis_mask_b = torch.ones(S, N, dtype=torch.bool, device=tracks_b.device) 58 | 59 | for i in range(N): 60 | # Find first visible frame for track i 61 | visible_frames = torch.where(vis_mask_b[:, i])[0] 62 | if len(visible_frames) == 0: 63 | # track is never visible; just assign black or something 64 | track_colors[i] = (0, 0, 0) 65 | continue 66 | 67 | first_s = int(visible_frames[0].item()) 68 | # use that frame's (x,y) 69 | x, y = tracks_b[first_s, i].tolist() 70 | 71 | # map (x,y) -> (R,G,B) in [0,1] 72 | r, g, b = color_from_xy(x, y, W=image_width, H=image_height, cmap_name=cmap_name) 73 | # scale to [0,255] 74 | r, g, b = int(r * 255), int(g * 255), int(b * 255) 75 | track_colors[i] = (r, g, b) 76 | 77 | return track_colors 78 | 79 | 80 | def visualize_tracks_on_images( 81 | images, 82 | tracks, 83 | track_vis_mask=None, 84 | out_dir="track_visuals_concat_by_xy", 85 | image_format="CHW", # "CHW" or "HWC" 86 | normalize_mode="[0,1]", 87 | cmap_name="hsv", # e.g. "hsv", "rainbow", "jet" 88 | frames_per_row=4, # New parameter for grid layout 89 | save_grid=True, # Flag to control whether to save the grid image 90 | ): 91 | """ 92 | Visualizes frames in a grid layout with specified frames per row. 93 | Each track's color is determined by its (x,y) position 94 | in the first visible frame (or frame 0 if always visible). 95 | Finally convert the BGR result to RGB before saving. 96 | Also saves each individual frame as a separate PNG file. 97 | 98 | Args: 99 | images: torch.Tensor (S, 3, H, W) if CHW or (S, H, W, 3) if HWC. 100 | tracks: torch.Tensor (S, N, 2), last dim = (x, y). 101 | track_vis_mask: torch.Tensor (S, N) or None. 102 | out_dir: folder to save visualizations. 103 | image_format: "CHW" or "HWC". 104 | normalize_mode: "[0,1]", "[-1,1]", or None for direct raw -> 0..255 105 | cmap_name: a matplotlib colormap name for color_from_xy. 106 | frames_per_row: number of frames to display in each row of the grid. 107 | save_grid: whether to save all frames in one grid image. 108 | 109 | Returns: 110 | None (saves images in out_dir). 111 | """ 112 | 113 | if len(tracks.shape) == 4: 114 | tracks = tracks.squeeze(0) 115 | images = images.squeeze(0) 116 | if track_vis_mask is not None: 117 | track_vis_mask = track_vis_mask.squeeze(0) 118 | 119 | import matplotlib 120 | 121 | matplotlib.use("Agg") # for non-interactive (optional) 122 | 123 | os.makedirs(out_dir, exist_ok=True) 124 | 125 | S = images.shape[0] 126 | _, N, _ = tracks.shape # (S, N, 2) 127 | 128 | # Move to CPU 129 | images = images.cpu().clone() 130 | tracks = tracks.cpu().clone() 131 | if track_vis_mask is not None: 132 | track_vis_mask = track_vis_mask.cpu().clone() 133 | 134 | # Infer H, W from images shape 135 | if image_format == "CHW": 136 | # e.g. images[s].shape = (3, H, W) 137 | H, W = images.shape[2], images.shape[3] 138 | else: 139 | # e.g. images[s].shape = (H, W, 3) 140 | H, W = images.shape[1], images.shape[2] 141 | 142 | # Pre-compute the color for each track i based on first visible position 143 | track_colors_rgb = get_track_colors_by_position( 144 | tracks, # shape (S, N, 2) 145 | vis_mask_b=track_vis_mask if track_vis_mask is not None else None, 146 | image_width=W, 147 | image_height=H, 148 | cmap_name=cmap_name, 149 | ) 150 | 151 | # We'll accumulate each frame's drawn image in a list 152 | frame_images = [] 153 | 154 | for s in range(S): 155 | # shape => either (3, H, W) or (H, W, 3) 156 | img = images[s] 157 | 158 | # Convert to (H, W, 3) 159 | if image_format == "CHW": 160 | img = img.permute(1, 2, 0) # (H, W, 3) 161 | # else "HWC", do nothing 162 | 163 | img = img.numpy().astype(np.float32) 164 | 165 | # Scale to [0,255] if needed 166 | if normalize_mode == "[0,1]": 167 | img = np.clip(img, 0, 1) * 255.0 168 | elif normalize_mode == "[-1,1]": 169 | img = (img + 1.0) * 0.5 * 255.0 170 | img = np.clip(img, 0, 255.0) 171 | # else no normalization 172 | 173 | # Convert to uint8 174 | img = img.astype(np.uint8) 175 | 176 | # For drawing in OpenCV, convert to BGR 177 | img_bgr = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) 178 | 179 | # Draw each visible track 180 | cur_tracks = tracks[s] # shape (N, 2) 181 | if track_vis_mask is not None: 182 | valid_indices = torch.where(track_vis_mask[s])[0] 183 | else: 184 | valid_indices = range(N) 185 | 186 | cur_tracks_np = cur_tracks.numpy() 187 | for i in valid_indices: 188 | x, y = cur_tracks_np[i] 189 | pt = (int(round(x)), int(round(y))) 190 | 191 | # track_colors_rgb[i] is (R,G,B). For OpenCV circle, we need BGR 192 | R, G, B = track_colors_rgb[i] 193 | color_bgr = (int(B), int(G), int(R)) 194 | cv2.circle(img_bgr, pt, radius=3, color=color_bgr, thickness=-1) 195 | 196 | # Convert back to RGB for consistent final saving: 197 | img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB) 198 | 199 | # Save individual frame 200 | frame_path = os.path.join(out_dir, f"frame_{s:04d}.png") 201 | # Convert to BGR for OpenCV imwrite 202 | frame_bgr = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2BGR) 203 | cv2.imwrite(frame_path, frame_bgr) 204 | 205 | frame_images.append(img_rgb) 206 | 207 | # Only create and save the grid image if save_grid is True 208 | if save_grid: 209 | # Calculate grid dimensions 210 | num_rows = (S + frames_per_row - 1) // frames_per_row # Ceiling division 211 | 212 | # Create a grid of images 213 | grid_img = None 214 | for row in range(num_rows): 215 | start_idx = row * frames_per_row 216 | end_idx = min(start_idx + frames_per_row, S) 217 | 218 | # Concatenate this row horizontally 219 | row_img = np.concatenate(frame_images[start_idx:end_idx], axis=1) 220 | 221 | # If this row has fewer than frames_per_row images, pad with black 222 | if end_idx - start_idx < frames_per_row: 223 | padding_width = (frames_per_row - (end_idx - start_idx)) * W 224 | padding = np.zeros((H, padding_width, 3), dtype=np.uint8) 225 | row_img = np.concatenate([row_img, padding], axis=1) 226 | 227 | # Add this row to the grid 228 | if grid_img is None: 229 | grid_img = row_img 230 | else: 231 | grid_img = np.concatenate([grid_img, row_img], axis=0) 232 | 233 | out_path = os.path.join(out_dir, "tracks_grid.png") 234 | # Convert back to BGR for OpenCV imwrite 235 | grid_img_bgr = cv2.cvtColor(grid_img, cv2.COLOR_RGB2BGR) 236 | cv2.imwrite(out_path, grid_img_bgr) 237 | print(f"[INFO] Saved color-by-XY track visualization grid -> {out_path}") 238 | 239 | print(f"[INFO] Saved {S} individual frames to {out_dir}/frame_*.png") 240 | -------------------------------------------------------------------------------- /evaluation/quarot/quant_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def round_ste(x: torch.Tensor): 4 | """ 5 | Implement Straight-Through Estimator for rounding operation. 6 | """ 7 | return (x.round() - x).detach() + x 8 | 9 | 10 | def get_qmin_qmax(bits, sym): 11 | if sym: 12 | q_max = torch.tensor(2 ** (bits - 1) - 1) 13 | q_min = -q_max -1 14 | else: 15 | q_max, q_min = torch.tensor(2 ** bits - 1), 0 16 | return q_max, q_min 17 | 18 | 19 | def sym_quant(x, scale, maxq): 20 | scale = scale.to(x.device) 21 | q = torch.clamp(round_ste(x / scale), -(maxq + 1), maxq) 22 | return q, scale 23 | 24 | # q是量化后的x 25 | def sym_dequant(q, scale): 26 | return scale * q 27 | 28 | # 先量化再解量化 29 | def sym_quant_dequant(x, scale, maxq): 30 | return sym_dequant(*sym_quant(x, scale, maxq)) 31 | 32 | 33 | def asym_quant(x, scale, zero, maxq): 34 | scale = scale.to(x.device) 35 | zero = zero.to(x.device) 36 | q = torch.clamp(round_ste(x / scale) + zero, 0, maxq) 37 | return q, scale, zero 38 | 39 | 40 | def asym_dequant(q, scale, zero): 41 | return scale * (q - zero) 42 | 43 | 44 | def asym_quant_dequant(x, scale, zero, maxq): 45 | return asym_dequant(*asym_quant(x, scale, zero, maxq)) 46 | 47 | class ActivationQuantizer(torch.nn.Module): 48 | ''' 49 | A class for quantizing the activations. We only support (both sym. and asym.) per-token quantization 50 | for the activations. 51 | ''' 52 | def __init__(self, bits, sym=False, lac=False, groupsize=-1, clip_ratio=None, ): 53 | super(ActivationQuantizer, self).__init__() 54 | self.bits = bits 55 | self.q_max, self.q_min = get_qmin_qmax(bits, sym) 56 | self.sym = sym 57 | self.groupsize = groupsize # 用于group_wise,尚未实现 58 | if self.groupsize > 0: 59 | raise NotImplementedError("Not support per-group quantization for activation yet.") 60 | self.lac = lac 61 | self._clip_ratio = clip_ratio 62 | 63 | if self.lac: 64 | init_value = 4. 65 | self.sigmoid = torch.nn.Sigmoid() 66 | 67 | self.clip_factor_a_max = torch.nn.Parameter(torch.ones((1, ))*init_value, requires_grad=True) 68 | self.clip_factor_a_min = torch.nn.Parameter(torch.ones((1, ))*init_value, requires_grad=True) 69 | 70 | self.enable = True 71 | def forward(self, x): 72 | if self.bits == 16 or (not self.enable): 73 | return x 74 | fq_x = self.fake_quant(x) 75 | return fq_x 76 | 77 | def fake_quant(self, x): 78 | x_dtype = x.dtype 79 | scale, zero = self.get_scale_zero(x) 80 | if self.sym: 81 | return sym_quant_dequant(x, scale, self.q_max.to(x)).to(x_dtype) 82 | else: 83 | return asym_quant_dequant(x, scale, zero, self.q_max.to(x)).to(x_dtype) # TODO 84 | 85 | def get_scale_zero(self, x): 86 | q_max = self.q_max.to(x) 87 | init_shape = x.shape 88 | if len(init_shape) == 2: 89 | reshaped_x = x.reshape((-1, x.shape[-1])) 90 | xmax, xmin = reshaped_x.amax(1, keepdim=True), reshaped_x.amin(1, keepdim=True) 91 | tmp = torch.zeros_like(xmax) 92 | xmax, xmin = torch.maximum(xmax, tmp), torch.minimum(xmin, tmp) 93 | 94 | 95 | if self.lac: 96 | xmax = xmax * self.sigmoid(self.clip_factor_a_max) 97 | xmin = xmin * self.sigmoid(self.clip_factor_a_min) 98 | elif self._clip_ratio is not None: 99 | xmax = xmax * self._clip_ratio 100 | xmin = xmin * self._clip_ratio 101 | if self.sym: 102 | xmax = torch.maximum(torch.abs(xmin), xmax) 103 | tmp = xmax == 0 104 | scale = (xmax / q_max) 105 | scale[tmp] = 1 106 | scale = scale.repeat(1, reshaped_x.shape[-1]).reshape(init_shape) 107 | zero = torch.zeros_like(scale) 108 | else: 109 | tmp = (xmin == 0) & (xmax == 0) 110 | xmin[tmp] = -1 111 | xmax[tmp] = +1 112 | scale = (xmax - xmin) / q_max 113 | zero = torch.round(-xmin / scale) 114 | 115 | scale = scale.repeat(1, reshaped_x.shape[-1]).reshape(init_shape) 116 | zero = zero.repeat(1, reshaped_x.shape[-1]).reshape(init_shape) 117 | else: 118 | reshaped_x = x 119 | 120 | xmax, xmin = reshaped_x.amax(-1, keepdim=True), reshaped_x.amin(-1, keepdim=True) 121 | tmp = torch.zeros_like(xmax) 122 | xmax, xmin = torch.maximum(xmax, tmp), torch.minimum(xmin, tmp) 123 | 124 | if self.lac: 125 | xmax = xmax * self.sigmoid(self.clip_factor_a_max) 126 | xmin = xmin * self.sigmoid(self.clip_factor_a_min) 127 | elif self._clip_ratio is not None: 128 | xmax = xmax * self._clip_ratio 129 | xmin = xmin * self._clip_ratio 130 | if self.sym: 131 | xmax = torch.maximum(torch.abs(xmin), xmax) 132 | tmp = xmax == 0 133 | scale = (xmax / q_max) 134 | scale[tmp] = 1 135 | scale = scale.repeat(1, 1, reshaped_x.shape[-1]).reshape(init_shape) 136 | zero = torch.zeros_like(scale) 137 | else: 138 | tmp = (xmin == 0) & (xmax == 0) 139 | xmin[tmp] = -1 140 | xmax[tmp] = +1 141 | scale = (xmax - xmin) / q_max 142 | zero = torch.round(-xmin / scale) 143 | 144 | scale = scale.repeat(1, reshaped_x.shape[-1]).reshape(init_shape) 145 | zero = zero.repeat(1, reshaped_x.shape[-1]).reshape(init_shape) 146 | 147 | return scale, zero 148 | 149 | 150 | class WeightQuantizer(torch.nn.Module): 151 | def __init__(self, shape=1): 152 | super(WeightQuantizer, self).__init__() 153 | self.register_buffer('maxq', torch.tensor(0)) 154 | self.register_buffer('scale', torch.zeros(shape)) 155 | self.register_buffer('zero', torch.zeros(shape)) 156 | 157 | self.enable = True 158 | 159 | def configure( 160 | self, 161 | bits, perchannel=False, sym=True, 162 | mse=False, norm=2.4, grid=100, maxshrink=.8 163 | ): 164 | self.bits = bits 165 | self.perchannel = perchannel 166 | self.sym = sym 167 | 168 | self.mse = mse 169 | self.norm = norm 170 | self.grid = grid 171 | self.maxshrink = maxshrink 172 | if sym: 173 | self.maxq = torch.tensor(2**(bits-1)-1) 174 | else: 175 | self.maxq = torch.tensor(2**bits - 1) 176 | 177 | 178 | self.init_done = False 179 | 180 | def find_params(self, x): 181 | 182 | if self.bits == 16 or (not self.enable): 183 | return 184 | dev = x.device 185 | self.maxq = self.maxq.to(dev) 186 | 187 | shape = x.shape 188 | if self.perchannel: 189 | x = x.flatten(1) 190 | else: 191 | x = x.flatten().unsqueeze(0) 192 | 193 | 194 | tmp = torch.zeros(x.shape[0], device=dev) 195 | xmin = torch.minimum(x.min(1)[0], tmp) 196 | xmax = torch.maximum(x.max(1)[0], tmp) 197 | 198 | if self.sym: 199 | xmax = torch.maximum(torch.abs(xmin), xmax).clamp(min=1e-5) 200 | self.scale = xmax / self.maxq 201 | self.zero = torch.zeros_like(self.scale) 202 | else: 203 | 204 | tmp = (xmin == 0) & (xmax == 0) 205 | 206 | xmin[tmp] = -1 207 | xmax[tmp] = +1 208 | self.scale = (xmax - xmin).clamp(min=1e-5) / self.maxq 209 | self.zero = torch.round(-xmin / self.scale) 210 | 211 | 212 | if self.mse: 213 | best = torch.full([x.shape[0]], float('inf'), device=dev) 214 | for i in range(int(self.maxshrink * self.grid)): 215 | p = 1 - i / self.grid 216 | xmin1 = p * xmin 217 | xmax1 = p * xmax 218 | 219 | if self.sym: 220 | scale1 = xmax1 / self.maxq 221 | zero1 = torch.zeros_like(scale1) 222 | q = sym_quant_dequant(x, scale1.unsqueeze(1), self.maxq) 223 | else: 224 | scale1 = (xmax1 - xmin1) / self.maxq 225 | zero1 = torch.round(-xmin1 / scale1) 226 | q = asym_quant_dequant(x, scale1.unsqueeze(1), zero1.unsqueeze(1), self.maxq) 227 | 228 | q -= x 229 | q.abs_() 230 | q.pow_(self.norm) 231 | err = torch.sum(q, 1) 232 | tmp = err < best 233 | if torch.any(tmp): 234 | best[tmp] = err[tmp] 235 | self.scale[tmp] = scale1[tmp] 236 | self.zero[tmp] = zero1[tmp] 237 | 238 | if not self.perchannel: 239 | tmp = shape[0] 240 | self.scale = self.scale.repeat(tmp) 241 | self.zero = self.zero.repeat(tmp) 242 | 243 | shape = [-1] + [1] * (len(shape) - 1) 244 | self.scale = self.scale.reshape(shape) 245 | self.zero = self.zero.reshape(shape) 246 | return 247 | 248 | def quantize(self, x): 249 | x_dtype = x.dtype 250 | if self.enable and self.ready() and self.bits < 16: 251 | if self.sym: 252 | return sym_quant_dequant(x, self.scale, self.maxq).to(x_dtype) 253 | return asym_quant_dequant(x, self.scale, self.zero, self.maxq).to(x_dtype) 254 | return x 255 | 256 | def forward(self, x): 257 | return self.quantize(x) 258 | 259 | def enabled(self): 260 | return self.maxq > 0 261 | 262 | def ready(self): 263 | return torch.all(self.scale != 0) 264 | 265 | 266 | def set_quantizer_state(model, enable=True): 267 | for m in model.modules(): 268 | if isinstance(m, (WeightQuantizer, ActivationQuantizer)): 269 | m.enable = enable 270 | return model 271 | 272 | 273 | def set_weight_quantizer_state(model, enable=True): 274 | for m in model.modules(): 275 | if isinstance(m, WeightQuantizer): 276 | m.enable = enable 277 | return model 278 | 279 | 280 | def set_act_quantizer_state(model, enable=True): 281 | for m in model.modules(): 282 | if isinstance(m, ActivationQuantizer): 283 | m.enable = enable 284 | return model 285 | -------------------------------------------------------------------------------- /vggt/layers/block.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 5 | 6 | # References: 7 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py 8 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py 9 | 10 | import logging 11 | import os 12 | from typing import Callable, List, Any, Tuple, Dict 13 | import warnings 14 | 15 | import torch 16 | from torch import nn, Tensor 17 | 18 | from .attention import Attention 19 | from .drop_path import DropPath 20 | from .layer_scale import LayerScale 21 | from .mlp import Mlp 22 | 23 | 24 | XFORMERS_AVAILABLE = False 25 | 26 | 27 | class Block(nn.Module): 28 | def __init__( 29 | self, 30 | dim: int, 31 | num_heads: int, 32 | mlp_ratio: float = 4.0, 33 | qkv_bias: bool = True, 34 | proj_bias: bool = True, 35 | ffn_bias: bool = True, 36 | drop: float = 0.0, 37 | attn_drop: float = 0.0, 38 | init_values=None, 39 | drop_path: float = 0.0, 40 | act_layer: Callable[..., nn.Module] = nn.GELU, 41 | norm_layer: Callable[..., nn.Module] = nn.LayerNorm, 42 | attn_class: Callable[..., nn.Module] = Attention, 43 | ffn_layer: Callable[..., nn.Module] = Mlp, 44 | qk_norm: bool = False, 45 | fused_attn: bool = True, # use F.scaled_dot_product_attention or not 46 | rope=None, 47 | ) -> None: 48 | super().__init__() 49 | 50 | self.norm1 = norm_layer(dim) 51 | 52 | self.attn = attn_class( 53 | dim, 54 | num_heads=num_heads, 55 | qkv_bias=qkv_bias, 56 | proj_bias=proj_bias, 57 | attn_drop=attn_drop, 58 | proj_drop=drop, 59 | qk_norm=qk_norm, 60 | fused_attn=fused_attn, 61 | rope=rope, 62 | ) 63 | 64 | self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() 65 | self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() 66 | 67 | self.norm2 = norm_layer(dim) 68 | mlp_hidden_dim = int(dim * mlp_ratio) 69 | self.mlp = ffn_layer( 70 | in_features=dim, 71 | hidden_features=mlp_hidden_dim, 72 | act_layer=act_layer, 73 | drop=drop, 74 | bias=ffn_bias, 75 | ) 76 | self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() 77 | self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() 78 | 79 | self.sample_drop_ratio = drop_path 80 | 81 | def forward(self, x: Tensor, pos=None) -> Tensor: 82 | def attn_residual_func(x: Tensor, pos=None) -> Tensor: 83 | return self.ls1(self.attn(self.norm1(x), pos=pos)) 84 | 85 | def ffn_residual_func(x: Tensor) -> Tensor: 86 | return self.ls2(self.mlp(self.norm2(x))) 87 | 88 | if self.training and self.sample_drop_ratio > 0.1: 89 | # the overhead is compensated only for a drop path rate larger than 0.1 90 | x = drop_add_residual_stochastic_depth( 91 | x, 92 | pos=pos, 93 | residual_func=attn_residual_func, 94 | sample_drop_ratio=self.sample_drop_ratio, 95 | ) 96 | x = drop_add_residual_stochastic_depth( 97 | x, 98 | residual_func=ffn_residual_func, 99 | sample_drop_ratio=self.sample_drop_ratio, 100 | ) 101 | elif self.training and self.sample_drop_ratio > 0.0: 102 | x = x + self.drop_path1(attn_residual_func(x, pos=pos)) 103 | x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2 104 | else: 105 | x = x + attn_residual_func(x, pos=pos) 106 | x = x + ffn_residual_func(x) 107 | return x 108 | 109 | 110 | def drop_add_residual_stochastic_depth( 111 | x: Tensor, 112 | residual_func: Callable[[Tensor], Tensor], 113 | sample_drop_ratio: float = 0.0, 114 | pos=None, 115 | ) -> Tensor: 116 | # 1) extract subset using permutation 117 | b, n, d = x.shape 118 | sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1) 119 | brange = (torch.randperm(b, device=x.device))[:sample_subset_size] 120 | x_subset = x[brange] 121 | 122 | # 2) apply residual_func to get residual 123 | if pos is not None: 124 | # if necessary, apply rope to the subset 125 | pos = pos[brange] 126 | residual = residual_func(x_subset, pos=pos) 127 | else: 128 | residual = residual_func(x_subset) 129 | 130 | x_flat = x.flatten(1) 131 | residual = residual.flatten(1) 132 | 133 | residual_scale_factor = b / sample_subset_size 134 | 135 | # 3) add the residual 136 | x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor) 137 | return x_plus_residual.view_as(x) 138 | 139 | 140 | def get_branges_scales(x, sample_drop_ratio=0.0): 141 | b, n, d = x.shape 142 | sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1) 143 | brange = (torch.randperm(b, device=x.device))[:sample_subset_size] 144 | residual_scale_factor = b / sample_subset_size 145 | return brange, residual_scale_factor 146 | 147 | 148 | def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None): 149 | if scaling_vector is None: 150 | x_flat = x.flatten(1) 151 | residual = residual.flatten(1) 152 | x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor) 153 | else: 154 | x_plus_residual = scaled_index_add( 155 | x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor 156 | ) 157 | return x_plus_residual 158 | 159 | 160 | attn_bias_cache: Dict[Tuple, Any] = {} 161 | 162 | 163 | def get_attn_bias_and_cat(x_list, branges=None): 164 | """ 165 | this will perform the index select, cat the tensors, and provide the attn_bias from cache 166 | """ 167 | batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list] 168 | all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list)) 169 | if all_shapes not in attn_bias_cache.keys(): 170 | seqlens = [] 171 | for b, x in zip(batch_sizes, x_list): 172 | for _ in range(b): 173 | seqlens.append(x.shape[1]) 174 | attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens) 175 | attn_bias._batch_sizes = batch_sizes 176 | attn_bias_cache[all_shapes] = attn_bias 177 | 178 | if branges is not None: 179 | cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1]) 180 | else: 181 | tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list) 182 | cat_tensors = torch.cat(tensors_bs1, dim=1) 183 | 184 | return attn_bias_cache[all_shapes], cat_tensors 185 | 186 | 187 | def drop_add_residual_stochastic_depth_list( 188 | x_list: List[Tensor], 189 | residual_func: Callable[[Tensor, Any], Tensor], 190 | sample_drop_ratio: float = 0.0, 191 | scaling_vector=None, 192 | ) -> Tensor: 193 | # 1) generate random set of indices for dropping samples in the batch 194 | branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list] 195 | branges = [s[0] for s in branges_scales] 196 | residual_scale_factors = [s[1] for s in branges_scales] 197 | 198 | # 2) get attention bias and index+concat the tensors 199 | attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges) 200 | 201 | # 3) apply residual_func to get residual, and split the result 202 | residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore 203 | 204 | outputs = [] 205 | for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors): 206 | outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x)) 207 | return outputs 208 | 209 | 210 | class NestedTensorBlock(Block): 211 | def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]: 212 | """ 213 | x_list contains a list of tensors to nest together and run 214 | """ 215 | assert isinstance(self.attn, MemEffAttention) 216 | 217 | if self.training and self.sample_drop_ratio > 0.0: 218 | 219 | def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor: 220 | return self.attn(self.norm1(x), attn_bias=attn_bias) 221 | 222 | def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor: 223 | return self.mlp(self.norm2(x)) 224 | 225 | x_list = drop_add_residual_stochastic_depth_list( 226 | x_list, 227 | residual_func=attn_residual_func, 228 | sample_drop_ratio=self.sample_drop_ratio, 229 | scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None, 230 | ) 231 | x_list = drop_add_residual_stochastic_depth_list( 232 | x_list, 233 | residual_func=ffn_residual_func, 234 | sample_drop_ratio=self.sample_drop_ratio, 235 | scaling_vector=self.ls2.gamma if isinstance(self.ls1, LayerScale) else None, 236 | ) 237 | return x_list 238 | else: 239 | 240 | def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor: 241 | return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias)) 242 | 243 | def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor: 244 | return self.ls2(self.mlp(self.norm2(x))) 245 | 246 | attn_bias, x = get_attn_bias_and_cat(x_list) 247 | x = x + attn_residual_func(x, attn_bias=attn_bias) 248 | x = x + ffn_residual_func(x) 249 | return attn_bias.split(x) 250 | 251 | def forward(self, x_or_x_list): 252 | if isinstance(x_or_x_list, Tensor): 253 | return super().forward(x_or_x_list) 254 | elif isinstance(x_or_x_list, list): 255 | if not XFORMERS_AVAILABLE: 256 | raise AssertionError("xFormers is required for using nested tensors") 257 | return self.forward_nested(x_or_x_list) 258 | else: 259 | raise AssertionError 260 | -------------------------------------------------------------------------------- /vggt/heads/track_modules/blocks.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | 8 | # Modified from https://github.com/facebookresearch/co-tracker/ 9 | 10 | import math 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.functional as F 14 | 15 | from .utils import bilinear_sampler 16 | from .modules import Mlp, AttnBlock, CrossAttnBlock, ResidualBlock 17 | 18 | 19 | class EfficientUpdateFormer(nn.Module): 20 | """ 21 | Transformer model that updates track estimates. 22 | """ 23 | 24 | def __init__( 25 | self, 26 | space_depth=6, 27 | time_depth=6, 28 | input_dim=320, 29 | hidden_size=384, 30 | num_heads=8, 31 | output_dim=130, 32 | mlp_ratio=4.0, 33 | add_space_attn=True, 34 | num_virtual_tracks=64, 35 | ): 36 | super().__init__() 37 | 38 | self.out_channels = 2 39 | self.num_heads = num_heads 40 | self.hidden_size = hidden_size 41 | self.add_space_attn = add_space_attn 42 | 43 | # Add input LayerNorm before linear projection 44 | self.input_norm = nn.LayerNorm(input_dim) 45 | self.input_transform = torch.nn.Linear(input_dim, hidden_size, bias=True) 46 | 47 | # Add output LayerNorm before final projection 48 | self.output_norm = nn.LayerNorm(hidden_size) 49 | self.flow_head = torch.nn.Linear(hidden_size, output_dim, bias=True) 50 | self.num_virtual_tracks = num_virtual_tracks 51 | 52 | if self.add_space_attn: 53 | self.virual_tracks = nn.Parameter(torch.randn(1, num_virtual_tracks, 1, hidden_size)) 54 | else: 55 | self.virual_tracks = None 56 | 57 | self.time_blocks = nn.ModuleList( 58 | [ 59 | AttnBlock( 60 | hidden_size, 61 | num_heads, 62 | mlp_ratio=mlp_ratio, 63 | attn_class=nn.MultiheadAttention, 64 | ) 65 | for _ in range(time_depth) 66 | ] 67 | ) 68 | 69 | if add_space_attn: 70 | self.space_virtual_blocks = nn.ModuleList( 71 | [ 72 | AttnBlock( 73 | hidden_size, 74 | num_heads, 75 | mlp_ratio=mlp_ratio, 76 | attn_class=nn.MultiheadAttention, 77 | ) 78 | for _ in range(space_depth) 79 | ] 80 | ) 81 | self.space_point2virtual_blocks = nn.ModuleList( 82 | [CrossAttnBlock(hidden_size, hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(space_depth)] 83 | ) 84 | self.space_virtual2point_blocks = nn.ModuleList( 85 | [CrossAttnBlock(hidden_size, hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(space_depth)] 86 | ) 87 | assert len(self.time_blocks) >= len(self.space_virtual2point_blocks) 88 | self.initialize_weights() 89 | 90 | def initialize_weights(self): 91 | def _basic_init(module): 92 | if isinstance(module, nn.Linear): 93 | torch.nn.init.xavier_uniform_(module.weight) 94 | if module.bias is not None: 95 | nn.init.constant_(module.bias, 0) 96 | torch.nn.init.trunc_normal_(self.flow_head.weight, std=0.001) 97 | 98 | self.apply(_basic_init) 99 | 100 | def forward(self, input_tensor, mask=None): 101 | # Apply input LayerNorm 102 | input_tensor = self.input_norm(input_tensor) 103 | tokens = self.input_transform(input_tensor) 104 | 105 | init_tokens = tokens 106 | 107 | B, _, T, _ = tokens.shape 108 | 109 | if self.add_space_attn: 110 | virtual_tokens = self.virual_tracks.repeat(B, 1, T, 1) 111 | tokens = torch.cat([tokens, virtual_tokens], dim=1) 112 | 113 | _, N, _, _ = tokens.shape 114 | 115 | j = 0 116 | for i in range(len(self.time_blocks)): 117 | time_tokens = tokens.contiguous().view(B * N, T, -1) # B N T C -> (B N) T C 118 | 119 | time_tokens = self.time_blocks[i](time_tokens) 120 | 121 | tokens = time_tokens.view(B, N, T, -1) # (B N) T C -> B N T C 122 | if self.add_space_attn and (i % (len(self.time_blocks) // len(self.space_virtual_blocks)) == 0): 123 | space_tokens = tokens.permute(0, 2, 1, 3).contiguous().view(B * T, N, -1) # B N T C -> (B T) N C 124 | point_tokens = space_tokens[:, : N - self.num_virtual_tracks] 125 | virtual_tokens = space_tokens[:, N - self.num_virtual_tracks :] 126 | 127 | virtual_tokens = self.space_virtual2point_blocks[j](virtual_tokens, point_tokens, mask=mask) 128 | virtual_tokens = self.space_virtual_blocks[j](virtual_tokens) 129 | point_tokens = self.space_point2virtual_blocks[j](point_tokens, virtual_tokens, mask=mask) 130 | 131 | space_tokens = torch.cat([point_tokens, virtual_tokens], dim=1) 132 | tokens = space_tokens.view(B, T, N, -1).permute(0, 2, 1, 3) # (B T) N C -> B N T C 133 | j += 1 134 | 135 | if self.add_space_attn: 136 | tokens = tokens[:, : N - self.num_virtual_tracks] 137 | 138 | tokens = tokens + init_tokens 139 | 140 | # Apply output LayerNorm before final projection 141 | tokens = self.output_norm(tokens) 142 | flow = self.flow_head(tokens) 143 | 144 | return flow, None 145 | 146 | 147 | class CorrBlock: 148 | def __init__(self, fmaps, num_levels=4, radius=4, multiple_track_feats=False, padding_mode="zeros"): 149 | """ 150 | Build a pyramid of feature maps from the input. 151 | 152 | fmaps: Tensor (B, S, C, H, W) 153 | num_levels: number of pyramid levels (each downsampled by factor 2) 154 | radius: search radius for sampling correlation 155 | multiple_track_feats: if True, split the target features per pyramid level 156 | padding_mode: passed to grid_sample / bilinear_sampler 157 | """ 158 | B, S, C, H, W = fmaps.shape 159 | self.S, self.C, self.H, self.W = S, C, H, W 160 | self.num_levels = num_levels 161 | self.radius = radius 162 | self.padding_mode = padding_mode 163 | self.multiple_track_feats = multiple_track_feats 164 | 165 | # Build pyramid: each level is half the spatial resolution of the previous 166 | self.fmaps_pyramid = [fmaps] # level 0 is full resolution 167 | current_fmaps = fmaps 168 | for i in range(num_levels - 1): 169 | B, S, C, H, W = current_fmaps.shape 170 | # Merge batch & sequence dimensions 171 | current_fmaps = current_fmaps.reshape(B * S, C, H, W) 172 | # Avg pool down by factor 2 173 | current_fmaps = F.avg_pool2d(current_fmaps, kernel_size=2, stride=2) 174 | _, _, H_new, W_new = current_fmaps.shape 175 | current_fmaps = current_fmaps.reshape(B, S, C, H_new, W_new) 176 | self.fmaps_pyramid.append(current_fmaps) 177 | 178 | # Precompute a delta grid (of shape (2r+1, 2r+1, 2)) for sampling. 179 | # This grid is added to the (scaled) coordinate centroids. 180 | r = self.radius 181 | dx = torch.linspace(-r, r, 2 * r + 1, device=fmaps.device, dtype=fmaps.dtype) 182 | dy = torch.linspace(-r, r, 2 * r + 1, device=fmaps.device, dtype=fmaps.dtype) 183 | # delta: for every (dy,dx) displacement (i.e. Δx, Δy) 184 | self.delta = torch.stack(torch.meshgrid(dy, dx, indexing="ij"), dim=-1) # shape: (2r+1, 2r+1, 2) 185 | 186 | def corr_sample(self, targets, coords): 187 | """ 188 | Instead of storing the entire correlation pyramid, we compute each level's correlation 189 | volume, sample it immediately, then discard it. This saves GPU memory. 190 | 191 | Args: 192 | targets: Tensor (B, S, N, C) — features for the current targets. 193 | coords: Tensor (B, S, N, 2) — coordinates at full resolution. 194 | 195 | Returns: 196 | Tensor (B, S, N, L) where L = num_levels * (2*radius+1)**2 (concatenated sampled correlations) 197 | """ 198 | B, S, N, C = targets.shape 199 | 200 | # If you have multiple track features, split them per level. 201 | if self.multiple_track_feats: 202 | targets_split = torch.split(targets, C // self.num_levels, dim=-1) 203 | 204 | out_pyramid = [] 205 | for i, fmaps in enumerate(self.fmaps_pyramid): 206 | # Get current spatial resolution H, W for this pyramid level. 207 | B, S, C, H, W = fmaps.shape 208 | # Reshape feature maps for correlation computation: 209 | # fmap2s: (B, S, C, H*W) 210 | fmap2s = fmaps.view(B, S, C, H * W) 211 | # Choose appropriate target features. 212 | fmap1 = targets_split[i] if self.multiple_track_feats else targets # shape: (B, S, N, C) 213 | 214 | # Compute correlation directly 215 | corrs = compute_corr_level(fmap1, fmap2s, C) 216 | corrs = corrs.view(B, S, N, H, W) 217 | 218 | # Prepare sampling grid: 219 | # Scale down the coordinates for the current level. 220 | centroid_lvl = coords.reshape(B * S * N, 1, 1, 2) / (2**i) 221 | # Make sure our precomputed delta grid is on the same device/dtype. 222 | delta_lvl = self.delta.to(coords.device).to(coords.dtype) 223 | # Now the grid for grid_sample is: 224 | # coords_lvl = centroid_lvl + delta_lvl (broadcasted over grid) 225 | coords_lvl = centroid_lvl + delta_lvl.view(1, 2 * self.radius + 1, 2 * self.radius + 1, 2) 226 | 227 | # Sample from the correlation volume using bilinear interpolation. 228 | # We reshape corrs to (B * S * N, 1, H, W) so grid_sample acts over each target. 229 | corrs_sampled = bilinear_sampler( 230 | corrs.reshape(B * S * N, 1, H, W), coords_lvl, padding_mode=self.padding_mode 231 | ) 232 | # The sampled output is (B * S * N, 1, 2r+1, 2r+1). Flatten the last two dims. 233 | corrs_sampled = corrs_sampled.view(B, S, N, -1) # Now shape: (B, S, N, (2r+1)^2) 234 | out_pyramid.append(corrs_sampled) 235 | 236 | # Concatenate all levels along the last dimension. 237 | out = torch.cat(out_pyramid, dim=-1).contiguous() 238 | return out 239 | 240 | 241 | def compute_corr_level(fmap1, fmap2s, C): 242 | # fmap1: (B, S, N, C) 243 | # fmap2s: (B, S, C, H*W) 244 | corrs = torch.matmul(fmap1, fmap2s) # (B, S, N, H*W) 245 | corrs = corrs.view(fmap1.shape[0], fmap1.shape[1], fmap1.shape[2], -1) # (B, S, N, H*W) 246 | return corrs / math.sqrt(C) 247 | -------------------------------------------------------------------------------- /evaluation/quarot/args_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | from datetime import datetime 4 | import logging 5 | from termcolor import colored 6 | import pprint 7 | 8 | 9 | class VGGTQuantizedConfig: 10 | def __init__(self): 11 | # General Arguments 12 | self.seed = 0 # Random seed for HuggingFace and PyTorch 13 | self.hf_token = None # HuggingFace token for model access 14 | 15 | # Activation Quantization Arguments 16 | self.a_bits = 16 # Number of bits for inputs of the linear layers 17 | self.a_groupsize = -1 # Groupsize for activation quantization 18 | self.a_asym = False # Use asymmetric activation quantization 19 | 20 | # Weight Quantization Arguments 21 | self.w_bits = 16 # Number of bits for weights of the linear layers 22 | self.w_groupsize = -1 # Groupsize for weight quantization 23 | self.w_asym = False # Use asymmetric weight quantization 24 | self.gptq = False # Use GPTQ for weight quantization 25 | self.gptq_mse = False # Use MSE search for optimal clipping threshold 26 | self.percdamp = 0.01 # Percent of average Hessian diagonal for dampening 27 | self.act_order = False # Use act-order in GPTQ 28 | 29 | # FlatQuant calibration Arguments 30 | self.epochs = 15 # Number of training epochs 31 | 32 | # 控制类型 33 | self.not_smooth = False # 是否启用smooth对角矩阵 34 | self.not_rot = False 35 | self.lwc = False # Use learnable weight clipping 36 | self.lac = False # Use learnable activation clipping 37 | self.rv = False 38 | 39 | # debug 40 | self.nsamples = 10 # Number of calibration data samples, 41 | self.cali_bsz = 1 # Batch size for FlatQuant 42 | self.qs_lr = 1e-3 # Learning rate for learnable transformation 43 | self.cali_trans = False # ⭐Enable calibration of transformations 44 | self.add_diag = True # Add per-channel scaling 45 | self.resume = False # Resume from previous checkpoint 46 | self.save_matrix = False # Save matrix-style parameters 47 | self.reload_matrix = False # Reload matrices for evaluation 48 | self.matrix_path = None # Path to pre-trained matrix parameters 49 | self.diag_init = "sq_style" # Way to initialize per-channel scaling 50 | self.diag_alpha = 0.5 # Hyperparameter for SmoothQuant initialization 51 | self.warmup = False # Warm up learning rate during training 52 | self.deactive_amp = False # Disable AMP training 53 | self.direct_inv = False # Use PyTorch inverse method 54 | self.separate_vtrans = False # Disable vtrans transformation integration 55 | 56 | 57 | # Experiments Arguments 58 | self.output_dir = "./outputs" # Output directory path 59 | self.exp_name = "exp_5b_test" # Experiment name 60 | self.cache_dir = None 61 | 62 | def update_nsamples(self,calib_data_num): 63 | self.nsamples = calib_data_num 64 | 65 | def update_from_args(self, wbit, abit, model_id, not_smooth, not_rot, lwc , lac,rv ,exp_name=None): 66 | self.w_bits = wbit 67 | self.a_bits = abit 68 | self.not_smooth = not_smooth 69 | self.not_rot = not_rot 70 | self.lwc = lwc 71 | self.lac = lac 72 | self.rv = rv 73 | if self.a_groupsize > -1: 74 | raise NotImplementedError 75 | 76 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 77 | self.quantize = (self.w_bits < 16) or (self.a_bits < 16) 78 | 79 | model_name = model_id.split("/")[-1].lower() 80 | sym_flag = "asym" if self.w_asym else "sym" 81 | self.exp_dir = os.path.join(self.output_dir, 82 | f"w{self.w_bits}a{self.a_bits}", f"{exp_name}_{model_name}_{sym_flag}") 83 | self.cache_dir = os.path.join(self.output_dir, 84 | f"w{self.w_bits}a{self.a_bits}") 85 | 86 | os.makedirs(self.exp_dir, exist_ok=True) 87 | 88 | def get_config(): 89 | return VGGTQuantizedConfig() 90 | 91 | 92 | def parser_gen(): 93 | parser = argparse.ArgumentParser() 94 | 95 | # General Arguments 96 | parser.add_argument('--seed', type=int, default=0, help='Random seed for HuggingFace and PyTorch.') 97 | parser.add_argument('--hf_token', type=str, default=None, help='HuggingFace token for model access.') 98 | 99 | # Activation Quantization Arguments 100 | parser.add_argument('--a_bits', type=int, default=16, 101 | help='''Number of bits for inputs of the linear layers. 102 | This applies to all linear layers in the model, including down-projection and out-projection.''') 103 | parser.add_argument('--a_groupsize', type=int, default=-1, 104 | help='Groupsize for activation quantization. Note that this should be the same as w_groupsize.') 105 | parser.add_argument('--a_asym', action="store_true", default=False, 106 | help='Use asymmetric activation quantization.') 107 | 108 | # Weight Quantization Arguments 109 | parser.add_argument('--w_bits', type=int, default=16, 110 | help='Number of bits for weights of the linear layers.') 111 | parser.add_argument('--w_groupsize', type=int, default=-1, 112 | help='Groupsize for weight quantization. Note that this should be the same as a_groupsize.') 113 | parser.add_argument('--w_asym', action="store_true", default=False, 114 | help='Use asymmetric weight quantization.') 115 | parser.add_argument('--gptq', action="store_true", default=False, 116 | help='Quantize the weights using GPTQ. If w_bits < 16 and this flag is not set, use RtN.') 117 | parser.add_argument('--gptq_mse', action="store_true", default=False, 118 | help='''Use MSE search to find the optimal clipping threshold for weight quantization. 119 | NOTE: Do not activate while using LWC.''') 120 | parser.add_argument('--percdamp', type=float, default=.01, 121 | help='Percent of the average Hessian diagonal to use for dampening.') 122 | parser.add_argument('--act_order', action="store_true", default=False, 123 | help='Use act-order in GPTQ.') 124 | 125 | # FlatQuant calibration Arguments 126 | parser.add_argument('--epochs', type=int, default=15, help='Number of training epochs.') 127 | parser.add_argument('--nsamples', type=int, default=128, 128 | help='Number of calibration data samples for FlatQuant and GPTQ.') 129 | parser.add_argument('--cali_bsz', type=int, default=4, 130 | help='Batch size for FlatQuant. Default is 4.') 131 | parser.add_argument("--flat_qs", type=float, default=1e-5, 132 | help='Learning rate for learnable transformation.') 133 | parser.add_argument("--cali_trans", default=False, action="store_true", 134 | help="Enable calibration of transformations.") 135 | parser.add_argument("--add_diag", default=False, action="store_true", 136 | help="Add per-channel scaling.") 137 | parser.add_argument("--lwc", default=False, action="store_true", 138 | help="Use learnable weight clipping.") 139 | parser.add_argument("--lac", default=False, action="store_true", 140 | help="Use learnable activation clipping.") 141 | parser.add_argument('--resume', action="store_true", default=False, 142 | help='Resume from a previous checkpoint for evaluation.') 143 | parser.add_argument('--save_matrix', action="store_true", default=False, 144 | help='Save the matrix-style parameters of FlatQuant.') 145 | parser.add_argument('--reload_matrix', action="store_true", default=False, 146 | help='Reload matrices and the inverse matrices for evaluation.') 147 | parser.add_argument('--matrix_path', type=str, default=None, 148 | help='Path to the pre-trained matrix-style parameters of FlatQuant.') 149 | parser.add_argument("--diag_init", type=str, default="sq_style", choices=["sq_style", "one_style"], 150 | help='The way to initialize per-channel scaling. Default is SmoothQuant style.') 151 | parser.add_argument("--diag_alpha", type=float, default=0.3, 152 | help='Hyperparameter for style initialization of per-channel scaling.') 153 | parser.add_argument("--warmup", default=False, action="store_true", help="Warm up the learning rate during training.") 154 | parser.add_argument("--deactive_amp", default=False, action="store_true", help="Disable AMP training.") 155 | parser.add_argument("--direct_inv", default=False, action="store_true", 156 | help="Use the inverse method in PyTorch to directly get the inverse matrix rather than SVD.") 157 | parser.add_argument("--separate_vtrans", default=False, action="store_true", 158 | help="Disable the integration of the vtrans transformation.") 159 | 160 | # Experiments Arguments 161 | parser.add_argument("--output_dir", type=str, default="./outputs", help="Output directory path.") 162 | parser.add_argument("--exp_name", type=str, default="exp", help="Experiment name.") 163 | 164 | args = parser.parse_args() 165 | if args.a_groupsize > -1: 166 | raise NotImplementedError 167 | 168 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 169 | args.quantize = (args.w_bits < 16) or (args.a_bits < 16) 170 | # cache path 171 | args.cache_dir = os.path.join(args.output_dir, ".cache") 172 | os.makedirs(args.cache_dir, exist_ok=True) 173 | # output path 174 | args.model_name = args.model.split("/")[-1] 175 | args.exp_dir = os.path.join(args.output_dir, args.model_name, f"w{args.w_bits}a{args.a_bits}", args.exp_name) 176 | os.makedirs(args.exp_dir, exist_ok=True) 177 | 178 | logger = create_logger(args.exp_dir) 179 | logger.info('Arguments: ') 180 | logger.info(pprint.pformat(vars(args))) 181 | logger.info('--' * 30) 182 | return args, logger 183 | 184 | 185 | def create_logger(exp_dir, dist_rank=0, name=''): 186 | # create logger 187 | logger = logging.getLogger(name) 188 | logger.setLevel(logging.INFO) 189 | logger.propagate = False 190 | 191 | # create formatter 192 | fmt = '[%(asctime)s %(name)s] (%(filename)s %(lineno)d): %(levelname)s %(message)s' 193 | color_fmt = colored('[%(asctime)s %(name)s]', 'green') + \ 194 | colored('(%(filename)s %(lineno)d)', 'yellow') + ': %(levelname)s %(message)s' 195 | 196 | # create console handlers for master process 197 | if dist_rank == 0: 198 | console_handler = logging.StreamHandler() 199 | console_handler.setLevel(logging.DEBUG) 200 | console_handler.setFormatter( 201 | logging.Formatter(fmt=color_fmt, datefmt='%Y-%m-%d %H:%M:%S')) 202 | logger.addHandler(console_handler) 203 | 204 | # create file handlers 205 | log_file = os.path.join(exp_dir, f'log_rank{dist_rank}_{datetime.now().strftime("%Y%m%d_%H%M%S")}.txt') 206 | file_handler = logging.FileHandler(log_file, mode='a') 207 | file_handler.setLevel(logging.DEBUG) 208 | file_handler.setFormatter(logging.Formatter(fmt=fmt, datefmt='%Y-%m-%d %H:%M:%S')) 209 | logger.addHandler(file_handler) 210 | 211 | return logger -------------------------------------------------------------------------------- /evaluation/tensor_to_pycolmap.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | 12 | import numpy as np 13 | import pycolmap 14 | 15 | 16 | 17 | # TODO: frame_idx should start from 1 instead of 0 in colmap 18 | 19 | 20 | def batch_matrix_to_pycolmap( 21 | points3d, 22 | extrinsics, 23 | intrinsics, 24 | tracks, 25 | image_size, 26 | masks=None, 27 | max_reproj_error=None, 28 | max_points3D_val=3000, 29 | shared_camera=False, 30 | camera_type="SIMPLE_PINHOLE", 31 | extra_params=None, 32 | min_inlier_per_frame=64 33 | ): 34 | """ 35 | Convert Batched Pytorch Tensors to PyCOLMAP 36 | 37 | Check https://github.com/colmap/pycolmap for more details about its format 38 | """ 39 | 40 | # points3d: Px3 41 | # extrinsics: Nx3x4 42 | # intrinsics: Nx3x3 43 | # tracks: NxPx2 44 | # masks: NxP 45 | # image_size: 2, assume all the frames have been padded to the same size 46 | # where N is the number of frames and P is the number of tracks 47 | 48 | N, P, _ = tracks.shape 49 | assert len(extrinsics) == N 50 | assert len(intrinsics) == N 51 | assert len(points3d) == P 52 | assert image_size.shape[0] == 2 53 | 54 | if max_reproj_error is not None: 55 | projected_points_2d, projected_points_cam = project_3D_points(points3d, extrinsics, intrinsics, return_points_cam=True) 56 | projected_diff = (projected_points_2d - tracks).norm(dim=-1) 57 | projected_points_2d[projected_points_cam[:, -1] <= 0] = 1e6 58 | reproj_mask = projected_diff < max_reproj_error 59 | 60 | if masks is not None and reproj_mask is not None: 61 | masks = torch.logical_and(masks, reproj_mask) 62 | elif masks is not None: 63 | masks = masks 64 | else: 65 | masks = reproj_mask 66 | 67 | assert masks is not None 68 | 69 | 70 | if masks.sum(1).min() < min_inlier_per_frame: 71 | print(f"Not enough inliers per frame, skip BA for this sequence.") 72 | return None, None 73 | 74 | extrinsics = extrinsics.cpu().numpy() 75 | intrinsics = intrinsics.cpu().numpy() 76 | 77 | if extra_params is not None: 78 | extra_params = extra_params.cpu().numpy() 79 | 80 | 81 | tracks = tracks.cpu().numpy() 82 | points3d = points3d.cpu().numpy() 83 | image_size = image_size.cpu().numpy() 84 | 85 | # Reconstruction object, following the format of PyCOLMAP/COLMAP 86 | reconstruction = pycolmap.Reconstruction() 87 | 88 | masks = masks.cpu().numpy() 89 | 90 | inlier_num = masks.sum(0) 91 | valid_mask = inlier_num >= 2 # a track is invalid if without two inliers 92 | valid_idx = np.nonzero(valid_mask)[0] 93 | 94 | # Only add 3D points that have sufficient 2D points 95 | for vidx in valid_idx: 96 | reconstruction.add_point3D( 97 | points3d[vidx], pycolmap.Track(), np.zeros(3) 98 | ) 99 | 100 | num_points3D = len(valid_idx) 101 | camera = None 102 | # frame idx 103 | for fidx in range(N): 104 | # set camera 105 | if camera is None or (not shared_camera): 106 | if camera_type == "PINHOLE": 107 | pycolmap_intri = np.array( 108 | [ 109 | intrinsics[fidx][0, 0], 110 | intrinsics[fidx][1, 1], 111 | intrinsics[fidx][0, 2], 112 | intrinsics[fidx][1, 2], 113 | ] 114 | ) 115 | elif camera_type == "SIMPLE_RADIAL": 116 | focal = (intrinsics[fidx][0, 0] + intrinsics[fidx][1, 1]) / 2 117 | pycolmap_intri = np.array( 118 | [ 119 | focal, 120 | intrinsics[fidx][0, 2], 121 | intrinsics[fidx][1, 2], 122 | extra_params[fidx][0], 123 | ] 124 | ) 125 | elif camera_type == "SIMPLE_PINHOLE": 126 | focal = (intrinsics[fidx][0, 0] + intrinsics[fidx][1, 1]) / 2 127 | pycolmap_intri = np.array( 128 | [ 129 | focal, 130 | intrinsics[fidx][0, 2], 131 | intrinsics[fidx][1, 2], 132 | ] 133 | ) 134 | else: 135 | raise ValueError( 136 | f"Camera type {camera_type} is not supported yet" 137 | ) 138 | 139 | camera = pycolmap.Camera( 140 | model=camera_type, 141 | width=image_size[0], 142 | height=image_size[1], 143 | params=pycolmap_intri, 144 | camera_id=fidx, 145 | ) 146 | 147 | # add camera 148 | reconstruction.add_camera(camera) 149 | 150 | # set image 151 | cam_from_world = pycolmap.Rigid3d( 152 | pycolmap.Rotation3d(extrinsics[fidx][:3, :3]), 153 | extrinsics[fidx][:3, 3], 154 | ) # Rot and Trans 155 | image = pycolmap.Image( 156 | id=fidx, 157 | name=f"image_{fidx}", 158 | camera_id=camera.camera_id, 159 | cam_from_world=cam_from_world, 160 | ) 161 | 162 | points2D_list = [] 163 | 164 | point2D_idx = 0 165 | # NOTE point3D_id start by 1 166 | for point3D_id in range(1, num_points3D + 1): 167 | original_track_idx = valid_idx[point3D_id - 1] 168 | 169 | if ( 170 | reconstruction.points3D[point3D_id].xyz < max_points3D_val 171 | ).all(): 172 | if masks[fidx][original_track_idx]: 173 | # It seems we don't need +0.5 for BA 174 | point2D_xy = tracks[fidx][original_track_idx] 175 | # Please note when adding the Point2D object 176 | # It not only requires the 2D xy location, but also the id to 3D point 177 | points2D_list.append( 178 | pycolmap.Point2D(point2D_xy, point3D_id) 179 | ) 180 | 181 | # add element 182 | track = reconstruction.points3D[point3D_id].track 183 | track.add_element(fidx, point2D_idx) 184 | point2D_idx += 1 185 | 186 | assert point2D_idx == len(points2D_list) 187 | 188 | try: 189 | image.points2D = pycolmap.ListPoint2D(points2D_list) 190 | image.registered = True 191 | except: 192 | print(f"frame {fidx} is out of BA") 193 | image.registered = False 194 | 195 | # add image 196 | reconstruction.add_image(image) 197 | 198 | return reconstruction, valid_mask 199 | 200 | 201 | def pycolmap_to_batch_matrix( 202 | reconstruction, device="cuda", camera_type="SIMPLE_PINHOLE" 203 | ): 204 | """ 205 | Convert a PyCOLMAP Reconstruction Object to batched PyTorch tensors. 206 | 207 | Args: 208 | reconstruction (pycolmap.Reconstruction): The reconstruction object from PyCOLMAP. 209 | device (str): The device to place the tensors on (default: "cuda"). 210 | camera_type (str): The type of camera model used (default: "SIMPLE_PINHOLE"). 211 | 212 | Returns: 213 | tuple: A tuple containing points3D, extrinsics, intrinsics, and optionally extra_params. 214 | """ 215 | 216 | num_images = len(reconstruction.images) 217 | max_points3D_id = max(reconstruction.point3D_ids()) 218 | points3D = np.zeros((max_points3D_id, 3)) 219 | 220 | for point3D_id in reconstruction.points3D: 221 | points3D[point3D_id - 1] = reconstruction.points3D[point3D_id].xyz 222 | points3D = torch.from_numpy(points3D).to(device) 223 | 224 | extrinsics = [] 225 | intrinsics = [] 226 | 227 | extra_params = [] if camera_type == "SIMPLE_RADIAL" else None 228 | 229 | for i in range(num_images): 230 | # Extract and append extrinsics 231 | pyimg = reconstruction.images[i] 232 | pycam = reconstruction.cameras[pyimg.camera_id] 233 | matrix = pyimg.cam_from_world.matrix() 234 | extrinsics.append(matrix) 235 | 236 | # Extract and append intrinsics 237 | calibration_matrix = pycam.calibration_matrix() 238 | intrinsics.append(calibration_matrix) 239 | 240 | if camera_type == "SIMPLE_RADIAL": 241 | extra_params.append(pycam.params[-1]) 242 | 243 | # Convert lists to torch tensors 244 | extrinsics = torch.from_numpy(np.stack(extrinsics)).to(device) 245 | 246 | intrinsics = torch.from_numpy(np.stack(intrinsics)).to(device) 247 | 248 | if camera_type == "SIMPLE_RADIAL": 249 | extra_params = torch.from_numpy(np.stack(extra_params)).to(device) 250 | extra_params = extra_params[:, None] 251 | 252 | return points3D, extrinsics, intrinsics, extra_params 253 | 254 | 255 | 256 | 257 | 258 | def project_3D_points( 259 | points3D, 260 | extrinsics, 261 | intrinsics=None, 262 | extra_params=None, 263 | return_points_cam=False, 264 | default=0, 265 | only_points_cam=False, 266 | ): 267 | """ 268 | Transforms 3D points to 2D using extrinsic and intrinsic parameters. 269 | Args: 270 | points3D (torch.Tensor): 3D points of shape Px3. 271 | extrinsics (torch.Tensor): Extrinsic parameters of shape Bx3x4. 272 | intrinsics (torch.Tensor): Intrinsic parameters of shape Bx3x3. 273 | extra_params (torch.Tensor): Extra parameters of shape BxN, which is used for radial distortion. 274 | Returns: 275 | torch.Tensor: Transformed 2D points of shape BxNx2. 276 | """ 277 | with torch.cuda.amp.autocast(dtype=torch.double): 278 | N = points3D.shape[0] # Number of points 279 | B = extrinsics.shape[0] # Batch size, i.e., number of cameras 280 | points3D_homogeneous = torch.cat( 281 | [points3D, torch.ones_like(points3D[..., 0:1])], dim=1 282 | ) # Nx4 283 | # Reshape for batch processing 284 | points3D_homogeneous = points3D_homogeneous.unsqueeze(0).expand( 285 | B, -1, -1 286 | ) # BxNx4 287 | 288 | # Step 1: Apply extrinsic parameters 289 | # Transform 3D points to camera coordinate system for all cameras 290 | points_cam = torch.bmm( 291 | extrinsics, points3D_homogeneous.transpose(-1, -2) 292 | ) 293 | 294 | if only_points_cam: 295 | return points_cam 296 | 297 | # Step 2: Apply intrinsic parameters and (optional) distortion 298 | points2D = img_from_cam(intrinsics, points_cam, extra_params) 299 | 300 | if return_points_cam: 301 | return points2D, points_cam 302 | return points2D 303 | 304 | 305 | def img_from_cam(intrinsics, points_cam, extra_params=None, default=0.0): 306 | """ 307 | Applies intrinsic parameters and optional distortion to the given 3D points. 308 | 309 | Args: 310 | intrinsics (torch.Tensor): Intrinsic camera parameters of shape Bx3x3. 311 | points_cam (torch.Tensor): 3D points in camera coordinates of shape Bx3xN. 312 | extra_params (torch.Tensor, optional): Distortion parameters of shape BxN, where N can be 1, 2, or 4. 313 | default (float, optional): Default value to replace NaNs in the output. 314 | 315 | Returns: 316 | points2D (torch.Tensor): 2D points in pixel coordinates of shape BxNx2. 317 | """ 318 | 319 | # Normalize by the third coordinate (homogeneous division) 320 | points_cam = points_cam / points_cam[:, 2:3, :] 321 | # Extract uv 322 | uv = points_cam[:, :2, :] 323 | 324 | # Apply distortion if extra_params are provided 325 | if extra_params is not None: 326 | uu, vv = apply_distortion(extra_params, uv[:, 0], uv[:, 1]) 327 | uv = torch.stack([uu, vv], dim=1) 328 | 329 | # Prepare points_cam for batch matrix multiplication 330 | points_cam_homo = torch.cat( 331 | (uv, torch.ones_like(uv[:, :1, :])), dim=1 332 | ) # Bx3xN 333 | # Apply intrinsic parameters using batch matrix multiplication 334 | points2D_homo = torch.bmm(intrinsics, points_cam_homo) # Bx3xN 335 | 336 | # Extract x and y coordinates 337 | points2D = points2D_homo[:, :2, :] # Bx2xN 338 | 339 | # Replace NaNs with default value 340 | points2D = torch.nan_to_num(points2D, nan=default) 341 | 342 | return points2D.transpose(1, 2) # BxNx2 343 | --------------------------------------------------------------------------------