├── .gitignore ├── .gitmodules ├── README.md ├── assets └── fig1_teaser.png ├── data ├── download_bonn.sh ├── download_ckpt.sh ├── download_davis.py ├── download_dynamic_replica.sh ├── download_kitti.sh ├── download_nyuv2.sh ├── download_pointodyssey.sh ├── download_scannetv2.sh ├── download_sintel.sh ├── download_spring.sh ├── download_tartanair.py ├── download_tartanair.sh ├── download_training_zipfiles.txt ├── download_tum_dynamics.sh ├── download_waymo.sh ├── evaluation_script.md └── prepare_training.md ├── datasets_preprocess ├── habitat │ ├── README.md │ ├── find_scenes.py │ ├── habitat_renderer │ │ ├── __init__.py │ │ ├── habitat_sim_envmaps_renderer.py │ │ ├── multiview_crop_generator.py │ │ ├── projections.py │ │ └── projections_conversions.py │ └── preprocess_habitat.py ├── path_to_root.py ├── prepare_bonn.py ├── prepare_kitti.py ├── prepare_nyuv2.py ├── prepare_scannet.py ├── prepare_tum.py ├── preprocess_arkitscenes.py ├── preprocess_blendedMVS.py ├── preprocess_co3d.py ├── preprocess_megadepth.py ├── preprocess_scannetpp.py ├── preprocess_staticthings3d.py ├── preprocess_waymo.py ├── preprocess_wildrgbd.py ├── scannet_sens_reader.py ├── sintel_get_dynamics.py └── waymo_make_pairs.py ├── demo.py ├── demo_data ├── lady-running.mp4 └── lady-running │ ├── 00000.jpg │ ├── 00001.jpg │ ├── 00002.jpg │ ├── 00003.jpg │ ├── 00004.jpg │ ├── 00005.jpg │ ├── 00006.jpg │ ├── 00007.jpg │ ├── 00008.jpg │ ├── 00009.jpg │ ├── 00010.jpg │ ├── 00011.jpg │ ├── 00012.jpg │ ├── 00013.jpg │ ├── 00014.jpg │ ├── 00015.jpg │ ├── 00016.jpg │ ├── 00017.jpg │ ├── 00018.jpg │ ├── 00019.jpg │ ├── 00020.jpg │ ├── 00021.jpg │ ├── 00022.jpg │ ├── 00023.jpg │ ├── 00024.jpg │ ├── 00025.jpg │ ├── 00026.jpg │ ├── 00027.jpg │ ├── 00028.jpg │ ├── 00029.jpg │ ├── 00030.jpg │ ├── 00031.jpg │ ├── 00032.jpg │ ├── 00033.jpg │ ├── 00034.jpg │ ├── 00035.jpg │ ├── 00036.jpg │ ├── 00037.jpg │ ├── 00038.jpg │ ├── 00039.jpg │ ├── 00040.jpg │ ├── 00041.jpg │ ├── 00042.jpg │ ├── 00043.jpg │ ├── 00044.jpg │ ├── 00045.jpg │ ├── 00046.jpg │ ├── 00047.jpg │ ├── 00048.jpg │ ├── 00049.jpg │ ├── 00050.jpg │ ├── 00051.jpg │ ├── 00052.jpg │ ├── 00053.jpg │ ├── 00054.jpg │ ├── 00055.jpg │ ├── 00056.jpg │ ├── 00057.jpg │ ├── 00058.jpg │ ├── 00059.jpg │ ├── 00060.jpg │ ├── 00061.jpg │ ├── 00062.jpg │ ├── 00063.jpg │ └── 00064.jpg ├── depth_metric.ipynb ├── dust3r ├── __init__.py ├── cloud_opt │ ├── __init__.py │ ├── base_opt.py │ ├── commons.py │ ├── init_im_poses.py │ ├── modular_optimizer.py │ ├── optimizer.py │ └── pair_viewer.py ├── datasets │ ├── __init__.py │ ├── arkitscenes.py │ ├── base │ │ ├── __init__.py │ │ ├── base_stereo_view_dataset.py │ │ ├── batched_sampler.py │ │ └── easy_dataset.py │ ├── blendedmvs.py │ ├── co3d.py │ ├── dynamic_replica.py │ ├── habitat.py │ ├── megadepth.py │ ├── pointodyssey.py │ ├── scannetpp.py │ ├── sintel.py │ ├── spring_dataset.py │ ├── staticthings3d.py │ ├── tartanair.py │ ├── utils │ │ ├── __init__.py │ │ ├── cropping.py │ │ └── transforms.py │ ├── waymo.py │ └── wildrgbd.py ├── demo.py ├── depth_eval.py ├── eval_metadata.py ├── heads │ ├── __init__.py │ ├── dpt_head.py │ ├── linear_head.py │ └── postprocess.py ├── image_pairs.py ├── inference.py ├── losses.py ├── model.py ├── optim_factory.py ├── patch_embed.py ├── pose_eval.py ├── post_process.py ├── training.py ├── utils │ ├── __init__.py │ ├── device.py │ ├── flow_vis.py │ ├── geometry.py │ ├── goem_opt.py │ ├── image.py │ ├── misc.py │ ├── parallel.py │ ├── path_to_croco.py │ ├── po_utils │ │ ├── __init__.py │ │ ├── basic.py │ │ ├── geom.py │ 
│ ├── improc.py │ │ └── misc.py │ ├── viz_demo.py │ └── vo_eval.py └── viz.py ├── launch.py ├── requirements.txt ├── requirements_optional.txt └── third_party ├── RAFT ├── LICENSE ├── README.md ├── alt_cuda_corr │ ├── correlation.cpp │ ├── correlation_kernel.cu │ └── setup.py ├── chairs_split.txt ├── core │ ├── __init__.py │ ├── configs │ │ └── congif_spring_M.json │ ├── corr.py │ ├── datasets.py │ ├── extractor.py │ ├── layer.py │ ├── raft.py │ ├── update.py │ └── utils │ │ ├── __init__.py │ │ ├── augmentor.py │ │ ├── flow_viz.py │ │ ├── frame_utils.py │ │ └── utils.py ├── demo.py ├── download_models.sh ├── evaluate.py ├── train.py ├── train_mixed.sh └── train_standard.sh ├── __init__.py ├── raft.py └── sam2 ├── .clang-format ├── .github └── workflows │ └── check_fmt.yml ├── .gitignore ├── .watchmanconfig ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── INSTALL.md ├── LICENSE ├── LICENSE_cctorch ├── MANIFEST.in ├── README.md ├── assets ├── model_diagram.png └── sa_v_dataset.jpg ├── backend.Dockerfile ├── docker-compose.yaml ├── pyproject.toml ├── sam2 ├── __init__.py ├── automatic_mask_generator.py ├── build_sam.py ├── configs │ ├── sam2.1 │ │ ├── sam2.1_hiera_b+.yaml │ │ ├── sam2.1_hiera_l.yaml │ │ ├── sam2.1_hiera_s.yaml │ │ └── sam2.1_hiera_t.yaml │ ├── sam2.1_training │ │ └── sam2.1_hiera_b+_MOSE_finetune.yaml │ └── sam2 │ │ ├── sam2_hiera_b+.yaml │ │ ├── sam2_hiera_l.yaml │ │ ├── sam2_hiera_s.yaml │ │ └── sam2_hiera_t.yaml ├── csrc │ └── connected_components.cu ├── modeling │ ├── __init__.py │ ├── backbones │ │ ├── __init__.py │ │ ├── hieradet.py │ │ ├── image_encoder.py │ │ └── utils.py │ ├── memory_attention.py │ ├── memory_encoder.py │ ├── position_encoding.py │ ├── sam │ │ ├── __init__.py │ │ ├── mask_decoder.py │ │ ├── prompt_encoder.py │ │ └── transformer.py │ ├── sam2_base.py │ └── sam2_utils.py ├── sam2_hiera_b+.yaml ├── sam2_hiera_l.yaml ├── sam2_hiera_s.yaml ├── sam2_hiera_t.yaml ├── sam2_image_predictor.py ├── sam2_video_predictor.py └── utils │ ├── __init__.py │ ├── amg.py │ ├── misc.py │ └── transforms.py ├── setup.py ├── tools ├── README.md └── vos_inference.py └── training ├── README.md ├── __init__.py ├── assets ├── MOSE_sample_train_list.txt └── MOSE_sample_val_list.txt ├── dataset ├── __init__.py ├── sam2_datasets.py ├── transforms.py ├── utils.py ├── vos_dataset.py ├── vos_raw_dataset.py ├── vos_sampler.py └── vos_segment_loader.py ├── loss_fns.py ├── model ├── __init__.py └── sam2.py ├── optimizer.py ├── scripts └── sav_frame_extraction_submitit.py ├── train.py ├── trainer.py └── utils ├── __init__.py ├── checkpoint_utils.py ├── data_utils.py ├── distributed.py ├── logger.py └── train_utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | data/*/ 2 | *checkpoints/ 3 | checkpoints 4 | demo_tmp 5 | davis_results* 6 | tmp/ 7 | tmp*/ 8 | wandb 9 | *.pth 10 | !*error_log.txt 11 | checkpoints* 12 | viser_result 13 | tmp.json 14 | tmp* 15 | avg_error 16 | *.gif 17 | *.mp4 18 | .vscode 19 | .gradio 20 | pyrightconfig.json 21 | results 22 | 23 | # Byte-compiled / optimized / DLL files 24 | __pycache__/ 25 | *.py[cod] 26 | *$py.class 27 | 28 | # C extensions 29 | *.so 30 | 31 | # Distribution / packaging 32 | .Python 33 | build/ 34 | develop-eggs/ 35 | dist/ 36 | downloads/ 37 | eggs/ 38 | .eggs/ 39 | lib/ 40 | lib64/ 41 | parts/ 42 | sdist/ 43 | var/ 44 | wheels/ 45 | pip-wheel-metadata/ 46 | share/python-wheels/ 47 | *.egg-info/ 48 | .installed.cfg 49 | *.egg 50 | MANIFEST 51 | 52 | # PyInstaller 53 | # Usually these files 
are written by a python script from a template 54 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 55 | *.manifest 56 | *.spec 57 | 58 | # Installer logs 59 | pip-log.txt 60 | pip-delete-this-directory.txt 61 | 62 | # Unit test / coverage reports 63 | htmlcov/ 64 | .tox/ 65 | .nox/ 66 | .coverage 67 | .coverage.* 68 | .cache 69 | nosetests.xml 70 | coverage.xml 71 | *.cover 72 | *.py,cover 73 | .hypothesis/ 74 | .pytest_cache/ 75 | 76 | # Translations 77 | *.mo 78 | *.pot 79 | 80 | # Django stuff: 81 | *.log 82 | local_settings.py 83 | db.sqlite3 84 | db.sqlite3-journal 85 | 86 | # Flask stuff: 87 | instance/ 88 | .webassets-cache 89 | 90 | # Scrapy stuff: 91 | .scrapy 92 | 93 | # Sphinx documentation 94 | docs/_build/ 95 | 96 | # PyBuilder 97 | target/ 98 | 99 | # Jupyter Notebook 100 | .ipynb_checkpoints 101 | 102 | # IPython 103 | profile_default/ 104 | ipython_config.py 105 | 106 | # pyenv 107 | .python-version 108 | 109 | # pipenv 110 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 111 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 112 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 113 | # install all needed dependencies. 114 | #Pipfile.lock 115 | 116 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 117 | __pypackages__/ 118 | 119 | # Celery stuff 120 | celerybeat-schedule 121 | celerybeat.pid 122 | 123 | # SageMath parsed files 124 | *.sage.py 125 | 126 | # Environments 127 | .env 128 | .venv 129 | env/ 130 | venv/ 131 | ENV/ 132 | env.bak/ 133 | venv.bak/ 134 | 135 | # Spyder project settings 136 | .spyderproject 137 | .spyproject 138 | 139 | # Rope project settings 140 | .ropeproject 141 | 142 | # mkdocs documentation 143 | /site 144 | 145 | # mypy 146 | .mypy_cache/ 147 | .dmypy.json 148 | dmypy.json 149 | 150 | # Pyre type checker 151 | .pyre/ 152 | zilejNuE 153 | waymo_processed.zip -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "croco"] 2 | path = croco 3 | url = https://github.com/junyi42/croco 4 | 5 | [submodule "viser"] 6 | path = viser 7 | url = https://github.com/junyi42/viser 8 | -------------------------------------------------------------------------------- /assets/fig1_teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Junyi42/monst3r/1747338dcc01a850f7105bdd7147ab167e400f97/assets/fig1_teaser.png -------------------------------------------------------------------------------- /data/download_bonn.sh: -------------------------------------------------------------------------------- 1 | # bonn 2 | mkdir -p bonn 3 | cd bonn 4 | wget https://www.ipb.uni-bonn.de/html/projects/rgbd_dynamic2019/rgbd_bonn_dataset.zip 5 | unzip rgbd_bonn_dataset.zip 6 | rm rgbd_bonn_dataset.zip 7 | cd .. 
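After `download_bonn.sh` finishes, it can be worth verifying that each extracted sequence has the layout the preprocessing step later in this repo expects (an `rgb/` and `depth/` folder plus a `groundtruth.txt` file, which is what `datasets_preprocess/prepare_bonn.py` reads). The snippet below is a minimal sanity-check sketch and is not part of the repository; it assumes the download script was run from the `data/` directory, so sequences live under `data/bonn/rgbd_bonn_dataset/`.

```python
# Minimal sanity check for the Bonn RGB-D download (sketch, not part of the repo).
# Assumes it is run from the data/ directory used by download_bonn.sh, and that each
# sequence folder contains rgb/, depth/ and groundtruth.txt, the layout consumed by
# datasets_preprocess/prepare_bonn.py.
import glob
import os

for seq in sorted(glob.glob("bonn/rgbd_bonn_dataset/*/")):
    n_rgb = len(glob.glob(os.path.join(seq, "rgb", "*.png")))
    n_depth = len(glob.glob(os.path.join(seq, "depth", "*.png")))
    has_gt = os.path.isfile(os.path.join(seq, "groundtruth.txt"))
    print(f"{os.path.basename(os.path.dirname(seq))}: {n_rgb} rgb, {n_depth} depth, groundtruth={has_gt}")
```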
-------------------------------------------------------------------------------- /data/download_ckpt.sh: -------------------------------------------------------------------------------- 1 | mkdir -p ../checkpoints/ 2 | gdown --fuzzy https://drive.google.com/file/d/1Z1jO_JmfZj0z3bgMvCwqfUhyZ1bIbc9E/view?usp=sharing -O ../checkpoints/ 3 | # THE original dust3r ckpt 4 | # wget https://download.europe.naverlabs.com/ComputerVision/DUSt3R/DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth -P ../checkpoints/ 5 | 6 | # sea-raft ckpt 7 | cd ../third_party/RAFT 8 | wget https://www.dropbox.com/s/4j4z58wuv8o0mfz/models.zip 9 | unzip models.zip 10 | rm models.zip 11 | gdown --fuzzy https://drive.google.com/file/d/1a0C5FTdhjM4rKrfXiGhec7eq2YM141lu/view?usp=drive_link -O models/ 12 | cd ../../data 13 | 14 | # sam2 ckpt 15 | cd ../third_party/sam2 16 | wget https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_large.pt -P checkpoints/ 17 | cd ../../data -------------------------------------------------------------------------------- /data/download_davis.py: -------------------------------------------------------------------------------- 1 | import zipfile 2 | import requests 3 | import glob 4 | import os 5 | import cv2 6 | import numpy as np 7 | from tqdm import tqdm 8 | 9 | def download_and_extract_davis(url, extract_to='.'): 10 | local_zip_file = os.path.join(extract_to, 'davis.zip') 11 | 12 | # Download the dataset 13 | print("Downloading DAVIS dataset...") 14 | response = requests.get(url, stream=True) 15 | with open(local_zip_file, 'wb') as file: 16 | for chunk in response.iter_content(chunk_size=1024): 17 | if chunk: 18 | file.write(chunk) 19 | print("Download complete.") 20 | 21 | # Extract the dataset 22 | print("Extracting DAVIS dataset...") 23 | with zipfile.ZipFile(local_zip_file, 'r') as zip_ref: 24 | zip_ref.extractall(extract_to) 25 | print("Extraction complete.") 26 | 27 | # Remove the zip file 28 | os.remove(local_zip_file) 29 | print("Removed the zip file.") 30 | 31 | def create_videos_from_images(image_root, video_root): 32 | if not os.path.exists(video_root): 33 | os.makedirs(video_root) 34 | 35 | # Iterate over each image set 36 | for image_set in os.listdir(image_root): 37 | image_set_path = os.path.join(image_root, image_set) 38 | if os.path.isdir(image_set_path): 39 | images = sorted(glob.glob(os.path.join(image_set_path, '*.jpg'))) 40 | 41 | # Read the first image to get dimensions 42 | frame = cv2.imread(images[0]) 43 | height, width, layers = frame.shape 44 | 45 | # Define the codec and create VideoWriter object 46 | video_file = os.path.join(video_root, f"{image_set}.mp4") 47 | out = cv2.VideoWriter(video_file, cv2.VideoWriter_fourcc(*'mp4v'), 30, (width, height)) 48 | 49 | for image_file in images: 50 | frame = cv2.imread(image_file) 51 | out.write(frame) 52 | 53 | out.release() 54 | print(f"Video for {image_set} created at {video_file}") 55 | 56 | 57 | davis_url = 'https://data.vision.ee.ethz.ch/csergi/share/davis/DAVIS-2017-trainval-480p.zip' # Update with the correct URL if needed 58 | extract_path = '../data/davis' 59 | video_output_path = '../data/davis_videos' 60 | 61 | # make directories if they don't exist 62 | if not os.path.exists(extract_path): 63 | os.makedirs(extract_path) 64 | 65 | # Download and extract the dataset 66 | download_and_extract_davis(davis_url, extract_path) 67 | 68 | # Create videos from the image sets 69 | create_videos_from_images(os.path.join(extract_path, 'DAVIS/JPEGImages/480p'), video_output_path) 70 | 71 | # get the mask where fg masked 
to white 72 | image_paths = glob.glob(r"../data/davis/DAVIS/JPEGImages/480p/*/*.jpg") 73 | image_paths.sort() 74 | print(f"Found {len(image_paths)} images.") 75 | mask_paths = glob.glob(r"../data/davis/DAVIS/Annotations/480p/*/*.png") 76 | mask_paths.sort() 77 | print(f"Found {len(mask_paths)} masks.") 78 | 79 | masked_dir_path = r"../data/davis/DAVIS/masked_images" 80 | os.makedirs(masked_dir_path, exist_ok=True) 81 | 82 | for img_path, mask_path in tqdm(zip(image_paths, mask_paths)): 83 | assert img_path.replace(".jpg", "").replace("JPEGImages", "") == mask_path.replace(".png", "").replace("Annotations", "") 84 | masked_path = img_path.replace(".jpg", ".jpg").replace('JPEGImages', 'masked_images') 85 | 86 | os.makedirs(os.path.dirname(masked_path), exist_ok=True) 87 | 88 | img = cv2.imread(img_path) 89 | mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE) 90 | 91 | if img is None: 92 | print(f"Failed to load image {img_path}") 93 | continue 94 | if mask is None: 95 | print(f"Failed to load mask {mask_path}") 96 | continue 97 | 98 | mask = cv2.threshold(mask, 2, 255, cv2.THRESH_BINARY)[1] 99 | mask = cv2.resize(mask, (img.shape[1], img.shape[0])) 100 | masked = np.stack([mask, mask, mask], axis=-1) 101 | success = cv2.imwrite(masked_path, masked) 102 | if success: 103 | # print(f"Saved {masked_path}") 104 | pass 105 | else: 106 | print(f"Failed to save {masked_path}") 107 | -------------------------------------------------------------------------------- /data/download_dynamic_replica.sh: -------------------------------------------------------------------------------- 1 | cd data 2 | mkdir -p dynamic_replica 3 | cd dynamic_replica 4 | 5 | # Generate and loop through the list of URLs 6 | for i in $(seq -w 000 085) 7 | do 8 | # Construct the filename and URL 9 | filename="train_${i}.zip" 10 | url="https://dl.fbaipublicfiles.com/dynamic_replica_v2/train/${filename}" 11 | 12 | # Download the zip file 13 | wget $url 14 | echo "Download of $filename completed" 15 | 16 | # Unzip the file 17 | unzip $filename 18 | echo "Unzipping of $filename completed" 19 | 20 | # Delete any directories ending with 'right' 21 | find . 
-maxdepth 1 -type d -name '*right' -exec rm -rf {} + 22 | 23 | # Delete the zip file 24 | rm $filename 25 | echo "Deletion of $filename completed" 26 | done 27 | 28 | # process the frame annotations 29 | mv frame_annotations_train.jgz frame_annotations_train.gz 30 | gunzip frame_annotations_train.gz 31 | mv frame_annotations_train frame_annotations_train.json -------------------------------------------------------------------------------- /data/download_kitti.sh: -------------------------------------------------------------------------------- 1 | # kitti 2 | mkdir -p kitti 3 | cd kitti 4 | wget https://s3.eu-central-1.amazonaws.com/avg-kitti/data_depth_selection.zip 5 | wget https://s3.eu-central-1.amazonaws.com/avg-kitti/data_depth_annotated.zip 6 | wget https://s3.eu-central-1.amazonaws.com/avg-kitti/raw_data/2011_09_26_drive_0002/2011_09_26_drive_0002_sync.zip 7 | wget https://s3.eu-central-1.amazonaws.com/avg-kitti/raw_data/2011_09_26_drive_0005/2011_09_26_drive_0005_sync.zip 8 | wget https://s3.eu-central-1.amazonaws.com/avg-kitti/raw_data/2011_09_26_drive_0013/2011_09_26_drive_0013_sync.zip 9 | wget https://s3.eu-central-1.amazonaws.com/avg-kitti/raw_data/2011_09_26_drive_0020/2011_09_26_drive_0020_sync.zip 10 | wget https://s3.eu-central-1.amazonaws.com/avg-kitti/raw_data/2011_09_26_drive_0023/2011_09_26_drive_0023_sync.zip 11 | wget https://s3.eu-central-1.amazonaws.com/avg-kitti/raw_data/2011_09_26_drive_0036/2011_09_26_drive_0036_sync.zip 12 | wget https://s3.eu-central-1.amazonaws.com/avg-kitti/raw_data/2011_09_26_drive_0079/2011_09_26_drive_0079_sync.zip 13 | wget https://s3.eu-central-1.amazonaws.com/avg-kitti/raw_data/2011_09_26_drive_0095/2011_09_26_drive_0095_sync.zip 14 | wget https://s3.eu-central-1.amazonaws.com/avg-kitti/raw_data/2011_09_26_drive_0113/2011_09_26_drive_0113_sync.zip 15 | wget https://s3.eu-central-1.amazonaws.com/avg-kitti/raw_data/2011_09_28_drive_0037/2011_09_28_drive_0037_sync.zip 16 | wget https://s3.eu-central-1.amazonaws.com/avg-kitti/raw_data/2011_09_29_drive_0026/2011_09_29_drive_0026_sync.zip 17 | wget https://s3.eu-central-1.amazonaws.com/avg-kitti/raw_data/2011_09_30_drive_0016/2011_09_30_drive_0016_sync.zip 18 | wget https://s3.eu-central-1.amazonaws.com/avg-kitti/raw_data/2011_10_03_drive_0047/2011_10_03_drive_0047_sync.zip 19 | find . -name "*.zip" -exec unzip -o -q {} \; 20 | # remove all zip files 21 | find . -name "*.zip" -exec rm {} \; 22 | cd .. 23 | -------------------------------------------------------------------------------- /data/download_nyuv2.sh: -------------------------------------------------------------------------------- 1 | # nyu-v2 2 | mkdir -p nyu_v2 3 | cd nyu_v2 4 | wget https://huggingface.co/datasets/sayakpaul/nyu_depth_v2/resolve/main/data/val-000000.tar -O val-000000.tar 5 | wget https://huggingface.co/datasets/sayakpaul/nyu_depth_v2/resolve/main/data/val-000001.tar -O val-000001.tar 6 | # unzip all 7 | find . -name "*.tar" -exec tar -xvf {} \; 8 | -------------------------------------------------------------------------------- /data/download_pointodyssey.sh: -------------------------------------------------------------------------------- 1 | # Download point_odyssey 2 | mkdir -p point_odyssey 3 | cd point_odyssey 4 | # train 5 | gdown --id 1ivaHRZV6iwxxH4qk8IAIyrOF9jrppDIP 6 | # test 7 | gdown --id 1jn8l28BBNw9f9wYFmd5WOCERH48-GsgB 8 | # sample 9 | gdown --id 1dnl9XMImdwKX2KcZCTuVDhcy5h8qzQIO 10 | # unzip all *.tar.gz 11 | find . -name "*.tar.gz" -exec tar -zxvf {} \; 12 | # remove all zip files 13 | find . 
-name "*.tar.gz" -exec rm {} \; -------------------------------------------------------------------------------- /data/download_scannetv2.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | mkdir -p scannetv2 4 | cd scannetv2 5 | # download the sens http://kaldir.vc.in.tum.de/scannet/v2/scans/scene0707_00/scene0707_00.sens from scene0707_00 to scene0806_00 6 | for i in {707..806}; do 7 | wget http://kaldir.vc.in.tum.de/scannet/v2/scans/scene0${i}_00/scene0${i}_00.sens 8 | done 9 | cd ../.. 10 | 11 | # Set the number of threads 12 | THREADS=4 13 | 14 | # Define the function to process each scene 15 | process_scene() { 16 | scene_id=$(printf "%04d" $1) # Format the scene ID 17 | filename="data/scannetv2/scene${scene_id}_00.sens" 18 | output_path="data/scannetv2/scene${scene_id}_00" 19 | 20 | # Run the data processing command 21 | python datasets_preprocess/scannet_sens_reader.py --filename $filename --output_path $output_path 22 | 23 | # Delete the .sens file 24 | rm -rf $filename 25 | } 26 | 27 | export -f process_scene # Export the function for use by xargs 28 | 29 | # Use seq -w to generate numbers in the range 0707-0806 and use xargs for multi-threading 30 | seq -w 707 806 | xargs -n 1 -P $THREADS -I {} bash -c 'process_scene "$@"' _ {} 31 | 32 | echo "All scenes have been processed." 33 | -------------------------------------------------------------------------------- /data/download_sintel.sh: -------------------------------------------------------------------------------- 1 | # Download Sintel 2 | mkdir -p sintel 3 | cd sintel 4 | # images 5 | wget --no-proxy http://files.is.tue.mpg.de/sintel/MPI-Sintel-training_images.zip 6 | # depth & cameras 7 | wget --no-proxy http://files.is.tue.mpg.de/jwulff/sintel/MPI-Sintel-depth-training-20150305.zip 8 | # flow 9 | wget --no-proxy http://files.is.tue.mpg.de/sintel/MPI-Sintel-training_extras.zip 10 | # unzip all 11 | find . -name "*.zip" -exec unzip -o -q {} \; 12 | # remove all zip files 13 | find . -name "*.zip" -exec rm {} \; 14 | cd .. 15 | 16 | # # preprocess the dynamic labels 17 | # conda activate monst3r 18 | # cd .. 19 | # python datasets_preprocess/sintel_get_dynamics.py --threshold 0.1 --save_dir dynamic_label_perfect 20 | -------------------------------------------------------------------------------- /data/download_spring.sh: -------------------------------------------------------------------------------- 1 | # download spring dataset 2 | gdown --folder https://drive.google.com/drive/folders/1oJqS7YOqtgO6l4WI_fdCZ-Jvp2RUvHZz?usp=sharing -O spring 3 | cd spring 4 | # unzip all 5 | find . -name "*.zip" -exec unzip -o -q {} \; 6 | # remove all zip files 7 | find . -name "*.zip" -exec rm {} \; 8 | # move data/spring/spring to data/spring 9 | mv spring/* . 
10 | rm -rf spring -------------------------------------------------------------------------------- /data/download_tartanair.sh: -------------------------------------------------------------------------------- 1 | TARGET_DIR="tartanair" 2 | mkdir -p "$TARGET_DIR" 3 | 4 | python download_tartanair.py --output-dir $TARGET_DIR --rgb --only-left --depth --only-hard 5 | 6 | # Find and unzip all zip files 7 | find "$TARGET_DIR" -type f -name "*.zip" -print0 | while IFS= read -r -d '' zipfile; do 8 | # Get the directory of the zip file 9 | zipdir=$(dirname "$zipfile") 10 | 11 | echo "Unzipping $zipfile to $zipdir" 12 | # Unzip to the respective directory, automatically overwrite existing files 13 | unzip -o -q "$zipfile" -d "$zipdir" 14 | 15 | # Check if the unzip was successful 16 | if [ $? -eq 0 ]; then 17 | echo "Deleting $zipfile" 18 | # Delete the zip file 19 | rm "$zipfile" 20 | else 21 | echo "Failed to unzip $zipfile" 22 | fi 23 | done -------------------------------------------------------------------------------- /data/download_tum_dynamics.sh: -------------------------------------------------------------------------------- 1 | # download tum-dynamic dataset 2 | mkdir -p tum 3 | cd tum 4 | wget https://cvg.cit.tum.de/rgbd/dataset/freiburg3/rgbd_dataset_freiburg3_sitting_static.tgz 5 | wget https://cvg.cit.tum.de/rgbd/dataset/freiburg3/rgbd_dataset_freiburg3_sitting_xyz.tgz 6 | wget https://cvg.cit.tum.de/rgbd/dataset/freiburg3/rgbd_dataset_freiburg3_sitting_halfsphere.tgz 7 | wget https://cvg.cit.tum.de/rgbd/dataset/freiburg3/rgbd_dataset_freiburg3_sitting_rpy.tgz 8 | wget https://cvg.cit.tum.de/rgbd/dataset/freiburg3/rgbd_dataset_freiburg3_walking_static.tgz 9 | wget https://cvg.cit.tum.de/rgbd/dataset/freiburg3/rgbd_dataset_freiburg3_walking_xyz.tgz 10 | wget https://cvg.cit.tum.de/rgbd/dataset/freiburg3/rgbd_dataset_freiburg3_walking_halfsphere.tgz 11 | wget https://cvg.cit.tum.de/rgbd/dataset/freiburg3/rgbd_dataset_freiburg3_walking_rpy.tgz 12 | 13 | # unzip all 14 | find . -name "*.tgz" -exec tar -zxvf {} \; 15 | # remove all zip files 16 | find . -name "*.tgz" -exec rm {} \; 17 | -------------------------------------------------------------------------------- /data/download_waymo.sh: -------------------------------------------------------------------------------- 1 | # download waymo dataset 2 | 3 | mkdir -p waymo 4 | cd waymo 5 | gsutil -m cp -r gs://waymo_open_dataset_v_1_4_2/individual_files/training/ . 6 | wget --no-proxy https://download.europe.naverlabs.com/ComputerVision/DUSt3R/waymo_pairs.npz 7 | cd .. -------------------------------------------------------------------------------- /data/prepare_training.md: -------------------------------------------------------------------------------- 1 | 2 | # Dataset Preparation for Training 3 | 4 | We provide scripts to prepare datasets for training, including **PointOdyssey**, **TartanAir**, **Spring**, and **Waymo**. For evaluation, we also provide a script for preparing the **Sintel** dataset. 5 | 6 | > [!NOTE] 7 | > The scripts provided here are for reference only. Please ensure you have obtained the necessary licenses from the original dataset providers before proceeding. 8 | 9 | ## Download Pre-Trained Models 10 | To download the pre-trained models, run the following commands: 11 | ```bash 12 | cd data 13 | wget https://download.europe.naverlabs.com/ComputerVision/DUSt3R/DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth -P ../checkpoints/ 14 | cd .. 
15 | ``` 16 | 17 | ## Dataset Setup 18 | 19 | ### PointOdyssey 20 | To download and prepare the **PointOdyssey** dataset, execute: 21 | ```bash 22 | cd data 23 | bash download_pointodyssey.sh 24 | cd .. 25 | ``` 26 | 27 | ### TartanAir 28 | To download and prepare the **TartanAir** dataset, execute: 29 | ```bash 30 | cd data 31 | bash download_tartanair.sh 32 | cd .. 33 | ``` 34 | 35 | ### Spring 36 | To download and prepare the **Spring** dataset, execute: 37 | ```bash 38 | cd data 39 | bash download_spring.sh 40 | cd .. 41 | ``` 42 | 43 | ### Waymo 44 | To download and prepare the **Waymo** dataset, follow these steps: 45 | 46 | 1. Set up Google Cloud SDK (if you haven't done so already): 47 | ```bash 48 | curl https://sdk.cloud.google.com | bash 49 | exec -l $SHELL 50 | gcloud init 51 | gcloud auth login 52 | ``` 53 | 54 | 2. Download the Waymo dataset: 55 | ```bash 56 | cd data 57 | bash download_waymo.sh 58 | cd .. 59 | ``` 60 | 61 | 3. Preprocess the dataset and create training pairs: 62 | ```bash 63 | python datasets_preprocess/preprocess_waymo.py 64 | python datasets_preprocess/waymo_make_pairs.py 65 | ``` 66 | 67 | ## Sintel (Evaluation) 68 | To download and prepare the **Sintel** dataset for evaluation, execute: 69 | ```bash 70 | cd data 71 | bash download_sintel.sh 72 | cd .. 73 | ``` 74 | -------------------------------------------------------------------------------- /datasets_preprocess/habitat/README.md: -------------------------------------------------------------------------------- 1 | ## Steps to reproduce synthetic training data using the Habitat-Sim simulator 2 | 3 | ### Create a conda environment 4 | ```bash 5 | conda create -n habitat python=3.8 habitat-sim=0.2.1 headless=2.0 -c aihabitat -c conda-forge 6 | conda activate habitat 7 | conda install pytorch -c pytorch 8 | pip install opencv-python tqdm 9 | ``` 10 | 11 | or (if you get the error `For headless systems, compile with --headless for EGL support`) 12 | ``` 13 | git clone --branch stable https://github.com/facebookresearch/habitat-sim.git 14 | cd habitat-sim 15 | 16 | conda create -n habitat python=3.9 cmake=3.14.0 17 | conda activate habitat 18 | pip install . -v 19 | conda install pytorch -c pytorch 20 | pip install opencv-python tqdm 21 | ``` 22 | 23 | ### Download Habitat-Sim scenes 24 | Download Habitat-Sim scenes: 25 | - Download links can be found here: https://github.com/facebookresearch/habitat-sim/blob/main/DATASETS.md 26 | - We used scenes from the HM3D, habitat-test-scenes, ReplicaCAD and ScanNet datasets. 27 | - Please put the scenes in a directory `$SCENES_DIR` following the structure below: 28 | (Note: the habitat-sim dataset installer may install an incompatible version for ReplicaCAD baked lighting. 29 | The correct scene dataset can be downloaded from Huggingface: `git clone git@hf.co:datasets/ai-habitat/ReplicaCAD_baked_lighting`). 
30 | ``` 31 | $SCENES_DIR/ 32 | ├──hm3d/ 33 | ├──gibson/ 34 | ├──habitat-test-scenes/ 35 | ├──ReplicaCAD_baked_lighting/ 36 | └──scannet/ 37 | ``` 38 | 39 | ### Download renderings metadata 40 | 41 | Download the metadata corresponding to each scene and extract it into a directory `$METADATA_DIR` 42 | ```bash 43 | wget https://download.europe.naverlabs.com/ComputerVision/DUSt3R/habitat_5views_v1_512x512_metadata.tar.gz 44 | tar -xvzf habitat_5views_v1_512x512_metadata.tar.gz 45 | ``` 46 | 47 | ### Render the scenes 48 | 49 | Render the scenes in an output directory `$OUTPUT_DIR` 50 | ```bash 51 | export METADATA_DIR="/path/to/habitat/5views_v1_512x512_metadata" 52 | export SCENES_DIR="/path/to/habitat/data/scene_datasets/" 53 | export OUTPUT_DIR="data/habitat_processed" 54 | cd datasets_preprocess/habitat/ 55 | export PYTHONPATH=$(pwd) 56 | # Print command lines to generate images corresponding to each scene 57 | python preprocess_habitat.py --scenes_dir=$SCENES_DIR --metadata_dir=$METADATA_DIR --output_dir=$OUTPUT_DIR 58 | # Launch these command lines in parallel e.g. using GNU-Parallel as follows: 59 | python preprocess_habitat.py --scenes_dir=$SCENES_DIR --metadata_dir=$METADATA_DIR --output_dir=$OUTPUT_DIR | parallel -j 16 60 | ``` 61 | 62 | ### Make a list of scenes 63 | 64 | ```bash 65 | python find_scenes.py --root $OUTPUT_DIR 66 | ``` -------------------------------------------------------------------------------- /datasets_preprocess/habitat/find_scenes.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (C) 2024-present Naver Corporation. All rights reserved. 3 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 4 | # 5 | # -------------------------------------------------------- 6 | # Script to export the list of scenes for habitat (after having rendered them). 7 | # Usage: 8 | # python3 datasets_preprocess/habitat/find_scenes.py --root data/habitat_processed 9 | # -------------------------------------------------------- 10 | import numpy as np 11 | import os 12 | from collections import defaultdict 13 | from tqdm import tqdm 14 | 15 | 16 | def find_all_scenes(habitat_root, n_scenes=[100000]): 17 | np.random.seed(777) 18 | 19 | try: 20 | fpath = os.path.join(habitat_root, f'Habitat_all_scenes.txt') 21 | list_subscenes = open(fpath).read().splitlines() 22 | 23 | except IOError: 24 | if input('parsing sub-folders to find scenes? (y/n) ') != 'y': 25 | return 26 | list_subscenes = [] 27 | for root, dirs, files in tqdm(os.walk(habitat_root)): 28 | for f in files: 29 | if not f.endswith('_1_depth.exr'): 30 | continue 31 | scene = os.path.join(os.path.relpath(root, habitat_root), f.replace('_1_depth.exr', '')) 32 | if hash(scene) % 1000 == 0: 33 | print('... adding', scene) 34 | list_subscenes.append(scene) 35 | 36 | with open(fpath, 'w') as f: 37 | f.write('\n'.join(list_subscenes)) 38 | print(f'>> wrote {fpath}') 39 | 40 | print(f'Loaded {len(list_subscenes)} sub-scenes') 41 | 42 | # separate scenes 43 | list_scenes = defaultdict(list) 44 | for scene in list_subscenes: 45 | scene, id = os.path.split(scene) 46 | list_scenes[scene].append(id) 47 | 48 | list_scenes = list(list_scenes.items()) 49 | print(f'from {len(list_scenes)} scenes in total') 50 | 51 | np.random.shuffle(list_scenes) 52 | train_scenes = list_scenes[len(list_scenes)//10:] 53 | val_scenes = list_scenes[:len(list_scenes)//10] 54 | 55 | def write_scene_list(scenes, n, fpath): 56 | sub_scenes = [os.path.join(scene, id) for scene, ids in scenes for id in ids] 57 | np.random.shuffle(sub_scenes) 58 | 59 | if len(sub_scenes) < n: 60 | return 61 | 62 | with open(fpath, 'w') as f: 63 | f.write('\n'.join(sub_scenes[:n])) 64 | print(f'>> wrote {fpath}') 65 | 66 | for n in n_scenes: 67 | write_scene_list(train_scenes, n, os.path.join(habitat_root, f'Habitat_{n}_scenes_train.txt')) 68 | write_scene_list(val_scenes, n//10, os.path.join(habitat_root, f'Habitat_{n//10}_scenes_val.txt')) 69 | 70 | 71 | if __name__ == "__main__": 72 | import argparse 73 | parser = argparse.ArgumentParser() 74 | parser.add_argument("--root", required=True) 75 | parser.add_argument("--n_scenes", nargs='+', default=[1_000, 1_000_000], type=int) 76 | 77 | args = parser.parse_args() 78 | find_all_scenes(args.root, args.n_scenes) 79 | -------------------------------------------------------------------------------- /datasets_preprocess/habitat/habitat_renderer/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2024-present Naver Corporation. All rights reserved. 2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 3 | -------------------------------------------------------------------------------- /datasets_preprocess/habitat/habitat_renderer/projections_conversions.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2024-present Naver Corporation. All rights reserved. 2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 3 | # 4 | # -------------------------------------------------------- 5 | # Remap data from one projection to another 6 | # -------------------------------------------------------- 7 | import numpy as np 8 | import cv2 9 | from habitat_renderer import projections 10 | 11 | class RemapProjection: 12 | def __init__(self, input_projection, output_projection, pixel_jittering_iterations=0, jittering_noise_level=0): 13 | """ 14 | Some naive random jittering can be introduced in the remapping to mitigate aliasing artifacts. 
15 | """ 16 | assert jittering_noise_level >= 0 17 | assert pixel_jittering_iterations >= 0 18 | 19 | maps = [] 20 | # Initial map 21 | self.output_rays = projections.get_projection_rays(output_projection) 22 | map_u, map_v = input_projection.project(self.output_rays) 23 | map_u, map_v = np.asarray(map_u, dtype=np.float32), np.asarray(map_v, dtype=np.float32) 24 | maps.append((map_u, map_v)) 25 | 26 | for _ in range(pixel_jittering_iterations): 27 | # Define multiple mappings using some coordinates jittering to mitigate aliasing effects 28 | crop_rays = projections.get_projection_rays(output_projection, jittering_noise_level) 29 | map_u, map_v = input_projection.project(crop_rays) 30 | map_u, map_v = np.asarray(map_u, dtype=np.float32), np.asarray(map_v, dtype=np.float32) 31 | maps.append((map_u, map_v)) 32 | self.maps = maps 33 | 34 | def convert(self, img, interpolation=cv2.INTER_LINEAR, borderMode=cv2.BORDER_WRAP, single_map=False): 35 | remapped = [] 36 | for map_u, map_v in self.maps: 37 | res = cv2.remap(img, map_u, map_v, interpolation=interpolation, borderMode=borderMode) 38 | remapped.append(res) 39 | if single_map: 40 | break 41 | if len(remapped) == 1: 42 | res = remapped[0] 43 | else: 44 | res = np.asarray(np.mean(remapped, axis=0), dtype=img.dtype) 45 | return res 46 | -------------------------------------------------------------------------------- /datasets_preprocess/path_to_root.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2024-present Naver Corporation. All rights reserved. 2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 3 | # 4 | # -------------------------------------------------------- 5 | # DUSt3R repo root import 6 | # -------------------------------------------------------- 7 | 8 | import sys 9 | import os.path as path 10 | HERE_PATH = path.normpath(path.dirname(__file__)) 11 | DUST3R_REPO_PATH = path.normpath(path.join(HERE_PATH, '../')) 12 | # workaround for sibling import 13 | sys.path.insert(0, DUST3R_REPO_PATH) 14 | -------------------------------------------------------------------------------- /datasets_preprocess/prepare_bonn.py: -------------------------------------------------------------------------------- 1 | # %% 2 | import glob 3 | import os 4 | import shutil 5 | dirs = glob.glob("../data/bonn/rgbd_bonn_dataset/*/") 6 | dirs = sorted(dirs) 7 | # extract frames 8 | for dir in dirs: 9 | frames = glob.glob(dir + 'rgb/*.png') 10 | frames = sorted(frames) 11 | # sample 110 frames at the stride of 2 12 | frames = frames[30:140] 13 | # cut frames after 110 14 | new_dir = dir + 'rgb_110/' 15 | 16 | for frame in frames: 17 | os.makedirs(new_dir, exist_ok=True) 18 | shutil.copy(frame, new_dir) 19 | # print(f'cp {frame} {new_dir}') 20 | 21 | depth_frames = glob.glob(dir + 'depth/*.png') 22 | depth_frames = sorted(depth_frames) 23 | # sample 110 frames at the stride of 2 24 | depth_frames = depth_frames[30:140] 25 | # cut frames after 110 26 | new_dir = dir + 'depth_110/' 27 | 28 | for frame in depth_frames: 29 | os.makedirs(new_dir, exist_ok=True) 30 | shutil.copy(frame, new_dir) 31 | # print(f'cp {frame} {new_dir}') 32 | import numpy as np 33 | for dir in dirs: 34 | gt_path = "groundtruth.txt" 35 | gt = np.loadtxt(dir + gt_path) 36 | gt_110 = gt[30:140] 37 | np.savetxt(dir + 'groundtruth_110.txt', gt_110) 38 | 39 | 40 | 41 | -------------------------------------------------------------------------------- /datasets_preprocess/prepare_kitti.py: 
-------------------------------------------------------------------------------- 1 | # %% 2 | #!/usr/bin/python 3 | 4 | from PIL import Image 5 | import numpy as np 6 | 7 | 8 | def depth_read(filename): 9 | # loads depth map D from png file 10 | # and returns it as a numpy array, 11 | # for details see readme.txt 12 | 13 | depth_png = np.array(Image.open(filename), dtype=int) 14 | # make sure we have a proper 16bit depth map here.. not 8bit! 15 | assert(np.max(depth_png) > 255) 16 | 17 | depth = depth_png.astype(np.float64) / 256.  # np.float was removed in NumPy 1.24; use an explicit dtype 18 | depth[depth_png == 0] = -1. 19 | return depth 20 | 21 | # %% 22 | import glob 23 | import os 24 | import shutil 25 | depth_dirs = glob.glob("../data/kitti/val/*/proj_depth/groundtruth/image_02") 26 | for dir in depth_dirs: 27 | # new depth dir 28 | new_depth_dir = "../data/kitti/depth_selection/val_selection_cropped/groundtruth_depth_gathered/" + dir.split("/")[-4]+"_02" 29 | # print(new_depth_dir) 30 | new_image_dir = "../data/kitti/depth_selection/val_selection_cropped/image_gathered/" + dir.split("/")[-4]+"_02" 31 | os.makedirs(new_depth_dir, exist_ok=True) 32 | os.makedirs(new_image_dir, exist_ok=True) 33 | for depth_file in sorted(glob.glob(dir + "/*.png"))[:110]: #../data/kitti/val/2011_09_26_drive_0002_sync/proj_depth/groundtruth/image_02/0000000005.png 34 | new_path = new_depth_dir + "/" + depth_file.split("/")[-1] 35 | shutil.copy(depth_file, new_path) 36 | # get the path of the corresponding image 37 | mid = "_".join(depth_file.split("/")[4].split("_")[:3]) 38 | image_file = depth_file.replace('val', mid).replace('proj_depth/groundtruth/image_02', 'image_02/data') 39 | print(image_file) 40 | # check if the image file exists 41 | if os.path.exists(image_file): 42 | new_path = new_image_dir + "/" + image_file.split("/")[-1] 43 | shutil.copy(image_file, new_path) 44 | else: 45 | print("Image file does not exist: ", image_file) 46 | 47 | 48 | -------------------------------------------------------------------------------- /datasets_preprocess/prepare_nyuv2.py: -------------------------------------------------------------------------------- 1 | # %% 2 | import h5py 3 | import numpy as np 4 | import os 5 | from glob import glob 6 | from PIL import Image 7 | 8 | # Set the path to your dataset directory 9 | dataset_dir = '../data/nyu-v2/val/official/' 10 | 11 | # Get a list of all .h5 files in the dataset directory 12 | file_paths = glob(os.path.join(dataset_dir, '*.h5')) 13 | 14 | # Create output directories for images and depth data 15 | output_image_dir = '../data/nyu-v2/val/nyu_images/' 16 | output_depth_dir = '../data/nyu-v2/val/nyu_depths/' 17 | os.makedirs(output_image_dir, exist_ok=True) 18 | os.makedirs(output_depth_dir, exist_ok=True) 19 | 20 | for file_path in file_paths: 21 | with h5py.File(file_path, 'r') as h5file: 22 | # Read depth and rgb data 23 | depth_data = h5file['depth'][:] 24 | rgb_data = h5file['rgb'][:] 25 | 26 | # Convert rgb data from (3, H, W) to (H, W, 3) 27 | rgb_data = np.transpose(rgb_data, (1, 2, 0)) 28 | 29 | # Ensure that rgb_data is of type uint8 30 | if rgb_data.dtype != np.uint8: 31 | rgb_data = rgb_data.astype(np.uint8) 32 | 33 | # Get the base filename without extension 34 | base_name = os.path.splitext(os.path.basename(file_path))[0] 35 | 36 | # Save the RGB image as PNG 37 | rgb_image = Image.fromarray(rgb_data) 38 | rgb_image.save(os.path.join(output_image_dir, f'{base_name}.png')) 39 | 40 | # Save the depth data as NPY file 41 | np.save(os.path.join(output_depth_dir, f'{base_name}.npy'), depth_data) 42 | 43 | 
print(f'Processed {base_name}') 44 | 45 | 46 | # %% 47 | import os 48 | import numpy as np 49 | from PIL import Image 50 | 51 | # Paths 52 | depth_npy_dir = '../data/nyu-v2/val/nyu_depths' 53 | output_img_dir = '../data/nyu-v2/val/nyu_depth_imgs' 54 | 55 | # Ensure the output directory exists 56 | os.makedirs(output_img_dir, exist_ok=True) 57 | 58 | # Iterate over all .npy files in the depth directory 59 | for npy_file in os.listdir(depth_npy_dir): 60 | if npy_file.endswith('.npy'): 61 | # Load depth data from .npy file 62 | depth_path = os.path.join(depth_npy_dir, npy_file) 63 | depth_data = np.load(depth_path) 64 | 65 | # Normalize depth data to range [0, 255] for saving as an image 66 | depth_min = depth_data.min() 67 | depth_max = depth_data.max() 68 | depth_normalized = (depth_data - depth_min) / (depth_max - depth_min) 69 | depth_uint8 = (depth_normalized * 255).astype(np.uint8) 70 | 71 | # Convert to an image 72 | depth_img = Image.fromarray(depth_uint8) 73 | 74 | # Save as PNG file 75 | img_name = os.path.splitext(npy_file)[0] + '.png' 76 | img_save_path = os.path.join(output_img_dir, img_name) 77 | depth_img.save(img_save_path) 78 | 79 | print(f'Saved {img_save_path}') 80 | 81 | print("Conversion completed!") 82 | 83 | 84 | 85 | -------------------------------------------------------------------------------- /datasets_preprocess/prepare_scannet.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | import shutil 4 | import numpy as np 5 | 6 | seq_list = sorted(os.listdir("../data/scannetv2")) 7 | for seq in seq_list: 8 | img_pathes = sorted(glob.glob(f"../data/scannetv2/{seq}/color/*.jpg"), key=lambda x: int(os.path.basename(x).split('.')[0])) 9 | depth_pathes = sorted(glob.glob(f"../data/scannetv2/{seq}/depth/*.png"), key=lambda x: int(os.path.basename(x).split('.')[0])) 10 | pose_pathes = sorted(glob.glob(f"../data/scannetv2/{seq}/pose/*.txt"), key=lambda x: int(os.path.basename(x).split('.')[0])) 11 | print(f"{seq}: {len(img_pathes)} {len(depth_pathes)}") 12 | 13 | new_color_dir = f"../data/scannetv2/{seq}/color_90" 14 | new_depth_dir = f"../data/scannetv2/{seq}/depth_90" 15 | 16 | new_img_pathes = img_pathes[:90*3:3] 17 | new_depth_pathes = depth_pathes[:90*3:3] 18 | new_pose_pathes = pose_pathes[:90*3:3] 19 | 20 | os.makedirs(new_color_dir, exist_ok=True) 21 | os.makedirs(new_depth_dir, exist_ok=True) 22 | 23 | for i, (img_path, depth_path) in enumerate(zip(new_img_pathes, new_depth_pathes)): 24 | shutil.copy(img_path, f"{new_color_dir}/frame_{i:04d}.jpg") 25 | shutil.copy(depth_path, f"{new_depth_dir}/frame_{i:04d}.png") 26 | 27 | pose_new_path = f"../data/scannetv2/{seq}/pose_90.txt" 28 | with open(pose_new_path, 'w') as f: 29 | for i, pose_path in enumerate(new_pose_pathes): 30 | with open(pose_path, 'r') as pose_file: 31 | pose = np.loadtxt(pose_file) 32 | pose = pose.reshape(-1) 33 | f.write(f"{' '.join(map(str, pose))}\n") 34 | -------------------------------------------------------------------------------- /datasets_preprocess/prepare_tum.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | import shutil 4 | import numpy as np 5 | 6 | def read_file_list(filename): 7 | """ 8 | Reads a trajectory from a text file. 9 | 10 | File format: 11 | The file format is "stamp d1 d2 d3 ...", where stamp denotes the time stamp (to be matched) 12 | and "d1 d2 d3.." is arbitrary data (e.g., a 3D position and 3D orientation) associated to this timestamp. 
13 | 14 | Input: 15 | filename -- File name 16 | 17 | Output: 18 | dict -- dictionary of (stamp,data) tuples 19 | 20 | """ 21 | file = open(filename) 22 | data = file.read() 23 | lines = data.replace(","," ").replace("\t"," ").split("\n") 24 | list = [[v.strip() for v in line.split(" ") if v.strip()!=""] for line in lines if len(line)>0 and line[0]!="#"] 25 | list = [(float(l[0]),l[1:]) for l in list if len(l)>1] 26 | return dict(list) 27 | 28 | def associate(first_list, second_list, offset, max_difference): 29 | """ 30 | Associate two dictionaries of (stamp, data). As the time stamps never match exactly, we aim 31 | to find the closest match for every input tuple. 32 | 33 | Input: 34 | first_list -- first dictionary of (stamp, data) tuples 35 | second_list -- second dictionary of (stamp, data) tuples 36 | offset -- time offset between both dictionaries (e.g., to model the delay between the sensors) 37 | max_difference -- search radius for candidate generation 38 | 39 | Output: 40 | matches -- list of matched tuples ((stamp1, data1), (stamp2, data2)) 41 | """ 42 | # Convert keys to sets for efficient removal 43 | first_keys = set(first_list.keys()) 44 | second_keys = set(second_list.keys()) 45 | 46 | potential_matches = [(abs(a - (b + offset)), a, b) 47 | for a in first_keys 48 | for b in second_keys 49 | if abs(a - (b + offset)) < max_difference] 50 | potential_matches.sort() 51 | matches = [] 52 | for diff, a, b in potential_matches: 53 | if a in first_keys and b in second_keys: 54 | first_keys.remove(a) 55 | second_keys.remove(b) 56 | matches.append((a, b)) 57 | 58 | matches.sort() 59 | return matches 60 | 61 | dirs = glob.glob("../data/tum/*/") 62 | dirs = sorted(dirs) 63 | # extract frames 64 | for dir in dirs: 65 | frames = [] 66 | gt = [] 67 | first_file = dir + 'rgb.txt' 68 | second_file = dir + 'groundtruth.txt' 69 | 70 | first_list = read_file_list(first_file) 71 | second_list = read_file_list(second_file) 72 | matches = associate(first_list, second_list, 0.0, 0.02) 73 | 74 | # for a,b in matches[:10]: 75 | # print("%f %s %f %s"%(a," ".join(first_list[a]),b," ".join(second_list[b]))) 76 | for a,b in matches: 77 | frames.append(dir + first_list[a][0]) 78 | gt.append([b]+second_list[b]) 79 | 80 | # sample 90 frames at the stride of 3 81 | frames = frames[::3][:90] 82 | # cut frames after 90 83 | new_dir = dir + 'rgb_90/' 84 | 85 | for frame in frames: 86 | os.makedirs(new_dir, exist_ok=True) 87 | shutil.copy(frame, new_dir) 88 | # print(f'cp {frame} {new_dir}') 89 | 90 | gt_90 = gt[::3][:90] 91 | with open(dir + 'groundtruth_90.txt', 'w') as f: 92 | for pose in gt_90: 93 | f.write(f"{' '.join(map(str, pose))}\n") -------------------------------------------------------------------------------- /datasets_preprocess/waymo_make_pairs.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | from tqdm import tqdm 4 | import cv2 5 | import numpy as np 6 | 7 | import numpy as np 8 | 9 | file_path = "data/waymo/waymo_pairs.npz" 10 | 11 | data = np.load(file_path) 12 | # data.files # ['scenes', 'frames', 'pairs'] 13 | 14 | scenes, frames, pairs = data['scenes'], data['frames'], data['pairs'] 15 | 16 | new_scenes = glob.glob("data/waymo_processed/*.tfrecord/") 17 | new_scenes_last = [scene.split("/")[-2] for scene in new_scenes] 18 | img_lens = [] 19 | for path in tqdm(new_scenes): 20 | imgs = glob.glob(path + "/*.jpg") 21 | img_lens.append(len(imgs)) 22 | 23 | new_frames = list(frames) 24 | new_pairs = [] 25 | strides = 
[1,2,3,4,5,6,7,8,9] 26 | step = 1 27 | for path in tqdm(new_scenes): 28 | imgs_track1 = glob.glob(path + "/*_1.jpg") 29 | imgs_track1.sort() 30 | imgs_track2 = glob.glob(path + "/*_2.jpg") 31 | imgs_track2.sort() 32 | imgs_track3 = glob.glob(path + "/*_3.jpg") 33 | imgs_track3.sort() 34 | imgs_track4 = glob.glob(path + "/*_4.jpg") 35 | imgs_track4.sort() 36 | imgs_track5 = glob.glob(path + "/*_5.jpg") 37 | imgs_track5.sort() 38 | for stride in strides: 39 | for i in range(0, len(imgs_track1)-stride, step): 40 | if os.path.exists(imgs_track1[i+stride]) and os.path.exists(imgs_track1[i]): 41 | new_pairs.append([new_scenes_last.index(path.split("/")[-2]), new_frames.index(imgs_track1[i].split('/')[-1].replace('.jpg','')), new_frames.index(imgs_track1[i+stride].split('/')[-1].replace('.jpg',''))]) 42 | for i in range(0, len(imgs_track2)-stride, step): 43 | if os.path.exists(imgs_track2[i+stride]) and os.path.exists(imgs_track2[i]): 44 | new_pairs.append([new_scenes_last.index(path.split("/")[-2]), new_frames.index(imgs_track2[i].split('/')[-1].replace('.jpg','')), new_frames.index(imgs_track2[i+stride].split('/')[-1].replace('.jpg',''))]) 45 | for i in range(0, len(imgs_track3)-stride, step): 46 | if os.path.exists(imgs_track3[i+stride]) and os.path.exists(imgs_track3[i]): 47 | new_pairs.append([new_scenes_last.index(path.split("/")[-2]), new_frames.index(imgs_track3[i].split('/')[-1].replace('.jpg','')), new_frames.index(imgs_track3[i+stride].split('/')[-1].replace('.jpg',''))]) 48 | for i in range(0, len(imgs_track4)-stride, step): 49 | if os.path.exists(imgs_track4[i+stride]) and os.path.exists(imgs_track4[i]): 50 | new_pairs.append([new_scenes_last.index(path.split("/")[-2]), new_frames.index(imgs_track4[i].split('/')[-1].replace('.jpg','')), new_frames.index(imgs_track4[i+stride].split('/')[-1].replace('.jpg',''))]) 51 | for i in range(0, len(imgs_track5)-stride, step): 52 | if os.path.exists(imgs_track5[i+stride]) and os.path.exists(imgs_track5[i]): 53 | new_pairs.append([new_scenes_last.index(path.split("/")[-2]), new_frames.index(imgs_track5[i].split('/')[-1].replace('.jpg','')), new_frames.index(imgs_track5[i+stride].split('/')[-1].replace('.jpg',''))]) 54 | 55 | print(len(new_pairs), "pairs") 56 | save_path = "data/waymo_processed/waymo_pairs_video.npz" 57 | np.savez(save_path, scenes=np.array(new_scenes_last), frames=np.array(new_frames), pairs=np.array(new_pairs)) -------------------------------------------------------------------------------- /demo_data/lady-running.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Junyi42/monst3r/1747338dcc01a850f7105bdd7147ab167e400f97/demo_data/lady-running.mp4 -------------------------------------------------------------------------------- /demo_data/lady-running/00000.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Junyi42/monst3r/1747338dcc01a850f7105bdd7147ab167e400f97/demo_data/lady-running/00000.jpg -------------------------------------------------------------------------------- /demo_data/lady-running/00001.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Junyi42/monst3r/1747338dcc01a850f7105bdd7147ab167e400f97/demo_data/lady-running/00001.jpg -------------------------------------------------------------------------------- /demo_data/lady-running/00002.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Junyi42/monst3r/1747338dcc01a850f7105bdd7147ab167e400f97/demo_data/lady-running/00002.jpg -------------------------------------------------------------------------------- /demo_data/lady-running/00003.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Junyi42/monst3r/1747338dcc01a850f7105bdd7147ab167e400f97/demo_data/lady-running/00003.jpg -------------------------------------------------------------------------------- /demo_data/lady-running/00004.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Junyi42/monst3r/1747338dcc01a850f7105bdd7147ab167e400f97/demo_data/lady-running/00004.jpg -------------------------------------------------------------------------------- /demo_data/lady-running/00005.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Junyi42/monst3r/1747338dcc01a850f7105bdd7147ab167e400f97/demo_data/lady-running/00005.jpg -------------------------------------------------------------------------------- /demo_data/lady-running/00006.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Junyi42/monst3r/1747338dcc01a850f7105bdd7147ab167e400f97/demo_data/lady-running/00006.jpg -------------------------------------------------------------------------------- /demo_data/lady-running/00007.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Junyi42/monst3r/1747338dcc01a850f7105bdd7147ab167e400f97/demo_data/lady-running/00007.jpg -------------------------------------------------------------------------------- /demo_data/lady-running/00008.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Junyi42/monst3r/1747338dcc01a850f7105bdd7147ab167e400f97/demo_data/lady-running/00008.jpg -------------------------------------------------------------------------------- /demo_data/lady-running/00009.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Junyi42/monst3r/1747338dcc01a850f7105bdd7147ab167e400f97/demo_data/lady-running/00009.jpg -------------------------------------------------------------------------------- /demo_data/lady-running/00010.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Junyi42/monst3r/1747338dcc01a850f7105bdd7147ab167e400f97/demo_data/lady-running/00010.jpg -------------------------------------------------------------------------------- /demo_data/lady-running/00011.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Junyi42/monst3r/1747338dcc01a850f7105bdd7147ab167e400f97/demo_data/lady-running/00011.jpg -------------------------------------------------------------------------------- /demo_data/lady-running/00012.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Junyi42/monst3r/1747338dcc01a850f7105bdd7147ab167e400f97/demo_data/lady-running/00012.jpg -------------------------------------------------------------------------------- /demo_data/lady-running/00013.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Junyi42/monst3r/1747338dcc01a850f7105bdd7147ab167e400f97/demo_data/lady-running/00013.jpg -------------------------------------------------------------------------------- /demo_data/lady-running/00014.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Junyi42/monst3r/1747338dcc01a850f7105bdd7147ab167e400f97/demo_data/lady-running/00014.jpg -------------------------------------------------------------------------------- /demo_data/lady-running/00015.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Junyi42/monst3r/1747338dcc01a850f7105bdd7147ab167e400f97/demo_data/lady-running/00015.jpg -------------------------------------------------------------------------------- /demo_data/lady-running/00016.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Junyi42/monst3r/1747338dcc01a850f7105bdd7147ab167e400f97/demo_data/lady-running/00016.jpg -------------------------------------------------------------------------------- /demo_data/lady-running/00017.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Junyi42/monst3r/1747338dcc01a850f7105bdd7147ab167e400f97/demo_data/lady-running/00017.jpg -------------------------------------------------------------------------------- /demo_data/lady-running/00018.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Junyi42/monst3r/1747338dcc01a850f7105bdd7147ab167e400f97/demo_data/lady-running/00018.jpg -------------------------------------------------------------------------------- /demo_data/lady-running/00019.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Junyi42/monst3r/1747338dcc01a850f7105bdd7147ab167e400f97/demo_data/lady-running/00019.jpg -------------------------------------------------------------------------------- /demo_data/lady-running/00020.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Junyi42/monst3r/1747338dcc01a850f7105bdd7147ab167e400f97/demo_data/lady-running/00020.jpg -------------------------------------------------------------------------------- /demo_data/lady-running/00021.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Junyi42/monst3r/1747338dcc01a850f7105bdd7147ab167e400f97/demo_data/lady-running/00021.jpg -------------------------------------------------------------------------------- /demo_data/lady-running/00022.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Junyi42/monst3r/1747338dcc01a850f7105bdd7147ab167e400f97/demo_data/lady-running/00022.jpg -------------------------------------------------------------------------------- /demo_data/lady-running/00023.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Junyi42/monst3r/1747338dcc01a850f7105bdd7147ab167e400f97/demo_data/lady-running/00023.jpg -------------------------------------------------------------------------------- /demo_data/lady-running/00024.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Junyi42/monst3r/1747338dcc01a850f7105bdd7147ab167e400f97/demo_data/lady-running/00024.jpg -------------------------------------------------------------------------------- /demo_data/lady-running/00025.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Junyi42/monst3r/1747338dcc01a850f7105bdd7147ab167e400f97/demo_data/lady-running/00025.jpg -------------------------------------------------------------------------------- /demo_data/lady-running/00026.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Junyi42/monst3r/1747338dcc01a850f7105bdd7147ab167e400f97/demo_data/lady-running/00026.jpg -------------------------------------------------------------------------------- /demo_data/lady-running/00027.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Junyi42/monst3r/1747338dcc01a850f7105bdd7147ab167e400f97/demo_data/lady-running/00027.jpg -------------------------------------------------------------------------------- /demo_data/lady-running/00028.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Junyi42/monst3r/1747338dcc01a850f7105bdd7147ab167e400f97/demo_data/lady-running/00028.jpg -------------------------------------------------------------------------------- /demo_data/lady-running/00029.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Junyi42/monst3r/1747338dcc01a850f7105bdd7147ab167e400f97/demo_data/lady-running/00029.jpg -------------------------------------------------------------------------------- /demo_data/lady-running/00030.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Junyi42/monst3r/1747338dcc01a850f7105bdd7147ab167e400f97/demo_data/lady-running/00030.jpg -------------------------------------------------------------------------------- /demo_data/lady-running/00031.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Junyi42/monst3r/1747338dcc01a850f7105bdd7147ab167e400f97/demo_data/lady-running/00031.jpg -------------------------------------------------------------------------------- /demo_data/lady-running/00032.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Junyi42/monst3r/1747338dcc01a850f7105bdd7147ab167e400f97/demo_data/lady-running/00032.jpg -------------------------------------------------------------------------------- /demo_data/lady-running/00033.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Junyi42/monst3r/1747338dcc01a850f7105bdd7147ab167e400f97/demo_data/lady-running/00033.jpg -------------------------------------------------------------------------------- /demo_data/lady-running/00034.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Junyi42/monst3r/1747338dcc01a850f7105bdd7147ab167e400f97/demo_data/lady-running/00034.jpg -------------------------------------------------------------------------------- /demo_data/lady-running/00035.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Junyi42/monst3r/1747338dcc01a850f7105bdd7147ab167e400f97/demo_data/lady-running/00035.jpg -------------------------------------------------------------------------------- /demo_data/lady-running/00036.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Junyi42/monst3r/1747338dcc01a850f7105bdd7147ab167e400f97/demo_data/lady-running/00036.jpg -------------------------------------------------------------------------------- /demo_data/lady-running/00037.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Junyi42/monst3r/1747338dcc01a850f7105bdd7147ab167e400f97/demo_data/lady-running/00037.jpg -------------------------------------------------------------------------------- /demo_data/lady-running/00038.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Junyi42/monst3r/1747338dcc01a850f7105bdd7147ab167e400f97/demo_data/lady-running/00038.jpg -------------------------------------------------------------------------------- /demo_data/lady-running/00039.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Junyi42/monst3r/1747338dcc01a850f7105bdd7147ab167e400f97/demo_data/lady-running/00039.jpg -------------------------------------------------------------------------------- /demo_data/lady-running/00040.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Junyi42/monst3r/1747338dcc01a850f7105bdd7147ab167e400f97/demo_data/lady-running/00040.jpg -------------------------------------------------------------------------------- /demo_data/lady-running/00041.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Junyi42/monst3r/1747338dcc01a850f7105bdd7147ab167e400f97/demo_data/lady-running/00041.jpg -------------------------------------------------------------------------------- /demo_data/lady-running/00042.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Junyi42/monst3r/1747338dcc01a850f7105bdd7147ab167e400f97/demo_data/lady-running/00042.jpg -------------------------------------------------------------------------------- /demo_data/lady-running/00043.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Junyi42/monst3r/1747338dcc01a850f7105bdd7147ab167e400f97/demo_data/lady-running/00043.jpg -------------------------------------------------------------------------------- /demo_data/lady-running/00044.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Junyi42/monst3r/1747338dcc01a850f7105bdd7147ab167e400f97/demo_data/lady-running/00044.jpg -------------------------------------------------------------------------------- /demo_data/lady-running/00045.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Junyi42/monst3r/1747338dcc01a850f7105bdd7147ab167e400f97/demo_data/lady-running/00045.jpg -------------------------------------------------------------------------------- /demo_data/lady-running/00046.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Junyi42/monst3r/1747338dcc01a850f7105bdd7147ab167e400f97/demo_data/lady-running/00046.jpg -------------------------------------------------------------------------------- /demo_data/lady-running/00047.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Junyi42/monst3r/1747338dcc01a850f7105bdd7147ab167e400f97/demo_data/lady-running/00047.jpg -------------------------------------------------------------------------------- /demo_data/lady-running/00048.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Junyi42/monst3r/1747338dcc01a850f7105bdd7147ab167e400f97/demo_data/lady-running/00048.jpg -------------------------------------------------------------------------------- /demo_data/lady-running/00049.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Junyi42/monst3r/1747338dcc01a850f7105bdd7147ab167e400f97/demo_data/lady-running/00049.jpg -------------------------------------------------------------------------------- /demo_data/lady-running/00050.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Junyi42/monst3r/1747338dcc01a850f7105bdd7147ab167e400f97/demo_data/lady-running/00050.jpg -------------------------------------------------------------------------------- /demo_data/lady-running/00051.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Junyi42/monst3r/1747338dcc01a850f7105bdd7147ab167e400f97/demo_data/lady-running/00051.jpg -------------------------------------------------------------------------------- /demo_data/lady-running/00052.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Junyi42/monst3r/1747338dcc01a850f7105bdd7147ab167e400f97/demo_data/lady-running/00052.jpg -------------------------------------------------------------------------------- /demo_data/lady-running/00053.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Junyi42/monst3r/1747338dcc01a850f7105bdd7147ab167e400f97/demo_data/lady-running/00053.jpg -------------------------------------------------------------------------------- /demo_data/lady-running/00054.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Junyi42/monst3r/1747338dcc01a850f7105bdd7147ab167e400f97/demo_data/lady-running/00054.jpg -------------------------------------------------------------------------------- /demo_data/lady-running/00055.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Junyi42/monst3r/1747338dcc01a850f7105bdd7147ab167e400f97/demo_data/lady-running/00055.jpg -------------------------------------------------------------------------------- /demo_data/lady-running/00056.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Junyi42/monst3r/1747338dcc01a850f7105bdd7147ab167e400f97/demo_data/lady-running/00056.jpg -------------------------------------------------------------------------------- /demo_data/lady-running/00057.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Junyi42/monst3r/1747338dcc01a850f7105bdd7147ab167e400f97/demo_data/lady-running/00057.jpg -------------------------------------------------------------------------------- /demo_data/lady-running/00058.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Junyi42/monst3r/1747338dcc01a850f7105bdd7147ab167e400f97/demo_data/lady-running/00058.jpg -------------------------------------------------------------------------------- /demo_data/lady-running/00059.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Junyi42/monst3r/1747338dcc01a850f7105bdd7147ab167e400f97/demo_data/lady-running/00059.jpg -------------------------------------------------------------------------------- /demo_data/lady-running/00060.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Junyi42/monst3r/1747338dcc01a850f7105bdd7147ab167e400f97/demo_data/lady-running/00060.jpg -------------------------------------------------------------------------------- /demo_data/lady-running/00061.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Junyi42/monst3r/1747338dcc01a850f7105bdd7147ab167e400f97/demo_data/lady-running/00061.jpg -------------------------------------------------------------------------------- /demo_data/lady-running/00062.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Junyi42/monst3r/1747338dcc01a850f7105bdd7147ab167e400f97/demo_data/lady-running/00062.jpg -------------------------------------------------------------------------------- /demo_data/lady-running/00063.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Junyi42/monst3r/1747338dcc01a850f7105bdd7147ab167e400f97/demo_data/lady-running/00063.jpg -------------------------------------------------------------------------------- /demo_data/lady-running/00064.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Junyi42/monst3r/1747338dcc01a850f7105bdd7147ab167e400f97/demo_data/lady-running/00064.jpg -------------------------------------------------------------------------------- /dust3r/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2024-present Naver Corporation. All rights reserved. 2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 3 | -------------------------------------------------------------------------------- /dust3r/cloud_opt/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2024-present Naver Corporation. All rights reserved. 2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 
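# A minimal, illustrative sketch of how the `global_aligner` wrapper and `GlobalAlignerMode`
# enum defined below are typically driven (load images -> build pairs -> run inference -> align).
# The checkpoint path, image size, scene-graph choice and optimizer hyper-parameters are
# assumptions, not values mandated by this module; the helper imports follow the standard
# DUSt3R API used elsewhere in this repo and may need adjusting to this repo's demo.
from dust3r.model import AsymmetricCroCo3DStereo
from dust3r.utils.image import load_images
from dust3r.image_pairs import make_pairs
from dust3r.inference import inference
from dust3r.cloud_opt import global_aligner, GlobalAlignerMode

device = 'cuda'
model = AsymmetricCroCo3DStereo.from_pretrained('checkpoints/<your_checkpoint>.pth').to(device)
imgs = load_images('demo_data/lady-running', size=512)           # list of view dicts ('img', 'true_shape', 'idx', ...)
pairs = make_pairs(imgs, scene_graph='swinstride-5', symmetrize=True)
output = inference(pairs, model, device, batch_size=1)           # dict with 'view1', 'view2', 'pred1', 'pred2'
scene = global_aligner(output, device, mode=GlobalAlignerMode.PointCloudOptimizer)
scene.compute_global_alignment(init='mst', niter=300, schedule='cosine', lr=0.01)
pts3d, poses = scene.get_pts3d(), scene.get_im_poses()           # aligned point maps and cam2world poses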
3 | # 4 | # -------------------------------------------------------- 5 | # global alignment optimization wrapper function 6 | # -------------------------------------------------------- 7 | from enum import Enum 8 | 9 | from .optimizer import PointCloudOptimizer 10 | from .modular_optimizer import ModularPointCloudOptimizer 11 | from .pair_viewer import PairViewer 12 | 13 | 14 | class GlobalAlignerMode(Enum): 15 | PointCloudOptimizer = "PointCloudOptimizer" 16 | ModularPointCloudOptimizer = "ModularPointCloudOptimizer" 17 | PairViewer = "PairViewer" 18 | 19 | 20 | def global_aligner(dust3r_output, device, mode=GlobalAlignerMode.PointCloudOptimizer, **optim_kw): 21 | # extract all inputs 22 | view1, view2, pred1, pred2 = [dust3r_output[k] for k in 'view1 view2 pred1 pred2'.split()] 23 | # build the optimizer 24 | if mode == GlobalAlignerMode.PointCloudOptimizer: 25 | net = PointCloudOptimizer(view1, view2, pred1, pred2, **optim_kw).to(device) 26 | elif mode == GlobalAlignerMode.ModularPointCloudOptimizer: 27 | net = ModularPointCloudOptimizer(view1, view2, pred1, pred2, **optim_kw).to(device) 28 | elif mode == GlobalAlignerMode.PairViewer: 29 | net = PairViewer(view1, view2, pred1, pred2, **optim_kw).to(device) 30 | else: 31 | raise NotImplementedError(f'Unknown mode {mode}') 32 | 33 | return net 34 | -------------------------------------------------------------------------------- /dust3r/cloud_opt/commons.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2024-present Naver Corporation. All rights reserved. 2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 3 | # 4 | # -------------------------------------------------------- 5 | # utility functions for global alignment 6 | # -------------------------------------------------------- 7 | import torch 8 | import torch.nn as nn 9 | import numpy as np 10 | from scipy.stats import zscore 11 | 12 | def edge_str(i, j): 13 | return f'{i}_{j}' 14 | 15 | 16 | def i_j_ij(ij): 17 | # inputs are (i, j) 18 | return edge_str(*ij), ij 19 | 20 | 21 | def edge_conf(conf_i, conf_j, edge): 22 | 23 | score = float(conf_i[edge].mean() * conf_j[edge].mean()) 24 | 25 | return score 26 | 27 | 28 | def compute_edge_scores(edges, conf_i, conf_j): 29 | score_dict = {(i, j): edge_conf(conf_i, conf_j, e) for e, (i, j) in edges} 30 | 31 | return score_dict 32 | 33 | def NoGradParamDict(x): 34 | assert isinstance(x, dict) 35 | return nn.ParameterDict(x).requires_grad_(False) 36 | 37 | 38 | def get_imshapes(edges, pred_i, pred_j): 39 | n_imgs = max(max(e) for e in edges) + 1 40 | imshapes = [None] * n_imgs 41 | for e, (i, j) in enumerate(edges): 42 | shape_i = tuple(pred_i[e].shape[0:2]) 43 | shape_j = tuple(pred_j[e].shape[0:2]) 44 | if imshapes[i]: 45 | assert imshapes[i] == shape_i, f'incorrect shape for image {i}' 46 | if imshapes[j]: 47 | assert imshapes[j] == shape_j, f'incorrect shape for image {j}' 48 | imshapes[i] = shape_i 49 | imshapes[j] = shape_j 50 | return imshapes 51 | 52 | 53 | def get_conf_trf(mode): 54 | if mode == 'log': 55 | def conf_trf(x): return x.log() 56 | elif mode == 'sqrt': 57 | def conf_trf(x): return x.sqrt() 58 | elif mode == 'm1': 59 | def conf_trf(x): return x-1 60 | elif mode in ('id', 'none'): 61 | def conf_trf(x): return x 62 | else: 63 | raise ValueError(f'bad mode for {mode=}') 64 | return conf_trf 65 | 66 | 67 | def l2_dist(a, b, weight): 68 | return ((a - b).square().sum(dim=-1) * weight) 69 | 70 | 71 | def l1_dist(a, b, weight): 72 | return ((a - b).norm(dim=-1) * 
weight) 73 | 74 | 75 | ALL_DISTS = dict(l1=l1_dist, l2=l2_dist) 76 | 77 | 78 | def signed_log1p(x): 79 | sign = torch.sign(x) 80 | return sign * torch.log1p(torch.abs(x)) 81 | 82 | 83 | def signed_expm1(x): 84 | sign = torch.sign(x) 85 | return sign * torch.expm1(torch.abs(x)) 86 | 87 | 88 | def cosine_schedule(t, lr_start, lr_end): 89 | assert 0 <= t <= 1 90 | return lr_end + (lr_start - lr_end) * (1+np.cos(t * np.pi))/2 91 | 92 | 93 | def linear_schedule(t, lr_start, lr_end): 94 | assert 0 <= t <= 1 95 | return lr_start + (lr_end - lr_start) * t 96 | 97 | def cycled_linear_schedule(t, lr_start, lr_end, num_cycles=2): 98 | assert 0 <= t <= 1 99 | cycle_t = t * num_cycles 100 | cycle_t = cycle_t - int(cycle_t) 101 | if t == 1: 102 | cycle_t = 1 103 | return linear_schedule(cycle_t, lr_start, lr_end) -------------------------------------------------------------------------------- /dust3r/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2024-present Naver Corporation. All rights reserved. 2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 3 | from .utils.transforms import * 4 | from .base.batched_sampler import BatchedRandomSampler # noqa 5 | from .arkitscenes import ARKitScenes # noqa 6 | from .blendedmvs import BlendedMVS # noqa 7 | from .co3d import Co3d # noqa 8 | from .habitat import Habitat # noqa 9 | from .megadepth import MegaDepth # noqa 10 | from .scannetpp import ScanNetpp # noqa 11 | from .staticthings3d import StaticThings3D # noqa 12 | from .waymo import Waymo # noqa 13 | from .wildrgbd import WildRGBD # noqa 14 | from .pointodyssey import PointOdysseyDUSt3R # noqa 15 | from .sintel import SintelDUSt3R # noqa 16 | from .tartanair import TarTanAirDUSt3R # noqa 17 | from .spring_dataset import SpringDUSt3R # noqa 18 | from .dynamic_replica import DynamicReplicaDUSt3R # noqa 19 | 20 | def get_data_loader(dataset, batch_size, num_workers=8, shuffle=True, drop_last=True, pin_mem=True): 21 | import torch 22 | from croco.utils.misc import get_world_size, get_rank 23 | 24 | # pytorch dataset 25 | if isinstance(dataset, str): 26 | dataset = eval(dataset) 27 | # dataset: "1000 @ Co3d(split='train', ROOT='data/co3d_subset_processed', aug_crop=16, mask_bg='rand', resolution=224, transform=ColorJitter)" 28 | # eval(dataset) returns Co3d(split='train', ROOT='data/co3d_subset_processed', aug_crop=16, mask_bg='rand', resolution=224, transform=ColorJitter) 29 | 30 | world_size = get_world_size() 31 | rank = get_rank() 32 | 33 | try: 34 | sampler = dataset.make_sampler(batch_size, shuffle=shuffle, world_size=world_size, 35 | rank=rank, drop_last=drop_last) 36 | except (AttributeError, NotImplementedError): 37 | # not avail for this dataset 38 | if torch.distributed.is_initialized(): 39 | sampler = torch.utils.data.DistributedSampler( 40 | dataset, num_replicas=world_size, rank=rank, shuffle=shuffle, drop_last=drop_last 41 | ) 42 | elif shuffle: 43 | sampler = torch.utils.data.RandomSampler(dataset) 44 | else: 45 | sampler = torch.utils.data.SequentialSampler(dataset) 46 | 47 | data_loader = torch.utils.data.DataLoader( 48 | dataset, 49 | sampler=sampler, 50 | batch_size=batch_size, 51 | num_workers=num_workers, 52 | pin_memory=pin_mem, 53 | drop_last=drop_last, 54 | ) 55 | 56 | return data_loader 57 | -------------------------------------------------------------------------------- /dust3r/datasets/arkitscenes.py: -------------------------------------------------------------------------------- 1 | # Copyright 
(C) 2024-present Naver Corporation. All rights reserved. 2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 3 | # 4 | # -------------------------------------------------------- 5 | # Dataloader for preprocessed arkitscenes 6 | # dataset at https://github.com/apple/ARKitScenes - Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International Public License https://github.com/apple/ARKitScenes/tree/main?tab=readme-ov-file#license 7 | # See datasets_preprocess/preprocess_arkitscenes.py 8 | # -------------------------------------------------------- 9 | import os.path as osp 10 | import cv2 11 | import numpy as np 12 | 13 | from dust3r.datasets.base.base_stereo_view_dataset import BaseStereoViewDataset 14 | from dust3r.utils.image import imread_cv2 15 | 16 | 17 | class ARKitScenes(BaseStereoViewDataset): 18 | def __init__(self, *args, split, ROOT, **kwargs): 19 | self.ROOT = ROOT 20 | super().__init__(*args, **kwargs) 21 | if split == "train": 22 | self.split = "Training" 23 | elif split == "test": 24 | self.split = "Test" 25 | else: 26 | raise ValueError("") 27 | 28 | self.loaded_data = self._load_data(self.split) 29 | 30 | def _load_data(self, split): 31 | with np.load(osp.join(self.ROOT, split, 'all_metadata.npz')) as data: 32 | self.scenes = data['scenes'] 33 | self.sceneids = data['sceneids'] 34 | self.images = data['images'] 35 | self.intrinsics = data['intrinsics'].astype(np.float32) 36 | self.trajectories = data['trajectories'].astype(np.float32) 37 | self.pairs = data['pairs'][:, :2].astype(int) 38 | 39 | def __len__(self): 40 | return len(self.pairs) 41 | 42 | def _get_views(self, idx, resolution, rng): 43 | 44 | image_idx1, image_idx2 = self.pairs[idx] 45 | 46 | views = [] 47 | for view_idx in [image_idx1, image_idx2]: 48 | scene_id = self.sceneids[view_idx] 49 | scene_dir = osp.join(self.ROOT, self.split, self.scenes[scene_id]) 50 | 51 | intrinsics = self.intrinsics[view_idx] 52 | camera_pose = self.trajectories[view_idx] 53 | basename = self.images[view_idx] 54 | 55 | # Load RGB image 56 | rgb_image = imread_cv2(osp.join(scene_dir, 'vga_wide', basename.replace('.png', '.jpg'))) 57 | # Load depthmap 58 | depthmap = imread_cv2(osp.join(scene_dir, 'lowres_depth', basename), cv2.IMREAD_UNCHANGED) 59 | depthmap = depthmap.astype(np.float32) / 1000 60 | depthmap[~np.isfinite(depthmap)] = 0 # invalid 61 | 62 | rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary( 63 | rgb_image, depthmap, intrinsics, resolution, rng=rng, info=view_idx) 64 | 65 | views.append(dict( 66 | img=rgb_image, 67 | depthmap=depthmap.astype(np.float32), 68 | camera_pose=camera_pose.astype(np.float32), 69 | camera_intrinsics=intrinsics.astype(np.float32), 70 | dataset='arkitscenes', 71 | label=self.scenes[scene_id] + '_' + basename, 72 | instance=f'{str(idx)}_{str(view_idx)}', 73 | )) 74 | 75 | return views 76 | 77 | 78 | if __name__ == "__main__": 79 | from dust3r.datasets.base.base_stereo_view_dataset import view_name 80 | from dust3r.viz import SceneViz, auto_cam_size 81 | from dust3r.utils.image import rgb 82 | 83 | dataset = ARKitScenes(split='train', ROOT="data/arkitscenes_processed", resolution=224, aug_crop=16) 84 | 85 | for idx in np.random.permutation(len(dataset)): 86 | views = dataset[idx] 87 | assert len(views) == 2 88 | print(view_name(views[0]), view_name(views[1])) 89 | viz = SceneViz() 90 | poses = [views[view_idx]['camera_pose'] for view_idx in [0, 1]] 91 | cam_size = max(auto_cam_size(poses), 0.001) 92 | for view_idx in [0, 1]: 93 | pts3d = views[view_idx]['pts3d'] 
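# NOTE: 'pts3d' and 'valid_mask' are not filled in by _get_views() above; they are presumably
# added by BaseStereoViewDataset.__getitem__, which back-projects the returned 'depthmap'
# through 'camera_intrinsics' and 'camera_pose' before the views reach this visualization loop.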
94 | valid_mask = views[view_idx]['valid_mask'] 95 | colors = rgb(views[view_idx]['img']) 96 | viz.add_pointcloud(pts3d, colors, valid_mask) 97 | viz.add_camera(pose_c2w=views[view_idx]['camera_pose'], 98 | focal=views[view_idx]['camera_intrinsics'][0, 0], 99 | color=(idx * 255, (1 - idx) * 255, 0), 100 | image=colors, 101 | cam_size=cam_size) 102 | viz.show() 103 | -------------------------------------------------------------------------------- /dust3r/datasets/base/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2024-present Naver Corporation. All rights reserved. 2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 3 | -------------------------------------------------------------------------------- /dust3r/datasets/base/batched_sampler.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2024-present Naver Corporation. All rights reserved. 2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 3 | # 4 | # -------------------------------------------------------- 5 | # Random sampling under a constraint 6 | # -------------------------------------------------------- 7 | import numpy as np 8 | import torch 9 | 10 | 11 | class BatchedRandomSampler: 12 | """ Random sampling under a constraint: each sample in the batch has the same feature, 13 | which is chosen randomly from a known pool of 'features' for each batch. 14 | 15 | For instance, the 'feature' could be the image aspect-ratio. 16 | 17 | The index returned is a tuple (sample_idx, feat_idx). 18 | This sampler ensures that each series of `batch_size` indices has the same `feat_idx`. 19 | """ 20 | 21 | def __init__(self, dataset, batch_size, pool_size, world_size=1, rank=0, drop_last=True): 22 | self.batch_size = batch_size 23 | self.pool_size = pool_size 24 | 25 | self.len_dataset = N = len(dataset) 26 | self.total_size = round_by(N, batch_size*world_size) if drop_last else N 27 | assert world_size == 1 or drop_last, 'must drop the last batch in distributed mode' 28 | 29 | # distributed sampler 30 | self.world_size = world_size 31 | self.rank = rank 32 | self.epoch = None 33 | 34 | def __len__(self): 35 | return self.total_size // self.world_size 36 | 37 | def set_epoch(self, epoch): 38 | self.epoch = epoch 39 | 40 | def __iter__(self): 41 | # prepare RNG 42 | if self.epoch is None: 43 | assert self.world_size == 1 and self.rank == 0, 'use set_epoch() if distributed mode is used' 44 | seed = int(torch.empty((), dtype=torch.int64).random_().item()) 45 | else: 46 | seed = self.epoch + 777 47 | rng = np.random.default_rng(seed=seed) 48 | 49 | # random indices (will restart from 0 if not drop_last) 50 | sample_idxs = np.arange(self.total_size) 51 | rng.shuffle(sample_idxs) 52 | 53 | # random feat_idxs (same across each batch) 54 | n_batches = (self.total_size+self.batch_size-1) // self.batch_size 55 | feat_idxs = rng.integers(self.pool_size, size=n_batches) 56 | feat_idxs = np.broadcast_to(feat_idxs[:, None], (n_batches, self.batch_size)) 57 | feat_idxs = feat_idxs.ravel()[:self.total_size] 58 | 59 | # put them together 60 | idxs = np.c_[sample_idxs, feat_idxs] # shape = (total_size, 2) 61 | 62 | # Distributed sampler: we select a subset of batches 63 | # make sure the slice for each node is aligned with batch_size 64 | size_per_proc = self.batch_size * ((self.total_size + self.world_size * 65 | self.batch_size-1) // (self.world_size * self.batch_size)) 66 | idxs = idxs[self.rank*size_per_proc: 
(self.rank+1)*size_per_proc] 67 | 68 | yield from (tuple(idx) for idx in idxs) 69 | 70 | 71 | def round_by(total, multiple, up=False): 72 | if up: 73 | total = total + multiple-1 74 | return (total//multiple) * multiple 75 | -------------------------------------------------------------------------------- /dust3r/datasets/blendedmvs.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2024-present Naver Corporation. All rights reserved. 2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 3 | # 4 | # -------------------------------------------------------- 5 | # Dataloader for preprocessed BlendedMVS 6 | # dataset at https://github.com/YoYo000/BlendedMVS 7 | # See datasets_preprocess/preprocess_blendedmvs.py 8 | # -------------------------------------------------------- 9 | import os.path as osp 10 | import numpy as np 11 | 12 | from dust3r.datasets.base.base_stereo_view_dataset import BaseStereoViewDataset 13 | from dust3r.utils.image import imread_cv2 14 | 15 | 16 | class BlendedMVS (BaseStereoViewDataset): 17 | """ Dataset of outdoor street scenes, 5 images each time 18 | """ 19 | 20 | def __init__(self, *args, ROOT, split=None, **kwargs): 21 | self.ROOT = ROOT 22 | super().__init__(*args, **kwargs) 23 | self._load_data(split) 24 | 25 | def _load_data(self, split): 26 | pairs = np.load(osp.join(self.ROOT, 'blendedmvs_pairs.npy')) 27 | if split is None: 28 | selection = slice(None) 29 | if split == 'train': 30 | # select 90% of all scenes 31 | selection = (pairs['seq_low'] % 10) > 0 32 | if split == 'val': 33 | # select 10% of all scenes 34 | selection = (pairs['seq_low'] % 10) == 0 35 | self.pairs = pairs[selection] 36 | 37 | # list of all scenes 38 | self.scenes = np.unique(self.pairs['seq_low']) # low is unique enough 39 | 40 | def __len__(self): 41 | return len(self.pairs) 42 | 43 | def get_stats(self): 44 | return f'{len(self)} pairs from {len(self.scenes)} scenes' 45 | 46 | def _get_views(self, pair_idx, resolution, rng): 47 | seqh, seql, img1, img2, score = self.pairs[pair_idx] 48 | 49 | seq = f"{seqh:08x}{seql:016x}" 50 | seq_path = osp.join(self.ROOT, seq) 51 | 52 | views = [] 53 | 54 | for view_index in [img1, img2]: 55 | impath = f"{view_index:08n}" 56 | image = imread_cv2(osp.join(seq_path, impath + ".jpg")) 57 | depthmap = imread_cv2(osp.join(seq_path, impath + ".exr")) 58 | camera_params = np.load(osp.join(seq_path, impath + ".npz")) 59 | 60 | intrinsics = np.float32(camera_params['intrinsics']) 61 | camera_pose = np.eye(4, dtype=np.float32) 62 | camera_pose[:3, :3] = camera_params['R_cam2world'] 63 | camera_pose[:3, 3] = camera_params['t_cam2world'] 64 | 65 | image, depthmap, intrinsics = self._crop_resize_if_necessary( 66 | image, depthmap, intrinsics, resolution, rng, info=(seq_path, impath)) 67 | 68 | views.append(dict( 69 | img=image, 70 | depthmap=depthmap, 71 | camera_pose=camera_pose, # cam2world 72 | camera_intrinsics=intrinsics, 73 | dataset='BlendedMVS', 74 | label=osp.relpath(seq_path, self.ROOT), 75 | instance=impath)) 76 | 77 | return views 78 | 79 | 80 | if __name__ == '__main__': 81 | from dust3r.datasets.base.base_stereo_view_dataset import view_name 82 | from dust3r.viz import SceneViz, auto_cam_size 83 | from dust3r.utils.image import rgb 84 | 85 | dataset = BlendedMVS(split='train', ROOT="data/blendedmvs_processed", resolution=224, aug_crop=16) 86 | 87 | for idx in np.random.permutation(len(dataset)): 88 | views = dataset[idx] 89 | assert len(views) == 2 90 | print(idx, view_name(views[0]), 
view_name(views[1])) 91 | viz = SceneViz() 92 | poses = [views[view_idx]['camera_pose'] for view_idx in [0, 1]] 93 | cam_size = max(auto_cam_size(poses), 0.001) 94 | for view_idx in [0, 1]: 95 | pts3d = views[view_idx]['pts3d'] 96 | valid_mask = views[view_idx]['valid_mask'] 97 | colors = rgb(views[view_idx]['img']) 98 | viz.add_pointcloud(pts3d, colors, valid_mask) 99 | viz.add_camera(pose_c2w=views[view_idx]['camera_pose'], 100 | focal=views[view_idx]['camera_intrinsics'][0, 0], 101 | color=(idx * 255, (1 - idx) * 255, 0), 102 | image=colors, 103 | cam_size=cam_size) 104 | viz.show() 105 | -------------------------------------------------------------------------------- /dust3r/datasets/scannetpp.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2024-present Naver Corporation. All rights reserved. 2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 3 | # 4 | # -------------------------------------------------------- 5 | # Dataloader for preprocessed scannet++ 6 | # dataset at https://github.com/scannetpp/scannetpp - non-commercial research and educational purposes 7 | # https://kaldir.vc.in.tum.de/scannetpp/static/scannetpp-terms-of-use.pdf 8 | # See datasets_preprocess/preprocess_scannetpp.py 9 | # -------------------------------------------------------- 10 | import os.path as osp 11 | import cv2 12 | import numpy as np 13 | 14 | from dust3r.datasets.base.base_stereo_view_dataset import BaseStereoViewDataset 15 | from dust3r.utils.image import imread_cv2 16 | 17 | 18 | class ScanNetpp(BaseStereoViewDataset): 19 | def __init__(self, *args, ROOT, **kwargs): 20 | self.ROOT = ROOT 21 | super().__init__(*args, **kwargs) 22 | assert self.split == 'train' 23 | self.loaded_data = self._load_data() 24 | 25 | def _load_data(self): 26 | with np.load(osp.join(self.ROOT, 'all_metadata.npz')) as data: 27 | self.scenes = data['scenes'] 28 | self.sceneids = data['sceneids'] 29 | self.images = data['images'] 30 | self.intrinsics = data['intrinsics'].astype(np.float32) 31 | self.trajectories = data['trajectories'].astype(np.float32) 32 | self.pairs = data['pairs'][:, :2].astype(int) 33 | 34 | def __len__(self): 35 | return len(self.pairs) 36 | 37 | def _get_views(self, idx, resolution, rng): 38 | 39 | image_idx1, image_idx2 = self.pairs[idx] 40 | 41 | views = [] 42 | for view_idx in [image_idx1, image_idx2]: 43 | scene_id = self.sceneids[view_idx] 44 | scene_dir = osp.join(self.ROOT, self.scenes[scene_id]) 45 | 46 | intrinsics = self.intrinsics[view_idx] 47 | camera_pose = self.trajectories[view_idx] 48 | basename = self.images[view_idx] 49 | 50 | # Load RGB image 51 | rgb_image = imread_cv2(osp.join(scene_dir, 'images', basename + '.jpg')) 52 | # Load depthmap 53 | depthmap = imread_cv2(osp.join(scene_dir, 'depth', basename + '.png'), cv2.IMREAD_UNCHANGED) 54 | depthmap = depthmap.astype(np.float32) / 1000 55 | depthmap[~np.isfinite(depthmap)] = 0 # invalid 56 | 57 | rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary( 58 | rgb_image, depthmap, intrinsics, resolution, rng=rng, info=view_idx) 59 | 60 | views.append(dict( 61 | img=rgb_image, 62 | depthmap=depthmap.astype(np.float32), 63 | camera_pose=camera_pose.astype(np.float32), 64 | camera_intrinsics=intrinsics.astype(np.float32), 65 | dataset='ScanNet++', 66 | label=self.scenes[scene_id] + '_' + basename, 67 | instance=f'{str(idx)}_{str(view_idx)}', 68 | )) 69 | return views 70 | 71 | 72 | if __name__ == "__main__": 73 | from dust3r.datasets.base.base_stereo_view_dataset 
import view_name 74 | from dust3r.viz import SceneViz, auto_cam_size 75 | from dust3r.utils.image import rgb 76 | 77 | dataset = ScanNetpp(split='train', ROOT="data/scannetpp_processed", resolution=224, aug_crop=16) 78 | 79 | for idx in np.random.permutation(len(dataset)): 80 | views = dataset[idx] 81 | assert len(views) == 2 82 | print(view_name(views[0]), view_name(views[1])) 83 | viz = SceneViz() 84 | poses = [views[view_idx]['camera_pose'] for view_idx in [0, 1]] 85 | cam_size = max(auto_cam_size(poses), 0.001) 86 | for view_idx in [0, 1]: 87 | pts3d = views[view_idx]['pts3d'] 88 | valid_mask = views[view_idx]['valid_mask'] 89 | colors = rgb(views[view_idx]['img']) 90 | viz.add_pointcloud(pts3d, colors, valid_mask) 91 | viz.add_camera(pose_c2w=views[view_idx]['camera_pose'], 92 | focal=views[view_idx]['camera_intrinsics'][0, 0], 93 | color=(idx*255, (1 - idx)*255, 0), 94 | image=colors, 95 | cam_size=cam_size) 96 | viz.show() 97 | -------------------------------------------------------------------------------- /dust3r/datasets/staticthings3d.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2024-present Naver Corporation. All rights reserved. 2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 3 | # 4 | # -------------------------------------------------------- 5 | # Dataloader for preprocessed StaticThings3D 6 | # dataset at https://github.com/lmb-freiburg/robustmvd/ 7 | # See datasets_preprocess/preprocess_staticthings3d.py 8 | # -------------------------------------------------------- 9 | import os.path as osp 10 | import numpy as np 11 | 12 | from dust3r.datasets.base.base_stereo_view_dataset import BaseStereoViewDataset 13 | from dust3r.utils.image import imread_cv2 14 | 15 | 16 | class StaticThings3D (BaseStereoViewDataset): 17 | """ Dataset of indoor scenes, 5 images each time 18 | """ 19 | def __init__(self, ROOT, *args, mask_bg='rand', **kwargs): 20 | self.ROOT = ROOT 21 | super().__init__(*args, **kwargs) 22 | 23 | assert mask_bg in (True, False, 'rand') 24 | self.mask_bg = mask_bg 25 | 26 | # loading all pairs 27 | assert self.split is None 28 | self.pairs = np.load(osp.join(ROOT, 'staticthings_pairs.npy')) 29 | 30 | def __len__(self): 31 | return len(self.pairs) 32 | 33 | def get_stats(self): 34 | return f'{len(self)} pairs' 35 | 36 | def _get_views(self, pair_idx, resolution, rng): 37 | scene, seq, cam1, im1, cam2, im2 = self.pairs[pair_idx] 38 | seq_path = osp.join('TRAIN', scene.decode('ascii'), f'{seq:04d}') 39 | 40 | views = [] 41 | 42 | mask_bg = (self.mask_bg == True) or (self.mask_bg == 'rand' and rng.choice(2)) 43 | 44 | CAM = {b'l':'left', b'r':'right'} 45 | for cam, idx in [(CAM[cam1], im1), (CAM[cam2], im2)]: 46 | num = f"{idx:04n}" 47 | img = num+"_clean.jpg" if rng.choice(2) else num+"_final.jpg" 48 | image = imread_cv2(osp.join(self.ROOT, seq_path, cam, img)) 49 | depthmap = imread_cv2(osp.join(self.ROOT, seq_path, cam, num+".exr")) 50 | camera_params = np.load(osp.join(self.ROOT, seq_path, cam, num+".npz")) 51 | 52 | intrinsics = camera_params['intrinsics'] 53 | camera_pose = camera_params['cam2world'] 54 | 55 | if mask_bg: 56 | depthmap[depthmap > 200] = 0 57 | 58 | image, depthmap, intrinsics = self._crop_resize_if_necessary(image, depthmap, intrinsics, resolution, rng, info=(seq_path,cam,img)) 59 | 60 | views.append(dict( 61 | img = image, 62 | depthmap = depthmap, 63 | camera_pose = camera_pose, # cam2world 64 | camera_intrinsics = intrinsics, 65 | dataset = 'StaticThings3D', 66 | label = 
seq_path, 67 | instance = cam+'_'+img)) 68 | 69 | return views 70 | 71 | 72 | if __name__ == '__main__': 73 | from dust3r.datasets.base.base_stereo_view_dataset import view_name 74 | from dust3r.viz import SceneViz, auto_cam_size 75 | from dust3r.utils.image import rgb 76 | 77 | dataset = StaticThings3D(ROOT="data/staticthings3d_processed", resolution=224, aug_crop=16) 78 | 79 | for idx in np.random.permutation(len(dataset)): 80 | views = dataset[idx] 81 | assert len(views) == 2 82 | print(idx, view_name(views[0]), view_name(views[1])) 83 | viz = SceneViz() 84 | poses = [views[view_idx]['camera_pose'] for view_idx in [0, 1]] 85 | cam_size = max(auto_cam_size(poses), 0.001) 86 | for view_idx in [0, 1]: 87 | pts3d = views[view_idx]['pts3d'] 88 | valid_mask = views[view_idx]['valid_mask'] 89 | colors = rgb(views[view_idx]['img']) 90 | viz.add_pointcloud(pts3d, colors, valid_mask) 91 | viz.add_camera(pose_c2w=views[view_idx]['camera_pose'], 92 | focal=views[view_idx]['camera_intrinsics'][0, 0], 93 | color=(idx*255, (1 - idx)*255, 0), 94 | image=colors, 95 | cam_size=cam_size) 96 | viz.show() 97 | -------------------------------------------------------------------------------- /dust3r/datasets/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2024-present Naver Corporation. All rights reserved. 2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 3 | -------------------------------------------------------------------------------- /dust3r/datasets/utils/transforms.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2024-present Naver Corporation. All rights reserved. 2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 3 | # 4 | # -------------------------------------------------------- 5 | # DUST3R default transforms 6 | # -------------------------------------------------------- 7 | import torchvision.transforms as tvf 8 | from dust3r.utils.image import ImgNorm 9 | 10 | # define the standard image transforms 11 | ColorJitter = tvf.Compose([tvf.ColorJitter(0.5, 0.5, 0.5, 0.1), ImgNorm]) 12 | -------------------------------------------------------------------------------- /dust3r/datasets/waymo.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2024-present Naver Corporation. All rights reserved. 2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 
3 | # 4 | # -------------------------------------------------------- 5 | # Dataloader for preprocessed WayMo 6 | # dataset at https://github.com/waymo-research/waymo-open-dataset 7 | # See datasets_preprocess/preprocess_waymo.py 8 | # -------------------------------------------------------- 9 | import sys 10 | sys.path.append('.') 11 | import os 12 | import os.path as osp 13 | import numpy as np 14 | 15 | from dust3r.datasets.base.base_stereo_view_dataset import BaseStereoViewDataset 16 | from dust3r.utils.image import imread_cv2 17 | 18 | 19 | class Waymo (BaseStereoViewDataset): 20 | """ Dataset of outdoor street scenes, 5 images each time 21 | """ 22 | 23 | def __init__(self, *args, ROOT, pairs_npz_name='waymo_pairs_video.npz', **kwargs): 24 | self.ROOT = ROOT 25 | self.pairs_npz_name = pairs_npz_name 26 | super().__init__(*args, **kwargs) 27 | self._load_data() 28 | 29 | def _load_data(self): 30 | with np.load(osp.join(self.ROOT, self.pairs_npz_name)) as data: 31 | self.scenes = data['scenes'] 32 | self.frames = data['frames'] 33 | self.inv_frames = {frame: i for i, frame in enumerate(data['frames'])} 34 | self.pairs = data['pairs'] # (array of (scene_id, img1_id, img2_id) 35 | assert self.pairs[:, 0].max() == len(self.scenes) - 1 36 | print(f'Loaded {self.get_stats()}') 37 | 38 | def __len__(self): 39 | return len(self.pairs) 40 | 41 | def get_stats(self): 42 | return f'{len(self)} pairs from {len(self.scenes)} scenes' 43 | 44 | def _get_views(self, pair_idx, resolution, rng): 45 | seq, img1, img2 = self.pairs[pair_idx] 46 | seq_path = osp.join(self.ROOT, self.scenes[seq]) 47 | 48 | views = [] 49 | 50 | for view_index in [img1, img2]: 51 | impath = self.frames[view_index] 52 | image = imread_cv2(osp.join(seq_path, impath + ".jpg")) 53 | depthmap = imread_cv2(osp.join(seq_path, impath + ".exr")) 54 | camera_params = np.load(osp.join(seq_path, impath + ".npz")) 55 | 56 | intrinsics = np.float32(camera_params['intrinsics']) 57 | camera_pose = np.float32(camera_params['cam2world']) 58 | 59 | image, depthmap, intrinsics = self._crop_resize_if_necessary( 60 | image, depthmap, intrinsics, resolution, rng, info=(seq_path, impath)) 61 | 62 | views.append(dict( 63 | img=image, 64 | depthmap=depthmap, 65 | camera_pose=camera_pose, # cam2world 66 | camera_intrinsics=intrinsics, 67 | dataset='Waymo', 68 | label=osp.relpath(seq_path, self.ROOT), 69 | instance=impath)) 70 | 71 | return views 72 | 73 | 74 | if __name__ == '__main__': 75 | from dust3r.datasets.base.base_stereo_view_dataset import view_name 76 | from dust3r.viz import SceneViz, auto_cam_size 77 | from dust3r.utils.image import rgb 78 | 79 | dataset = Waymo(split='train', ROOT="data/waymo_processed", resolution=512, aug_crop=16) 80 | idxs = np.arange(0, len(dataset)-1, (len(dataset)-1)//10) 81 | for idx in idxs: 82 | views = dataset[idx] 83 | assert len(views) == 2 84 | print(idx, view_name(views[0]), view_name(views[1])) 85 | viz = SceneViz() 86 | poses = [views[view_idx]['camera_pose'] for view_idx in [0, 1]] 87 | cam_size = max(auto_cam_size(poses), 0.001) 88 | for view_idx in [0, 1]: 89 | pts3d = views[view_idx]['pts3d'] 90 | valid_mask = views[view_idx]['valid_mask'] 91 | colors = rgb(views[view_idx]['img']) 92 | viz.add_pointcloud(pts3d, colors, valid_mask) 93 | viz.add_camera(pose_c2w=views[view_idx]['camera_pose'], 94 | focal=views[view_idx]['camera_intrinsics'][0, 0], 95 | color=(idx * 255, (1 - idx) * 255, 0), 96 | image=colors, 97 | cam_size=cam_size) 98 | os.makedirs('./tmp/waymo', exist_ok=True) 99 | path = 
f"./tmp/waymo/waymo_scene_{idx}.glb" 100 | viz.save_glb(path) 101 | -------------------------------------------------------------------------------- /dust3r/datasets/wildrgbd.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2024-present Naver Corporation. All rights reserved. 2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 3 | # 4 | # -------------------------------------------------------- 5 | # Dataloader for preprocessed WildRGB-D 6 | # dataset at https://github.com/wildrgbd/wildrgbd/ 7 | # See datasets_preprocess/preprocess_wildrgbd.py 8 | # -------------------------------------------------------- 9 | import os.path as osp 10 | 11 | import cv2 12 | import numpy as np 13 | 14 | from dust3r.datasets.co3d import Co3d 15 | from dust3r.utils.image import imread_cv2 16 | 17 | 18 | class WildRGBD(Co3d): 19 | def __init__(self, mask_bg=True, *args, ROOT, **kwargs): 20 | super().__init__(mask_bg, *args, ROOT=ROOT, **kwargs) 21 | self.dataset_label = 'WildRGBD' 22 | 23 | def _get_metadatapath(self, obj, instance, view_idx): 24 | return osp.join(self.ROOT, obj, instance, 'metadata', f'{view_idx:0>5d}.npz') 25 | 26 | def _get_impath(self, obj, instance, view_idx): 27 | return osp.join(self.ROOT, obj, instance, 'rgb', f'{view_idx:0>5d}.jpg') 28 | 29 | def _get_depthpath(self, obj, instance, view_idx): 30 | return osp.join(self.ROOT, obj, instance, 'depth', f'{view_idx:0>5d}.png') 31 | 32 | def _get_maskpath(self, obj, instance, view_idx): 33 | return osp.join(self.ROOT, obj, instance, 'masks', f'{view_idx:0>5d}.png') 34 | 35 | def _read_depthmap(self, depthpath, input_metadata): 36 | # We store depths in the depth scale of 1000. 37 | # That is, when we load depth image and divide by 1000, we could get depth in meters. 38 | depthmap = imread_cv2(depthpath, cv2.IMREAD_UNCHANGED) 39 | depthmap = depthmap.astype(np.float32) / 1000.0 40 | return depthmap 41 | 42 | 43 | if __name__ == "__main__": 44 | from dust3r.datasets.base.base_stereo_view_dataset import view_name 45 | from dust3r.viz import SceneViz, auto_cam_size 46 | from dust3r.utils.image import rgb 47 | 48 | dataset = WildRGBD(split='train', ROOT="data/wildrgbd_processed", resolution=224, aug_crop=16) 49 | 50 | for idx in np.random.permutation(len(dataset)): 51 | views = dataset[idx] 52 | assert len(views) == 2 53 | print(view_name(views[0]), view_name(views[1])) 54 | viz = SceneViz() 55 | poses = [views[view_idx]['camera_pose'] for view_idx in [0, 1]] 56 | cam_size = max(auto_cam_size(poses), 0.001) 57 | for view_idx in [0, 1]: 58 | pts3d = views[view_idx]['pts3d'] 59 | valid_mask = views[view_idx]['valid_mask'] 60 | colors = rgb(views[view_idx]['img']) 61 | viz.add_pointcloud(pts3d, colors, valid_mask) 62 | viz.add_camera(pose_c2w=views[view_idx]['camera_pose'], 63 | focal=views[view_idx]['camera_intrinsics'][0, 0], 64 | color=(idx * 255, (1 - idx) * 255, 0), 65 | image=colors, 66 | cam_size=cam_size) 67 | viz.show() 68 | -------------------------------------------------------------------------------- /dust3r/demo.py: -------------------------------------------------------------------------------- 1 | ../demo.py -------------------------------------------------------------------------------- /dust3r/heads/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2024-present Naver Corporation. All rights reserved. 2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 
3 | # 4 | # -------------------------------------------------------- 5 | # head factory 6 | # -------------------------------------------------------- 7 | from .linear_head import LinearPts3d 8 | from .dpt_head import create_dpt_head 9 | 10 | 11 | def head_factory(head_type, output_mode, net, has_conf=False): 12 | """" build a prediction head for the decoder 13 | """ 14 | if head_type == 'linear' and output_mode == 'pts3d': 15 | return LinearPts3d(net, has_conf) 16 | elif head_type == 'dpt' and output_mode == 'pts3d': 17 | return create_dpt_head(net, has_conf=has_conf) 18 | else: 19 | raise NotImplementedError(f"unexpected {head_type=} and {output_mode=}") 20 | -------------------------------------------------------------------------------- /dust3r/heads/linear_head.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2024-present Naver Corporation. All rights reserved. 2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 3 | # 4 | # -------------------------------------------------------- 5 | # linear head implementation for DUST3R 6 | # -------------------------------------------------------- 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from dust3r.heads.postprocess import postprocess 10 | 11 | 12 | class LinearPts3d (nn.Module): 13 | """ 14 | Linear head for dust3r 15 | Each token outputs: - 16x16 3D points (+ confidence) 16 | """ 17 | 18 | def __init__(self, net, has_conf=False): 19 | super().__init__() 20 | self.patch_size = net.patch_embed.patch_size[0] 21 | self.depth_mode = net.depth_mode 22 | self.conf_mode = net.conf_mode 23 | self.has_conf = has_conf 24 | 25 | self.proj = nn.Linear(net.dec_embed_dim, (3 + has_conf)*self.patch_size**2) 26 | 27 | def setup(self, croconet): 28 | pass 29 | 30 | def forward(self, decout, img_shape): 31 | H, W = img_shape 32 | tokens = decout[-1] 33 | B, S, D = tokens.shape 34 | 35 | # extract 3D points 36 | feat = self.proj(tokens) # B,S,D 37 | feat = feat.transpose(-1, -2).view(B, -1, H//self.patch_size, W//self.patch_size) 38 | feat = F.pixel_shuffle(feat, self.patch_size) # B,3,H,W 39 | 40 | # permute + norm depth 41 | return postprocess(feat, self.depth_mode, self.conf_mode) 42 | -------------------------------------------------------------------------------- /dust3r/heads/postprocess.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2024-present Naver Corporation. All rights reserved. 2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 
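# A tiny numerical check of the (mode, vmin, vmax) convention consumed by reg_dense_depth /
# reg_dense_conf defined below. The ('exp', -inf, +inf) depth mode and ('exp', 1, +inf)
# confidence mode are assumed here because they are the usual DUSt3R settings; treat them as
# an example, not a requirement of this module.
import torch
from dust3r.heads.postprocess import reg_dense_depth, reg_dense_conf

raw = torch.tensor([[[[1.0, 2.0, 2.0]]]])                        # B,H,W,3 raw head output
pts = reg_dense_depth(raw, mode=('exp', -float('inf'), float('inf')))
# the direction raw/|raw| is preserved; the norm |raw| = 3 is remapped to expm1(3) ~= 19.09
conf = reg_dense_conf(torch.zeros(1, 1, 1), mode=('exp', 1, float('inf')))
# exp(0) = 1 shifted by vmin = 1, so conf == 2 everywhere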
3 | # 4 | # -------------------------------------------------------- 5 | # post process function for all heads: extract 3D points/confidence from output 6 | # -------------------------------------------------------- 7 | import torch 8 | 9 | 10 | def postprocess(out, depth_mode, conf_mode): 11 | """ 12 | extract 3D points/confidence from prediction head output 13 | """ 14 | fmap = out.permute(0, 2, 3, 1) # B,H,W,3 15 | res = dict(pts3d=reg_dense_depth(fmap[:, :, :, 0:3], mode=depth_mode)) 16 | 17 | if conf_mode is not None: 18 | res['conf'] = reg_dense_conf(fmap[:, :, :, 3], mode=conf_mode) 19 | return res 20 | 21 | 22 | def reg_dense_depth(xyz, mode): 23 | """ 24 | extract 3D points from prediction head output 25 | """ 26 | mode, vmin, vmax = mode 27 | 28 | no_bounds = (vmin == -float('inf')) and (vmax == float('inf')) 29 | assert no_bounds 30 | 31 | if mode == 'linear': 32 | if no_bounds: 33 | return xyz # [-inf, +inf] 34 | return xyz.clip(min=vmin, max=vmax) 35 | 36 | # distance to origin 37 | d = xyz.norm(dim=-1, keepdim=True) 38 | xyz = xyz / d.clip(min=1e-8) 39 | 40 | if mode == 'square': 41 | return xyz * d.square() 42 | 43 | if mode == 'exp': 44 | return xyz * torch.expm1(d) 45 | 46 | raise ValueError(f'bad {mode=}') 47 | 48 | 49 | def reg_dense_conf(x, mode): 50 | """ 51 | extract confidence from prediction head output 52 | """ 53 | mode, vmin, vmax = mode 54 | if mode == 'exp': 55 | return vmin + x.exp().clip(max=vmax-vmin) 56 | if mode == 'sigmoid': 57 | return (vmax - vmin) * torch.sigmoid(x) + vmin 58 | raise ValueError(f'bad {mode=}') 59 | -------------------------------------------------------------------------------- /dust3r/image_pairs.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # utilities needed to load image pairs 3 | # -------------------------------------------------------- 4 | import numpy as np 5 | import torch 6 | 7 | 8 | def make_pairs(imgs, scene_graph='complete', prefilter=None, symmetrize=True, force_symmetrize=False): 9 | pairs = [] 10 | if scene_graph == 'complete': # complete graph 11 | for i in range(len(imgs)): 12 | for j in range(i): 13 | pairs.append((imgs[i], imgs[j])) 14 | elif scene_graph.startswith('swin'): 15 | iscyclic = not scene_graph.endswith('noncyclic') 16 | try: 17 | winsize = int(scene_graph.split('-')[1]) 18 | except Exception as e: 19 | winsize = 3 20 | pairsid = set() 21 | if scene_graph.startswith('swinstride'): 22 | stride = 2 23 | elif scene_graph.startswith('swin2stride'): 24 | stride = 3 25 | else: 26 | stride = 1 27 | if scene_graph.startswith('swinskip_start'): 28 | start = 2 29 | else: 30 | start = 1 31 | for i in range(len(imgs)): 32 | for j in range(start, stride*winsize + start, stride): 33 | idx = (i + j) 34 | if iscyclic: 35 | idx = idx % len(imgs) # explicit loop closure 36 | if idx >= len(imgs): 37 | continue 38 | pairsid.add((i, idx) if i < idx else (idx, i)) 39 | for i, j in pairsid: 40 | pairs.append((imgs[i], imgs[j])) 41 | elif scene_graph.startswith('logwin'): 42 | iscyclic = not scene_graph.endswith('noncyclic') 43 | try: 44 | winsize = int(scene_graph.split('-')[1]) 45 | except Exception as e: 46 | winsize = 3 47 | offsets = [2**i for i in range(winsize)] 48 | pairsid = set() 49 | for i in range(len(imgs)): 50 | ixs_l = [i - off for off in offsets] 51 | ixs_r = [i + off for off in offsets] 52 | for j in ixs_l + ixs_r: 53 | if iscyclic: 54 | j = j % len(imgs) # Explicit loop closure 55 | if j < 0 or j >= 
len(imgs) or j == i: 56 | continue 57 | pairsid.add((i, j) if i < j else (j, i)) 58 | for i, j in pairsid: 59 | pairs.append((imgs[i], imgs[j])) 60 | elif scene_graph.startswith('oneref'): 61 | refid = int(scene_graph.split('-')[1]) if '-' in scene_graph else 0 62 | for j in range(len(imgs)): 63 | pairs.append((imgs[refid], imgs[j])) 64 | 65 | if (symmetrize and not scene_graph.startswith('oneref') and not scene_graph.startswith('swin-1')) or len(imgs) == 2 or force_symmetrize: 66 | pairs += [(img2, img1) for img1, img2 in pairs] 67 | 68 | # now, remove edges 69 | if isinstance(prefilter, str) and prefilter.startswith('seq'): 70 | pairs = filter_pairs_seq(pairs, int(prefilter[3:])) 71 | 72 | if isinstance(prefilter, str) and prefilter.startswith('cyc'): 73 | pairs = filter_pairs_seq(pairs, int(prefilter[3:]), cyclic=True) 74 | 75 | return pairs 76 | 77 | 78 | def sel(x, kept): 79 | if isinstance(x, dict): 80 | return {k: sel(v, kept) for k, v in x.items()} 81 | if isinstance(x, (torch.Tensor, np.ndarray)): 82 | return x[kept] 83 | if isinstance(x, (tuple, list)): 84 | return type(x)([x[k] for k in kept]) 85 | 86 | 87 | def _filter_edges_seq(edges, seq_dis_thr, cyclic=False): 88 | # number of images 89 | n = max(max(e) for e in edges) + 1 90 | 91 | kept = [] 92 | for e, (i, j) in enumerate(edges): 93 | dis = abs(i - j) 94 | if cyclic: 95 | dis = min(dis, abs(i + n - j), abs(i - n - j)) 96 | if dis <= seq_dis_thr: 97 | kept.append(e) 98 | return kept 99 | 100 | 101 | def filter_pairs_seq(pairs, seq_dis_thr, cyclic=False): 102 | edges = [(img1['idx'], img2['idx']) for img1, img2 in pairs] 103 | kept = _filter_edges_seq(edges, seq_dis_thr, cyclic=cyclic) 104 | return [pairs[i] for i in kept] 105 | 106 | 107 | def filter_edges_seq(view1, view2, pred1, pred2, seq_dis_thr, cyclic=False): 108 | edges = [(int(i), int(j)) for i, j in zip(view1['idx'], view2['idx'])] 109 | kept = _filter_edges_seq(edges, seq_dis_thr, cyclic=cyclic) 110 | print(f'>> Filtering edges more than {seq_dis_thr} frames apart: kept {len(kept)}/{len(edges)} edges') 111 | return sel(view1, kept), sel(view2, kept), sel(pred1, kept), sel(pred2, kept) 112 | -------------------------------------------------------------------------------- /dust3r/optim_factory.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2024-present Naver Corporation. All rights reserved. 2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 3 | # 4 | # -------------------------------------------------------- 5 | # optimization functions 6 | # -------------------------------------------------------- 7 | 8 | 9 | def adjust_learning_rate_by_lr(optimizer, lr): 10 | for param_group in optimizer.param_groups: 11 | if "lr_scale" in param_group: 12 | param_group["lr"] = lr * param_group["lr_scale"] 13 | else: 14 | param_group["lr"] = lr 15 | -------------------------------------------------------------------------------- /dust3r/patch_embed.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2024-present Naver Corporation. All rights reserved. 2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 
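# A sketch of the ManyAR_PatchEmbed convention defined below: every image in the batch is
# stored in landscape orientation (W >= H) and `true_shape` records the real (height, width),
# so portrait frames are simply transposed back inside forward(). The 224x288 resolution and
# 1024-d embedding are illustrative, and the example assumes croco's PatchEmbed base class
# provides the `proj` and `position_getter` members used below.
import torch
from dust3r.patch_embed import get_patch_embed

embed = get_patch_embed('ManyAR_PatchEmbed', img_size=224, patch_size=16, enc_embed_dim=1024)
imgs = torch.randn(2, 3, 224, 288)                    # both frames stored landscape
true_shape = torch.tensor([[224, 288],                # frame 0: genuinely landscape
                           [288, 224]])               # frame 1: portrait, stored transposed
x, pos = embed(imgs, true_shape)
print(x.shape, pos.shape)                             # (2, 252, 1024) (2, 252, 2)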
3 | # 4 | # -------------------------------------------------------- 5 | # PatchEmbed implementation for DUST3R, 6 | # in particular ManyAR_PatchEmbed that Handle images with non-square aspect ratio 7 | # -------------------------------------------------------- 8 | import torch 9 | import dust3r.utils.path_to_croco # noqa: F401 10 | from models.blocks import PatchEmbed # noqa 11 | 12 | 13 | def get_patch_embed(patch_embed_cls, img_size, patch_size, enc_embed_dim): 14 | assert patch_embed_cls in ['PatchEmbedDust3R', 'ManyAR_PatchEmbed'] 15 | patch_embed = eval(patch_embed_cls)(img_size, patch_size, 3, enc_embed_dim) 16 | return patch_embed 17 | 18 | 19 | class PatchEmbedDust3R(PatchEmbed): 20 | def forward(self, x, **kw): 21 | B, C, H, W = x.shape 22 | assert H % self.patch_size[0] == 0, f"Input image height ({H}) is not a multiple of patch size ({self.patch_size[0]})." 23 | assert W % self.patch_size[1] == 0, f"Input image width ({W}) is not a multiple of patch size ({self.patch_size[1]})." 24 | x = self.proj(x) 25 | pos = self.position_getter(B, x.size(2), x.size(3), x.device) 26 | if self.flatten: 27 | x = x.flatten(2).transpose(1, 2) # BCHW -> BNC 28 | x = self.norm(x) 29 | return x, pos 30 | 31 | 32 | class ManyAR_PatchEmbed (PatchEmbed): 33 | """ Handle images with non-square aspect ratio. 34 | All images in the same batch have the same aspect ratio. 35 | true_shape = [(height, width) ...] indicates the actual shape of each image. 36 | """ 37 | 38 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True, init='xavier'): 39 | self.embed_dim = embed_dim 40 | super().__init__(img_size, patch_size, in_chans, embed_dim, norm_layer, flatten, init) 41 | 42 | def forward(self, img, true_shape): 43 | B, C, H, W = img.shape 44 | assert W >= H, f'img should be in landscape mode, but got {W=} {H=}' 45 | assert H % self.patch_size[0] == 0, f"Input image height ({H}) is not a multiple of patch size ({self.patch_size[0]})." 46 | assert W % self.patch_size[1] == 0, f"Input image width ({W}) is not a multiple of patch size ({self.patch_size[1]})." 47 | assert true_shape.shape == (B, 2), f"true_shape has the wrong shape={true_shape.shape}" 48 | 49 | # size expressed in tokens 50 | W //= self.patch_size[0] 51 | H //= self.patch_size[1] 52 | n_tokens = H * W 53 | 54 | height, width = true_shape.T 55 | is_landscape = (width >= height) 56 | is_portrait = ~is_landscape 57 | 58 | # allocate result 59 | x = img.new_zeros((B, n_tokens, self.embed_dim)) 60 | pos = img.new_zeros((B, n_tokens, 2), dtype=torch.int64) 61 | 62 | # linear projection, transposed if necessary 63 | x[is_landscape] = self.proj(img[is_landscape]).permute(0, 2, 3, 1).flatten(1, 2).float() 64 | x[is_portrait] = self.proj(img[is_portrait].swapaxes(-1, -2)).permute(0, 2, 3, 1).flatten(1, 2).float() 65 | 66 | pos[is_landscape] = self.position_getter(1, H, W, pos.device) 67 | pos[is_portrait] = self.position_getter(1, W, H, pos.device) 68 | 69 | x = self.norm(x) 70 | return x, pos 71 | -------------------------------------------------------------------------------- /dust3r/post_process.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2024-present Naver Corporation. All rights reserved. 2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 
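# Illustrative sketch of how the ManyAR_PatchEmbed defined above is driven (hypothetical
# sizes; assumes the croco submodule is initialized so dust3r.patch_embed imports cleanly).
# Every image tensor is stored in landscape orientation (W >= H); true_shape carries the
# real (height, width) so portrait images can be transposed back inside forward().
import torch
from dust3r.patch_embed import get_patch_embed

embed = get_patch_embed('ManyAR_PatchEmbed', img_size=512, patch_size=16, enc_embed_dim=1024)

imgs = torch.randn(2, 3, 384, 512)            # both stored as 384x512 landscape tensors
true_shape = torch.tensor([[384, 512],        # image 0 really is landscape
                           [512, 384]])       # image 1 is portrait, stored transposed

tokens, pos = embed(imgs, true_shape)         # tokens: (2, 768, 1024), pos: (2, 768, 2)
                                              # 768 = (384 // 16) * (512 // 16) patch tokens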
3 | # 4 | # -------------------------------------------------------- 5 | # utilities for interpreting the DUST3R output 6 | # -------------------------------------------------------- 7 | import numpy as np 8 | import torch 9 | from dust3r.utils.geometry import xy_grid 10 | 11 | 12 | def estimate_focal_knowing_depth(pts3d, pp, focal_mode='median', min_focal=0., max_focal=np.inf): 13 | """ Reprojection method, for when the absolute depth is known: 14 | 1) estimate the camera focal using a robust estimator 15 | 2) reproject points onto true rays, minimizing a certain error 16 | """ 17 | B, H, W, THREE = pts3d.shape 18 | assert THREE == 3 19 | 20 | # centered pixel grid 21 | pixels = xy_grid(W, H, device=pts3d.device).view(1, -1, 2) - pp.view(-1, 1, 2) # B,HW,2 22 | pts3d = pts3d.flatten(1, 2) # (B, HW, 3) 23 | 24 | if focal_mode == 'median': 25 | with torch.no_grad(): 26 | # direct estimation of focal 27 | u, v = pixels.unbind(dim=-1) 28 | x, y, z = pts3d.unbind(dim=-1) 29 | fx_votes = (u * z) / x 30 | fy_votes = (v * z) / y 31 | 32 | # assume square pixels, hence same focal for X and Y 33 | f_votes = torch.cat((fx_votes.view(B, -1), fy_votes.view(B, -1)), dim=-1) 34 | focal = torch.nanmedian(f_votes, dim=-1).values 35 | 36 | elif focal_mode == 'weiszfeld': 37 | # init focal with l2 closed form 38 | # we try to find focal = argmin Sum | pixel - focal * (x,y)/z| 39 | xy_over_z = (pts3d[..., :2] / pts3d[..., 2:3]).nan_to_num(posinf=0, neginf=0) # homogeneous (x,y,1) 40 | 41 | dot_xy_px = (xy_over_z * pixels).sum(dim=-1) 42 | dot_xy_xy = xy_over_z.square().sum(dim=-1) 43 | 44 | focal = dot_xy_px.mean(dim=1) / dot_xy_xy.mean(dim=1) 45 | 46 | # iterative re-weighted least-squares 47 | for iter in range(10): 48 | # re-weighting by inverse of distance 49 | dis = (pixels - focal.view(-1, 1, 1) * xy_over_z).norm(dim=-1) 50 | # print(dis.nanmean(-1)) 51 | w = dis.clip(min=1e-8).reciprocal() 52 | # update the scaling with the new weights 53 | focal = (w * dot_xy_px).mean(dim=1) / (w * dot_xy_xy).mean(dim=1) 54 | else: 55 | raise ValueError(f'bad {focal_mode=}') 56 | 57 | focal_base = max(H, W) / (2 * np.tan(np.deg2rad(60) / 2)) # size / 1.1547005383792515 58 | focal = focal.clip(min=min_focal*focal_base, max=max_focal*focal_base) 59 | # print(focal) 60 | return focal 61 | -------------------------------------------------------------------------------- /dust3r/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2024-present Naver Corporation. All rights reserved. 2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 3 | -------------------------------------------------------------------------------- /dust3r/utils/device.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2024-present Naver Corporation. All rights reserved. 2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 3 | # 4 | # -------------------------------------------------------- 5 | # utilitary functions for DUSt3R 6 | # -------------------------------------------------------- 7 | import numpy as np 8 | import torch 9 | 10 | 11 | def todevice(batch, device, callback=None, non_blocking=False): 12 | ''' Transfer some variables to another device (i.e. GPU, CPU:torch, CPU:numpy). 13 | 14 | batch: list, tuple, dict of tensors or other things 15 | device: pytorch device or 'numpy' 16 | callback: function that would be called on every sub-elements. 
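
    Example (illustrative, round-tripping a mixed batch):
        >>> import numpy as np
        >>> batch = {'img': np.zeros((2, 3), np.float32), 'label': [0, 1]}
        >>> todevice(batch, 'cpu')['img'].dtype    # the ndarray becomes a torch tensor
        torch.float32
        >>> to_numpy(todevice(batch, 'cpu'))['img'].shape
        (2, 3)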
17 | ''' 18 | if callback: 19 | batch = callback(batch) 20 | 21 | if isinstance(batch, dict): 22 | return {k: todevice(v, device) for k, v in batch.items()} 23 | 24 | if isinstance(batch, (tuple, list)): 25 | return type(batch)(todevice(x, device) for x in batch) 26 | 27 | x = batch 28 | if device == 'numpy': 29 | if isinstance(x, torch.Tensor): 30 | x = x.detach().cpu().numpy() 31 | elif x is not None: 32 | if isinstance(x, np.ndarray): 33 | x = torch.from_numpy(x) 34 | if torch.is_tensor(x): 35 | x = x.to(device, non_blocking=non_blocking) 36 | return x 37 | 38 | 39 | to_device = todevice # alias 40 | 41 | 42 | def to_numpy(x): return todevice(x, 'numpy') 43 | def to_cpu(x): return todevice(x, 'cpu') 44 | def to_cuda(x): return todevice(x, 'cuda') 45 | 46 | 47 | def collate_with_cat(whatever, lists=False): 48 | if isinstance(whatever, dict): 49 | return {k: collate_with_cat(vals, lists=lists) for k, vals in whatever.items()} 50 | 51 | elif isinstance(whatever, (tuple, list)): 52 | if len(whatever) == 0: 53 | return whatever 54 | elem = whatever[0] 55 | T = type(whatever) 56 | 57 | if elem is None: 58 | return None 59 | if isinstance(elem, (bool, float, int, str)): 60 | return whatever 61 | if isinstance(elem, tuple): 62 | return T(collate_with_cat(x, lists=lists) for x in zip(*whatever)) 63 | if isinstance(elem, dict): 64 | return {k: collate_with_cat([e[k] for e in whatever], lists=lists) for k in elem} 65 | 66 | if isinstance(elem, torch.Tensor): 67 | return listify(whatever) if lists else torch.cat(whatever) 68 | if isinstance(elem, np.ndarray): 69 | return listify(whatever) if lists else torch.cat([torch.from_numpy(x) for x in whatever]) 70 | 71 | # otherwise, we just chain lists 72 | return sum(whatever, T()) 73 | 74 | 75 | def listify(elems): 76 | return [x for e in elems for x in e] 77 | -------------------------------------------------------------------------------- /dust3r/utils/parallel.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2024-present Naver Corporation. All rights reserved. 2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 3 | # 4 | # -------------------------------------------------------- 5 | # utilitary functions for multiprocessing 6 | # -------------------------------------------------------- 7 | from tqdm import tqdm 8 | from multiprocessing.dummy import Pool as ThreadPool 9 | from multiprocessing import cpu_count 10 | 11 | 12 | def parallel_threads(function, args, workers=0, star_args=False, kw_args=False, front_num=1, Pool=ThreadPool, **tqdm_kw): 13 | """ tqdm but with parallel execution. 14 | 15 | Will essentially return 16 | res = [ function(arg) # default 17 | function(*arg) # if star_args is True 18 | function(**arg) # if kw_args is True 19 | for arg in args] 20 | 21 | Note: 22 | the first elements of args will not be parallelized. 23 | This can be useful for debugging. 
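
    Example (illustrative):
        >>> parallel_threads(lambda x: x * x, range(8), workers=4)
        [0, 1, 4, 9, 16, 25, 36, 49]
        >>> parallel_threads(lambda a, b: a + b, [(1, 2), (3, 4)], workers=2, star_args=True)
        [3, 7]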
24 | """ 25 | while workers <= 0: 26 | workers += cpu_count() 27 | if workers == 1: 28 | front_num = float('inf') 29 | 30 | # convert into an iterable 31 | try: 32 | n_args_parallel = len(args) - front_num 33 | except TypeError: 34 | n_args_parallel = None 35 | args = iter(args) 36 | 37 | # sequential execution first 38 | front = [] 39 | while len(front) < front_num: 40 | try: 41 | a = next(args) 42 | except StopIteration: 43 | return front # end of the iterable 44 | front.append(function(*a) if star_args else function(**a) if kw_args else function(a)) 45 | 46 | # then parallel execution 47 | out = [] 48 | with Pool(workers) as pool: 49 | # Pass the elements of args into function 50 | if star_args: 51 | futures = pool.imap(starcall, [(function, a) for a in args]) 52 | elif kw_args: 53 | futures = pool.imap(starstarcall, [(function, a) for a in args]) 54 | else: 55 | futures = pool.imap(function, args) 56 | # Print out the progress as tasks complete 57 | for f in tqdm(futures, total=n_args_parallel, **tqdm_kw): 58 | out.append(f) 59 | return front + out 60 | 61 | 62 | def parallel_processes(*args, **kwargs): 63 | """ Same as parallel_threads, with processes 64 | """ 65 | import multiprocessing as mp 66 | kwargs['Pool'] = mp.Pool 67 | return parallel_threads(*args, **kwargs) 68 | 69 | 70 | def starcall(args): 71 | """ convenient wrapper for Process.Pool """ 72 | function, args = args 73 | return function(*args) 74 | 75 | 76 | def starstarcall(args): 77 | """ convenient wrapper for Process.Pool """ 78 | function, args = args 79 | return function(**args) 80 | -------------------------------------------------------------------------------- /dust3r/utils/path_to_croco.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2024-present Naver Corporation. All rights reserved. 2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 
3 | # 4 | # -------------------------------------------------------- 5 | # CroCo submodule import 6 | # -------------------------------------------------------- 7 | 8 | import sys 9 | import os.path as path 10 | HERE_PATH = path.normpath(path.dirname(__file__)) 11 | CROCO_REPO_PATH = path.normpath(path.join(HERE_PATH, '../../croco')) 12 | CROCO_MODELS_PATH = path.join(CROCO_REPO_PATH, 'models') 13 | # check the presence of models directory in repo to be sure its cloned 14 | if path.isdir(CROCO_MODELS_PATH): 15 | # workaround for sibling import 16 | sys.path.insert(0, CROCO_REPO_PATH) 17 | else: 18 | raise ImportError(f"croco is not initialized, could not find: {CROCO_MODELS_PATH}.\n " 19 | "Did you forget to run 'git submodule update --init --recursive' ?") 20 | -------------------------------------------------------------------------------- /dust3r/utils/po_utils/__init__.py: -------------------------------------------------------------------------------- 1 | # junyi -------------------------------------------------------------------------------- /launch.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # training executable for DUSt3R 3 | # -------------------------------------------------------- 4 | from dust3r.training import get_args_parser, train, load_model 5 | from dust3r.pose_eval import eval_pose_estimation 6 | from dust3r.depth_eval import eval_mono_depth_estimation 7 | import croco.utils.misc as misc # noqa 8 | import torch 9 | import torch.backends.cudnn as cudnn 10 | import numpy as np 11 | import os 12 | 13 | if __name__ == '__main__': 14 | args = get_args_parser() 15 | args = args.parse_args() 16 | if args.mode.startswith('eval'): 17 | misc.init_distributed_mode(args) 18 | global_rank = misc.get_rank() 19 | world_size = misc.get_world_size() 20 | device = "cuda" if torch.cuda.is_available() else "cpu" 21 | device = torch.device(device) 22 | 23 | # fix the seed 24 | seed = args.seed + misc.get_rank() 25 | torch.manual_seed(seed) 26 | np.random.seed(seed) 27 | cudnn.benchmark = args.cudnn_benchmark 28 | model, _ = load_model(args, device) 29 | os.makedirs(args.output_dir, exist_ok=True) 30 | 31 | if args.mode == 'eval_pose': 32 | ate_mean, rpe_trans_mean, rpe_rot_mean, outfile_list, bug = eval_pose_estimation(args, model, device, save_dir=args.output_dir) 33 | print(f'ATE mean: {ate_mean}, RPE trans mean: {rpe_trans_mean}, RPE rot mean: {rpe_rot_mean}') 34 | if args.mode == 'eval_depth': 35 | eval_mono_depth_estimation(args, model, device) 36 | 37 | exit(0) 38 | train(args) 39 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # minimal requirements for inference 2 | torch 3 | torchvision 4 | roma 5 | gradio 6 | matplotlib 7 | tqdm 8 | opencv-python 9 | scipy 10 | einops 11 | gdown 12 | trimesh 13 | pyglet<2 14 | huggingface-hub[torch]>=0.22 15 | 16 | # for camera trajectory 17 | evo 18 | 19 | # for sam2, customized version 20 | -e third_party/sam2 -------------------------------------------------------------------------------- /requirements_optional.txt: -------------------------------------------------------------------------------- 1 | pillow-heif # add heif/heic image support 2 | pyrender # for rendering depths in scannetpp 3 | kapture # for visloc data loading 4 | kapture-localization 5 | numpy-quaternion 6 | pycolmap # for pnp 7 | 
poselib # for pnp 8 | 9 | # for download tartanair 10 | boto3 11 | # for preprocess waymo 12 | tensorflow 13 | waymo-open-dataset-tf-2-12-0 --no-deps 14 | # need to change bytearray() to bytes() in waymo_open_dataset package 15 | 16 | # for training 17 | # for logging 18 | wandb 19 | tensorboard 20 | #for pointodyssey 21 | prettytable 22 | scikit-image 23 | scikit-learn 24 | h5py 25 | gdown 26 | # for scannet 27 | pypng -------------------------------------------------------------------------------- /third_party/RAFT/LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2020, princeton-vl 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | * Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | * Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | * Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 | -------------------------------------------------------------------------------- /third_party/RAFT/README.md: -------------------------------------------------------------------------------- 1 | # RAFT 2 | This repository contains the source code for our paper: 3 | 4 | [RAFT: Recurrent All Pairs Field Transforms for Optical Flow](https://arxiv.org/pdf/2003.12039.pdf)
5 | ECCV 2020
6 | Zachary Teed and Jia Deng
7 | 8 | 9 | 10 | ## Requirements 11 | The code has been tested with PyTorch 1.6 and Cuda 10.1. 12 | ```Shell 13 | conda create --name raft 14 | conda activate raft 15 | conda install pytorch=1.6.0 torchvision=0.7.0 cudatoolkit=10.1 matplotlib tensorboard scipy opencv -c pytorch 16 | ``` 17 | 18 | ## Demos 19 | Pretrained models can be downloaded by running 20 | ```Shell 21 | ./download_models.sh 22 | ``` 23 | or downloaded from [google drive](https://drive.google.com/drive/folders/1sWDsfuZ3Up38EUQt7-JDTT1HcGHuJgvT?usp=sharing) 24 | 25 | You can demo a trained model on a sequence of frames 26 | ```Shell 27 | python demo.py --model=models/raft-things.pth --path=demo-frames 28 | ``` 29 | 30 | ## Required Data 31 | To evaluate/train RAFT, you will need to download the required datasets. 32 | * [FlyingChairs](https://lmb.informatik.uni-freiburg.de/resources/datasets/FlyingChairs.en.html#flyingchairs) 33 | * [FlyingThings3D](https://lmb.informatik.uni-freiburg.de/resources/datasets/SceneFlowDatasets.en.html) 34 | * [Sintel](http://sintel.is.tue.mpg.de/) 35 | * [KITTI](http://www.cvlibs.net/datasets/kitti/eval_scene_flow.php?benchmark=flow) 36 | * [HD1K](http://hci-benchmark.iwr.uni-heidelberg.de/) (optional) 37 | 38 | 39 | By default `datasets.py` will search for the datasets in these locations. You can create symbolic links to wherever the datasets were downloaded in the `datasets` folder 40 | 41 | ```Shell 42 | ├── datasets 43 | ├── Sintel 44 | ├── test 45 | ├── training 46 | ├── KITTI 47 | ├── testing 48 | ├── training 49 | ├── devkit 50 | ├── FlyingChairs_release 51 | ├── data 52 | ├── FlyingThings3D 53 | ├── frames_cleanpass 54 | ├── frames_finalpass 55 | ├── optical_flow 56 | ``` 57 | 58 | ## Evaluation 59 | You can evaluate a trained model using `evaluate.py` 60 | ```Shell 61 | python evaluate.py --model=models/raft-things.pth --dataset=sintel --mixed_precision 62 | ``` 63 | 64 | ## Training 65 | We used the following training schedule in our paper (2 GPUs). Training logs will be written to the `runs` directory, which can be visualized using tensorboard 66 | ```Shell 67 | ./train_standard.sh 68 | ``` 69 | 70 | If you have an RTX GPU, training can be accelerated using mixed precision. You can expect similar results in this setting (1 GPU) 71 | ```Shell 72 | ./train_mixed.sh 73 | ``` 74 | 75 | ## (Optional) Efficient Implementation 76 | You can optionally use our alternate (efficient) implementation by compiling the provided cuda extension 77 | ```Shell 78 | cd alt_cuda_corr && python setup.py install && cd .. 79 | ``` 80 | and running `demo.py` and `evaluate.py` with the `--alternate_corr` flag. Note, this implementation is somewhat slower than all-pairs, but uses significantly less GPU memory during the forward pass.
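
For reference, the steps performed by `demo.py` can also be driven directly from Python. The snippet below is an illustrative sketch rather than an official API: the checkpoint and frame paths are placeholders, and it assumes a CUDA device and that the `core` modules are importable exactly as in `demo.py`.

```Python
import argparse
import numpy as np
import torch
from PIL import Image

from raft import RAFT                      # from RAFT/core
from utils.utils import InputPadder

parser = argparse.ArgumentParser()
parser.add_argument('--model')
parser.add_argument('--path')
parser.add_argument('--small', action='store_true')
parser.add_argument('--mixed_precision', action='store_true')
parser.add_argument('--alternate_corr', action='store_true')
args = parser.parse_args(['--model', 'models/raft-things.pth', '--path', 'demo-frames'])

model = torch.nn.DataParallel(RAFT(args))
model.load_state_dict(torch.load(args.model))
model = model.module.cuda().eval()

def load_image(path):
    img = np.array(Image.open(path)).astype(np.uint8)
    return torch.from_numpy(img).permute(2, 0, 1).float()[None].cuda()

image1 = load_image('demo-frames/frame_0001.png')    # placeholder file names
image2 = load_image('demo-frames/frame_0002.png')

with torch.no_grad():
    padder = InputPadder(image1.shape)                # pad so H and W are divisible by 8
    image1, image2 = padder.pad(image1, image2)
    flow_low, flow_up = model(image1, image2, iters=20, test_mode=True)
    flow = padder.unpad(flow_up)                      # (1, 2, H, W) flow field
```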
81 | -------------------------------------------------------------------------------- /third_party/RAFT/alt_cuda_corr/correlation.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | // CUDA forward declarations 5 | std::vector corr_cuda_forward( 6 | torch::Tensor fmap1, 7 | torch::Tensor fmap2, 8 | torch::Tensor coords, 9 | int radius); 10 | 11 | std::vector corr_cuda_backward( 12 | torch::Tensor fmap1, 13 | torch::Tensor fmap2, 14 | torch::Tensor coords, 15 | torch::Tensor corr_grad, 16 | int radius); 17 | 18 | // C++ interface 19 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 20 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 21 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 22 | 23 | std::vector corr_forward( 24 | torch::Tensor fmap1, 25 | torch::Tensor fmap2, 26 | torch::Tensor coords, 27 | int radius) { 28 | CHECK_INPUT(fmap1); 29 | CHECK_INPUT(fmap2); 30 | CHECK_INPUT(coords); 31 | 32 | return corr_cuda_forward(fmap1, fmap2, coords, radius); 33 | } 34 | 35 | 36 | std::vector corr_backward( 37 | torch::Tensor fmap1, 38 | torch::Tensor fmap2, 39 | torch::Tensor coords, 40 | torch::Tensor corr_grad, 41 | int radius) { 42 | CHECK_INPUT(fmap1); 43 | CHECK_INPUT(fmap2); 44 | CHECK_INPUT(coords); 45 | CHECK_INPUT(corr_grad); 46 | 47 | return corr_cuda_backward(fmap1, fmap2, coords, corr_grad, radius); 48 | } 49 | 50 | 51 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 52 | m.def("forward", &corr_forward, "CORR forward"); 53 | m.def("backward", &corr_backward, "CORR backward"); 54 | } -------------------------------------------------------------------------------- /third_party/RAFT/alt_cuda_corr/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 3 | 4 | 5 | setup( 6 | name='correlation', 7 | ext_modules=[ 8 | CUDAExtension('alt_cuda_corr', 9 | sources=['correlation.cpp', 'correlation_kernel.cu'], 10 | extra_compile_args={'cxx': [], 'nvcc': ['-O3']}), 11 | ], 12 | cmdclass={ 13 | 'build_ext': BuildExtension 14 | }) 15 | 16 | -------------------------------------------------------------------------------- /third_party/RAFT/core/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Junyi42/monst3r/1747338dcc01a850f7105bdd7147ab167e400f97/third_party/RAFT/core/__init__.py -------------------------------------------------------------------------------- /third_party/RAFT/core/configs/congif_spring_M.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "spring-M", 3 | "dataset": "spring", 4 | "gpus": [0, 1, 2, 3, 4, 5, 6, 7], 5 | 6 | "use_var": true, 7 | "var_min": 0, 8 | "var_max": 10, 9 | "pretrain": "resnet34", 10 | "initial_dim": 64, 11 | "block_dims": [64, 128, 256], 12 | "radius": 4, 13 | "dim": 128, 14 | "num_blocks": 2, 15 | "iters": 4, 16 | 17 | "image_size": [540, 960], 18 | "scale": -1, 19 | "batch_size": 32, 20 | "epsilon": 1e-8, 21 | "lr": 4e-4, 22 | "wdecay": 1e-5, 23 | "dropout": 0, 24 | "clip": 1.0, 25 | "gamma": 0.85, 26 | "num_steps": 120000, 27 | 28 | "restore_ckpt": null, 29 | "coarse_config": null 30 | } -------------------------------------------------------------------------------- /third_party/RAFT/core/utils/__init__.py: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Junyi42/monst3r/1747338dcc01a850f7105bdd7147ab167e400f97/third_party/RAFT/core/utils/__init__.py -------------------------------------------------------------------------------- /third_party/RAFT/core/utils/flow_viz.py: -------------------------------------------------------------------------------- 1 | # Flow visualization code used from https://github.com/tomrunia/OpticalFlow_Visualization 2 | 3 | 4 | # MIT License 5 | # 6 | # Copyright (c) 2018 Tom Runia 7 | # 8 | # Permission is hereby granted, free of charge, to any person obtaining a copy 9 | # of this software and associated documentation files (the "Software"), to deal 10 | # in the Software without restriction, including without limitation the rights 11 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 12 | # copies of the Software, and to permit persons to whom the Software is 13 | # furnished to do so, subject to conditions. 14 | # 15 | # Author: Tom Runia 16 | # Date Created: 2018-08-03 17 | 18 | import numpy as np 19 | 20 | def make_colorwheel(): 21 | """ 22 | Generates a color wheel for optical flow visualization as presented in: 23 | Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007) 24 | URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf 25 | 26 | Code follows the original C++ source code of Daniel Scharstein. 27 | Code follows the the Matlab source code of Deqing Sun. 28 | 29 | Returns: 30 | np.ndarray: Color wheel 31 | """ 32 | 33 | RY = 15 34 | YG = 6 35 | GC = 4 36 | CB = 11 37 | BM = 13 38 | MR = 6 39 | 40 | ncols = RY + YG + GC + CB + BM + MR 41 | colorwheel = np.zeros((ncols, 3)) 42 | col = 0 43 | 44 | # RY 45 | colorwheel[0:RY, 0] = 255 46 | colorwheel[0:RY, 1] = np.floor(255*np.arange(0,RY)/RY) 47 | col = col+RY 48 | # YG 49 | colorwheel[col:col+YG, 0] = 255 - np.floor(255*np.arange(0,YG)/YG) 50 | colorwheel[col:col+YG, 1] = 255 51 | col = col+YG 52 | # GC 53 | colorwheel[col:col+GC, 1] = 255 54 | colorwheel[col:col+GC, 2] = np.floor(255*np.arange(0,GC)/GC) 55 | col = col+GC 56 | # CB 57 | colorwheel[col:col+CB, 1] = 255 - np.floor(255*np.arange(CB)/CB) 58 | colorwheel[col:col+CB, 2] = 255 59 | col = col+CB 60 | # BM 61 | colorwheel[col:col+BM, 2] = 255 62 | colorwheel[col:col+BM, 0] = np.floor(255*np.arange(0,BM)/BM) 63 | col = col+BM 64 | # MR 65 | colorwheel[col:col+MR, 2] = 255 - np.floor(255*np.arange(MR)/MR) 66 | colorwheel[col:col+MR, 0] = 255 67 | return colorwheel 68 | 69 | 70 | def flow_uv_to_colors(u, v, convert_to_bgr=False): 71 | """ 72 | Applies the flow color wheel to (possibly clipped) flow components u and v. 73 | 74 | According to the C++ source code of Daniel Scharstein 75 | According to the Matlab source code of Deqing Sun 76 | 77 | Args: 78 | u (np.ndarray): Input horizontal flow of shape [H,W] 79 | v (np.ndarray): Input vertical flow of shape [H,W] 80 | convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False. 
81 | 82 | Returns: 83 | np.ndarray: Flow visualization image of shape [H,W,3] 84 | """ 85 | flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8) 86 | colorwheel = make_colorwheel() # shape [55x3] 87 | ncols = colorwheel.shape[0] 88 | rad = np.sqrt(np.square(u) + np.square(v)) 89 | a = np.arctan2(-v, -u)/np.pi 90 | fk = (a+1) / 2*(ncols-1) 91 | k0 = np.floor(fk).astype(np.int32) 92 | k1 = k0 + 1 93 | k1[k1 == ncols] = 0 94 | f = fk - k0 95 | for i in range(colorwheel.shape[1]): 96 | tmp = colorwheel[:,i] 97 | col0 = tmp[k0] / 255.0 98 | col1 = tmp[k1] / 255.0 99 | col = (1-f)*col0 + f*col1 100 | idx = (rad <= 1) 101 | col[idx] = 1 - rad[idx] * (1-col[idx]) 102 | col[~idx] = col[~idx] * 0.75 # out of range 103 | # Note the 2-i => BGR instead of RGB 104 | ch_idx = 2-i if convert_to_bgr else i 105 | flow_image[:,:,ch_idx] = np.floor(255 * col) 106 | return flow_image 107 | 108 | 109 | def flow_to_image(flow_uv, clip_flow=None, convert_to_bgr=False): 110 | """ 111 | Expects a two dimensional flow image of shape. 112 | 113 | Args: 114 | flow_uv (np.ndarray): Flow UV image of shape [H,W,2] 115 | clip_flow (float, optional): Clip maximum of flow values. Defaults to None. 116 | convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False. 117 | 118 | Returns: 119 | np.ndarray: Flow visualization image of shape [H,W,3] 120 | """ 121 | assert flow_uv.ndim == 3, 'input flow must have three dimensions' 122 | assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]' 123 | if clip_flow is not None: 124 | flow_uv = np.clip(flow_uv, 0, clip_flow) 125 | u = flow_uv[:,:,0] 126 | v = flow_uv[:,:,1] 127 | rad = np.sqrt(np.square(u) + np.square(v)) 128 | rad_max = np.max(rad) 129 | epsilon = 1e-5 130 | u = u / (rad_max + epsilon) 131 | v = v / (rad_max + epsilon) 132 | return flow_uv_to_colors(u, v, convert_to_bgr) -------------------------------------------------------------------------------- /third_party/RAFT/core/utils/frame_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image 3 | from os.path import * 4 | import re 5 | 6 | import cv2 7 | cv2.setNumThreads(0) 8 | cv2.ocl.setUseOpenCL(False) 9 | 10 | TAG_CHAR = np.array([202021.25], np.float32) 11 | 12 | def readFlow(fn): 13 | """ Read .flo file in Middlebury format""" 14 | # Code adapted from: 15 | # http://stackoverflow.com/questions/28013200/reading-middlebury-flow-files-with-python-bytes-array-numpy 16 | 17 | # WARNING: this will work on little-endian architectures (eg Intel x86) only! 18 | # print 'fn = %s'%(fn) 19 | with open(fn, 'rb') as f: 20 | magic = np.fromfile(f, np.float32, count=1) 21 | if 202021.25 != magic: 22 | print('Magic number incorrect. 
Invalid .flo file') 23 | return None 24 | else: 25 | w = np.fromfile(f, np.int32, count=1) 26 | h = np.fromfile(f, np.int32, count=1) 27 | # print 'Reading %d x %d flo file\n' % (w, h) 28 | data = np.fromfile(f, np.float32, count=2*int(w)*int(h)) 29 | # Reshape data into 3D array (columns, rows, bands) 30 | # The reshape here is for visualization, the original code is (w,h,2) 31 | return np.resize(data, (int(h), int(w), 2)) 32 | 33 | def readPFM(file): 34 | file = open(file, 'rb') 35 | 36 | color = None 37 | width = None 38 | height = None 39 | scale = None 40 | endian = None 41 | 42 | header = file.readline().rstrip() 43 | if header == b'PF': 44 | color = True 45 | elif header == b'Pf': 46 | color = False 47 | else: 48 | raise Exception('Not a PFM file.') 49 | 50 | dim_match = re.match(rb'^(\d+)\s(\d+)\s$', file.readline()) 51 | if dim_match: 52 | width, height = map(int, dim_match.groups()) 53 | else: 54 | raise Exception('Malformed PFM header.') 55 | 56 | scale = float(file.readline().rstrip()) 57 | if scale < 0: # little-endian 58 | endian = '<' 59 | scale = -scale 60 | else: 61 | endian = '>' # big-endian 62 | 63 | data = np.fromfile(file, endian + 'f') 64 | shape = (height, width, 3) if color else (height, width) 65 | 66 | data = np.reshape(data, shape) 67 | data = np.flipud(data) 68 | return data 69 | 70 | def writeFlow(filename,uv,v=None): 71 | """ Write optical flow to file. 72 | 73 | If v is None, uv is assumed to contain both u and v channels, 74 | stacked in depth. 75 | Original code by Deqing Sun, adapted from Daniel Scharstein. 76 | """ 77 | nBands = 2 78 | 79 | if v is None: 80 | assert(uv.ndim == 3) 81 | assert(uv.shape[2] == 2) 82 | u = uv[:,:,0] 83 | v = uv[:,:,1] 84 | else: 85 | u = uv 86 | 87 | assert(u.shape == v.shape) 88 | height,width = u.shape 89 | f = open(filename,'wb') 90 | # write the header 91 | f.write(TAG_CHAR) 92 | np.array(width).astype(np.int32).tofile(f) 93 | np.array(height).astype(np.int32).tofile(f) 94 | # arrange into matrix form 95 | tmp = np.zeros((height, width*nBands)) 96 | tmp[:,np.arange(width)*2] = u 97 | tmp[:,np.arange(width)*2 + 1] = v 98 | tmp.astype(np.float32).tofile(f) 99 | f.close() 100 | 101 | 102 | def readFlowKITTI(filename): 103 | flow = cv2.imread(filename, cv2.IMREAD_ANYDEPTH|cv2.IMREAD_COLOR) 104 | flow = flow[:,:,::-1].astype(np.float32) 105 | flow, valid = flow[:, :, :2], flow[:, :, 2] 106 | flow = (flow - 2**15) / 64.0 107 | return flow, valid 108 | 109 | def readDispKITTI(filename): 110 | disp = cv2.imread(filename, cv2.IMREAD_ANYDEPTH) / 256.0 111 | valid = disp > 0.0 112 | flow = np.stack([-disp, np.zeros_like(disp)], -1) 113 | return flow, valid 114 | 115 | 116 | def writeFlowKITTI(filename, uv): 117 | uv = 64.0 * uv + 2**15 118 | valid = np.ones([uv.shape[0], uv.shape[1], 1]) 119 | uv = np.concatenate([uv, valid], axis=-1).astype(np.uint16) 120 | cv2.imwrite(filename, uv[..., ::-1]) 121 | 122 | 123 | def read_gen(file_name, pil=False): 124 | ext = splitext(file_name)[-1] 125 | if ext == '.png' or ext == '.jpeg' or ext == '.ppm' or ext == '.jpg': 126 | return Image.open(file_name) 127 | elif ext == '.bin' or ext == '.raw': 128 | return np.load(file_name) 129 | elif ext == '.flo': 130 | return readFlow(file_name).astype(np.float32) 131 | elif ext == '.pfm': 132 | flow = readPFM(file_name).astype(np.float32) 133 | if len(flow.shape) == 2: 134 | return flow 135 | else: 136 | return flow[:, :, :-1] 137 | return [] -------------------------------------------------------------------------------- 
/third_party/RAFT/core/utils/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import numpy as np 4 | from scipy import interpolate 5 | 6 | def coords_grid2(batch, ht, wd, device): 7 | coords = torch.meshgrid(torch.arange(ht, device=device), torch.arange(wd, device=device)) 8 | coords = torch.stack(coords[::-1], dim=0).float() 9 | return coords[None].repeat(batch, 1, 1, 1) 10 | 11 | class InputPadder: 12 | """ Pads images such that dimensions are divisible by 8 """ 13 | def __init__(self, dims, mode='sintel'): 14 | self.ht, self.wd = dims[-2:] 15 | pad_ht = (((self.ht // 8) + 1) * 8 - self.ht) % 8 16 | pad_wd = (((self.wd // 8) + 1) * 8 - self.wd) % 8 17 | if mode == 'sintel': 18 | self._pad = [pad_wd//2, pad_wd - pad_wd//2, pad_ht//2, pad_ht - pad_ht//2] 19 | else: 20 | self._pad = [pad_wd//2, pad_wd - pad_wd//2, 0, pad_ht] 21 | 22 | def pad(self, *inputs): 23 | return [F.pad(x, self._pad, mode='replicate') for x in inputs] 24 | 25 | def unpad(self, x): 26 | ht, wd = x.shape[-2:] 27 | c = [self._pad[2], ht-self._pad[3], self._pad[0], wd-self._pad[1]] 28 | return x[..., c[0]:c[1], c[2]:c[3]] 29 | 30 | def forward_interpolate(flow): 31 | flow = flow.detach().cpu().numpy() 32 | dx, dy = flow[0], flow[1] 33 | 34 | ht, wd = dx.shape 35 | x0, y0 = np.meshgrid(np.arange(wd), np.arange(ht)) 36 | 37 | x1 = x0 + dx 38 | y1 = y0 + dy 39 | 40 | x1 = x1.reshape(-1) 41 | y1 = y1.reshape(-1) 42 | dx = dx.reshape(-1) 43 | dy = dy.reshape(-1) 44 | 45 | valid = (x1 > 0) & (x1 < wd) & (y1 > 0) & (y1 < ht) 46 | x1 = x1[valid] 47 | y1 = y1[valid] 48 | dx = dx[valid] 49 | dy = dy[valid] 50 | 51 | flow_x = interpolate.griddata( 52 | (x1, y1), dx, (x0, y0), method='nearest', fill_value=0) 53 | 54 | flow_y = interpolate.griddata( 55 | (x1, y1), dy, (x0, y0), method='nearest', fill_value=0) 56 | 57 | flow = np.stack([flow_x, flow_y], axis=0) 58 | return torch.from_numpy(flow).float() 59 | 60 | 61 | def bilinear_sampler(img, coords, mode='bilinear', mask=False): 62 | """ Wrapper for grid_sample, uses pixel coordinates """ 63 | H, W = img.shape[-2:] 64 | xgrid, ygrid = coords.split([1,1], dim=-1) 65 | xgrid = 2*xgrid/(W-1) - 1 66 | ygrid = 2*ygrid/(H-1) - 1 67 | 68 | grid = torch.cat([xgrid, ygrid], dim=-1) 69 | img = F.grid_sample(img, grid, align_corners=True) 70 | 71 | if mask: 72 | mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1) 73 | return img, mask.float() 74 | 75 | return img 76 | 77 | 78 | def coords_grid(batch, ht, wd): 79 | coords = torch.meshgrid(torch.arange(ht), torch.arange(wd)) 80 | coords = torch.stack(coords[::-1], dim=0).float() 81 | return coords[None].repeat(batch, 1, 1, 1) 82 | 83 | 84 | def upflow8(flow, mode='bilinear'): 85 | new_size = (8 * flow.shape[2], 8 * flow.shape[3]) 86 | return 8 * F.interpolate(flow, size=new_size, mode=mode, align_corners=True) 87 | -------------------------------------------------------------------------------- /third_party/RAFT/demo.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('core') 3 | 4 | import argparse 5 | import os 6 | import cv2 7 | import glob 8 | import numpy as np 9 | import torch 10 | from PIL import Image 11 | 12 | from raft import RAFT 13 | from utils import flow_viz 14 | from utils.utils import InputPadder 15 | 16 | 17 | 18 | DEVICE = 'cuda' 19 | 20 | def load_image(imfile): 21 | img = np.array(Image.open(imfile)).astype(np.uint8) 22 | img = 
torch.from_numpy(img).permute(2, 0, 1).float() 23 | return img[None].to(DEVICE) 24 | 25 | 26 | def viz(img, flo): 27 | img = img[0].permute(1,2,0).cpu().numpy() 28 | flo = flo[0].permute(1,2,0).cpu().numpy() 29 | 30 | # map flow to rgb image 31 | flo = flow_viz.flow_to_image(flo) 32 | img_flo = np.concatenate([img, flo], axis=0) 33 | 34 | # import matplotlib.pyplot as plt 35 | # plt.imshow(img_flo / 255.0) 36 | # plt.show() 37 | 38 | cv2.imshow('image', img_flo[:, :, [2,1,0]]/255.0) 39 | cv2.waitKey() 40 | 41 | 42 | def demo(args): 43 | model = torch.nn.DataParallel(RAFT(args)) 44 | model.load_state_dict(torch.load(args.model)) 45 | 46 | model = model.module 47 | model.to(DEVICE) 48 | model.eval() 49 | 50 | with torch.no_grad(): 51 | images = glob.glob(os.path.join(args.path, '*.png')) + \ 52 | glob.glob(os.path.join(args.path, '*.jpg')) 53 | 54 | images = sorted(images) 55 | for imfile1, imfile2 in zip(images[:-1], images[1:]): 56 | image1 = load_image(imfile1) 57 | image2 = load_image(imfile2) 58 | 59 | padder = InputPadder(image1.shape) 60 | image1, image2 = padder.pad(image1, image2) 61 | 62 | flow_low, flow_up = model(image1, image2, iters=20, test_mode=True) 63 | viz(image1, flow_up) 64 | 65 | 66 | if __name__ == '__main__': 67 | parser = argparse.ArgumentParser() 68 | parser.add_argument('--model', help="restore checkpoint") 69 | parser.add_argument('--path', help="dataset for evaluation") 70 | parser.add_argument('--small', action='store_true', help='use small model') 71 | parser.add_argument('--mixed_precision', action='store_true', help='use mixed precision') 72 | parser.add_argument('--alternate_corr', action='store_true', help='use efficent correlation implementation') 73 | args = parser.parse_args() 74 | 75 | demo(args) 76 | -------------------------------------------------------------------------------- /third_party/RAFT/download_models.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | wget https://www.dropbox.com/s/4j4z58wuv8o0mfz/models.zip 3 | unzip models.zip 4 | -------------------------------------------------------------------------------- /third_party/RAFT/train_mixed.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | mkdir -p checkpoints 3 | python -u train.py --name raft-chairs --stage chairs --validation chairs --gpus 0 --num_steps 120000 --batch_size 8 --lr 0.00025 --image_size 368 496 --wdecay 0.0001 --mixed_precision 4 | python -u train.py --name raft-things --stage things --validation sintel --restore_ckpt checkpoints/raft-chairs.pth --gpus 0 --num_steps 120000 --batch_size 5 --lr 0.0001 --image_size 400 720 --wdecay 0.0001 --mixed_precision 5 | python -u train.py --name raft-sintel --stage sintel --validation sintel --restore_ckpt checkpoints/raft-things.pth --gpus 0 --num_steps 120000 --batch_size 5 --lr 0.0001 --image_size 368 768 --wdecay 0.00001 --gamma=0.85 --mixed_precision 6 | python -u train.py --name raft-kitti --stage kitti --validation kitti --restore_ckpt checkpoints/raft-sintel.pth --gpus 0 --num_steps 50000 --batch_size 5 --lr 0.0001 --image_size 288 960 --wdecay 0.00001 --gamma=0.85 --mixed_precision 7 | -------------------------------------------------------------------------------- /third_party/RAFT/train_standard.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | mkdir -p checkpoints 3 | python -u train.py --name raft-chairs --stage chairs --validation chairs --gpus 0 1 --num_steps 
100000 --batch_size 10 --lr 0.0004 --image_size 368 496 --wdecay 0.0001 4 | python -u train.py --name raft-things --stage things --validation sintel --restore_ckpt checkpoints/raft-chairs.pth --gpus 0 1 --num_steps 100000 --batch_size 6 --lr 0.000125 --image_size 400 720 --wdecay 0.0001 5 | python -u train.py --name raft-sintel --stage sintel --validation sintel --restore_ckpt checkpoints/raft-things.pth --gpus 0 1 --num_steps 100000 --batch_size 6 --lr 0.000125 --image_size 368 768 --wdecay 0.00001 --gamma=0.85 6 | python -u train.py --name raft-kitti --stage kitti --validation kitti --restore_ckpt checkpoints/raft-sintel.pth --gpus 0 1 --num_steps 50000 --batch_size 6 --lr 0.0001 --image_size 288 960 --wdecay 0.00001 --gamma=0.85 7 | -------------------------------------------------------------------------------- /third_party/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Junyi42/monst3r/1747338dcc01a850f7105bdd7147ab167e400f97/third_party/__init__.py -------------------------------------------------------------------------------- /third_party/raft.py: -------------------------------------------------------------------------------- 1 | 2 | import sys 3 | import argparse 4 | import torch 5 | import json 6 | from os.path import dirname, join 7 | RAFT_PATH_ROOT = join(dirname(__file__), 'RAFT') 8 | RAFT_PATH_CORE = join(RAFT_PATH_ROOT, 'core') 9 | sys.path.append(RAFT_PATH_CORE) 10 | from raft import RAFT, RAFT2 # nopep8 11 | from utils.utils import InputPadder # nopep8 12 | 13 | # %% 14 | # utility functions 15 | 16 | def json_to_args(json_path): 17 | # return a argparse.Namespace object 18 | with open(json_path, 'r') as f: 19 | data = json.load(f) 20 | args = argparse.Namespace() 21 | args_dict = args.__dict__ 22 | for key, value in data.items(): 23 | args_dict[key] = value 24 | return args 25 | 26 | def parse_args(parser): 27 | entry = parser.parse_args(args=[]) 28 | json_path = entry.cfg 29 | args = json_to_args(json_path) 30 | args_dict = args.__dict__ 31 | for index, (key, value) in enumerate(vars(entry).items()): 32 | args_dict[key] = value 33 | return args 34 | 35 | def get_input_padder(shape): 36 | return InputPadder(shape, mode='sintel') 37 | 38 | 39 | def load_RAFT(model_path=None): 40 | if model_path is None or 'M' not in model_path: # RAFT1 41 | parser = argparse.ArgumentParser() 42 | parser.add_argument('--model', help="restore checkpoint", default=model_path) 43 | parser.add_argument('--path', help="dataset for evaluation") 44 | parser.add_argument('--small', action='store_true', help='use small model') 45 | parser.add_argument('--mixed_precision', 46 | action='store_true', help='use mixed precision') 47 | parser.add_argument('--alternate_corr', action='store_true', 48 | help='use efficient correlation implementation') 49 | 50 | # Set default value for --model if model_path is provided 51 | args = parser.parse_args( 52 | ['--model', model_path if model_path else join(RAFT_PATH_ROOT, 'models', 'raft-sintel.pth'), '--path', './']) 53 | 54 | net = RAFT(args) 55 | else: # RAFT2 56 | parser = argparse.ArgumentParser() 57 | parser.add_argument('--cfg', help='experiment configure file name', default="third_party/RAFT/core/configs/congif_spring_M.json") 58 | parser.add_argument('--model', help='checkpoint path', default=model_path) 59 | parser.add_argument('--device', help='inference device', type=str, default='cpu') 60 | args = parse_args(parser) 61 | net = RAFT2(args) 62 | 63 | if 
torch.cuda.is_available(): 64 | state_dict = torch.load(args.model) 65 | else: 66 | state_dict = torch.load(args.model, map_location="cpu") 67 | print('Loaded pretrained RAFT model from', args.model) 68 | new_state_dict = {} 69 | for k in state_dict: 70 | if 'module' in k: 71 | name = k[7:] 72 | else: 73 | name = k 74 | new_state_dict[name] = state_dict[k] 75 | net.load_state_dict(new_state_dict) 76 | return net.eval() 77 | 78 | if __name__ == "__main__": 79 | net = load_RAFT(model_path='third_party/RAFT/models/Tartan-C-T432x960-M.pth') 80 | print(net) -------------------------------------------------------------------------------- /third_party/sam2/.clang-format: -------------------------------------------------------------------------------- 1 | AccessModifierOffset: -1 2 | AlignAfterOpenBracket: AlwaysBreak 3 | AlignConsecutiveAssignments: false 4 | AlignConsecutiveDeclarations: false 5 | AlignEscapedNewlinesLeft: true 6 | AlignOperands: false 7 | AlignTrailingComments: false 8 | AllowAllParametersOfDeclarationOnNextLine: false 9 | AllowShortBlocksOnASingleLine: false 10 | AllowShortCaseLabelsOnASingleLine: false 11 | AllowShortFunctionsOnASingleLine: Empty 12 | AllowShortIfStatementsOnASingleLine: false 13 | AllowShortLoopsOnASingleLine: false 14 | AlwaysBreakAfterReturnType: None 15 | AlwaysBreakBeforeMultilineStrings: true 16 | AlwaysBreakTemplateDeclarations: true 17 | BinPackArguments: false 18 | BinPackParameters: false 19 | BraceWrapping: 20 | AfterClass: false 21 | AfterControlStatement: false 22 | AfterEnum: false 23 | AfterFunction: false 24 | AfterNamespace: false 25 | AfterObjCDeclaration: false 26 | AfterStruct: false 27 | AfterUnion: false 28 | BeforeCatch: false 29 | BeforeElse: false 30 | IndentBraces: false 31 | BreakBeforeBinaryOperators: None 32 | BreakBeforeBraces: Attach 33 | BreakBeforeTernaryOperators: true 34 | BreakConstructorInitializersBeforeComma: false 35 | BreakAfterJavaFieldAnnotations: false 36 | BreakStringLiterals: false 37 | ColumnLimit: 80 38 | CommentPragmas: '^ IWYU pragma:' 39 | ConstructorInitializerAllOnOneLineOrOnePerLine: true 40 | ConstructorInitializerIndentWidth: 4 41 | ContinuationIndentWidth: 4 42 | Cpp11BracedListStyle: true 43 | DerivePointerAlignment: false 44 | DisableFormat: false 45 | ForEachMacros: [ FOR_EACH, FOR_EACH_R, FOR_EACH_RANGE, ] 46 | IncludeCategories: 47 | - Regex: '^<.*\.h(pp)?>' 48 | Priority: 1 49 | - Regex: '^<.*' 50 | Priority: 2 51 | - Regex: '.*' 52 | Priority: 3 53 | IndentCaseLabels: true 54 | IndentWidth: 2 55 | IndentWrappedFunctionNames: false 56 | KeepEmptyLinesAtTheStartOfBlocks: false 57 | MacroBlockBegin: '' 58 | MacroBlockEnd: '' 59 | MaxEmptyLinesToKeep: 1 60 | NamespaceIndentation: None 61 | ObjCBlockIndentWidth: 2 62 | ObjCSpaceAfterProperty: false 63 | ObjCSpaceBeforeProtocolList: false 64 | PenaltyBreakBeforeFirstCallParameter: 1 65 | PenaltyBreakComment: 300 66 | PenaltyBreakFirstLessLess: 120 67 | PenaltyBreakString: 1000 68 | PenaltyExcessCharacter: 1000000 69 | PenaltyReturnTypeOnItsOwnLine: 200 70 | PointerAlignment: Left 71 | ReflowComments: true 72 | SortIncludes: true 73 | SpaceAfterCStyleCast: false 74 | SpaceBeforeAssignmentOperators: true 75 | SpaceBeforeParens: ControlStatements 76 | SpaceInEmptyParentheses: false 77 | SpacesBeforeTrailingComments: 1 78 | SpacesInAngles: false 79 | SpacesInContainerLiterals: true 80 | SpacesInCStyleCastParentheses: false 81 | SpacesInParentheses: false 82 | SpacesInSquareBrackets: false 83 | Standard: Cpp11 84 | TabWidth: 8 85 | UseTab: Never 86 | 
-------------------------------------------------------------------------------- /third_party/sam2/.github/workflows/check_fmt.yml: -------------------------------------------------------------------------------- 1 | name: SAM2/fmt 2 | on: 3 | pull_request: 4 | branches: 5 | - main 6 | jobs: 7 | ufmt_check: 8 | runs-on: ubuntu-latest 9 | steps: 10 | - name: Check formatting 11 | uses: omnilib/ufmt@action-v1 12 | with: 13 | path: sam2 tools 14 | version: "2.0.0b2" 15 | python-version: "3.10" 16 | black-version: "24.2.0" 17 | usort-version: "1.0.2" 18 | -------------------------------------------------------------------------------- /third_party/sam2/.gitignore: -------------------------------------------------------------------------------- 1 | .vscode/ 2 | .DS_Store 3 | __pycache__/ 4 | *-checkpoint.ipynb 5 | .venv 6 | *.egg* 7 | build/* 8 | _C.* 9 | outputs/* 10 | checkpoints/*.pt 11 | -------------------------------------------------------------------------------- /third_party/sam2/.watchmanconfig: -------------------------------------------------------------------------------- 1 | {} -------------------------------------------------------------------------------- /third_party/sam2/CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to make participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies within all project spaces, and it also applies when 49 | an individual is representing the project or its community in public spaces. 
50 | Examples of representing a project or community include using an official 51 | project e-mail address, posting via an official social media account, or acting 52 | as an appointed representative at an online or offline event. Representation of 53 | a project may be further defined and clarified by project maintainers. 54 | 55 | This Code of Conduct also applies outside the project spaces when there is a 56 | reasonable belief that an individual's behavior may have a negative impact on 57 | the project or its community. 58 | 59 | ## Enforcement 60 | 61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 62 | reported by contacting the project team at . All 63 | complaints will be reviewed and investigated and will result in a response that 64 | is deemed necessary and appropriate to the circumstances. The project team is 65 | obligated to maintain confidentiality with regard to the reporter of an incident. 66 | Further details of specific enforcement policies may be posted separately. 67 | 68 | Project maintainers who do not follow or enforce the Code of Conduct in good 69 | faith may face temporary or permanent repercussions as determined by other 70 | members of the project's leadership. 71 | 72 | ## Attribution 73 | 74 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, 75 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html 76 | 77 | [homepage]: https://www.contributor-covenant.org 78 | 79 | For answers to common questions about this code of conduct, see 80 | https://www.contributor-covenant.org/faq 81 | -------------------------------------------------------------------------------- /third_party/sam2/CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to segment-anything 2 | We want to make contributing to this project as easy and transparent as 3 | possible. 4 | 5 | ## Pull Requests 6 | We actively welcome your pull requests. 7 | 8 | 1. Fork the repo and create your branch from `main`. 9 | 2. If you've added code that should be tested, add tests. 10 | 3. If you've changed APIs, update the documentation. 11 | 4. Ensure the test suite passes. 12 | 5. Make sure your code lints, using the `ufmt format` command. Linting requires `black==24.2.0`, `usort==1.0.2`, and `ufmt==2.0.0b2`, which can be installed via `pip install -e ".[dev]"`. 13 | 6. If you haven't already, complete the Contributor License Agreement ("CLA"). 14 | 15 | ## Contributor License Agreement ("CLA") 16 | In order to accept your pull request, we need you to submit a CLA. You only need 17 | to do this once to work on any of Facebook's open source projects. 18 | 19 | Complete your CLA here: 20 | 21 | ## Issues 22 | We use GitHub issues to track public bugs. Please ensure your description is 23 | clear and has sufficient instructions to be able to reproduce the issue. 24 | 25 | Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe 26 | disclosure of security bugs. In those cases, please go through the process 27 | outlined on that page and do not file a public issue. 28 | 29 | ## License 30 | By contributing to segment-anything, you agree that your contributions will be licensed 31 | under the LICENSE file in the root directory of this source tree. 
32 | -------------------------------------------------------------------------------- /third_party/sam2/LICENSE_cctorch: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2020, the respective contributors, as shown by the AUTHORS file. 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | 1. Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | 2. Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | 3. Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 | -------------------------------------------------------------------------------- /third_party/sam2/MANIFEST.in: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | recursive-include sam2 *.yaml #include all config files 8 | -------------------------------------------------------------------------------- /third_party/sam2/assets/model_diagram.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Junyi42/monst3r/1747338dcc01a850f7105bdd7147ab167e400f97/third_party/sam2/assets/model_diagram.png -------------------------------------------------------------------------------- /third_party/sam2/assets/sa_v_dataset.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Junyi42/monst3r/1747338dcc01a850f7105bdd7147ab167e400f97/third_party/sam2/assets/sa_v_dataset.jpg -------------------------------------------------------------------------------- /third_party/sam2/backend.Dockerfile: -------------------------------------------------------------------------------- 1 | ARG BASE_IMAGE=pytorch/pytorch:2.3.1-cuda12.1-cudnn8-runtime 2 | ARG MODEL_SIZE=base_plus 3 | 4 | FROM ${BASE_IMAGE} 5 | 6 | # Gunicorn environment variables 7 | ENV GUNICORN_WORKERS=1 8 | ENV GUNICORN_THREADS=2 9 | ENV GUNICORN_PORT=5000 10 | 11 | # SAM 2 environment variables 12 | ENV APP_ROOT=/opt/sam2 13 | ENV PYTHONUNBUFFERED=1 14 | ENV SAM2_BUILD_CUDA=0 15 | ENV MODEL_SIZE=${MODEL_SIZE} 16 | 17 | # Install system requirements 18 | RUN apt-get update && apt-get install -y --no-install-recommends \ 19 | ffmpeg \ 20 | libavutil-dev \ 21 | libavcodec-dev \ 22 | libavformat-dev \ 23 | libswscale-dev \ 24 | pkg-config \ 25 | build-essential \ 26 | libffi-dev 27 | 28 | COPY setup.py . 29 | COPY README.md . 30 | 31 | RUN pip install --upgrade pip setuptools 32 | RUN pip install -e ".[interactive-demo]" 33 | 34 | # https://github.com/Kosinkadink/ComfyUI-VideoHelperSuite/issues/69#issuecomment-1826764707 35 | RUN rm /opt/conda/bin/ffmpeg && ln -s /bin/ffmpeg /opt/conda/bin/ffmpeg 36 | 37 | # Make app directory. This directory will host all files required for the 38 | # backend and SAM 2 inference files. 
39 | RUN mkdir ${APP_ROOT} 40 | 41 | # Copy backend server files 42 | COPY demo/backend/server ${APP_ROOT}/server 43 | 44 | # Copy SAM 2 inference files 45 | COPY sam2 ${APP_ROOT}/server/sam2 46 | 47 | # Download SAM 2.1 checkpoints 48 | ADD https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_tiny.pt ${APP_ROOT}/checkpoints/sam2.1_hiera_tiny.pt 49 | ADD https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_small.pt ${APP_ROOT}/checkpoints/sam2.1_hiera_small.pt 50 | ADD https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_base_plus.pt ${APP_ROOT}/checkpoints/sam2.1_hiera_base_plus.pt 51 | ADD https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_large.pt ${APP_ROOT}/checkpoints/sam2.1_hiera_large.pt 52 | 53 | WORKDIR ${APP_ROOT}/server 54 | 55 | # https://pythonspeed.com/articles/gunicorn-in-docker/ 56 | CMD gunicorn --worker-tmp-dir /dev/shm \ 57 | --worker-class gthread app:app \ 58 | --log-level info \ 59 | --access-logfile /dev/stdout \ 60 | --log-file /dev/stderr \ 61 | --workers ${GUNICORN_WORKERS} \ 62 | --threads ${GUNICORN_THREADS} \ 63 | --bind 0.0.0.0:${GUNICORN_PORT} \ 64 | --timeout 60 65 | -------------------------------------------------------------------------------- /third_party/sam2/docker-compose.yaml: -------------------------------------------------------------------------------- 1 | services: 2 | frontend: 3 | image: sam2/frontend 4 | build: 5 | context: ./demo/frontend 6 | dockerfile: frontend.Dockerfile 7 | ports: 8 | - 7262:80 9 | 10 | backend: 11 | image: sam2/backend 12 | build: 13 | context: . 14 | dockerfile: backend.Dockerfile 15 | ports: 16 | - 7263:5000 17 | volumes: 18 | - ./demo/data/:/data/:rw 19 | environment: 20 | - SERVER_ENVIRONMENT=DEV 21 | - GUNICORN_WORKERS=1 22 | # Inference API needs to have at least 2 threads to handle an incoming 23 | # parallel cancel propagation request 24 | - GUNICORN_THREADS=2 25 | - GUNICORN_PORT=5000 26 | - API_URL=http://localhost:7263 27 | - DEFAULT_VIDEO_PATH=gallery/05_default_juggle.mp4 28 | # # ffmpeg/video encode settings 29 | - FFMPEG_NUM_THREADS=1 30 | - VIDEO_ENCODE_CODEC=libx264 31 | - VIDEO_ENCODE_CRF=23 32 | - VIDEO_ENCODE_FPS=24 33 | - VIDEO_ENCODE_MAX_WIDTH=1280 34 | - VIDEO_ENCODE_MAX_HEIGHT=720 35 | - VIDEO_ENCODE_VERBOSE=False 36 | deploy: 37 | resources: 38 | reservations: 39 | devices: 40 | - driver: nvidia 41 | count: 1 42 | capabilities: [gpu] 43 | -------------------------------------------------------------------------------- /third_party/sam2/pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = [ 3 | "setuptools>=61.0", 4 | "torch>=2.3.1", 5 | ] 6 | build-backend = "setuptools.build_meta" 7 | -------------------------------------------------------------------------------- /third_party/sam2/sam2/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | from hydra import initialize_config_module 8 | from hydra.core.global_hydra import GlobalHydra 9 | 10 | if not GlobalHydra.instance().is_initialized(): 11 | initialize_config_module("sam2", version_base="1.2") 12 | -------------------------------------------------------------------------------- /third_party/sam2/sam2/configs/sam2.1/sam2.1_hiera_b+.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # Model 4 | model: 5 | _target_: sam2.modeling.sam2_base.SAM2Base 6 | image_encoder: 7 | _target_: sam2.modeling.backbones.image_encoder.ImageEncoder 8 | scalp: 1 9 | trunk: 10 | _target_: sam2.modeling.backbones.hieradet.Hiera 11 | embed_dim: 112 12 | num_heads: 2 13 | neck: 14 | _target_: sam2.modeling.backbones.image_encoder.FpnNeck 15 | position_encoding: 16 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine 17 | num_pos_feats: 256 18 | normalize: true 19 | scale: null 20 | temperature: 10000 21 | d_model: 256 22 | backbone_channel_list: [896, 448, 224, 112] 23 | fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features 24 | fpn_interp_model: nearest 25 | 26 | memory_attention: 27 | _target_: sam2.modeling.memory_attention.MemoryAttention 28 | d_model: 256 29 | pos_enc_at_input: true 30 | layer: 31 | _target_: sam2.modeling.memory_attention.MemoryAttentionLayer 32 | activation: relu 33 | dim_feedforward: 2048 34 | dropout: 0.1 35 | pos_enc_at_attn: false 36 | self_attention: 37 | _target_: sam2.modeling.sam.transformer.RoPEAttention 38 | rope_theta: 10000.0 39 | feat_sizes: [32, 32] 40 | embedding_dim: 256 41 | num_heads: 1 42 | downsample_rate: 1 43 | dropout: 0.1 44 | d_model: 256 45 | pos_enc_at_cross_attn_keys: true 46 | pos_enc_at_cross_attn_queries: false 47 | cross_attention: 48 | _target_: sam2.modeling.sam.transformer.RoPEAttention 49 | rope_theta: 10000.0 50 | feat_sizes: [32, 32] 51 | rope_k_repeat: True 52 | embedding_dim: 256 53 | num_heads: 1 54 | downsample_rate: 1 55 | dropout: 0.1 56 | kv_in_dim: 64 57 | num_layers: 4 58 | 59 | memory_encoder: 60 | _target_: sam2.modeling.memory_encoder.MemoryEncoder 61 | out_dim: 64 62 | position_encoding: 63 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine 64 | num_pos_feats: 64 65 | normalize: true 66 | scale: null 67 | temperature: 10000 68 | mask_downsampler: 69 | _target_: sam2.modeling.memory_encoder.MaskDownSampler 70 | kernel_size: 3 71 | stride: 2 72 | padding: 1 73 | fuser: 74 | _target_: sam2.modeling.memory_encoder.Fuser 75 | layer: 76 | _target_: sam2.modeling.memory_encoder.CXBlock 77 | dim: 256 78 | kernel_size: 7 79 | padding: 3 80 | layer_scale_init_value: 1e-6 81 | use_dwconv: True # depth-wise convs 82 | num_layers: 2 83 | 84 | num_maskmem: 7 85 | image_size: 1024 86 | # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask 87 | sigmoid_scale_for_mem_enc: 20.0 88 | sigmoid_bias_for_mem_enc: -10.0 89 | use_mask_input_as_output_without_sam: true 90 | # Memory 91 | directly_add_no_mem_embed: true 92 | no_obj_embed_spatial: true 93 | # use high-resolution feature map in the SAM mask decoder 94 | use_high_res_features_in_sam: true 95 | # output 3 masks on the first click on initial conditioning frames 96 | multimask_output_in_sam: true 97 | # SAM heads 98 | iou_prediction_use_sigmoid: True 99 | # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder 100 | use_obj_ptrs_in_encoder: true 101 | 
add_tpos_enc_to_obj_ptrs: true 102 | proj_tpos_enc_in_obj_ptrs: true 103 | use_signed_tpos_enc_to_obj_ptrs: true 104 | only_obj_ptrs_in_the_past_for_eval: true 105 | # object occlusion prediction 106 | pred_obj_scores: true 107 | pred_obj_scores_mlp: true 108 | fixed_no_obj_ptr: true 109 | # multimask tracking settings 110 | multimask_output_for_tracking: true 111 | use_multimask_token_for_obj_ptr: true 112 | multimask_min_pt_num: 0 113 | multimask_max_pt_num: 1 114 | use_mlp_for_obj_ptr_proj: true 115 | # Compilation flag 116 | compile_image_encoder: False 117 | -------------------------------------------------------------------------------- /third_party/sam2/sam2/configs/sam2.1/sam2.1_hiera_l.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # Model 4 | model: 5 | _target_: sam2.modeling.sam2_base.SAM2Base 6 | image_encoder: 7 | _target_: sam2.modeling.backbones.image_encoder.ImageEncoder 8 | scalp: 1 9 | trunk: 10 | _target_: sam2.modeling.backbones.hieradet.Hiera 11 | embed_dim: 144 12 | num_heads: 2 13 | stages: [2, 6, 36, 4] 14 | global_att_blocks: [23, 33, 43] 15 | window_pos_embed_bkg_spatial_size: [7, 7] 16 | window_spec: [8, 4, 16, 8] 17 | neck: 18 | _target_: sam2.modeling.backbones.image_encoder.FpnNeck 19 | position_encoding: 20 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine 21 | num_pos_feats: 256 22 | normalize: true 23 | scale: null 24 | temperature: 10000 25 | d_model: 256 26 | backbone_channel_list: [1152, 576, 288, 144] 27 | fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features 28 | fpn_interp_model: nearest 29 | 30 | memory_attention: 31 | _target_: sam2.modeling.memory_attention.MemoryAttention 32 | d_model: 256 33 | pos_enc_at_input: true 34 | layer: 35 | _target_: sam2.modeling.memory_attention.MemoryAttentionLayer 36 | activation: relu 37 | dim_feedforward: 2048 38 | dropout: 0.1 39 | pos_enc_at_attn: false 40 | self_attention: 41 | _target_: sam2.modeling.sam.transformer.RoPEAttention 42 | rope_theta: 10000.0 43 | feat_sizes: [32, 32] 44 | embedding_dim: 256 45 | num_heads: 1 46 | downsample_rate: 1 47 | dropout: 0.1 48 | d_model: 256 49 | pos_enc_at_cross_attn_keys: true 50 | pos_enc_at_cross_attn_queries: false 51 | cross_attention: 52 | _target_: sam2.modeling.sam.transformer.RoPEAttention 53 | rope_theta: 10000.0 54 | feat_sizes: [32, 32] 55 | rope_k_repeat: True 56 | embedding_dim: 256 57 | num_heads: 1 58 | downsample_rate: 1 59 | dropout: 0.1 60 | kv_in_dim: 64 61 | num_layers: 4 62 | 63 | memory_encoder: 64 | _target_: sam2.modeling.memory_encoder.MemoryEncoder 65 | out_dim: 64 66 | position_encoding: 67 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine 68 | num_pos_feats: 64 69 | normalize: true 70 | scale: null 71 | temperature: 10000 72 | mask_downsampler: 73 | _target_: sam2.modeling.memory_encoder.MaskDownSampler 74 | kernel_size: 3 75 | stride: 2 76 | padding: 1 77 | fuser: 78 | _target_: sam2.modeling.memory_encoder.Fuser 79 | layer: 80 | _target_: sam2.modeling.memory_encoder.CXBlock 81 | dim: 256 82 | kernel_size: 7 83 | padding: 3 84 | layer_scale_init_value: 1e-6 85 | use_dwconv: True # depth-wise convs 86 | num_layers: 2 87 | 88 | num_maskmem: 7 89 | image_size: 1024 90 | # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask 91 | sigmoid_scale_for_mem_enc: 20.0 92 | sigmoid_bias_for_mem_enc: -10.0 93 | use_mask_input_as_output_without_sam: true 94 | # Memory 95 | 
directly_add_no_mem_embed: true 96 | no_obj_embed_spatial: true 97 | # use high-resolution feature map in the SAM mask decoder 98 | use_high_res_features_in_sam: true 99 | # output 3 masks on the first click on initial conditioning frames 100 | multimask_output_in_sam: true 101 | # SAM heads 102 | iou_prediction_use_sigmoid: True 103 | # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder 104 | use_obj_ptrs_in_encoder: true 105 | add_tpos_enc_to_obj_ptrs: true 106 | proj_tpos_enc_in_obj_ptrs: true 107 | use_signed_tpos_enc_to_obj_ptrs: true 108 | only_obj_ptrs_in_the_past_for_eval: true 109 | # object occlusion prediction 110 | pred_obj_scores: true 111 | pred_obj_scores_mlp: true 112 | fixed_no_obj_ptr: true 113 | # multimask tracking settings 114 | multimask_output_for_tracking: true 115 | use_multimask_token_for_obj_ptr: true 116 | multimask_min_pt_num: 0 117 | multimask_max_pt_num: 1 118 | use_mlp_for_obj_ptr_proj: true 119 | # Compilation flag 120 | compile_image_encoder: False 121 | -------------------------------------------------------------------------------- /third_party/sam2/sam2/configs/sam2.1/sam2.1_hiera_s.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # Model 4 | model: 5 | _target_: sam2.modeling.sam2_base.SAM2Base 6 | image_encoder: 7 | _target_: sam2.modeling.backbones.image_encoder.ImageEncoder 8 | scalp: 1 9 | trunk: 10 | _target_: sam2.modeling.backbones.hieradet.Hiera 11 | embed_dim: 96 12 | num_heads: 1 13 | stages: [1, 2, 11, 2] 14 | global_att_blocks: [7, 10, 13] 15 | window_pos_embed_bkg_spatial_size: [7, 7] 16 | neck: 17 | _target_: sam2.modeling.backbones.image_encoder.FpnNeck 18 | position_encoding: 19 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine 20 | num_pos_feats: 256 21 | normalize: true 22 | scale: null 23 | temperature: 10000 24 | d_model: 256 25 | backbone_channel_list: [768, 384, 192, 96] 26 | fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features 27 | fpn_interp_model: nearest 28 | 29 | memory_attention: 30 | _target_: sam2.modeling.memory_attention.MemoryAttention 31 | d_model: 256 32 | pos_enc_at_input: true 33 | layer: 34 | _target_: sam2.modeling.memory_attention.MemoryAttentionLayer 35 | activation: relu 36 | dim_feedforward: 2048 37 | dropout: 0.1 38 | pos_enc_at_attn: false 39 | self_attention: 40 | _target_: sam2.modeling.sam.transformer.RoPEAttention 41 | rope_theta: 10000.0 42 | feat_sizes: [32, 32] 43 | embedding_dim: 256 44 | num_heads: 1 45 | downsample_rate: 1 46 | dropout: 0.1 47 | d_model: 256 48 | pos_enc_at_cross_attn_keys: true 49 | pos_enc_at_cross_attn_queries: false 50 | cross_attention: 51 | _target_: sam2.modeling.sam.transformer.RoPEAttention 52 | rope_theta: 10000.0 53 | feat_sizes: [32, 32] 54 | rope_k_repeat: True 55 | embedding_dim: 256 56 | num_heads: 1 57 | downsample_rate: 1 58 | dropout: 0.1 59 | kv_in_dim: 64 60 | num_layers: 4 61 | 62 | memory_encoder: 63 | _target_: sam2.modeling.memory_encoder.MemoryEncoder 64 | out_dim: 64 65 | position_encoding: 66 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine 67 | num_pos_feats: 64 68 | normalize: true 69 | scale: null 70 | temperature: 10000 71 | mask_downsampler: 72 | _target_: sam2.modeling.memory_encoder.MaskDownSampler 73 | kernel_size: 3 74 | stride: 2 75 | padding: 1 76 | fuser: 77 | _target_: sam2.modeling.memory_encoder.Fuser 78 | layer: 79 | _target_: sam2.modeling.memory_encoder.CXBlock 
80 | dim: 256 81 | kernel_size: 7 82 | padding: 3 83 | layer_scale_init_value: 1e-6 84 | use_dwconv: True # depth-wise convs 85 | num_layers: 2 86 | 87 | num_maskmem: 7 88 | image_size: 1024 89 | # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask 90 | sigmoid_scale_for_mem_enc: 20.0 91 | sigmoid_bias_for_mem_enc: -10.0 92 | use_mask_input_as_output_without_sam: true 93 | # Memory 94 | directly_add_no_mem_embed: true 95 | no_obj_embed_spatial: true 96 | # use high-resolution feature map in the SAM mask decoder 97 | use_high_res_features_in_sam: true 98 | # output 3 masks on the first click on initial conditioning frames 99 | multimask_output_in_sam: true 100 | # SAM heads 101 | iou_prediction_use_sigmoid: True 102 | # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder 103 | use_obj_ptrs_in_encoder: true 104 | add_tpos_enc_to_obj_ptrs: true 105 | proj_tpos_enc_in_obj_ptrs: true 106 | use_signed_tpos_enc_to_obj_ptrs: true 107 | only_obj_ptrs_in_the_past_for_eval: true 108 | # object occlusion prediction 109 | pred_obj_scores: true 110 | pred_obj_scores_mlp: true 111 | fixed_no_obj_ptr: true 112 | # multimask tracking settings 113 | multimask_output_for_tracking: true 114 | use_multimask_token_for_obj_ptr: true 115 | multimask_min_pt_num: 0 116 | multimask_max_pt_num: 1 117 | use_mlp_for_obj_ptr_proj: true 118 | # Compilation flag 119 | compile_image_encoder: False 120 | -------------------------------------------------------------------------------- /third_party/sam2/sam2/configs/sam2.1/sam2.1_hiera_t.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # Model 4 | model: 5 | _target_: sam2.modeling.sam2_base.SAM2Base 6 | image_encoder: 7 | _target_: sam2.modeling.backbones.image_encoder.ImageEncoder 8 | scalp: 1 9 | trunk: 10 | _target_: sam2.modeling.backbones.hieradet.Hiera 11 | embed_dim: 96 12 | num_heads: 1 13 | stages: [1, 2, 7, 2] 14 | global_att_blocks: [5, 7, 9] 15 | window_pos_embed_bkg_spatial_size: [7, 7] 16 | neck: 17 | _target_: sam2.modeling.backbones.image_encoder.FpnNeck 18 | position_encoding: 19 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine 20 | num_pos_feats: 256 21 | normalize: true 22 | scale: null 23 | temperature: 10000 24 | d_model: 256 25 | backbone_channel_list: [768, 384, 192, 96] 26 | fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features 27 | fpn_interp_model: nearest 28 | 29 | memory_attention: 30 | _target_: sam2.modeling.memory_attention.MemoryAttention 31 | d_model: 256 32 | pos_enc_at_input: true 33 | layer: 34 | _target_: sam2.modeling.memory_attention.MemoryAttentionLayer 35 | activation: relu 36 | dim_feedforward: 2048 37 | dropout: 0.1 38 | pos_enc_at_attn: false 39 | self_attention: 40 | _target_: sam2.modeling.sam.transformer.RoPEAttention 41 | rope_theta: 10000.0 42 | feat_sizes: [32, 32] 43 | embedding_dim: 256 44 | num_heads: 1 45 | downsample_rate: 1 46 | dropout: 0.1 47 | d_model: 256 48 | pos_enc_at_cross_attn_keys: true 49 | pos_enc_at_cross_attn_queries: false 50 | cross_attention: 51 | _target_: sam2.modeling.sam.transformer.RoPEAttention 52 | rope_theta: 10000.0 53 | feat_sizes: [32, 32] 54 | rope_k_repeat: True 55 | embedding_dim: 256 56 | num_heads: 1 57 | downsample_rate: 1 58 | dropout: 0.1 59 | kv_in_dim: 64 60 | num_layers: 4 61 | 62 | memory_encoder: 63 | _target_: sam2.modeling.memory_encoder.MemoryEncoder 64 | out_dim: 64 65 
| position_encoding: 66 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine 67 | num_pos_feats: 64 68 | normalize: true 69 | scale: null 70 | temperature: 10000 71 | mask_downsampler: 72 | _target_: sam2.modeling.memory_encoder.MaskDownSampler 73 | kernel_size: 3 74 | stride: 2 75 | padding: 1 76 | fuser: 77 | _target_: sam2.modeling.memory_encoder.Fuser 78 | layer: 79 | _target_: sam2.modeling.memory_encoder.CXBlock 80 | dim: 256 81 | kernel_size: 7 82 | padding: 3 83 | layer_scale_init_value: 1e-6 84 | use_dwconv: True # depth-wise convs 85 | num_layers: 2 86 | 87 | num_maskmem: 7 88 | image_size: 1024 89 | # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask 90 | # SAM decoder 91 | sigmoid_scale_for_mem_enc: 20.0 92 | sigmoid_bias_for_mem_enc: -10.0 93 | use_mask_input_as_output_without_sam: true 94 | # Memory 95 | directly_add_no_mem_embed: true 96 | no_obj_embed_spatial: true 97 | # use high-resolution feature map in the SAM mask decoder 98 | use_high_res_features_in_sam: true 99 | # output 3 masks on the first click on initial conditioning frames 100 | multimask_output_in_sam: true 101 | # SAM heads 102 | iou_prediction_use_sigmoid: True 103 | # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder 104 | use_obj_ptrs_in_encoder: true 105 | add_tpos_enc_to_obj_ptrs: true 106 | proj_tpos_enc_in_obj_ptrs: true 107 | use_signed_tpos_enc_to_obj_ptrs: true 108 | only_obj_ptrs_in_the_past_for_eval: true 109 | # object occlusion prediction 110 | pred_obj_scores: true 111 | pred_obj_scores_mlp: true 112 | fixed_no_obj_ptr: true 113 | # multimask tracking settings 114 | multimask_output_for_tracking: true 115 | use_multimask_token_for_obj_ptr: true 116 | multimask_min_pt_num: 0 117 | multimask_max_pt_num: 1 118 | use_mlp_for_obj_ptr_proj: true 119 | # Compilation flag 120 | # HieraT does not currently support compilation, should always be set to False 121 | compile_image_encoder: False 122 | -------------------------------------------------------------------------------- /third_party/sam2/sam2/configs/sam2/sam2_hiera_b+.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # Model 4 | model: 5 | _target_: sam2.modeling.sam2_base.SAM2Base 6 | image_encoder: 7 | _target_: sam2.modeling.backbones.image_encoder.ImageEncoder 8 | scalp: 1 9 | trunk: 10 | _target_: sam2.modeling.backbones.hieradet.Hiera 11 | embed_dim: 112 12 | num_heads: 2 13 | neck: 14 | _target_: sam2.modeling.backbones.image_encoder.FpnNeck 15 | position_encoding: 16 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine 17 | num_pos_feats: 256 18 | normalize: true 19 | scale: null 20 | temperature: 10000 21 | d_model: 256 22 | backbone_channel_list: [896, 448, 224, 112] 23 | fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features 24 | fpn_interp_model: nearest 25 | 26 | memory_attention: 27 | _target_: sam2.modeling.memory_attention.MemoryAttention 28 | d_model: 256 29 | pos_enc_at_input: true 30 | layer: 31 | _target_: sam2.modeling.memory_attention.MemoryAttentionLayer 32 | activation: relu 33 | dim_feedforward: 2048 34 | dropout: 0.1 35 | pos_enc_at_attn: false 36 | self_attention: 37 | _target_: sam2.modeling.sam.transformer.RoPEAttention 38 | rope_theta: 10000.0 39 | feat_sizes: [32, 32] 40 | embedding_dim: 256 41 | num_heads: 1 42 | downsample_rate: 1 43 | dropout: 0.1 44 | d_model: 256 45 | 
pos_enc_at_cross_attn_keys: true 46 | pos_enc_at_cross_attn_queries: false 47 | cross_attention: 48 | _target_: sam2.modeling.sam.transformer.RoPEAttention 49 | rope_theta: 10000.0 50 | feat_sizes: [32, 32] 51 | rope_k_repeat: True 52 | embedding_dim: 256 53 | num_heads: 1 54 | downsample_rate: 1 55 | dropout: 0.1 56 | kv_in_dim: 64 57 | num_layers: 4 58 | 59 | memory_encoder: 60 | _target_: sam2.modeling.memory_encoder.MemoryEncoder 61 | out_dim: 64 62 | position_encoding: 63 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine 64 | num_pos_feats: 64 65 | normalize: true 66 | scale: null 67 | temperature: 10000 68 | mask_downsampler: 69 | _target_: sam2.modeling.memory_encoder.MaskDownSampler 70 | kernel_size: 3 71 | stride: 2 72 | padding: 1 73 | fuser: 74 | _target_: sam2.modeling.memory_encoder.Fuser 75 | layer: 76 | _target_: sam2.modeling.memory_encoder.CXBlock 77 | dim: 256 78 | kernel_size: 7 79 | padding: 3 80 | layer_scale_init_value: 1e-6 81 | use_dwconv: True # depth-wise convs 82 | num_layers: 2 83 | 84 | num_maskmem: 7 85 | image_size: 1024 86 | # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask 87 | sigmoid_scale_for_mem_enc: 20.0 88 | sigmoid_bias_for_mem_enc: -10.0 89 | use_mask_input_as_output_without_sam: true 90 | # Memory 91 | directly_add_no_mem_embed: true 92 | # use high-resolution feature map in the SAM mask decoder 93 | use_high_res_features_in_sam: true 94 | # output 3 masks on the first click on initial conditioning frames 95 | multimask_output_in_sam: true 96 | # SAM heads 97 | iou_prediction_use_sigmoid: True 98 | # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder 99 | use_obj_ptrs_in_encoder: true 100 | add_tpos_enc_to_obj_ptrs: false 101 | only_obj_ptrs_in_the_past_for_eval: true 102 | # object occlusion prediction 103 | pred_obj_scores: true 104 | pred_obj_scores_mlp: true 105 | fixed_no_obj_ptr: true 106 | # multimask tracking settings 107 | multimask_output_for_tracking: true 108 | use_multimask_token_for_obj_ptr: true 109 | multimask_min_pt_num: 0 110 | multimask_max_pt_num: 1 111 | use_mlp_for_obj_ptr_proj: true 112 | # Compilation flag 113 | compile_image_encoder: False 114 | -------------------------------------------------------------------------------- /third_party/sam2/sam2/configs/sam2/sam2_hiera_l.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # Model 4 | model: 5 | _target_: sam2.modeling.sam2_base.SAM2Base 6 | image_encoder: 7 | _target_: sam2.modeling.backbones.image_encoder.ImageEncoder 8 | scalp: 1 9 | trunk: 10 | _target_: sam2.modeling.backbones.hieradet.Hiera 11 | embed_dim: 144 12 | num_heads: 2 13 | stages: [2, 6, 36, 4] 14 | global_att_blocks: [23, 33, 43] 15 | window_pos_embed_bkg_spatial_size: [7, 7] 16 | window_spec: [8, 4, 16, 8] 17 | neck: 18 | _target_: sam2.modeling.backbones.image_encoder.FpnNeck 19 | position_encoding: 20 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine 21 | num_pos_feats: 256 22 | normalize: true 23 | scale: null 24 | temperature: 10000 25 | d_model: 256 26 | backbone_channel_list: [1152, 576, 288, 144] 27 | fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features 28 | fpn_interp_model: nearest 29 | 30 | memory_attention: 31 | _target_: sam2.modeling.memory_attention.MemoryAttention 32 | d_model: 256 33 | pos_enc_at_input: true 34 | layer: 35 | _target_: 
sam2.modeling.memory_attention.MemoryAttentionLayer 36 | activation: relu 37 | dim_feedforward: 2048 38 | dropout: 0.1 39 | pos_enc_at_attn: false 40 | self_attention: 41 | _target_: sam2.modeling.sam.transformer.RoPEAttention 42 | rope_theta: 10000.0 43 | feat_sizes: [32, 32] 44 | embedding_dim: 256 45 | num_heads: 1 46 | downsample_rate: 1 47 | dropout: 0.1 48 | d_model: 256 49 | pos_enc_at_cross_attn_keys: true 50 | pos_enc_at_cross_attn_queries: false 51 | cross_attention: 52 | _target_: sam2.modeling.sam.transformer.RoPEAttention 53 | rope_theta: 10000.0 54 | feat_sizes: [32, 32] 55 | rope_k_repeat: True 56 | embedding_dim: 256 57 | num_heads: 1 58 | downsample_rate: 1 59 | dropout: 0.1 60 | kv_in_dim: 64 61 | num_layers: 4 62 | 63 | memory_encoder: 64 | _target_: sam2.modeling.memory_encoder.MemoryEncoder 65 | out_dim: 64 66 | position_encoding: 67 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine 68 | num_pos_feats: 64 69 | normalize: true 70 | scale: null 71 | temperature: 10000 72 | mask_downsampler: 73 | _target_: sam2.modeling.memory_encoder.MaskDownSampler 74 | kernel_size: 3 75 | stride: 2 76 | padding: 1 77 | fuser: 78 | _target_: sam2.modeling.memory_encoder.Fuser 79 | layer: 80 | _target_: sam2.modeling.memory_encoder.CXBlock 81 | dim: 256 82 | kernel_size: 7 83 | padding: 3 84 | layer_scale_init_value: 1e-6 85 | use_dwconv: True # depth-wise convs 86 | num_layers: 2 87 | 88 | num_maskmem: 7 89 | image_size: 1024 90 | # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask 91 | sigmoid_scale_for_mem_enc: 20.0 92 | sigmoid_bias_for_mem_enc: -10.0 93 | use_mask_input_as_output_without_sam: true 94 | # Memory 95 | directly_add_no_mem_embed: true 96 | # use high-resolution feature map in the SAM mask decoder 97 | use_high_res_features_in_sam: true 98 | # output 3 masks on the first click on initial conditioning frames 99 | multimask_output_in_sam: true 100 | # SAM heads 101 | iou_prediction_use_sigmoid: True 102 | # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder 103 | use_obj_ptrs_in_encoder: true 104 | add_tpos_enc_to_obj_ptrs: false 105 | only_obj_ptrs_in_the_past_for_eval: true 106 | # object occlusion prediction 107 | pred_obj_scores: true 108 | pred_obj_scores_mlp: true 109 | fixed_no_obj_ptr: true 110 | # multimask tracking settings 111 | multimask_output_for_tracking: true 112 | use_multimask_token_for_obj_ptr: true 113 | multimask_min_pt_num: 0 114 | multimask_max_pt_num: 1 115 | use_mlp_for_obj_ptr_proj: true 116 | # Compilation flag 117 | compile_image_encoder: False 118 | -------------------------------------------------------------------------------- /third_party/sam2/sam2/configs/sam2/sam2_hiera_s.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # Model 4 | model: 5 | _target_: sam2.modeling.sam2_base.SAM2Base 6 | image_encoder: 7 | _target_: sam2.modeling.backbones.image_encoder.ImageEncoder 8 | scalp: 1 9 | trunk: 10 | _target_: sam2.modeling.backbones.hieradet.Hiera 11 | embed_dim: 96 12 | num_heads: 1 13 | stages: [1, 2, 11, 2] 14 | global_att_blocks: [7, 10, 13] 15 | window_pos_embed_bkg_spatial_size: [7, 7] 16 | neck: 17 | _target_: sam2.modeling.backbones.image_encoder.FpnNeck 18 | position_encoding: 19 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine 20 | num_pos_feats: 256 21 | normalize: true 22 | scale: null 23 | temperature: 10000 24 | d_model: 256 25 | 
backbone_channel_list: [768, 384, 192, 96] 26 | fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features 27 | fpn_interp_model: nearest 28 | 29 | memory_attention: 30 | _target_: sam2.modeling.memory_attention.MemoryAttention 31 | d_model: 256 32 | pos_enc_at_input: true 33 | layer: 34 | _target_: sam2.modeling.memory_attention.MemoryAttentionLayer 35 | activation: relu 36 | dim_feedforward: 2048 37 | dropout: 0.1 38 | pos_enc_at_attn: false 39 | self_attention: 40 | _target_: sam2.modeling.sam.transformer.RoPEAttention 41 | rope_theta: 10000.0 42 | feat_sizes: [32, 32] 43 | embedding_dim: 256 44 | num_heads: 1 45 | downsample_rate: 1 46 | dropout: 0.1 47 | d_model: 256 48 | pos_enc_at_cross_attn_keys: true 49 | pos_enc_at_cross_attn_queries: false 50 | cross_attention: 51 | _target_: sam2.modeling.sam.transformer.RoPEAttention 52 | rope_theta: 10000.0 53 | feat_sizes: [32, 32] 54 | rope_k_repeat: True 55 | embedding_dim: 256 56 | num_heads: 1 57 | downsample_rate: 1 58 | dropout: 0.1 59 | kv_in_dim: 64 60 | num_layers: 4 61 | 62 | memory_encoder: 63 | _target_: sam2.modeling.memory_encoder.MemoryEncoder 64 | out_dim: 64 65 | position_encoding: 66 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine 67 | num_pos_feats: 64 68 | normalize: true 69 | scale: null 70 | temperature: 10000 71 | mask_downsampler: 72 | _target_: sam2.modeling.memory_encoder.MaskDownSampler 73 | kernel_size: 3 74 | stride: 2 75 | padding: 1 76 | fuser: 77 | _target_: sam2.modeling.memory_encoder.Fuser 78 | layer: 79 | _target_: sam2.modeling.memory_encoder.CXBlock 80 | dim: 256 81 | kernel_size: 7 82 | padding: 3 83 | layer_scale_init_value: 1e-6 84 | use_dwconv: True # depth-wise convs 85 | num_layers: 2 86 | 87 | num_maskmem: 7 88 | image_size: 1024 89 | # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask 90 | sigmoid_scale_for_mem_enc: 20.0 91 | sigmoid_bias_for_mem_enc: -10.0 92 | use_mask_input_as_output_without_sam: true 93 | # Memory 94 | directly_add_no_mem_embed: true 95 | # use high-resolution feature map in the SAM mask decoder 96 | use_high_res_features_in_sam: true 97 | # output 3 masks on the first click on initial conditioning frames 98 | multimask_output_in_sam: true 99 | # SAM heads 100 | iou_prediction_use_sigmoid: True 101 | # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder 102 | use_obj_ptrs_in_encoder: true 103 | add_tpos_enc_to_obj_ptrs: false 104 | only_obj_ptrs_in_the_past_for_eval: true 105 | # object occlusion prediction 106 | pred_obj_scores: true 107 | pred_obj_scores_mlp: true 108 | fixed_no_obj_ptr: true 109 | # multimask tracking settings 110 | multimask_output_for_tracking: true 111 | use_multimask_token_for_obj_ptr: true 112 | multimask_min_pt_num: 0 113 | multimask_max_pt_num: 1 114 | use_mlp_for_obj_ptr_proj: true 115 | # Compilation flag 116 | compile_image_encoder: False 117 | -------------------------------------------------------------------------------- /third_party/sam2/sam2/configs/sam2/sam2_hiera_t.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # Model 4 | model: 5 | _target_: sam2.modeling.sam2_base.SAM2Base 6 | image_encoder: 7 | _target_: sam2.modeling.backbones.image_encoder.ImageEncoder 8 | scalp: 1 9 | trunk: 10 | _target_: sam2.modeling.backbones.hieradet.Hiera 11 | embed_dim: 96 12 | num_heads: 1 13 | stages: [1, 2, 7, 2] 14 | global_att_blocks: [5, 7, 9] 15 | 
window_pos_embed_bkg_spatial_size: [7, 7] 16 | neck: 17 | _target_: sam2.modeling.backbones.image_encoder.FpnNeck 18 | position_encoding: 19 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine 20 | num_pos_feats: 256 21 | normalize: true 22 | scale: null 23 | temperature: 10000 24 | d_model: 256 25 | backbone_channel_list: [768, 384, 192, 96] 26 | fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features 27 | fpn_interp_model: nearest 28 | 29 | memory_attention: 30 | _target_: sam2.modeling.memory_attention.MemoryAttention 31 | d_model: 256 32 | pos_enc_at_input: true 33 | layer: 34 | _target_: sam2.modeling.memory_attention.MemoryAttentionLayer 35 | activation: relu 36 | dim_feedforward: 2048 37 | dropout: 0.1 38 | pos_enc_at_attn: false 39 | self_attention: 40 | _target_: sam2.modeling.sam.transformer.RoPEAttention 41 | rope_theta: 10000.0 42 | feat_sizes: [32, 32] 43 | embedding_dim: 256 44 | num_heads: 1 45 | downsample_rate: 1 46 | dropout: 0.1 47 | d_model: 256 48 | pos_enc_at_cross_attn_keys: true 49 | pos_enc_at_cross_attn_queries: false 50 | cross_attention: 51 | _target_: sam2.modeling.sam.transformer.RoPEAttention 52 | rope_theta: 10000.0 53 | feat_sizes: [32, 32] 54 | rope_k_repeat: True 55 | embedding_dim: 256 56 | num_heads: 1 57 | downsample_rate: 1 58 | dropout: 0.1 59 | kv_in_dim: 64 60 | num_layers: 4 61 | 62 | memory_encoder: 63 | _target_: sam2.modeling.memory_encoder.MemoryEncoder 64 | out_dim: 64 65 | position_encoding: 66 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine 67 | num_pos_feats: 64 68 | normalize: true 69 | scale: null 70 | temperature: 10000 71 | mask_downsampler: 72 | _target_: sam2.modeling.memory_encoder.MaskDownSampler 73 | kernel_size: 3 74 | stride: 2 75 | padding: 1 76 | fuser: 77 | _target_: sam2.modeling.memory_encoder.Fuser 78 | layer: 79 | _target_: sam2.modeling.memory_encoder.CXBlock 80 | dim: 256 81 | kernel_size: 7 82 | padding: 3 83 | layer_scale_init_value: 1e-6 84 | use_dwconv: True # depth-wise convs 85 | num_layers: 2 86 | 87 | num_maskmem: 7 88 | image_size: 1024 89 | # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask 90 | # SAM decoder 91 | sigmoid_scale_for_mem_enc: 20.0 92 | sigmoid_bias_for_mem_enc: -10.0 93 | use_mask_input_as_output_without_sam: true 94 | # Memory 95 | directly_add_no_mem_embed: true 96 | # use high-resolution feature map in the SAM mask decoder 97 | use_high_res_features_in_sam: true 98 | # output 3 masks on the first click on initial conditioning frames 99 | multimask_output_in_sam: true 100 | # SAM heads 101 | iou_prediction_use_sigmoid: True 102 | # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder 103 | use_obj_ptrs_in_encoder: true 104 | add_tpos_enc_to_obj_ptrs: false 105 | only_obj_ptrs_in_the_past_for_eval: true 106 | # object occlusion prediction 107 | pred_obj_scores: true 108 | pred_obj_scores_mlp: true 109 | fixed_no_obj_ptr: true 110 | # multimask tracking settings 111 | multimask_output_for_tracking: true 112 | use_multimask_token_for_obj_ptr: true 113 | multimask_min_pt_num: 0 114 | multimask_max_pt_num: 1 115 | use_mlp_for_obj_ptr_proj: true 116 | # Compilation flag 117 | # HieraT does not currently support compilation, should always be set to False 118 | compile_image_encoder: False 119 | -------------------------------------------------------------------------------- /third_party/sam2/sam2/modeling/__init__.py: 
-------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /third_party/sam2/sam2/modeling/backbones/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /third_party/sam2/sam2/modeling/backbones/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | """Some utilities for backbones, in particular for windowing""" 8 | 9 | from typing import Tuple 10 | 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.functional as F 14 | 15 | 16 | def window_partition(x, window_size): 17 | """ 18 | Partition into non-overlapping windows with padding if needed. 19 | Args: 20 | x (tensor): input tokens with [B, H, W, C]. 21 | window_size (int): window size. 22 | Returns: 23 | windows: windows after partition with [B * num_windows, window_size, window_size, C]. 24 | (Hp, Wp): padded height and width before partition 25 | """ 26 | B, H, W, C = x.shape 27 | 28 | pad_h = (window_size - H % window_size) % window_size 29 | pad_w = (window_size - W % window_size) % window_size 30 | if pad_h > 0 or pad_w > 0: 31 | x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h)) 32 | Hp, Wp = H + pad_h, W + pad_w 33 | 34 | x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C) 35 | windows = ( 36 | x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) 37 | ) 38 | return windows, (Hp, Wp) 39 | 40 | 41 | def window_unpartition(windows, window_size, pad_hw, hw): 42 | """ 43 | Window unpartition into original sequences and removing padding. 44 | Args: 45 | x (tensor): input tokens with [B * num_windows, window_size, window_size, C]. 46 | window_size (int): window size. 47 | pad_hw (Tuple): padded height and width (Hp, Wp). 48 | hw (Tuple): original height and width (H, W) before padding. 49 | Returns: 50 | x: unpartitioned sequences with [B, H, W, C]. 51 | """ 52 | Hp, Wp = pad_hw 53 | H, W = hw 54 | B = windows.shape[0] // (Hp * Wp // window_size // window_size) 55 | x = windows.view( 56 | B, Hp // window_size, Wp // window_size, window_size, window_size, -1 57 | ) 58 | x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1) 59 | 60 | if Hp > H or Wp > W: 61 | x = x[:, :H, :W, :].contiguous() 62 | return x 63 | 64 | 65 | class PatchEmbed(nn.Module): 66 | """ 67 | Image to Patch Embedding. 68 | """ 69 | 70 | def __init__( 71 | self, 72 | kernel_size: Tuple[int, ...] = (7, 7), 73 | stride: Tuple[int, ...] = (4, 4), 74 | padding: Tuple[int, ...] = (3, 3), 75 | in_chans: int = 3, 76 | embed_dim: int = 768, 77 | ): 78 | """ 79 | Args: 80 | kernel_size (Tuple): kernel size of the projection layer. 81 | stride (Tuple): stride of the projection layer. 
82 | padding (Tuple): padding size of the projection layer. 83 | in_chans (int): Number of input image channels. 84 | embed_dim (int): Patch embedding dimension. 85 | """ 86 | super().__init__() 87 | self.proj = nn.Conv2d( 88 | in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding 89 | ) 90 | 91 | def forward(self, x: torch.Tensor) -> torch.Tensor: 92 | x = self.proj(x) 93 | # B C H W -> B H W C 94 | x = x.permute(0, 2, 3, 1) 95 | return x 96 | -------------------------------------------------------------------------------- /third_party/sam2/sam2/modeling/sam/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /third_party/sam2/sam2/sam2_hiera_b+.yaml: -------------------------------------------------------------------------------- 1 | configs/sam2/sam2_hiera_b+.yaml -------------------------------------------------------------------------------- /third_party/sam2/sam2/sam2_hiera_l.yaml: -------------------------------------------------------------------------------- 1 | configs/sam2/sam2_hiera_l.yaml -------------------------------------------------------------------------------- /third_party/sam2/sam2/sam2_hiera_s.yaml: -------------------------------------------------------------------------------- 1 | configs/sam2/sam2_hiera_s.yaml -------------------------------------------------------------------------------- /third_party/sam2/sam2/sam2_hiera_t.yaml: -------------------------------------------------------------------------------- 1 | configs/sam2/sam2_hiera_t.yaml -------------------------------------------------------------------------------- /third_party/sam2/sam2/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /third_party/sam2/tools/README.md: -------------------------------------------------------------------------------- 1 | ## SAM 2 toolkits 2 | 3 | This directory provides toolkits for additional SAM 2 use cases. 4 | 5 | ### Semi-supervised VOS inference 6 | 7 | The `vos_inference.py` script can be used to generate predictions for semi-supervised video object segmentation (VOS) evaluation on datasets such as [DAVIS](https://davischallenge.org/index.html), [MOSE](https://henghuiding.github.io/MOSE/) or the SA-V dataset. 8 | 9 | After installing SAM 2 and its dependencies, it can be used as follows ([DAVIS 2017 dataset](https://davischallenge.org/davis2017/code.html) as an example). This script saves the prediction PNG files to the directory given by `--output_mask_dir`.
10 | ```bash 11 | python ./tools/vos_inference.py \ 12 | --sam2_cfg configs/sam2.1/sam2.1_hiera_b+.yaml \ 13 | --sam2_checkpoint ./checkpoints/sam2.1_hiera_base_plus.pt \ 14 | --base_video_dir /path-to-davis-2017/JPEGImages/480p \ 15 | --input_mask_dir /path-to-davis-2017/Annotations/480p \ 16 | --video_list_file /path-to-davis-2017/ImageSets/2017/val.txt \ 17 | --output_mask_dir ./outputs/davis_2017_pred_pngs 18 | ``` 19 | (replace `/path-to-davis-2017` with the path to DAVIS 2017 dataset) 20 | 21 | To evaluate on the SA-V dataset with per-object PNG files for the object masks, we need to **add the `--per_obj_png_file` flag** as follows (using SA-V val as an example). This script will also save per-object PNG files for the output masks under the `--per_obj_png_file` flag. 22 | ```bash 23 | python ./tools/vos_inference.py \ 24 | --sam2_cfg configs/sam2.1/sam2.1_hiera_b+.yaml \ 25 | --sam2_checkpoint ./checkpoints/sam2.1_hiera_base_plus.pt \ 26 | --base_video_dir /path-to-sav-val/JPEGImages_24fps \ 27 | --input_mask_dir /path-to-sav-val/Annotations_6fps \ 28 | --video_list_file /path-to-sav-val/sav_val.txt \ 29 | --per_obj_png_file \ 30 | --output_mask_dir ./outputs/sav_val_pred_pngs 31 | ``` 32 | (replace `/path-to-sav-val` with the path to SA-V val) 33 | 34 | Then, we can use the evaluation tools or servers for each dataset to get the performance of the prediction PNG files above. 35 | 36 | Note: by default, the `vos_inference.py` script above assumes that all objects to track already appear on frame 0 in each video (as is the case in DAVIS, MOSE or SA-V). **For VOS datasets that don't have all objects to track appearing in the first frame (such as LVOS or YouTube-VOS), please add the `--track_object_appearing_later_in_video` flag when using `vos_inference.py`**. 37 | -------------------------------------------------------------------------------- /third_party/sam2/training/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
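A note on the outputs of `vos_inference.py` described in `tools/README.md` above: for DAVIS-style evaluation the predictions are indexed PNG files in which, by the DAVIS convention, each pixel stores an object id (0 for background). The sketch below is a minimal, unofficial example of reading one such prediction back and scoring it against a ground-truth annotation with a per-object IoU. It is not part of the SAM 2 toolkit: the helper names are made up, the paths are placeholders matching the commands above (with a made-up sequence name), and it assumes the default (non `--per_obj_png_file`) indexed-PNG layout.

```python
# Unofficial sketch: score one predicted frame from vos_inference.py against the
# ground-truth annotation. Assumes DAVIS-style indexed PNGs where pixel value == object id.
import numpy as np
from PIL import Image


def load_id_mask(path: str) -> np.ndarray:
    # For palette ("P" mode) PNGs, np.array() yields the palette indices, i.e. the object ids.
    return np.array(Image.open(path))


def per_object_iou(pred_path: str, gt_path: str) -> dict:
    # Both masks must have the same spatial resolution.
    pred, gt = load_id_mask(pred_path), load_id_mask(gt_path)
    ious = {}
    for obj_id in np.unique(gt):
        if obj_id == 0:  # 0 is background
            continue
        p, g = pred == obj_id, gt == obj_id
        inter = np.logical_and(p, g).sum()
        union = np.logical_or(p, g).sum()
        ious[int(obj_id)] = float(inter) / float(union)
    return ious


if __name__ == "__main__":
    # Placeholder paths: a frame from --output_mask_dir vs. the DAVIS 2017 annotation.
    print(per_object_iou(
        "./outputs/davis_2017_pred_pngs/bike-packing/00000.png",
        "/path-to-davis-2017/Annotations/480p/bike-packing/00000.png",
    ))
```

For reported numbers, the official evaluation tools or servers for each dataset should be used instead, as noted in the README above.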
6 | -------------------------------------------------------------------------------- /third_party/sam2/training/assets/MOSE_sample_val_list.txt: -------------------------------------------------------------------------------- 1 | 32e5d721 2 | 5bad0bab 3 | 267bfd6c 4 | 0a43a414 5 | 56c56ca9 6 | 9a1146b3 7 | c6ad7aaf 8 | 78a1f4b1 9 | fc455e73 10 | 072e7b3f 11 | 77ccb57d 12 | a76ee415 13 | 8cdcfc17 14 | 5d518b42 15 | 376dd830 16 | 0e843fc8 17 | 2af0e766 18 | 2bd4e845 19 | de2f2a6a 20 | ade9ee91 21 | 001ca3cb 22 | fc4c1c67 23 | 8ef55579 24 | b84ce852 25 | 4cc8528a 26 | 767ffaaa 27 | 112a2ef0 28 | a338c8aa 29 | cbd144f5 30 | 5ff72128 31 | 86a949e2 32 | 9f2323ac 33 | 1fab1d1c 34 | 75924351 35 | ef55817b 36 | 02deca50 37 | 4d979d99 38 | 4d65f873 39 | 28470fa0 40 | 0d1575fe 41 | 06ea172e 42 | 29a6ddc2 43 | 797f1bec 44 | 780e7a99 45 | b9ed5b44 46 | 02a236b4 47 | 607d8ff5 48 | af5666b2 49 | 0558d0ed 50 | a938c6b2 51 | 103df575 52 | 77110e80 53 | 739e5a07 54 | 6763a576 55 | 06ebc138 56 | ba4b3b09 57 | b35cc2f3 58 | 4e0597a0 59 | 5949ee84 60 | 5348d547 61 | 323c4236 62 | b3b51117 63 | 55727ddd 64 | ab2714f3 65 | d2878895 66 | c0734cb3 67 | 94f7c53e 68 | 2a2745e5 69 | 442ffb54 70 | 3592425a 71 | 50ae03b0 72 | 5f150435 73 | 3067f9fa 74 | 9ffb2818 75 | adeaf5aa 76 | 31caacec 77 | 1cd99b86 78 | aa22f9d0 79 | 8fa50320 80 | e6348d2c 81 | 42ff84a5 82 | 8c8b7913 83 | c96adcbc 84 | 495be321 85 | db735509 86 | ee113fc4 87 | a678cdab 88 | c409ca4d 89 | 68d2b259 90 | 592b4dee 91 | 4e2b4dc7 92 | eb4d26e1 93 | 2009a00f 94 | bec5c89d 95 | 67191f24 96 | a3e85b4b 97 | da7080cd 98 | 80d978e9 99 | 36dcb93f 100 | a41e8c44 101 | 12fdc864 102 | 46d140ea 103 | 657c9dd9 104 | a86f84ee 105 | 90c1c43d 106 | 33015509 107 | afc7664d 108 | 23df06e1 109 | 291d4799 110 | 0ab75563 111 | 251bf059 112 | bcefdcc4 113 | ce9a2796 114 | 94d3403a 115 | 8f2e04bc 116 | f9cda066 117 | 9dfa2cc5 118 | 66924c91 119 | e765a09e 120 | 15654ee1 121 | 48e0bd39 122 | ee095221 123 | 2463609b 124 | 544d0d1f 125 | 51b8c2e1 126 | d321dde4 127 | 4cb11a5f 128 | d7058a0d 129 | 37af282a 130 | fabae187 131 | 7be91184 132 | 181ec185 133 | 2d16ceeb 134 | b56be4b1 135 | 6699eff0 136 | 79acac96 137 | d61c4665 138 | 0c13e1e7 139 | 100f6ecf 140 | 71217dfc 141 | 82df0888 142 | 4c42c747 143 | c9fdf703 144 | d2efeb4b 145 | 69ed9d14 146 | 64914fb6 147 | 255bedbc 148 | 4ea934d8 149 | a034feb2 150 | e4f4ddae 151 | e36a3026 152 | c1489591 153 | 111bb373 154 | e1d9fb32 155 | 93e22d48 156 | c1ec4b26 157 | d9638e69 158 | 60ab04c5 159 | cfe7773a 160 | 62132822 161 | 2f5fb2a3 162 | 7bdd197d 163 | 033333fd 164 | 130fcdbe 165 | 12e509c2 166 | 67138c33 167 | 6f90cc5f 168 | 4e3020fe 169 | bbdd8bb7 170 | b399ccdb 171 | fecd10d2 172 | 2e0967f7 173 | f509054f 174 | 792c6ff7 175 | 48e2afc5 176 | d904c048 177 | 111e0a5c 178 | b83024e2 179 | e6a7b79c 180 | bdc5ccf7 181 | b8146d00 182 | 9d394f1a 183 | 645b84f9 184 | 95ab2d0f 185 | e6f8a31d 186 | b4f876fb 187 | dc2c570d 188 | 3afd02d7 189 | 5c80c82c 190 | b1b32ddd 191 | 9f25fc61 192 | ba538072 193 | f8916fef 194 | 43c04ad2 195 | a658e949 196 | 2861dd53 197 | f6e40aba 198 | 09d305d1 199 | aac33bff 200 | 8d9d4c08 201 | -------------------------------------------------------------------------------- /third_party/sam2/training/dataset/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | -------------------------------------------------------------------------------- /third_party/sam2/training/dataset/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | """Some wrapping utilities extended from pytorch's to support repeat factor sampling in particular""" 8 | 9 | from typing import Iterable 10 | 11 | import torch 12 | from torch.utils.data import ( 13 | ConcatDataset as TorchConcatDataset, 14 | Dataset, 15 | Subset as TorchSubset, 16 | ) 17 | 18 | 19 | class ConcatDataset(TorchConcatDataset): 20 | def __init__(self, datasets: Iterable[Dataset]) -> None: 21 | super(ConcatDataset, self).__init__(datasets) 22 | 23 | self.repeat_factors = torch.cat([d.repeat_factors for d in datasets]) 24 | 25 | def set_epoch(self, epoch: int): 26 | for dataset in self.datasets: 27 | if hasattr(dataset, "epoch"): 28 | dataset.epoch = epoch 29 | if hasattr(dataset, "set_epoch"): 30 | dataset.set_epoch(epoch) 31 | 32 | 33 | class Subset(TorchSubset): 34 | def __init__(self, dataset, indices) -> None: 35 | super(Subset, self).__init__(dataset, indices) 36 | 37 | self.repeat_factors = dataset.repeat_factors[indices] 38 | assert len(indices) == len(self.repeat_factors) 39 | 40 | 41 | # Adapted from Detectron2 42 | class RepeatFactorWrapper(Dataset): 43 | """ 44 | Thin wrapper around a dataset to implement repeat factor sampling. 45 | The underlying dataset must have a repeat_factors member to indicate the per-image factor. 46 | Set it to all ones to disable repeat factor sampling. 47 | """ 48 | 49 | def __init__(self, dataset, seed: int = 0): 50 | self.dataset = dataset 51 | self.epoch_ids = None 52 | self._seed = seed 53 | 54 | # Split into whole number (_int_part) and fractional (_frac_part) parts. 55 | self._int_part = torch.trunc(dataset.repeat_factors) 56 | self._frac_part = dataset.repeat_factors - self._int_part 57 | 58 | def _get_epoch_indices(self, generator): 59 | """ 60 | Create a list of dataset indices (with repeats) to use for one epoch. 61 | 62 | Args: 63 | generator (torch.Generator): pseudo random number generator used for 64 | stochastic rounding. 65 | 66 | Returns: 67 | torch.Tensor: list of dataset indices to use in one epoch. Each index 68 | is repeated based on its calculated repeat factor. 69 | """ 70 | # Since repeat factors are fractional, we use stochastic rounding so 71 | # that the target repeat factor is achieved in expectation over the 72 | # course of training 73 | rands = torch.rand(len(self._frac_part), generator=generator) 74 | rep_factors = self._int_part + (rands < self._frac_part).float() 75 | # Construct a list of indices in which we repeat images as specified 76 | indices = [] 77 | for dataset_index, rep_factor in enumerate(rep_factors): 78 | indices.extend([dataset_index] * int(rep_factor.item())) 79 | return torch.tensor(indices, dtype=torch.int64) 80 | 81 | def __len__(self): 82 | if self.epoch_ids is None: 83 | # Here we raise an error instead of returning the original len(self.dataset) to avoid 84 | # accidentally using the unwrapped length. Otherwise it's error-prone since the 85 | # length changes to `len(self.epoch_ids)` after set_epoch is called.
86 | raise RuntimeError("please call set_epoch first to get wrapped length") 87 | # return len(self.dataset) 88 | 89 | return len(self.epoch_ids) 90 | 91 | def set_epoch(self, epoch: int): 92 | g = torch.Generator() 93 | g.manual_seed(self._seed + epoch) 94 | self.epoch_ids = self._get_epoch_indices(g) 95 | if hasattr(self.dataset, "set_epoch"): 96 | self.dataset.set_epoch(epoch) 97 | 98 | def __getitem__(self, idx): 99 | if self.epoch_ids is None: 100 | raise RuntimeError( 101 | "Repeat ids haven't been computed. Did you forget to call set_epoch?" 102 | ) 103 | 104 | return self.dataset[self.epoch_ids[idx]] 105 | -------------------------------------------------------------------------------- /third_party/sam2/training/dataset/vos_sampler.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import random 8 | from dataclasses import dataclass 9 | from typing import List 10 | 11 | from training.dataset.vos_segment_loader import LazySegments 12 | 13 | MAX_RETRIES = 1000 14 | 15 | 16 | @dataclass 17 | class SampledFramesAndObjects: 18 | frames: List[int] 19 | object_ids: List[int] 20 | 21 | 22 | class VOSSampler: 23 | def __init__(self, sort_frames=True): 24 | # frames are ordered by frame id when sort_frames is True 25 | self.sort_frames = sort_frames 26 | 27 | def sample(self, video): 28 | raise NotImplementedError() 29 | 30 | 31 | class RandomUniformSampler(VOSSampler): 32 | def __init__( 33 | self, 34 | num_frames, 35 | max_num_objects, 36 | reverse_time_prob=0.0, 37 | ): 38 | self.num_frames = num_frames 39 | self.max_num_objects = max_num_objects 40 | self.reverse_time_prob = reverse_time_prob 41 | 42 | def sample(self, video, segment_loader, epoch=None): 43 | 44 | for retry in range(MAX_RETRIES): 45 | if len(video.frames) < self.num_frames: 46 | raise Exception( 47 | f"Cannot sample {self.num_frames} frames from video {video.video_name} as it only has {len(video.frames)} annotated frames." 
48 | ) 49 | start = random.randrange(0, len(video.frames) - self.num_frames + 1) 50 | frames = [video.frames[start + step] for step in range(self.num_frames)] 51 | if random.uniform(0, 1) < self.reverse_time_prob: 52 | # Reverse time 53 | frames = frames[::-1] 54 | 55 | # Get first frame object ids 56 | visible_object_ids = [] 57 | loaded_segms = segment_loader.load(frames[0].frame_idx) 58 | if isinstance(loaded_segms, LazySegments): 59 | # LazySegments for SA1BRawDataset 60 | visible_object_ids = list(loaded_segms.keys()) 61 | else: 62 | for object_id, segment in segment_loader.load( 63 | frames[0].frame_idx 64 | ).items(): 65 | if segment.sum(): 66 | visible_object_ids.append(object_id) 67 | 68 | # First frame needs to have at least a target to track 69 | if len(visible_object_ids) > 0: 70 | break 71 | if retry >= MAX_RETRIES - 1: 72 | raise Exception("No visible objects") 73 | 74 | object_ids = random.sample( 75 | visible_object_ids, 76 | min(len(visible_object_ids), self.max_num_objects), 77 | ) 78 | return SampledFramesAndObjects(frames=frames, object_ids=object_ids) 79 | 80 | 81 | class EvalSampler(VOSSampler): 82 | """ 83 | VOS Sampler for evaluation: sampling all the frames and all the objects in a video 84 | """ 85 | 86 | def __init__( 87 | self, 88 | ): 89 | super().__init__() 90 | 91 | def sample(self, video, segment_loader, epoch=None): 92 | """ 93 | Sampling all the frames and all the objects 94 | """ 95 | if self.sort_frames: 96 | # ordered by frame id 97 | frames = sorted(video.frames, key=lambda x: x.frame_idx) 98 | else: 99 | # use the original order 100 | frames = video.frames 101 | object_ids = segment_loader.load(frames[0].frame_idx).keys() 102 | if len(object_ids) == 0: 103 | raise Exception("First frame of the video has no objects") 104 | 105 | return SampledFramesAndObjects(frames=frames, object_ids=object_ids) 106 | -------------------------------------------------------------------------------- /third_party/sam2/training/model/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /third_party/sam2/training/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | --------------------------------------------------------------------------------
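The training-side samplers dumped above (`RepeatFactorWrapper` in `training/dataset/utils.py` and `RandomUniformSampler`/`EvalSampler` in `training/dataset/vos_sampler.py`) decide how often each video is visited per epoch and which frames and objects are drawn from it. The least obvious piece is the stochastic rounding in `_get_epoch_indices`: a fractional repeat factor r contributes floor(r) copies plus one extra copy with probability r - floor(r), so the target factor is met only in expectation across epochs, which is also why `set_epoch` must be called before `len()` or indexing (hence the `RuntimeError` guards). The stand-alone sketch below re-implements just that rounding step outside the training package, with made-up repeat factors, to show the per-epoch counts converging to the factors on average:

```python
# Stand-alone illustration (not from the repo) of the stochastic rounding used by
# RepeatFactorWrapper._get_epoch_indices: floor(r) repeats plus one more with prob frac(r).
import torch

repeat_factors = torch.tensor([1.0, 2.5, 0.3])  # made-up per-video repeat factors
int_part = torch.trunc(repeat_factors)
frac_part = repeat_factors - int_part


def epoch_indices(epoch: int, seed: int = 0) -> torch.Tensor:
    # Mirrors set_epoch(epoch): a fresh generator seeded with seed + epoch.
    g = torch.Generator().manual_seed(seed + epoch)
    rands = torch.rand(len(frac_part), generator=g)
    reps = (int_part + (rands < frac_part).float()).long()
    # Equivalent to the index-list loop in _get_epoch_indices.
    return torch.repeat_interleave(torch.arange(len(reps)), reps)


num_epochs = 10_000
counts = torch.zeros(len(repeat_factors))
for epoch in range(num_epochs):
    counts += torch.bincount(epoch_indices(epoch), minlength=len(repeat_factors)).float()

print(counts / num_epochs)  # approaches tensor([1.0, 2.5, 0.3]) in expectation
```

Because the rounding is re-drawn from a generator seeded with `seed + epoch`, every epoch gets a different but reproducible index list, which is exactly the behaviour the wrapper's `set_epoch` contract encodes.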