├── mmmg_eval ├── sam2 │ ├── .watchmanconfig │ ├── sam2 │ │ ├── sam2_hiera_l.yaml │ │ ├── sam2_hiera_s.yaml │ │ ├── sam2_hiera_t.yaml │ │ ├── sam2_hiera_b+.yaml │ │ ├── modeling │ │ │ ├── __init__.py │ │ │ ├── sam │ │ │ │ ├── __init__.py │ │ │ │ └── prompt_encoder.py │ │ │ ├── backbones │ │ │ │ ├── __init__.py │ │ │ │ ├── utils.py │ │ │ │ └── image_encoder.py │ │ │ ├── memory_attention.py │ │ │ ├── memory_encoder.py │ │ │ └── position_encoding.py │ │ ├── utils │ │ │ ├── __init__.py │ │ │ └── transforms.py │ │ ├── __init__.py │ │ ├── benchmark.py │ │ ├── configs │ │ │ ├── sam2 │ │ │ │ ├── sam2_hiera_b+.yaml │ │ │ │ ├── sam2_hiera_s.yaml │ │ │ │ ├── sam2_hiera_l.yaml │ │ │ │ └── sam2_hiera_t.yaml │ │ │ └── sam2.1 │ │ │ │ ├── sam2.1_hiera_b+.yaml │ │ │ │ ├── sam2.1_hiera_s.yaml │ │ │ │ ├── sam2.1_hiera_l.yaml │ │ │ │ └── sam2.1_hiera_t.yaml │ │ ├── build_sam.py │ │ └── csrc │ │ │ └── connected_components.cu │ ├── assets │ │ ├── model_diagram.png │ │ └── sa_v_dataset.jpg │ ├── pyproject.toml │ ├── .gitignore │ ├── training │ │ ├── __init__.py │ │ ├── model │ │ │ └── __init__.py │ │ ├── utils │ │ │ ├── __init__.py │ │ │ ├── data_utils.py │ │ │ └── logger.py │ │ ├── dataset │ │ │ ├── __init__.py │ │ │ ├── vos_sampler.py │ │ │ ├── utils.py │ │ │ ├── vos_dataset.py │ │ │ └── sam2_datasets.py │ │ ├── assets │ │ │ └── MOSE_sample_val_list.txt │ │ ├── scripts │ │ │ └── sav_frame_extraction_submitit.py │ │ └── README.md │ ├── MANIFEST.in │ ├── .github │ │ └── workflows │ │ │ └── check_fmt.yml │ ├── docker-compose.yaml │ ├── CONTRIBUTING.md │ ├── LICENSE_cctorch │ ├── backend.Dockerfile │ ├── tools │ │ └── README.md │ ├── checkpoints │ │ └── download_ckpts.sh │ ├── .clang-format │ ├── CODE_OF_CONDUCT.md │ ├── RELEASE_NOTES.md │ └── setup.py ├── utils │ ├── __pycache__ │ │ ├── all_configs.cpython-310.pyc │ │ ├── gpt_api_pool.cpython-310.pyc │ │ └── instruction_prompt.cpython-310.pyc │ ├── gpt_api_pool.py │ ├── stat_sam2.py │ ├── all_configs.py │ ├── instruction_prompt.py │ └── stat_knowledge.py ├── test_load_dataset_and_validate.py └── step1_knowledge_integrity.py ├── .gitignore ├── imgs └── figure5.jpg ├── environment.yaml ├── evaluate.py └── README.md /mmmg_eval/sam2/.watchmanconfig: -------------------------------------------------------------------------------- 1 | {} -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | /results/ 2 | /images/ 3 | /data/ 4 | /output/ -------------------------------------------------------------------------------- /mmmg_eval/sam2/sam2/sam2_hiera_l.yaml: -------------------------------------------------------------------------------- 1 | configs/sam2/sam2_hiera_l.yaml -------------------------------------------------------------------------------- /mmmg_eval/sam2/sam2/sam2_hiera_s.yaml: -------------------------------------------------------------------------------- 1 | configs/sam2/sam2_hiera_s.yaml -------------------------------------------------------------------------------- /mmmg_eval/sam2/sam2/sam2_hiera_t.yaml: -------------------------------------------------------------------------------- 1 | configs/sam2/sam2_hiera_t.yaml -------------------------------------------------------------------------------- /mmmg_eval/sam2/sam2/sam2_hiera_b+.yaml: -------------------------------------------------------------------------------- 1 | configs/sam2/sam2_hiera_b+.yaml 
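Note: each of the four top-level `sam2/sam2_hiera_*.yaml` files above is a one-line stub whose content is just the path of the real config under `configs/`; `sam2/__init__.py` (later in this dump) registers the `sam2` package as a Hydra config module, so a config name plus a checkpoint fetched by `checkpoints/download_ckpts.sh` is all that `build_sam.py` needs. A minimal sketch of that flow, assuming the upstream SAM 2 image-predictor API (`SAM2ImagePredictor` is not part of the trimmed tree above) and a hypothetical input image:

```python
import numpy as np
from PIL import Image

from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor

model_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml"    # resolved through the "sam2" Hydra config module
checkpoint = "./checkpoints/sam2.1_hiera_large.pt"  # fetched by checkpoints/download_ckpts.sh

# Build the model from config + weights, then wrap it for single-image prompting
predictor = SAM2ImagePredictor(build_sam2(model_cfg, checkpoint, device="cuda"))
predictor.set_image(np.array(Image.open("example.jpg").convert("RGB")))  # hypothetical image
masks, scores, _ = predictor.predict(
    point_coords=np.array([[210, 350]]),  # one positive click, mirroring sam2/benchmark.py
    point_labels=np.array([1]),
)
```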
-------------------------------------------------------------------------------- /imgs/figure5.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MMMGBench/MMMG/HEAD/imgs/figure5.jpg -------------------------------------------------------------------------------- /mmmg_eval/sam2/assets/model_diagram.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MMMGBench/MMMG/HEAD/mmmg_eval/sam2/assets/model_diagram.png -------------------------------------------------------------------------------- /mmmg_eval/sam2/assets/sa_v_dataset.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MMMGBench/MMMG/HEAD/mmmg_eval/sam2/assets/sa_v_dataset.jpg -------------------------------------------------------------------------------- /mmmg_eval/utils/__pycache__/all_configs.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MMMGBench/MMMG/HEAD/mmmg_eval/utils/__pycache__/all_configs.cpython-310.pyc -------------------------------------------------------------------------------- /mmmg_eval/utils/__pycache__/gpt_api_pool.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MMMGBench/MMMG/HEAD/mmmg_eval/utils/__pycache__/gpt_api_pool.cpython-310.pyc -------------------------------------------------------------------------------- /mmmg_eval/sam2/pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = [ 3 | "setuptools>=61.0", 4 | "torch>=2.5.1", 5 | ] 6 | build-backend = "setuptools.build_meta" 7 | -------------------------------------------------------------------------------- /mmmg_eval/utils/__pycache__/instruction_prompt.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MMMGBench/MMMG/HEAD/mmmg_eval/utils/__pycache__/instruction_prompt.cpython-310.pyc -------------------------------------------------------------------------------- /mmmg_eval/sam2/.gitignore: -------------------------------------------------------------------------------- 1 | .vscode/ 2 | .DS_Store 3 | __pycache__/ 4 | *-checkpoint.ipynb 5 | .venv 6 | *.egg* 7 | build/* 8 | _C.* 9 | outputs/* 10 | checkpoints/*.pt 11 | demo/backend/checkpoints/*.pt 12 | -------------------------------------------------------------------------------- /mmmg_eval/sam2/training/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /mmmg_eval/sam2/sam2/modeling/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | -------------------------------------------------------------------------------- /mmmg_eval/sam2/sam2/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /mmmg_eval/sam2/training/model/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /mmmg_eval/sam2/training/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /mmmg_eval/sam2/sam2/modeling/sam/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /mmmg_eval/sam2/training/dataset/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /mmmg_eval/sam2/sam2/modeling/backbones/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /mmmg_eval/sam2/MANIFEST.in: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | recursive-include sam2 *.yaml #include all config files 8 | -------------------------------------------------------------------------------- /environment.yaml: -------------------------------------------------------------------------------- 1 | name: mmmg 2 | channels: 3 | - defaults 4 | - conda-forge 5 | dependencies: 6 | - python=3.10 7 | - pip 8 | - pip: 9 | - paddleocr==2.10.0 10 | - paddlepaddle-gpu==2.6.2 11 | - openai==1.82.0 12 | - azure-identity 13 | - datasets 14 | - networkx==3.4.2 15 | - scikit-learn 16 | - matplotlib 17 | - tqdm 18 | - -e ./mmmg_eval/sam2 19 | -------------------------------------------------------------------------------- /mmmg_eval/test_load_dataset_and_validate.py: -------------------------------------------------------------------------------- 1 | from datasets import load_dataset 2 | import json 3 | cache_root = "/detr_blob/v-luoyuxuan/hf_cache" 4 | ds = load_dataset("MMMGbench/MMMGBench", split="test", cache_dir=cache_root, trust_remote_code=True) 5 | 6 | all_jobs = [ 7 | ( 8 | dict(sample), 9 | 10 | ) 11 | for i, sample in enumerate(ds) 12 | ] 13 | 14 | print(all_jobs[0]) -------------------------------------------------------------------------------- /mmmg_eval/sam2/sam2/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from hydra import initialize_config_module 8 | from hydra.core.global_hydra import GlobalHydra 9 | 10 | if not GlobalHydra.instance().is_initialized(): 11 | initialize_config_module("sam2", version_base="1.2") 12 | -------------------------------------------------------------------------------- /mmmg_eval/sam2/.github/workflows/check_fmt.yml: -------------------------------------------------------------------------------- 1 | name: SAM2/fmt 2 | on: 3 | pull_request: 4 | branches: 5 | - main 6 | jobs: 7 | ufmt_check: 8 | runs-on: ubuntu-latest 9 | steps: 10 | - name: Check formatting 11 | uses: omnilib/ufmt@action-v1 12 | with: 13 | path: sam2 tools 14 | version: "2.0.0b2" 15 | python-version: "3.10" 16 | black-version: "24.2.0" 17 | usort-version: "1.0.2" 18 | -------------------------------------------------------------------------------- /mmmg_eval/utils/gpt_api_pool.py: -------------------------------------------------------------------------------- 1 | gpt_api_pool = [ 2 | #TODO: Add your GPT API pool configuration here 3 | ] 4 | 5 | ''' 6 | AzureOpenAI Resource example: 7 | 8 | { 9 | "index": 13, 10 | "azure_endpoint": YOUR_ENDPOINT, 11 | "api_key":YOUR_API_KEY, 12 | "api_version": PREVIEW, 13 | "rate_limit_requests": RATE_LIMIT 14 | "rate_limit_tokens": TOKEN_NUMS, 15 | "deployment_name": "o3" 16 | } 17 | 18 | NOTICE PLEASE DISABLE THE AZURE CONTENT FILTER. 19 | 20 | We recommend to use the "o3" deployment from official OpenAI website, which is the most powerful model available. 21 | ''' -------------------------------------------------------------------------------- /mmmg_eval/sam2/docker-compose.yaml: -------------------------------------------------------------------------------- 1 | services: 2 | frontend: 3 | image: sam2/frontend 4 | build: 5 | context: ./demo/frontend 6 | dockerfile: frontend.Dockerfile 7 | ports: 8 | - 7262:80 9 | 10 | backend: 11 | image: sam2/backend 12 | build: 13 | context: . 
14 | dockerfile: backend.Dockerfile 15 | ports: 16 | - 7263:5000 17 | volumes: 18 | - ./demo/data/:/data/:rw 19 | environment: 20 | - SERVER_ENVIRONMENT=DEV 21 | - GUNICORN_WORKERS=1 22 | # Inference API needs to have at least 2 threads to handle an incoming 23 | # parallel cancel propagation request 24 | - GUNICORN_THREADS=2 25 | - GUNICORN_PORT=5000 26 | - API_URL=http://localhost:7263 27 | - DEFAULT_VIDEO_PATH=gallery/05_default_juggle.mp4 28 | # # ffmpeg/video encode settings 29 | - FFMPEG_NUM_THREADS=1 30 | - VIDEO_ENCODE_CODEC=libx264 31 | - VIDEO_ENCODE_CRF=23 32 | - VIDEO_ENCODE_FPS=24 33 | - VIDEO_ENCODE_MAX_WIDTH=1280 34 | - VIDEO_ENCODE_MAX_HEIGHT=720 35 | - VIDEO_ENCODE_VERBOSE=False 36 | deploy: 37 | resources: 38 | reservations: 39 | devices: 40 | - driver: nvidia 41 | count: 1 42 | capabilities: [gpu] 43 | -------------------------------------------------------------------------------- /mmmg_eval/sam2/CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to segment-anything 2 | We want to make contributing to this project as easy and transparent as 3 | possible. 4 | 5 | ## Pull Requests 6 | We actively welcome your pull requests. 7 | 8 | 1. Fork the repo and create your branch from `main`. 9 | 2. If you've added code that should be tested, add tests. 10 | 3. If you've changed APIs, update the documentation. 11 | 4. Ensure the test suite passes. 12 | 5. Make sure your code lints, using the `ufmt format` command. Linting requires `black==24.2.0`, `usort==1.0.2`, and `ufmt==2.0.0b2`, which can be installed via `pip install -e ".[dev]"`. 13 | 6. If you haven't already, complete the Contributor License Agreement ("CLA"). 14 | 15 | ## Contributor License Agreement ("CLA") 16 | In order to accept your pull request, we need you to submit a CLA. You only need 17 | to do this once to work on any of Facebook's open source projects. 18 | 19 | Complete your CLA here: 20 | 21 | ## Issues 22 | We use GitHub issues to track public bugs. Please ensure your description is 23 | clear and has sufficient instructions to be able to reproduce the issue. 24 | 25 | Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe 26 | disclosure of security bugs. In those cases, please go through the process 27 | outlined on that page and do not file a public issue. 28 | 29 | ## License 30 | By contributing to segment-anything, you agree that your contributions will be licensed 31 | under the LICENSE file in the root directory of this source tree. 32 | -------------------------------------------------------------------------------- /mmmg_eval/sam2/LICENSE_cctorch: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2020, the respective contributors, as shown by the AUTHORS file. 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | 1. Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | 2. Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | 3. 
Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 | -------------------------------------------------------------------------------- /mmmg_eval/sam2/backend.Dockerfile: -------------------------------------------------------------------------------- 1 | ARG BASE_IMAGE=pytorch/pytorch:2.5.1-cuda12.1-cudnn9-runtime 2 | ARG MODEL_SIZE=base_plus 3 | 4 | FROM ${BASE_IMAGE} 5 | 6 | # Gunicorn environment variables 7 | ENV GUNICORN_WORKERS=1 8 | ENV GUNICORN_THREADS=2 9 | ENV GUNICORN_PORT=5000 10 | 11 | # SAM 2 environment variables 12 | ENV APP_ROOT=/opt/sam2 13 | ENV PYTHONUNBUFFERED=1 14 | ENV SAM2_BUILD_CUDA=0 15 | ENV MODEL_SIZE=${MODEL_SIZE} 16 | 17 | # Install system requirements 18 | RUN apt-get update && apt-get install -y --no-install-recommends \ 19 | ffmpeg \ 20 | libavutil-dev \ 21 | libavcodec-dev \ 22 | libavformat-dev \ 23 | libswscale-dev \ 24 | pkg-config \ 25 | build-essential \ 26 | libffi-dev 27 | 28 | COPY setup.py . 29 | COPY README.md . 30 | 31 | RUN pip install --upgrade pip setuptools 32 | RUN pip install -e ".[interactive-demo]" 33 | 34 | # https://github.com/Kosinkadink/ComfyUI-VideoHelperSuite/issues/69#issuecomment-1826764707 35 | RUN rm /opt/conda/bin/ffmpeg && ln -s /bin/ffmpeg /opt/conda/bin/ffmpeg 36 | 37 | # Make app directory. This directory will host all files required for the 38 | # backend and SAM 2 inference files. 
39 | RUN mkdir ${APP_ROOT} 40 | 41 | # Copy backend server files 42 | COPY demo/backend/server ${APP_ROOT}/server 43 | 44 | # Copy SAM 2 inference files 45 | COPY sam2 ${APP_ROOT}/server/sam2 46 | 47 | # Download SAM 2.1 checkpoints 48 | ADD https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_tiny.pt ${APP_ROOT}/checkpoints/sam2.1_hiera_tiny.pt 49 | ADD https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_small.pt ${APP_ROOT}/checkpoints/sam2.1_hiera_small.pt 50 | ADD https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_base_plus.pt ${APP_ROOT}/checkpoints/sam2.1_hiera_base_plus.pt 51 | ADD https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_large.pt ${APP_ROOT}/checkpoints/sam2.1_hiera_large.pt 52 | 53 | WORKDIR ${APP_ROOT}/server 54 | 55 | # https://pythonspeed.com/articles/gunicorn-in-docker/ 56 | CMD gunicorn --worker-tmp-dir /dev/shm \ 57 | --worker-class gthread app:app \ 58 | --log-level info \ 59 | --access-logfile /dev/stdout \ 60 | --log-file /dev/stderr \ 61 | --workers ${GUNICORN_WORKERS} \ 62 | --threads ${GUNICORN_THREADS} \ 63 | --bind 0.0.0.0:${GUNICORN_PORT} \ 64 | --timeout 60 65 | -------------------------------------------------------------------------------- /mmmg_eval/sam2/tools/README.md: -------------------------------------------------------------------------------- 1 | ## SAM 2 toolkits 2 | 3 | This directory provides toolkits for additional SAM 2 use cases. 4 | 5 | ### Semi-supervised VOS inference 6 | 7 | The `vos_inference.py` script can be used to generate predictions for semi-supervised video object segmentation (VOS) evaluation on datasets such as [DAVIS](https://davischallenge.org/index.html), [MOSE](https://henghuiding.github.io/MOSE/) or the SA-V dataset. 8 | 9 | After installing SAM 2 and its dependencies, it can be used as follows ([DAVIS 2017 dataset](https://davischallenge.org/davis2017/code.html) as an example). This script saves the prediction PNG files to the `--output_mask_dir`. 10 | ```bash 11 | python ./tools/vos_inference.py \ 12 | --sam2_cfg configs/sam2.1/sam2.1_hiera_b+.yaml \ 13 | --sam2_checkpoint ./checkpoints/sam2.1_hiera_base_plus.pt \ 14 | --base_video_dir /path-to-davis-2017/JPEGImages/480p \ 15 | --input_mask_dir /path-to-davis-2017/Annotations/480p \ 16 | --video_list_file /path-to-davis-2017/ImageSets/2017/val.txt \ 17 | --output_mask_dir ./outputs/davis_2017_pred_pngs 18 | ``` 19 | (replace `/path-to-davis-2017` with the path to DAVIS 2017 dataset) 20 | 21 | To evaluate on the SA-V dataset with per-object PNG files for the object masks, we need to **add the `--per_obj_png_file` flag** as follows (using SA-V val as an example). This script will also save per-object PNG files for the output masks under the `--per_obj_png_file` flag. 22 | ```bash 23 | python ./tools/vos_inference.py \ 24 | --sam2_cfg configs/sam2.1/sam2.1_hiera_b+.yaml \ 25 | --sam2_checkpoint ./checkpoints/sam2.1_hiera_base_plus.pt \ 26 | --base_video_dir /path-to-sav-val/JPEGImages_24fps \ 27 | --input_mask_dir /path-to-sav-val/Annotations_6fps \ 28 | --video_list_file /path-to-sav-val/sav_val.txt \ 29 | --per_obj_png_file \ 30 | --output_mask_dir ./outputs/sav_val_pred_pngs 31 | ``` 32 | (replace `/path-to-sav-val` with the path to SA-V val) 33 | 34 | Then, we can use the evaluation tools or servers for each dataset to get the performance of the prediction PNG files above. 
35 | 36 | Note: by default, the `vos_inference.py` script above assumes that all objects to track already appear on frame 0 in each video (as is the case in DAVIS, MOSE or SA-V). **For VOS datasets that don't have all objects to track appearing in the first frame (such as LVOS or YouTube-VOS), please add the `--track_object_appearing_later_in_video` flag when using `vos_inference.py`**. 37 | -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import subprocess 3 | import os 4 | from mmmg_eval.utils.gpt_api_pool import gpt_api_pool 5 | 6 | def run_cmd(cmd): 7 | print(f"[Running] {cmd}") 8 | result = subprocess.run(cmd, shell=True) 9 | if result.returncode != 0: 10 | raise RuntimeError(f"Command failed: {cmd}") 11 | 12 | def main(): 13 | parser = argparse.ArgumentParser(description="Unified MMMG Evaluation Pipeline") 14 | parser.add_argument("-i", "--img_dir", required=True, help="Directory containing generated images") 15 | parser.add_argument("-o", "--output_dir", default="./output/", help="Directory to save results") 16 | parser.add_argument("-s", "--sam2_ckpt", required=True, help="Path to SAM2.1 checkpoint") 17 | parser.add_argument("-m", "--t2i_method", required=True, help="Name of the T2I method") 18 | parser.add_argument("-a", "--api_name", required=True, help="Name of the OpenAI evaluator API method") 19 | parser.add_argument("-c", "--hf_cache", default="./data/MMMG", help="HuggingFace cache path (optional)") 20 | 21 | args = parser.parse_args() 22 | api_name = args.api_name 23 | t2i_model = args.t2i_method 24 | img_dir = args.img_dir 25 | out_dir = os.path.join(args.output_dir, t2i_model) 26 | sam2_ckpt = args.sam2_ckpt 27 | hf_cache = args.hf_cache 28 | num_workers = min(40, len(gpt_api_pool)*10) 29 | 30 | os.makedirs(out_dir, exist_ok=True) 31 | 32 | # Step 1 – Knowledge Fidelity 33 | cmd1 = f"python mmmg_eval/step1_knowledge_integrity.py -i {img_dir} -o {out_dir}/step1 -m {t2i_model} -a {api_name} -c {hf_cache} --num_workers {num_workers}" 34 | run_cmd(cmd1) 35 | 36 | # Formulate Knowledge Fidelity result into a JSON file 37 | cmd2 = f"python mmmg_eval/utils/stat_knowledge.py --result_folder {out_dir}/step1 --image_folder {img_dir} --api_name {api_name} --output_dir {out_dir} --save_name {t2i_model}_step1_summarize" 38 | run_cmd(cmd2) 39 | 40 | # Step 2 – Visual Readability 41 | cmd3 = f"python mmmg_eval/step2_readability.py -s {sam2_ckpt} -i {img_dir} -o {out_dir}/step2 --save_name {t2i_model}_step2_summarize" 42 | run_cmd(cmd3) 43 | 44 | # Formulate Readability result into a JSON file 45 | cmd4 = f"python mmmg_eval/utils/stat_sam2.py -s {out_dir}/step2 -o {out_dir} -n {t2i_model}_step2_summarize" 46 | run_cmd(cmd4) 47 | 48 | # Step 3 – Final Score Aggregation 49 | cmd5 = f"python mmmg_eval/step3_stat.py --data_dir {out_dir}/{t2i_model}_step1_summarize.json --score_dir {out_dir}/{t2i_model}_step2_summarize.json --save_dir {out_dir}/{t2i_model}_MMMGStat.json" 50 | run_cmd(cmd5) 51 | 52 | print(f"\n✅ Evaluation complete. Final results saved in {out_dir}") 53 | 54 | if __name__ == "__main__": 55 | main() 56 | -------------------------------------------------------------------------------- /mmmg_eval/sam2/checkpoints/download_ckpts.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Copyright (c) Meta Platforms, Inc. and affiliates. 4 | # All rights reserved.
5 | 6 | # This source code is licensed under the license found in the 7 | # LICENSE file in the root directory of this source tree. 8 | 9 | # Use either wget or curl to download the checkpoints 10 | if command -v wget &> /dev/null; then 11 | CMD="wget" 12 | elif command -v curl &> /dev/null; then 13 | CMD="curl -L -O" 14 | else 15 | echo "Please install wget or curl to download the checkpoints." 16 | exit 1 17 | fi 18 | 19 | # Define the URLs for SAM 2 checkpoints 20 | # SAM2_BASE_URL="https://dl.fbaipublicfiles.com/segment_anything_2/072824" 21 | # sam2_hiera_t_url="${SAM2_BASE_URL}/sam2_hiera_tiny.pt" 22 | # sam2_hiera_s_url="${SAM2_BASE_URL}/sam2_hiera_small.pt" 23 | # sam2_hiera_b_plus_url="${SAM2_BASE_URL}/sam2_hiera_base_plus.pt" 24 | # sam2_hiera_l_url="${SAM2_BASE_URL}/sam2_hiera_large.pt" 25 | 26 | # Download each of the four checkpoints using wget 27 | # echo "Downloading sam2_hiera_tiny.pt checkpoint..." 28 | # $CMD $sam2_hiera_t_url || { echo "Failed to download checkpoint from $sam2_hiera_t_url"; exit 1; } 29 | 30 | # echo "Downloading sam2_hiera_small.pt checkpoint..." 31 | # $CMD $sam2_hiera_s_url || { echo "Failed to download checkpoint from $sam2_hiera_s_url"; exit 1; } 32 | 33 | # echo "Downloading sam2_hiera_base_plus.pt checkpoint..." 34 | # $CMD $sam2_hiera_b_plus_url || { echo "Failed to download checkpoint from $sam2_hiera_b_plus_url"; exit 1; } 35 | 36 | # echo "Downloading sam2_hiera_large.pt checkpoint..." 37 | # $CMD $sam2_hiera_l_url || { echo "Failed to download checkpoint from $sam2_hiera_l_url"; exit 1; } 38 | 39 | # Define the URLs for SAM 2.1 checkpoints 40 | SAM2p1_BASE_URL="https://dl.fbaipublicfiles.com/segment_anything_2/092824" 41 | sam2p1_hiera_t_url="${SAM2p1_BASE_URL}/sam2.1_hiera_tiny.pt" 42 | sam2p1_hiera_s_url="${SAM2p1_BASE_URL}/sam2.1_hiera_small.pt" 43 | sam2p1_hiera_b_plus_url="${SAM2p1_BASE_URL}/sam2.1_hiera_base_plus.pt" 44 | sam2p1_hiera_l_url="${SAM2p1_BASE_URL}/sam2.1_hiera_large.pt" 45 | 46 | # SAM 2.1 checkpoints 47 | echo "Downloading sam2.1_hiera_tiny.pt checkpoint..." 48 | $CMD $sam2p1_hiera_t_url || { echo "Failed to download checkpoint from $sam2p1_hiera_t_url"; exit 1; } 49 | 50 | echo "Downloading sam2.1_hiera_small.pt checkpoint..." 51 | $CMD $sam2p1_hiera_s_url || { echo "Failed to download checkpoint from $sam2p1_hiera_s_url"; exit 1; } 52 | 53 | echo "Downloading sam2.1_hiera_base_plus.pt checkpoint..." 54 | $CMD $sam2p1_hiera_b_plus_url || { echo "Failed to download checkpoint from $sam2p1_hiera_b_plus_url"; exit 1; } 55 | 56 | echo "Downloading sam2.1_hiera_large.pt checkpoint..." 57 | $CMD $sam2p1_hiera_l_url || { echo "Failed to download checkpoint from $sam2p1_hiera_l_url"; exit 1; } 58 | 59 | echo "All checkpoints are downloaded successfully." 
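With the SAM 2.1 weights in place, the whole benchmark is driven from the repository root through `evaluate.py` (shown earlier). A minimal invocation sketch; the image directory and the method/API names below are hypothetical placeholders:

```bash
# Fetch the SAM 2.1 checkpoints next to the vendored SAM 2 copy
cd mmmg_eval/sam2/checkpoints && bash ./download_ckpts.sh && cd -

# Run the three-step MMMG pipeline (knowledge fidelity -> visual readability -> aggregation).
# -i: generated images to score   -m: T2I method name   -a: evaluator API/deployment name
# -s: SAM 2.1 checkpoint          -o: output root
python evaluate.py \
  -i ./images/my_t2i_outputs \
  -m my_t2i_model \
  -a o3 \
  -s mmmg_eval/sam2/checkpoints/sam2.1_hiera_large.pt \
  -o ./output/
```

Note that `gpt_api_pool.py` must be populated first, since `evaluate.py` derives its worker count from the pool size.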
60 | -------------------------------------------------------------------------------- /mmmg_eval/sam2/.clang-format: -------------------------------------------------------------------------------- 1 | AccessModifierOffset: -1 2 | AlignAfterOpenBracket: AlwaysBreak 3 | AlignConsecutiveAssignments: false 4 | AlignConsecutiveDeclarations: false 5 | AlignEscapedNewlinesLeft: true 6 | AlignOperands: false 7 | AlignTrailingComments: false 8 | AllowAllParametersOfDeclarationOnNextLine: false 9 | AllowShortBlocksOnASingleLine: false 10 | AllowShortCaseLabelsOnASingleLine: false 11 | AllowShortFunctionsOnASingleLine: Empty 12 | AllowShortIfStatementsOnASingleLine: false 13 | AllowShortLoopsOnASingleLine: false 14 | AlwaysBreakAfterReturnType: None 15 | AlwaysBreakBeforeMultilineStrings: true 16 | AlwaysBreakTemplateDeclarations: true 17 | BinPackArguments: false 18 | BinPackParameters: false 19 | BraceWrapping: 20 | AfterClass: false 21 | AfterControlStatement: false 22 | AfterEnum: false 23 | AfterFunction: false 24 | AfterNamespace: false 25 | AfterObjCDeclaration: false 26 | AfterStruct: false 27 | AfterUnion: false 28 | BeforeCatch: false 29 | BeforeElse: false 30 | IndentBraces: false 31 | BreakBeforeBinaryOperators: None 32 | BreakBeforeBraces: Attach 33 | BreakBeforeTernaryOperators: true 34 | BreakConstructorInitializersBeforeComma: false 35 | BreakAfterJavaFieldAnnotations: false 36 | BreakStringLiterals: false 37 | ColumnLimit: 80 38 | CommentPragmas: '^ IWYU pragma:' 39 | ConstructorInitializerAllOnOneLineOrOnePerLine: true 40 | ConstructorInitializerIndentWidth: 4 41 | ContinuationIndentWidth: 4 42 | Cpp11BracedListStyle: true 43 | DerivePointerAlignment: false 44 | DisableFormat: false 45 | ForEachMacros: [ FOR_EACH, FOR_EACH_R, FOR_EACH_RANGE, ] 46 | IncludeCategories: 47 | - Regex: '^<.*\.h(pp)?>' 48 | Priority: 1 49 | - Regex: '^<.*' 50 | Priority: 2 51 | - Regex: '.*' 52 | Priority: 3 53 | IndentCaseLabels: true 54 | IndentWidth: 2 55 | IndentWrappedFunctionNames: false 56 | KeepEmptyLinesAtTheStartOfBlocks: false 57 | MacroBlockBegin: '' 58 | MacroBlockEnd: '' 59 | MaxEmptyLinesToKeep: 1 60 | NamespaceIndentation: None 61 | ObjCBlockIndentWidth: 2 62 | ObjCSpaceAfterProperty: false 63 | ObjCSpaceBeforeProtocolList: false 64 | PenaltyBreakBeforeFirstCallParameter: 1 65 | PenaltyBreakComment: 300 66 | PenaltyBreakFirstLessLess: 120 67 | PenaltyBreakString: 1000 68 | PenaltyExcessCharacter: 1000000 69 | PenaltyReturnTypeOnItsOwnLine: 200 70 | PointerAlignment: Left 71 | ReflowComments: true 72 | SortIncludes: true 73 | SpaceAfterCStyleCast: false 74 | SpaceBeforeAssignmentOperators: true 75 | SpaceBeforeParens: ControlStatements 76 | SpaceInEmptyParentheses: false 77 | SpacesBeforeTrailingComments: 1 78 | SpacesInAngles: false 79 | SpacesInContainerLiterals: true 80 | SpacesInCStyleCastParentheses: false 81 | SpacesInParentheses: false 82 | SpacesInSquareBrackets: false 83 | Standard: Cpp11 84 | TabWidth: 8 85 | UseTab: Never 86 | -------------------------------------------------------------------------------- /mmmg_eval/utils/stat_sam2.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | """ 5 | Collect SAM-2.1 segmentation stats. 
6 | 7 | Example 8 | ------- 9 | python collect_stats.py \ 10 | --src_dir /path/to/sam2_outputs \ 11 | --dest_dir /path/to/statistics \ 12 | --name my_model \ 13 | --num_workers 8 14 | """ 15 | 16 | import argparse 17 | import json 18 | import os 19 | from functools import partial 20 | from multiprocessing import Pool, cpu_count 21 | 22 | 23 | def process_file(src_dir: str, filename: str): 24 | """Read //anno.json and return a (key, info_dict) pair.""" 25 | try: 26 | file_path = os.path.join(src_dir, filename, "anno.json") 27 | image_path = os.path.join(src_dir, filename, "sam2_result.png") 28 | 29 | with open(file_path, "r") as f: 30 | data = json.load(f) 31 | 32 | key = data["image_uid"].split("__")[-1].split(".")[0] 33 | return key, { 34 | "region_count": data["region_count"], 35 | "image_path": data["image_path"], 36 | "sam_path": image_path, 37 | } 38 | except Exception as e: # noqa: BLE001 39 | print(f"[skip] {filename}: {e}") 40 | return None 41 | 42 | 43 | def main(): 44 | parser = argparse.ArgumentParser( 45 | description="Merge per-image anno.json into a single stat file." 46 | ) 47 | parser.add_argument("-s", "--src_dir", required=True, 48 | help="Root dir that contains per-image sub-folders") 49 | parser.add_argument("-o", "--dest_dir", required=True, 50 | help="Folder to save the merged stat JSON") 51 | parser.add_argument("-n", "--name", 52 | help="Prefix for output JSON (default: basename of src_dir)") 53 | parser.add_argument("-j", "--num_workers", type=int, default=cpu_count(), 54 | help="Parallel workers (default: CPU cores)") 55 | args = parser.parse_args() 56 | 57 | src_dir = os.path.abspath(args.src_dir) 58 | dest_dir = os.path.abspath(args.dest_dir) 59 | name = args.name or os.path.basename(src_dir.rstrip("/")) 60 | os.makedirs(dest_dir, exist_ok=True) 61 | 62 | # enumerate sub-folders 63 | sub_dirs = [d for d in os.listdir(src_dir) 64 | if os.path.isdir(os.path.join(src_dir, d))] 65 | print(f"Found {len(sub_dirs)} samples in {src_dir}") 66 | 67 | # pool-map 68 | with Pool(processes=args.num_workers) as pool: 69 | results = pool.map(partial(process_file, src_dir), sub_dirs) 70 | 71 | # flatten & filter None 72 | merged = {k: v for pair in results if pair for k, v in [pair]} 73 | 74 | # save 75 | out_path = os.path.join(dest_dir, f"{name}.json") 76 | with open(out_path, "w") as f: 77 | json.dump(merged, f, indent=4) 78 | print(f"Saved {len(merged)} entries → {out_path}") 79 | 80 | 81 | if __name__ == "__main__": 82 | main() 83 | -------------------------------------------------------------------------------- /mmmg_eval/sam2/training/assets/MOSE_sample_val_list.txt: -------------------------------------------------------------------------------- 1 | 32e5d721 2 | 5bad0bab 3 | 267bfd6c 4 | 0a43a414 5 | 56c56ca9 6 | 9a1146b3 7 | c6ad7aaf 8 | 78a1f4b1 9 | fc455e73 10 | 072e7b3f 11 | 77ccb57d 12 | a76ee415 13 | 8cdcfc17 14 | 5d518b42 15 | 376dd830 16 | 0e843fc8 17 | 2af0e766 18 | 2bd4e845 19 | de2f2a6a 20 | ade9ee91 21 | 001ca3cb 22 | fc4c1c67 23 | 8ef55579 24 | b84ce852 25 | 4cc8528a 26 | 767ffaaa 27 | 112a2ef0 28 | a338c8aa 29 | cbd144f5 30 | 5ff72128 31 | 86a949e2 32 | 9f2323ac 33 | 1fab1d1c 34 | 75924351 35 | ef55817b 36 | 02deca50 37 | 4d979d99 38 | 4d65f873 39 | 28470fa0 40 | 0d1575fe 41 | 06ea172e 42 | 29a6ddc2 43 | 797f1bec 44 | 780e7a99 45 | b9ed5b44 46 | 02a236b4 47 | 607d8ff5 48 | af5666b2 49 | 0558d0ed 50 | a938c6b2 51 | 103df575 52 | 77110e80 53 | 739e5a07 54 | 6763a576 55 | 06ebc138 56 | ba4b3b09 57 | b35cc2f3 58 | 4e0597a0 59 | 5949ee84 60 | 5348d547 61 | 
323c4236 62 | b3b51117 63 | 55727ddd 64 | ab2714f3 65 | d2878895 66 | c0734cb3 67 | 94f7c53e 68 | 2a2745e5 69 | 442ffb54 70 | 3592425a 71 | 50ae03b0 72 | 5f150435 73 | 3067f9fa 74 | 9ffb2818 75 | adeaf5aa 76 | 31caacec 77 | 1cd99b86 78 | aa22f9d0 79 | 8fa50320 80 | e6348d2c 81 | 42ff84a5 82 | 8c8b7913 83 | c96adcbc 84 | 495be321 85 | db735509 86 | ee113fc4 87 | a678cdab 88 | c409ca4d 89 | 68d2b259 90 | 592b4dee 91 | 4e2b4dc7 92 | eb4d26e1 93 | 2009a00f 94 | bec5c89d 95 | 67191f24 96 | a3e85b4b 97 | da7080cd 98 | 80d978e9 99 | 36dcb93f 100 | a41e8c44 101 | 12fdc864 102 | 46d140ea 103 | 657c9dd9 104 | a86f84ee 105 | 90c1c43d 106 | 33015509 107 | afc7664d 108 | 23df06e1 109 | 291d4799 110 | 0ab75563 111 | 251bf059 112 | bcefdcc4 113 | ce9a2796 114 | 94d3403a 115 | 8f2e04bc 116 | f9cda066 117 | 9dfa2cc5 118 | 66924c91 119 | e765a09e 120 | 15654ee1 121 | 48e0bd39 122 | ee095221 123 | 2463609b 124 | 544d0d1f 125 | 51b8c2e1 126 | d321dde4 127 | 4cb11a5f 128 | d7058a0d 129 | 37af282a 130 | fabae187 131 | 7be91184 132 | 181ec185 133 | 2d16ceeb 134 | b56be4b1 135 | 6699eff0 136 | 79acac96 137 | d61c4665 138 | 0c13e1e7 139 | 100f6ecf 140 | 71217dfc 141 | 82df0888 142 | 4c42c747 143 | c9fdf703 144 | d2efeb4b 145 | 69ed9d14 146 | 64914fb6 147 | 255bedbc 148 | 4ea934d8 149 | a034feb2 150 | e4f4ddae 151 | e36a3026 152 | c1489591 153 | 111bb373 154 | e1d9fb32 155 | 93e22d48 156 | c1ec4b26 157 | d9638e69 158 | 60ab04c5 159 | cfe7773a 160 | 62132822 161 | 2f5fb2a3 162 | 7bdd197d 163 | 033333fd 164 | 130fcdbe 165 | 12e509c2 166 | 67138c33 167 | 6f90cc5f 168 | 4e3020fe 169 | bbdd8bb7 170 | b399ccdb 171 | fecd10d2 172 | 2e0967f7 173 | f509054f 174 | 792c6ff7 175 | 48e2afc5 176 | d904c048 177 | 111e0a5c 178 | b83024e2 179 | e6a7b79c 180 | bdc5ccf7 181 | b8146d00 182 | 9d394f1a 183 | 645b84f9 184 | 95ab2d0f 185 | e6f8a31d 186 | b4f876fb 187 | dc2c570d 188 | 3afd02d7 189 | 5c80c82c 190 | b1b32ddd 191 | 9f25fc61 192 | ba538072 193 | f8916fef 194 | 43c04ad2 195 | a658e949 196 | 2861dd53 197 | f6e40aba 198 | 09d305d1 199 | aac33bff 200 | 8d9d4c08 201 | -------------------------------------------------------------------------------- /mmmg_eval/sam2/sam2/benchmark.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | import os 8 | import time 9 | 10 | import numpy as np 11 | import torch 12 | from tqdm import tqdm 13 | 14 | from sam2.build_sam import build_sam2_video_predictor 15 | 16 | # Only cuda supported 17 | assert torch.cuda.is_available() 18 | device = torch.device("cuda") 19 | 20 | torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__() 21 | if torch.cuda.get_device_properties(0).major >= 8: 22 | # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices) 23 | torch.backends.cuda.matmul.allow_tf32 = True 24 | torch.backends.cudnn.allow_tf32 = True 25 | 26 | # Config and checkpoint 27 | sam2_checkpoint = "checkpoints/sam2.1_hiera_base_plus.pt" 28 | model_cfg = "configs/sam2.1/sam2.1_hiera_b+.yaml" 29 | 30 | # Build video predictor with vos_optimized=True setting 31 | predictor = build_sam2_video_predictor( 32 | model_cfg, sam2_checkpoint, device=device, vos_optimized=True 33 | ) 34 | 35 | 36 | # Initialize with video 37 | video_dir = "notebooks/videos/bedroom" 38 | # scan all the JPEG frame names in this directory 39 | frame_names = [ 40 | p 41 | for p in os.listdir(video_dir) 42 | if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"] 43 | ] 44 | frame_names.sort(key=lambda p: int(os.path.splitext(p)[0])) 45 | inference_state = predictor.init_state(video_path=video_dir) 46 | 47 | 48 | # Number of runs, warmup etc 49 | warm_up, runs = 5, 25 50 | verbose = True 51 | num_frames = len(frame_names) 52 | total, count = 0, 0 53 | torch.cuda.empty_cache() 54 | 55 | # We will select an object with a click. 56 | # See video_predictor_example.ipynb for more detailed explanation 57 | ann_frame_idx, ann_obj_id = 0, 1 58 | # Add a positive click at (x, y) = (210, 350) 59 | # For labels, `1` means positive click 60 | points = np.array([[210, 350]], dtype=np.float32) 61 | labels = np.array([1], np.int32) 62 | 63 | _, out_obj_ids, out_mask_logits = predictor.add_new_points_or_box( 64 | inference_state=inference_state, 65 | frame_idx=ann_frame_idx, 66 | obj_id=ann_obj_id, 67 | points=points, 68 | labels=labels, 69 | ) 70 | 71 | # Warmup and then average FPS over several runs 72 | with torch.autocast("cuda", torch.bfloat16): 73 | with torch.inference_mode(): 74 | for i in tqdm(range(runs), disable=not verbose, desc="Benchmarking"): 75 | start = time.time() 76 | # Start tracking 77 | for ( 78 | out_frame_idx, 79 | out_obj_ids, 80 | out_mask_logits, 81 | ) in predictor.propagate_in_video(inference_state): 82 | pass 83 | 84 | end = time.time() 85 | total += end - start 86 | count += 1 87 | if i == warm_up - 1: 88 | print("Warmup FPS: ", count * num_frames / total) 89 | total = 0 90 | count = 0 91 | 92 | print("FPS: ", count * num_frames / total) 93 | -------------------------------------------------------------------------------- /mmmg_eval/utils/all_configs.py: -------------------------------------------------------------------------------- 1 | from datasets import load_dataset, concatenate_datasets 2 | import os 3 | from concurrent.futures import ThreadPoolExecutor, as_completed 4 | 5 | def load_all_mmmg_configs(cache_dir="~/.cache/mmmg", max_workers=8): 6 | """ 7 | Load and concatenate all config splits from the MMMGbench/MMMG dataset. 8 | 9 | Args: 10 | cache_dir (str): HuggingFace cache directory. 11 | max_workers (int): Maximum number of parallel threads to use. 12 | 13 | Returns: 14 | Dataset: The concatenated dataset with all config entries. 
15 | """ 16 | all_configs = [ 17 | 'PhD_Biology', 'PhD_Chemistry', 'PhD_Economics', 'PhD_Engineering', 'PhD_Geography', 18 | 'PhD_History', 'PhD_Literature', 'PhD_Math', 'PhD_Philosophy', 'PhD_Sociology', 19 | 'highschool_Biology', 'highschool_Chemistry', 'highschool_Economics', 'highschool_Engineering', 20 | 'highschool_Geography', 'highschool_History', 'highschool_Literature', 'highschool_Math', 21 | 'highschool_Philosophy', 'highschool_Sociology', 'preschool_Biology', 'preschool_Chemistry', 22 | 'preschool_Economics', 'preschool_Engineering', 'preschool_Geography', 'preschool_History', 23 | 'preschool_Literature', 'preschool_Math', 'preschool_Sociology', 'primaryschool_Biology', 24 | 'primaryschool_Chemistry', 'primaryschool_Economics', 'primaryschool_Engineering', 25 | 'primaryschool_Geography', 'primaryschool_History', 'primaryschool_Literature', 26 | 'primaryschool_Math', 'primaryschool_Philosophy', 'primaryschool_Sociology', 27 | 'secondaryschool_Biology', 'secondaryschool_Chemistry', 'secondaryschool_Economics', 28 | 'secondaryschool_Engineering', 'secondaryschool_Geography', 'secondaryschool_History', 29 | 'secondaryschool_Literature', 'secondaryschool_Math', 'secondaryschool_Philosophy', 30 | 'secondaryschool_Sociology', 'undergraduate_Biology', 'undergraduate_Chemistry', 31 | 'undergraduate_Economics', 'undergraduate_Engineering', 'undergraduate_Geography', 32 | 'undergraduate_History', 'undergraduate_Literature', 'undergraduate_Math', 33 | 'undergraduate_Philosophy', 'undergraduate_Sociology' 34 | ] 35 | 36 | expanded_cache_dir = os.path.expanduser(cache_dir) 37 | 38 | def load_one_config(cfg_name): 39 | print(f"Loading config: {cfg_name}...") 40 | return load_dataset( 41 | "MMMGbench/MMMG", 42 | name=cfg_name, 43 | split="test", 44 | cache_dir=expanded_cache_dir, 45 | trust_remote_code=True, 46 | ) 47 | 48 | all_datasets = [] 49 | with ThreadPoolExecutor(max_workers=max_workers) as executor: 50 | futures = {executor.submit(load_one_config, cfg): cfg for cfg in all_configs} 51 | for future in as_completed(futures): 52 | cfg = futures[future] 53 | try: 54 | ds = future.result() 55 | all_datasets.append(ds) 56 | except Exception as e: 57 | print(f"❌ Failed to load {cfg}: {e}") 58 | 59 | full_dataset = concatenate_datasets(all_datasets) 60 | print(f"✅ Loaded total {len(full_dataset):,} samples from {len(all_datasets)} configs.") 61 | return full_dataset 62 | -------------------------------------------------------------------------------- /mmmg_eval/sam2/sam2/modeling/backbones/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | """Some utilities for backbones, in particular for windowing""" 8 | 9 | from typing import Tuple 10 | 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.functional as F 14 | 15 | 16 | def window_partition(x, window_size): 17 | """ 18 | Partition into non-overlapping windows with padding if needed. 19 | Args: 20 | x (tensor): input tokens with [B, H, W, C]. 21 | window_size (int): window size. 22 | Returns: 23 | windows: windows after partition with [B * num_windows, window_size, window_size, C]. 
24 | (Hp, Wp): padded height and width before partition 25 | """ 26 | B, H, W, C = x.shape 27 | 28 | pad_h = (window_size - H % window_size) % window_size 29 | pad_w = (window_size - W % window_size) % window_size 30 | if pad_h > 0 or pad_w > 0: 31 | x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h)) 32 | Hp, Wp = H + pad_h, W + pad_w 33 | 34 | x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C) 35 | windows = x.permute(0, 1, 3, 2, 4, 5).reshape(-1, window_size, window_size, C) 36 | return windows, (Hp, Wp) 37 | 38 | 39 | def window_unpartition(windows, window_size, pad_hw, hw): 40 | """ 41 | Window unpartition into original sequences and removing padding. 42 | Args: 43 | x (tensor): input tokens with [B * num_windows, window_size, window_size, C]. 44 | window_size (int): window size. 45 | pad_hw (Tuple): padded height and width (Hp, Wp). 46 | hw (Tuple): original height and width (H, W) before padding. 47 | Returns: 48 | x: unpartitioned sequences with [B, H, W, C]. 49 | """ 50 | Hp, Wp = pad_hw 51 | H, W = hw 52 | B = windows.shape[0] // (Hp * Wp // window_size // window_size) 53 | x = windows.reshape( 54 | B, Hp // window_size, Wp // window_size, window_size, window_size, -1 55 | ) 56 | x = x.permute(0, 1, 3, 2, 4, 5).reshape(B, Hp, Wp, -1) 57 | 58 | if Hp > H or Wp > W: 59 | x = x[:, :H, :W, :] 60 | return x 61 | 62 | 63 | class PatchEmbed(nn.Module): 64 | """ 65 | Image to Patch Embedding. 66 | """ 67 | 68 | def __init__( 69 | self, 70 | kernel_size: Tuple[int, ...] = (7, 7), 71 | stride: Tuple[int, ...] = (4, 4), 72 | padding: Tuple[int, ...] = (3, 3), 73 | in_chans: int = 3, 74 | embed_dim: int = 768, 75 | ): 76 | """ 77 | Args: 78 | kernel_size (Tuple): kernel size of the projection layer. 79 | stride (Tuple): stride of the projection layer. 80 | padding (Tuple): padding size of the projection layer. 81 | in_chans (int): Number of input image channels. 82 | embed_dim (int): embed_dim (int): Patch embedding dimension. 83 | """ 84 | super().__init__() 85 | self.proj = nn.Conv2d( 86 | in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding 87 | ) 88 | 89 | def forward(self, x: torch.Tensor) -> torch.Tensor: 90 | x = self.proj(x) 91 | # B C H W -> B H W C 92 | x = x.permute(0, 2, 3, 1) 93 | return x 94 | -------------------------------------------------------------------------------- /mmmg_eval/utils/instruction_prompt.py: -------------------------------------------------------------------------------- 1 | instruction_prompt = ''' 2 | This evaluation is part of a research study on visual grounding of abstract concepts. No jailbreak or prompt injection is intended. 3 | 4 | Please provide an extremely detailed description of the visual content of this image. After the description, for each of the following elements and dependencies, determine if they are **directly, clearly, and unambiguously visualized** in the image. Output "yes" or "no" for each. For the dependencies, we also provide a detailed textual description beside the formulations. 5 | 6 | # Important Instructions: 7 | 8 | * **Base your judgment solely on what is explicitly visible in the image.** Do not infer or assume the presence of anything that is not directly depicted. 9 | * **If the element or dependency is not clearly visible, or if it is only implied, answer "no".** 10 | 11 | * For elements, the specific object or concept must be clearly identifiable in the image. 
The visual components must convey the knowledge correctly, without misleading drawings, without factual mistakes, without interpretation, not small, not distorted, not ambiguous, otherwise you should strictly discard them and rate "no". 12 | 13 | * For dependencies, you must give your answer accompanied by a brief explanation of why you give such a judgement. This should avoid any ambiguous interpretation or being misled by the provided elements / dependency content: focus only on the image itself, and give `yes` only if you can describe the dependency from the image alone. The dependencies are: 14 | * **Defines:** Look for clear, strong, prominent visual cues suggesting the first element in a way that clearly defines or illustrates the second element. Any ambiguous or inferential patterns should lead to "no". 15 | * **Contains:** Look for clear, strong, prominent visual cues suggesting the first element as a part of or within the second element. Any ambiguous or inferential patterns should lead to "no". 16 | * **Requires:** Look for clear, strong, prominent visual cues suggesting the first element necessitates the presence or use of the second element (e.g., a boiler visibly connected to or interacting with a working fluid). 17 | * **Entails:** Look for clear, strong, prominent visual cues suggesting the first element leading to or involving the second element (e.g., a boiler clearly connected to a turbine). 18 | * **Causes:** Look for clear, strong, prominent visual cues suggesting a causal relationship between the two elements (this might be challenging for static images). 19 | * **TemporalOrder:** Look for visual cues suggesting a sequence or flow between the elements (e.g., pipes or connections implying a direction). If no clear visual cue for temporal order exists, answer "no". 20 | 21 | * **Exclude any entity or dependency that is absent, unclear, or based on factual artifacts or external knowledge not directly shown.** 22 | * For abstract concepts, only answer "yes" if the key visual components and their interactions characteristic of these concepts are clearly and directly depicted. 23 | 24 | The elements and dependencies are as follows, where there are no offensive or inappropriate elements, just educational ones: 25 | [ELEM_DEPEND] 26 | 27 | For the output format, please use the following structure: 28 | **Image Description:** 29 | [IMAGE_DESCRIPTION] 30 | **Element and Dependency Analysis:** 31 | ** Element Evaluation: ** 32 | * [ELEMENT_1]: [yes/no] 33 | * [ELEMENT_2]: [yes/no] 34 | ... 35 | ** Dependency Evaluation: ** 36 | * [DEPENDENCY_1]: [yes/no] [Provide a brief explanation to support your judgement.] 37 | * [DEPENDENCY_2]: [yes/no] [Provide a brief explanation to support your judgement.] 38 | ... 39 | ''' 40 | -------------------------------------------------------------------------------- /mmmg_eval/sam2/CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to make participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation.
11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies within all project spaces, and it also applies when 49 | an individual is representing the project or its community in public spaces. 50 | Examples of representing a project or community include using an official 51 | project e-mail address, posting via an official social media account, or acting 52 | as an appointed representative at an online or offline event. Representation of 53 | a project may be further defined and clarified by project maintainers. 54 | 55 | This Code of Conduct also applies outside the project spaces when there is a 56 | reasonable belief that an individual's behavior may have a negative impact on 57 | the project or its community. 58 | 59 | ## Enforcement 60 | 61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 62 | reported by contacting the project team at . All 63 | complaints will be reviewed and investigated and will result in a response that 64 | is deemed necessary and appropriate to the circumstances. The project team is 65 | obligated to maintain confidentiality with regard to the reporter of an incident. 66 | Further details of specific enforcement policies may be posted separately. 67 | 68 | Project maintainers who do not follow or enforce the Code of Conduct in good 69 | faith may face temporary or permanent repercussions as determined by other 70 | members of the project's leadership. 
71 | 72 | ## Attribution 73 | 74 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, 75 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html 76 | 77 | [homepage]: https://www.contributor-covenant.org 78 | 79 | For answers to common questions about this code of conduct, see 80 | https://www.contributor-covenant.org/faq 81 | -------------------------------------------------------------------------------- /mmmg_eval/sam2/RELEASE_NOTES.md: -------------------------------------------------------------------------------- 1 | ## SAM 2 release notes 2 | 3 | ### 12/11/2024 -- full model compilation for a major VOS speedup and a new `SAM2VideoPredictor` to better handle multi-object tracking 4 | 5 | - We now support `torch.compile` of the entire SAM 2 model on videos, which can be turned on by setting `vos_optimized=True` in `build_sam2_video_predictor` (it uses the new `SAM2VideoPredictorVOS` predictor class in `sam2/sam2_video_predictor.py`). 6 | * Compared to the previous setting (which only compiles the image encoder backbone), the new full model compilation gives a major speedup in inference FPS. 7 | * In the VOS prediction script `tools/vos_inference.py`, you can specify this option in `tools/vos_inference.py` via the `--use_vos_optimized_video_predictor` flag. 8 | * Note that turning on this flag might introduce a small variance in the predictions due to numerical differences caused by `torch.compile` of the full model. 9 | * **PyTorch 2.5.1 is the minimum version for full support of this feature**. (Earlier PyTorch versions might run into compilation errors in some cases.) Therefore, we have updated the minimum PyTorch version to 2.5.1 accordingly in the installation scripts. 10 | - We also update the implementation of the `SAM2VideoPredictor` class for the SAM 2 video prediction in `sam2/sam2_video_predictor.py`, which allows for independent per-object inference. Specifically, in the new `SAM2VideoPredictor`: 11 | * Now **we handle the inference of each object independently** (as if we are opening a separate session for each object) while sharing their backbone features. 12 | * This change allows us to relax the assumption of prompting for multi-object tracking. Previously (due to the batching behavior in inference), if a video frame receives clicks for only a subset of objects, the rest of the (non-prompted) objects are assumed to be non-existent in this frame (i.e., in such frames, the user is telling SAM 2 that the rest of the objects don't appear). Now, if a frame receives clicks for only a subset of objects, we do not make any assumptions about the remaining (non-prompted) objects (i.e., now each object is handled independently and is not affected by how other objects are prompted). As a result, **we allow adding new objects after tracking starts** after this change (which was previously a restriction on usage). 13 | * We believe that the new version is a more natural inference behavior and therefore switched to it as the default behavior. The previous implementation of `SAM2VideoPredictor` is backed up to in `sam2/sam2_video_predictor_legacy.py`. All the VOS inference results using `tools/vos_inference.py` should remain the same after this change to the `SAM2VideoPredictor` class. 14 | 15 | ### 09/30/2024 -- SAM 2.1 Developer Suite (new checkpoints, training code, web demo) is released 16 | 17 | - A new suite of improved model checkpoints (denoted as **SAM 2.1**) are released. 
See [Model Description](#model-description) for details. 18 | * To use the new SAM 2.1 checkpoints, you need the latest model code from this repo. If you have installed an earlier version of this repo, please first uninstall the previous version via `pip uninstall SAM-2`, pull the latest code from this repo (with `git pull`), and then reinstall the repo following [Installation](#installation) below. 19 | - The training (and fine-tuning) code has been released. See [`training/README.md`](training/README.md) on how to get started. 20 | - The frontend + backend code for the SAM 2 web demo has been released. See [`demo/README.md`](demo/README.md) for details. 21 | 22 | ### 07/29/2024 -- SAM 2 is released 23 | 24 | - We release Segment Anything Model 2 (SAM 2), a foundation model towards solving promptable visual segmentation in images and videos. 25 | * SAM 2 code: https://github.com/facebookresearch/sam2 26 | * SAM 2 demo: https://sam2.metademolab.com/ 27 | * SAM 2 paper: https://arxiv.org/abs/2408.00714 28 | -------------------------------------------------------------------------------- /mmmg_eval/sam2/training/dataset/vos_sampler.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import random 8 | from dataclasses import dataclass 9 | from typing import List 10 | 11 | from training.dataset.vos_segment_loader import LazySegments 12 | 13 | MAX_RETRIES = 1000 14 | 15 | 16 | @dataclass 17 | class SampledFramesAndObjects: 18 | frames: List[int] 19 | object_ids: List[int] 20 | 21 | 22 | class VOSSampler: 23 | def __init__(self, sort_frames=True): 24 | # frames are ordered by frame id when sort_frames is True 25 | self.sort_frames = sort_frames 26 | 27 | def sample(self, video): 28 | raise NotImplementedError() 29 | 30 | 31 | class RandomUniformSampler(VOSSampler): 32 | def __init__( 33 | self, 34 | num_frames, 35 | max_num_objects, 36 | reverse_time_prob=0.0, 37 | ): 38 | self.num_frames = num_frames 39 | self.max_num_objects = max_num_objects 40 | self.reverse_time_prob = reverse_time_prob 41 | 42 | def sample(self, video, segment_loader, epoch=None): 43 | 44 | for retry in range(MAX_RETRIES): 45 | if len(video.frames) < self.num_frames: 46 | raise Exception( 47 | f"Cannot sample {self.num_frames} frames from video {video.video_name} as it only has {len(video.frames)} annotated frames." 
48 | ) 49 | start = random.randrange(0, len(video.frames) - self.num_frames + 1) 50 | frames = [video.frames[start + step] for step in range(self.num_frames)] 51 | if random.uniform(0, 1) < self.reverse_time_prob: 52 | # Reverse time 53 | frames = frames[::-1] 54 | 55 | # Get first frame object ids 56 | visible_object_ids = [] 57 | loaded_segms = segment_loader.load(frames[0].frame_idx) 58 | if isinstance(loaded_segms, LazySegments): 59 | # LazySegments for SA1BRawDataset 60 | visible_object_ids = list(loaded_segms.keys()) 61 | else: 62 | for object_id, segment in segment_loader.load( 63 | frames[0].frame_idx 64 | ).items(): 65 | if segment.sum(): 66 | visible_object_ids.append(object_id) 67 | 68 | # First frame needs to have at least a target to track 69 | if len(visible_object_ids) > 0: 70 | break 71 | if retry >= MAX_RETRIES - 1: 72 | raise Exception("No visible objects") 73 | 74 | object_ids = random.sample( 75 | visible_object_ids, 76 | min(len(visible_object_ids), self.max_num_objects), 77 | ) 78 | return SampledFramesAndObjects(frames=frames, object_ids=object_ids) 79 | 80 | 81 | class EvalSampler(VOSSampler): 82 | """ 83 | VOS Sampler for evaluation: sampling all the frames and all the objects in a video 84 | """ 85 | 86 | def __init__( 87 | self, 88 | ): 89 | super().__init__() 90 | 91 | def sample(self, video, segment_loader, epoch=None): 92 | """ 93 | Sampling all the frames and all the objects 94 | """ 95 | if self.sort_frames: 96 | # ordered by frame id 97 | frames = sorted(video.frames, key=lambda x: x.frame_idx) 98 | else: 99 | # use the original order 100 | frames = video.frames 101 | object_ids = segment_loader.load(frames[0].frame_idx).keys() 102 | if len(object_ids) == 0: 103 | raise Exception("First frame of the video has no objects") 104 | 105 | return SampledFramesAndObjects(frames=frames, object_ids=object_ids) 106 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

2 | # MMMG: A Massive, Multidisciplinary, Multi-Tier Generation Benchmark for Text-to-Image Reasoning 3 | 4 |

5 | 6 | Official evaluation toolkit for **MMMG**: the **M**assive **M**ulti-discipline **M**ulti-tier Knowledge-Image **G**eneration benchmark. 7 | 8 | * ✨ **[Project Page](https://mmmgbench.github.io/)** 9 | * 📄 **[Paper (arXiv 2506.10963)](https://arxiv.org/abs/2506.10963)** 10 | * 💾 **[MMMG Dataset on HuggingFace](https://huggingface.co/datasets/MMMGBench/MMMG)** 11 | * 📷 **[Sampled Results](https://huggingface.co/datasets/MMMGBench/MMMG_Result)** 12 | * 📂 **[Training Set](https://huggingface.co/datasets/MMMGBench/MMMG_Train)** 13 | 14 | --- 15 | 16 | 17 | ## ✨ Overview 18 | ![teaser](imgs/figure5.jpg) 19 | 20 | **MMMG** is a large-scale benchmark designed to assess text-to-image (T2I) models on their ability to generate *faithful* and *visually readable* images based on knowledge-intensive prompts, spanning multiple academic disciplines and educational levels. 21 | 22 | **MMMG-Score** is computed as: 23 | 24 | > **MMMG-Score = Knowledge Fidelity (1 - GED) × Visual Readability (SAM2.1)** 25 | 26 | Where: 27 | 28 | * **GED**: Graph Edit Distance between predicted and ground-truth concept graphs. 29 | * **SAM2.1**: Visual readability score based on SAM2.1 segmentation accuracy. 30 | 31 | --- 32 | 33 | ## 📬 News 34 | - **2025.11.29** We have benchmarked [Nano Banana Pro](https://blog.google/technology/ai/nano-banana-pro/), which is currently the leading model. 35 | - **2025.9.19** Our work has been accepted by NeurIPS 2025! 36 | - **2025.6.10** The repository has been updated. 37 | 38 | --- 39 | 40 | ## ♻️ Installation 41 | 42 | ```bash 43 | git clone https://github.com/MMMGBench/MMMG.git 44 | cd MMMG 45 | conda env create -f environment.yaml 46 | conda activate mmmg 47 | ``` 48 | 49 | --- 50 | 51 | ## 📊 Dataset Preparation 52 | 53 | Place your generated images under the following structure: 54 | 55 | ``` 56 | /data/ 57 | ├─ preschool/ 58 | ├─ primaryschool/ 59 | ├─ secondaryschool/ 60 | ├─ highschool/ 61 | ├─ undergraduate/ 62 | └─ PhD/ 63 | ``` 64 | 65 | Each folder contains model-generated images named as `.png`. 66 | 67 | --- 68 | 69 | ## 💡 Run Evaluation 70 | 71 | We use the Azure OpenAI service for knowledge integrity evaluation. If you use a different API interface (e.g., the standard OpenAI API), please **modify**: 72 | 73 | ```bash 74 | mmmg_eval/step1_knowledge_integrity.py 75 | ``` 76 | 77 | Insert your API keys into: 78 | 79 | ```bash 80 | mmmg_eval/utils/gpt_api_pool.py 81 | ``` 82 | 83 | ### Example: Evaluate GPT-4o Generations 84 | 85 | ```bash 86 | python evaluate.py \ 87 | --img_dir ./data/GPT-4o \ 88 | --output_dir ./output \ 89 | --sam2_ckpt /YOUR/PATH/TO/sam2/checkpoints/sam2.1_hiera_large.pt \ 90 | --t2i_method GPT-4o \ 91 | --api_name o3 \ 92 | --hf_cache ./data/MMMG 93 | ``` 94 | 95 | ### Arguments 96 | 97 | * `--img_dir`: Path to generated images (organized by education tier). 98 | * `--output_dir`: Where evaluation logs and scores will be saved. 99 | * `--sam2_ckpt`: Path to the pretrained SAM2.1 checkpoint. 100 | * `--t2i_method`: Name of the T2I model under evaluation. 101 | * `--api_name`: LLM backend (e.g., `gpt-4`, `gpt-4o`, `o3`). 102 | * `--hf_cache`: Path to the HuggingFace cache for loading ground-truth graphs.
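For intuition, the snippet below sketches how the MMMG-Score defined in the Overview combines its two terms. It is a minimal illustration only: the function name and the assumption that both the GED and the readability score are normalized to [0, 1] are ours, not part of the toolkit's `evaluate.py` API.

```python
# Illustrative sketch of the MMMG-Score combination (not the toolkit's actual API).
# `ged` would come from comparing the predicted concept graph against the
# ground-truth graph; `readability` from the SAM2.1-based visual readability step.
# Both values are assumed here to lie in [0, 1].

def mmmg_score(ged: float, readability: float) -> float:
    """Knowledge fidelity (1 - GED) multiplied by visual readability."""
    knowledge_fidelity = 1.0 - ged
    return knowledge_fidelity * readability

# Example: a mildly incorrect concept graph (GED = 0.2) rendered quite
# readably (readability = 0.9) yields an MMMG-Score of 0.72.
print(mmmg_score(ged=0.2, readability=0.9))  # 0.72
```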
103 | 104 | --- 105 | 106 | ## 📅 Citation 107 | 108 | If you find MMMG helpful in your research, please consider citing our paper: 109 | 110 | ```bibtex 111 | @inproceedings{luo2025mmmg, 112 | title={Mmmg: A massive, multidisciplinary, multi-tier generation benchmark for text-to-image reasoning}, 113 | author={Luo, Yuxuan and Yuan, Yuhui and Chen, Junwen and Cai, Haonan and Yue, Ziyi and Yang, Yuwei and Daha, Fatima Zohra and Li, Ji and Lian, Zhouhui}, 114 | booktitle={The Thirty-ninth Annual Conference on Neural Information Processing Systems Datasets and Benchmarks Track}, 115 | year={2025} 116 | } 117 | ``` 118 | -------------------------------------------------------------------------------- /mmmg_eval/sam2/sam2/configs/sam2/sam2_hiera_b+.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # Model 4 | model: 5 | _target_: sam2.modeling.sam2_base.SAM2Base 6 | image_encoder: 7 | _target_: sam2.modeling.backbones.image_encoder.ImageEncoder 8 | scalp: 1 9 | trunk: 10 | _target_: sam2.modeling.backbones.hieradet.Hiera 11 | embed_dim: 112 12 | num_heads: 2 13 | neck: 14 | _target_: sam2.modeling.backbones.image_encoder.FpnNeck 15 | position_encoding: 16 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine 17 | num_pos_feats: 256 18 | normalize: true 19 | scale: null 20 | temperature: 10000 21 | d_model: 256 22 | backbone_channel_list: [896, 448, 224, 112] 23 | fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features 24 | fpn_interp_model: nearest 25 | 26 | memory_attention: 27 | _target_: sam2.modeling.memory_attention.MemoryAttention 28 | d_model: 256 29 | pos_enc_at_input: true 30 | layer: 31 | _target_: sam2.modeling.memory_attention.MemoryAttentionLayer 32 | activation: relu 33 | dim_feedforward: 2048 34 | dropout: 0.1 35 | pos_enc_at_attn: false 36 | self_attention: 37 | _target_: sam2.modeling.sam.transformer.RoPEAttention 38 | rope_theta: 10000.0 39 | feat_sizes: [64, 64] 40 | embedding_dim: 256 41 | num_heads: 1 42 | downsample_rate: 1 43 | dropout: 0.1 44 | d_model: 256 45 | pos_enc_at_cross_attn_keys: true 46 | pos_enc_at_cross_attn_queries: false 47 | cross_attention: 48 | _target_: sam2.modeling.sam.transformer.RoPEAttention 49 | rope_theta: 10000.0 50 | feat_sizes: [64, 64] 51 | rope_k_repeat: True 52 | embedding_dim: 256 53 | num_heads: 1 54 | downsample_rate: 1 55 | dropout: 0.1 56 | kv_in_dim: 64 57 | num_layers: 4 58 | 59 | memory_encoder: 60 | _target_: sam2.modeling.memory_encoder.MemoryEncoder 61 | out_dim: 64 62 | position_encoding: 63 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine 64 | num_pos_feats: 64 65 | normalize: true 66 | scale: null 67 | temperature: 10000 68 | mask_downsampler: 69 | _target_: sam2.modeling.memory_encoder.MaskDownSampler 70 | kernel_size: 3 71 | stride: 2 72 | padding: 1 73 | fuser: 74 | _target_: sam2.modeling.memory_encoder.Fuser 75 | layer: 76 | _target_: sam2.modeling.memory_encoder.CXBlock 77 | dim: 256 78 | kernel_size: 7 79 | padding: 3 80 | layer_scale_init_value: 1e-6 81 | use_dwconv: True # depth-wise convs 82 | num_layers: 2 83 | 84 | num_maskmem: 7 85 | image_size: 1024 86 | # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask 87 | sigmoid_scale_for_mem_enc: 20.0 88 | sigmoid_bias_for_mem_enc: -10.0 89 | use_mask_input_as_output_without_sam: true 90 | # Memory 91 | directly_add_no_mem_embed: true 92 | # use high-resolution feature map in the SAM mask decoder 
93 | use_high_res_features_in_sam: true 94 | # output 3 masks on the first click on initial conditioning frames 95 | multimask_output_in_sam: true 96 | # SAM heads 97 | iou_prediction_use_sigmoid: True 98 | # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder 99 | use_obj_ptrs_in_encoder: true 100 | add_tpos_enc_to_obj_ptrs: false 101 | only_obj_ptrs_in_the_past_for_eval: true 102 | # object occlusion prediction 103 | pred_obj_scores: true 104 | pred_obj_scores_mlp: true 105 | fixed_no_obj_ptr: true 106 | # multimask tracking settings 107 | multimask_output_for_tracking: true 108 | use_multimask_token_for_obj_ptr: true 109 | multimask_min_pt_num: 0 110 | multimask_max_pt_num: 1 111 | use_mlp_for_obj_ptr_proj: true 112 | # Compilation flag 113 | compile_image_encoder: False 114 | -------------------------------------------------------------------------------- /mmmg_eval/sam2/sam2/configs/sam2.1/sam2.1_hiera_b+.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # Model 4 | model: 5 | _target_: sam2.modeling.sam2_base.SAM2Base 6 | image_encoder: 7 | _target_: sam2.modeling.backbones.image_encoder.ImageEncoder 8 | scalp: 1 9 | trunk: 10 | _target_: sam2.modeling.backbones.hieradet.Hiera 11 | embed_dim: 112 12 | num_heads: 2 13 | neck: 14 | _target_: sam2.modeling.backbones.image_encoder.FpnNeck 15 | position_encoding: 16 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine 17 | num_pos_feats: 256 18 | normalize: true 19 | scale: null 20 | temperature: 10000 21 | d_model: 256 22 | backbone_channel_list: [896, 448, 224, 112] 23 | fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features 24 | fpn_interp_model: nearest 25 | 26 | memory_attention: 27 | _target_: sam2.modeling.memory_attention.MemoryAttention 28 | d_model: 256 29 | pos_enc_at_input: true 30 | layer: 31 | _target_: sam2.modeling.memory_attention.MemoryAttentionLayer 32 | activation: relu 33 | dim_feedforward: 2048 34 | dropout: 0.1 35 | pos_enc_at_attn: false 36 | self_attention: 37 | _target_: sam2.modeling.sam.transformer.RoPEAttention 38 | rope_theta: 10000.0 39 | feat_sizes: [64, 64] 40 | embedding_dim: 256 41 | num_heads: 1 42 | downsample_rate: 1 43 | dropout: 0.1 44 | d_model: 256 45 | pos_enc_at_cross_attn_keys: true 46 | pos_enc_at_cross_attn_queries: false 47 | cross_attention: 48 | _target_: sam2.modeling.sam.transformer.RoPEAttention 49 | rope_theta: 10000.0 50 | feat_sizes: [64, 64] 51 | rope_k_repeat: True 52 | embedding_dim: 256 53 | num_heads: 1 54 | downsample_rate: 1 55 | dropout: 0.1 56 | kv_in_dim: 64 57 | num_layers: 4 58 | 59 | memory_encoder: 60 | _target_: sam2.modeling.memory_encoder.MemoryEncoder 61 | out_dim: 64 62 | position_encoding: 63 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine 64 | num_pos_feats: 64 65 | normalize: true 66 | scale: null 67 | temperature: 10000 68 | mask_downsampler: 69 | _target_: sam2.modeling.memory_encoder.MaskDownSampler 70 | kernel_size: 3 71 | stride: 2 72 | padding: 1 73 | fuser: 74 | _target_: sam2.modeling.memory_encoder.Fuser 75 | layer: 76 | _target_: sam2.modeling.memory_encoder.CXBlock 77 | dim: 256 78 | kernel_size: 7 79 | padding: 3 80 | layer_scale_init_value: 1e-6 81 | use_dwconv: True # depth-wise convs 82 | num_layers: 2 83 | 84 | num_maskmem: 7 85 | image_size: 1024 86 | # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask 87 | 
sigmoid_scale_for_mem_enc: 20.0 88 | sigmoid_bias_for_mem_enc: -10.0 89 | use_mask_input_as_output_without_sam: true 90 | # Memory 91 | directly_add_no_mem_embed: true 92 | no_obj_embed_spatial: true 93 | # use high-resolution feature map in the SAM mask decoder 94 | use_high_res_features_in_sam: true 95 | # output 3 masks on the first click on initial conditioning frames 96 | multimask_output_in_sam: true 97 | # SAM heads 98 | iou_prediction_use_sigmoid: True 99 | # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder 100 | use_obj_ptrs_in_encoder: true 101 | add_tpos_enc_to_obj_ptrs: true 102 | proj_tpos_enc_in_obj_ptrs: true 103 | use_signed_tpos_enc_to_obj_ptrs: true 104 | only_obj_ptrs_in_the_past_for_eval: true 105 | # object occlusion prediction 106 | pred_obj_scores: true 107 | pred_obj_scores_mlp: true 108 | fixed_no_obj_ptr: true 109 | # multimask tracking settings 110 | multimask_output_for_tracking: true 111 | use_multimask_token_for_obj_ptr: true 112 | multimask_min_pt_num: 0 113 | multimask_max_pt_num: 1 114 | use_mlp_for_obj_ptr_proj: true 115 | # Compilation flag 116 | compile_image_encoder: False 117 | -------------------------------------------------------------------------------- /mmmg_eval/sam2/sam2/configs/sam2/sam2_hiera_s.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # Model 4 | model: 5 | _target_: sam2.modeling.sam2_base.SAM2Base 6 | image_encoder: 7 | _target_: sam2.modeling.backbones.image_encoder.ImageEncoder 8 | scalp: 1 9 | trunk: 10 | _target_: sam2.modeling.backbones.hieradet.Hiera 11 | embed_dim: 96 12 | num_heads: 1 13 | stages: [1, 2, 11, 2] 14 | global_att_blocks: [7, 10, 13] 15 | window_pos_embed_bkg_spatial_size: [7, 7] 16 | neck: 17 | _target_: sam2.modeling.backbones.image_encoder.FpnNeck 18 | position_encoding: 19 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine 20 | num_pos_feats: 256 21 | normalize: true 22 | scale: null 23 | temperature: 10000 24 | d_model: 256 25 | backbone_channel_list: [768, 384, 192, 96] 26 | fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features 27 | fpn_interp_model: nearest 28 | 29 | memory_attention: 30 | _target_: sam2.modeling.memory_attention.MemoryAttention 31 | d_model: 256 32 | pos_enc_at_input: true 33 | layer: 34 | _target_: sam2.modeling.memory_attention.MemoryAttentionLayer 35 | activation: relu 36 | dim_feedforward: 2048 37 | dropout: 0.1 38 | pos_enc_at_attn: false 39 | self_attention: 40 | _target_: sam2.modeling.sam.transformer.RoPEAttention 41 | rope_theta: 10000.0 42 | feat_sizes: [64, 64] 43 | embedding_dim: 256 44 | num_heads: 1 45 | downsample_rate: 1 46 | dropout: 0.1 47 | d_model: 256 48 | pos_enc_at_cross_attn_keys: true 49 | pos_enc_at_cross_attn_queries: false 50 | cross_attention: 51 | _target_: sam2.modeling.sam.transformer.RoPEAttention 52 | rope_theta: 10000.0 53 | feat_sizes: [64, 64] 54 | rope_k_repeat: True 55 | embedding_dim: 256 56 | num_heads: 1 57 | downsample_rate: 1 58 | dropout: 0.1 59 | kv_in_dim: 64 60 | num_layers: 4 61 | 62 | memory_encoder: 63 | _target_: sam2.modeling.memory_encoder.MemoryEncoder 64 | out_dim: 64 65 | position_encoding: 66 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine 67 | num_pos_feats: 64 68 | normalize: true 69 | scale: null 70 | temperature: 10000 71 | mask_downsampler: 72 | _target_: sam2.modeling.memory_encoder.MaskDownSampler 73 | kernel_size: 3 74 | stride: 2 75 | padding: 1 
76 | fuser: 77 | _target_: sam2.modeling.memory_encoder.Fuser 78 | layer: 79 | _target_: sam2.modeling.memory_encoder.CXBlock 80 | dim: 256 81 | kernel_size: 7 82 | padding: 3 83 | layer_scale_init_value: 1e-6 84 | use_dwconv: True # depth-wise convs 85 | num_layers: 2 86 | 87 | num_maskmem: 7 88 | image_size: 1024 89 | # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask 90 | sigmoid_scale_for_mem_enc: 20.0 91 | sigmoid_bias_for_mem_enc: -10.0 92 | use_mask_input_as_output_without_sam: true 93 | # Memory 94 | directly_add_no_mem_embed: true 95 | # use high-resolution feature map in the SAM mask decoder 96 | use_high_res_features_in_sam: true 97 | # output 3 masks on the first click on initial conditioning frames 98 | multimask_output_in_sam: true 99 | # SAM heads 100 | iou_prediction_use_sigmoid: True 101 | # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder 102 | use_obj_ptrs_in_encoder: true 103 | add_tpos_enc_to_obj_ptrs: false 104 | only_obj_ptrs_in_the_past_for_eval: true 105 | # object occlusion prediction 106 | pred_obj_scores: true 107 | pred_obj_scores_mlp: true 108 | fixed_no_obj_ptr: true 109 | # multimask tracking settings 110 | multimask_output_for_tracking: true 111 | use_multimask_token_for_obj_ptr: true 112 | multimask_min_pt_num: 0 113 | multimask_max_pt_num: 1 114 | use_mlp_for_obj_ptr_proj: true 115 | # Compilation flag 116 | compile_image_encoder: False 117 | -------------------------------------------------------------------------------- /mmmg_eval/sam2/sam2/configs/sam2/sam2_hiera_l.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # Model 4 | model: 5 | _target_: sam2.modeling.sam2_base.SAM2Base 6 | image_encoder: 7 | _target_: sam2.modeling.backbones.image_encoder.ImageEncoder 8 | scalp: 1 9 | trunk: 10 | _target_: sam2.modeling.backbones.hieradet.Hiera 11 | embed_dim: 144 12 | num_heads: 2 13 | stages: [2, 6, 36, 4] 14 | global_att_blocks: [23, 33, 43] 15 | window_pos_embed_bkg_spatial_size: [7, 7] 16 | window_spec: [8, 4, 16, 8] 17 | neck: 18 | _target_: sam2.modeling.backbones.image_encoder.FpnNeck 19 | position_encoding: 20 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine 21 | num_pos_feats: 256 22 | normalize: true 23 | scale: null 24 | temperature: 10000 25 | d_model: 256 26 | backbone_channel_list: [1152, 576, 288, 144] 27 | fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features 28 | fpn_interp_model: nearest 29 | 30 | memory_attention: 31 | _target_: sam2.modeling.memory_attention.MemoryAttention 32 | d_model: 256 33 | pos_enc_at_input: true 34 | layer: 35 | _target_: sam2.modeling.memory_attention.MemoryAttentionLayer 36 | activation: relu 37 | dim_feedforward: 2048 38 | dropout: 0.1 39 | pos_enc_at_attn: false 40 | self_attention: 41 | _target_: sam2.modeling.sam.transformer.RoPEAttention 42 | rope_theta: 10000.0 43 | feat_sizes: [64, 64] 44 | embedding_dim: 256 45 | num_heads: 1 46 | downsample_rate: 1 47 | dropout: 0.1 48 | d_model: 256 49 | pos_enc_at_cross_attn_keys: true 50 | pos_enc_at_cross_attn_queries: false 51 | cross_attention: 52 | _target_: sam2.modeling.sam.transformer.RoPEAttention 53 | rope_theta: 10000.0 54 | feat_sizes: [64, 64] 55 | rope_k_repeat: True 56 | embedding_dim: 256 57 | num_heads: 1 58 | downsample_rate: 1 59 | dropout: 0.1 60 | kv_in_dim: 64 61 | num_layers: 4 62 | 63 | memory_encoder: 64 | _target_: 
sam2.modeling.memory_encoder.MemoryEncoder 65 | out_dim: 64 66 | position_encoding: 67 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine 68 | num_pos_feats: 64 69 | normalize: true 70 | scale: null 71 | temperature: 10000 72 | mask_downsampler: 73 | _target_: sam2.modeling.memory_encoder.MaskDownSampler 74 | kernel_size: 3 75 | stride: 2 76 | padding: 1 77 | fuser: 78 | _target_: sam2.modeling.memory_encoder.Fuser 79 | layer: 80 | _target_: sam2.modeling.memory_encoder.CXBlock 81 | dim: 256 82 | kernel_size: 7 83 | padding: 3 84 | layer_scale_init_value: 1e-6 85 | use_dwconv: True # depth-wise convs 86 | num_layers: 2 87 | 88 | num_maskmem: 7 89 | image_size: 1024 90 | # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask 91 | sigmoid_scale_for_mem_enc: 20.0 92 | sigmoid_bias_for_mem_enc: -10.0 93 | use_mask_input_as_output_without_sam: true 94 | # Memory 95 | directly_add_no_mem_embed: true 96 | # use high-resolution feature map in the SAM mask decoder 97 | use_high_res_features_in_sam: true 98 | # output 3 masks on the first click on initial conditioning frames 99 | multimask_output_in_sam: true 100 | # SAM heads 101 | iou_prediction_use_sigmoid: True 102 | # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder 103 | use_obj_ptrs_in_encoder: true 104 | add_tpos_enc_to_obj_ptrs: false 105 | only_obj_ptrs_in_the_past_for_eval: true 106 | # object occlusion prediction 107 | pred_obj_scores: true 108 | pred_obj_scores_mlp: true 109 | fixed_no_obj_ptr: true 110 | # multimask tracking settings 111 | multimask_output_for_tracking: true 112 | use_multimask_token_for_obj_ptr: true 113 | multimask_min_pt_num: 0 114 | multimask_max_pt_num: 1 115 | use_mlp_for_obj_ptr_proj: true 116 | # Compilation flag 117 | compile_image_encoder: False 118 | -------------------------------------------------------------------------------- /mmmg_eval/sam2/sam2/configs/sam2/sam2_hiera_t.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # Model 4 | model: 5 | _target_: sam2.modeling.sam2_base.SAM2Base 6 | image_encoder: 7 | _target_: sam2.modeling.backbones.image_encoder.ImageEncoder 8 | scalp: 1 9 | trunk: 10 | _target_: sam2.modeling.backbones.hieradet.Hiera 11 | embed_dim: 96 12 | num_heads: 1 13 | stages: [1, 2, 7, 2] 14 | global_att_blocks: [5, 7, 9] 15 | window_pos_embed_bkg_spatial_size: [7, 7] 16 | neck: 17 | _target_: sam2.modeling.backbones.image_encoder.FpnNeck 18 | position_encoding: 19 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine 20 | num_pos_feats: 256 21 | normalize: true 22 | scale: null 23 | temperature: 10000 24 | d_model: 256 25 | backbone_channel_list: [768, 384, 192, 96] 26 | fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features 27 | fpn_interp_model: nearest 28 | 29 | memory_attention: 30 | _target_: sam2.modeling.memory_attention.MemoryAttention 31 | d_model: 256 32 | pos_enc_at_input: true 33 | layer: 34 | _target_: sam2.modeling.memory_attention.MemoryAttentionLayer 35 | activation: relu 36 | dim_feedforward: 2048 37 | dropout: 0.1 38 | pos_enc_at_attn: false 39 | self_attention: 40 | _target_: sam2.modeling.sam.transformer.RoPEAttention 41 | rope_theta: 10000.0 42 | feat_sizes: [64, 64] 43 | embedding_dim: 256 44 | num_heads: 1 45 | downsample_rate: 1 46 | dropout: 0.1 47 | d_model: 256 48 | pos_enc_at_cross_attn_keys: true 49 | pos_enc_at_cross_attn_queries: false 50 | 
cross_attention: 51 | _target_: sam2.modeling.sam.transformer.RoPEAttention 52 | rope_theta: 10000.0 53 | feat_sizes: [64, 64] 54 | rope_k_repeat: True 55 | embedding_dim: 256 56 | num_heads: 1 57 | downsample_rate: 1 58 | dropout: 0.1 59 | kv_in_dim: 64 60 | num_layers: 4 61 | 62 | memory_encoder: 63 | _target_: sam2.modeling.memory_encoder.MemoryEncoder 64 | out_dim: 64 65 | position_encoding: 66 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine 67 | num_pos_feats: 64 68 | normalize: true 69 | scale: null 70 | temperature: 10000 71 | mask_downsampler: 72 | _target_: sam2.modeling.memory_encoder.MaskDownSampler 73 | kernel_size: 3 74 | stride: 2 75 | padding: 1 76 | fuser: 77 | _target_: sam2.modeling.memory_encoder.Fuser 78 | layer: 79 | _target_: sam2.modeling.memory_encoder.CXBlock 80 | dim: 256 81 | kernel_size: 7 82 | padding: 3 83 | layer_scale_init_value: 1e-6 84 | use_dwconv: True # depth-wise convs 85 | num_layers: 2 86 | 87 | num_maskmem: 7 88 | image_size: 1024 89 | # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask 90 | # SAM decoder 91 | sigmoid_scale_for_mem_enc: 20.0 92 | sigmoid_bias_for_mem_enc: -10.0 93 | use_mask_input_as_output_without_sam: true 94 | # Memory 95 | directly_add_no_mem_embed: true 96 | # use high-resolution feature map in the SAM mask decoder 97 | use_high_res_features_in_sam: true 98 | # output 3 masks on the first click on initial conditioning frames 99 | multimask_output_in_sam: true 100 | # SAM heads 101 | iou_prediction_use_sigmoid: True 102 | # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder 103 | use_obj_ptrs_in_encoder: true 104 | add_tpos_enc_to_obj_ptrs: false 105 | only_obj_ptrs_in_the_past_for_eval: true 106 | # object occlusion prediction 107 | pred_obj_scores: true 108 | pred_obj_scores_mlp: true 109 | fixed_no_obj_ptr: true 110 | # multimask tracking settings 111 | multimask_output_for_tracking: true 112 | use_multimask_token_for_obj_ptr: true 113 | multimask_min_pt_num: 0 114 | multimask_max_pt_num: 1 115 | use_mlp_for_obj_ptr_proj: true 116 | # Compilation flag 117 | # HieraT does not currently support compilation, should always be set to False 118 | compile_image_encoder: False 119 | -------------------------------------------------------------------------------- /mmmg_eval/sam2/sam2/configs/sam2.1/sam2.1_hiera_s.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # Model 4 | model: 5 | _target_: sam2.modeling.sam2_base.SAM2Base 6 | image_encoder: 7 | _target_: sam2.modeling.backbones.image_encoder.ImageEncoder 8 | scalp: 1 9 | trunk: 10 | _target_: sam2.modeling.backbones.hieradet.Hiera 11 | embed_dim: 96 12 | num_heads: 1 13 | stages: [1, 2, 11, 2] 14 | global_att_blocks: [7, 10, 13] 15 | window_pos_embed_bkg_spatial_size: [7, 7] 16 | neck: 17 | _target_: sam2.modeling.backbones.image_encoder.FpnNeck 18 | position_encoding: 19 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine 20 | num_pos_feats: 256 21 | normalize: true 22 | scale: null 23 | temperature: 10000 24 | d_model: 256 25 | backbone_channel_list: [768, 384, 192, 96] 26 | fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features 27 | fpn_interp_model: nearest 28 | 29 | memory_attention: 30 | _target_: sam2.modeling.memory_attention.MemoryAttention 31 | d_model: 256 32 | pos_enc_at_input: true 33 | layer: 34 | _target_: 
sam2.modeling.memory_attention.MemoryAttentionLayer 35 | activation: relu 36 | dim_feedforward: 2048 37 | dropout: 0.1 38 | pos_enc_at_attn: false 39 | self_attention: 40 | _target_: sam2.modeling.sam.transformer.RoPEAttention 41 | rope_theta: 10000.0 42 | feat_sizes: [64, 64] 43 | embedding_dim: 256 44 | num_heads: 1 45 | downsample_rate: 1 46 | dropout: 0.1 47 | d_model: 256 48 | pos_enc_at_cross_attn_keys: true 49 | pos_enc_at_cross_attn_queries: false 50 | cross_attention: 51 | _target_: sam2.modeling.sam.transformer.RoPEAttention 52 | rope_theta: 10000.0 53 | feat_sizes: [64, 64] 54 | rope_k_repeat: True 55 | embedding_dim: 256 56 | num_heads: 1 57 | downsample_rate: 1 58 | dropout: 0.1 59 | kv_in_dim: 64 60 | num_layers: 4 61 | 62 | memory_encoder: 63 | _target_: sam2.modeling.memory_encoder.MemoryEncoder 64 | out_dim: 64 65 | position_encoding: 66 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine 67 | num_pos_feats: 64 68 | normalize: true 69 | scale: null 70 | temperature: 10000 71 | mask_downsampler: 72 | _target_: sam2.modeling.memory_encoder.MaskDownSampler 73 | kernel_size: 3 74 | stride: 2 75 | padding: 1 76 | fuser: 77 | _target_: sam2.modeling.memory_encoder.Fuser 78 | layer: 79 | _target_: sam2.modeling.memory_encoder.CXBlock 80 | dim: 256 81 | kernel_size: 7 82 | padding: 3 83 | layer_scale_init_value: 1e-6 84 | use_dwconv: True # depth-wise convs 85 | num_layers: 2 86 | 87 | num_maskmem: 7 88 | image_size: 1024 89 | # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask 90 | sigmoid_scale_for_mem_enc: 20.0 91 | sigmoid_bias_for_mem_enc: -10.0 92 | use_mask_input_as_output_without_sam: true 93 | # Memory 94 | directly_add_no_mem_embed: true 95 | no_obj_embed_spatial: true 96 | # use high-resolution feature map in the SAM mask decoder 97 | use_high_res_features_in_sam: true 98 | # output 3 masks on the first click on initial conditioning frames 99 | multimask_output_in_sam: true 100 | # SAM heads 101 | iou_prediction_use_sigmoid: True 102 | # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder 103 | use_obj_ptrs_in_encoder: true 104 | add_tpos_enc_to_obj_ptrs: true 105 | proj_tpos_enc_in_obj_ptrs: true 106 | use_signed_tpos_enc_to_obj_ptrs: true 107 | only_obj_ptrs_in_the_past_for_eval: true 108 | # object occlusion prediction 109 | pred_obj_scores: true 110 | pred_obj_scores_mlp: true 111 | fixed_no_obj_ptr: true 112 | # multimask tracking settings 113 | multimask_output_for_tracking: true 114 | use_multimask_token_for_obj_ptr: true 115 | multimask_min_pt_num: 0 116 | multimask_max_pt_num: 1 117 | use_mlp_for_obj_ptr_proj: true 118 | # Compilation flag 119 | compile_image_encoder: False 120 | -------------------------------------------------------------------------------- /mmmg_eval/sam2/training/dataset/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | """Some wrapping utilities extended from pytorch's to support repeat factor sampling in particular""" 8 | 9 | from typing import Iterable 10 | 11 | import torch 12 | from torch.utils.data import ( 13 | ConcatDataset as TorchConcatDataset, 14 | Dataset, 15 | Subset as TorchSubset, 16 | ) 17 | 18 | 19 | class ConcatDataset(TorchConcatDataset): 20 | def __init__(self, datasets: Iterable[Dataset]) -> None: 21 | super(ConcatDataset, self).__init__(datasets) 22 | 23 | self.repeat_factors = torch.cat([d.repeat_factors for d in datasets]) 24 | 25 | def set_epoch(self, epoch: int): 26 | for dataset in self.datasets: 27 | if hasattr(dataset, "epoch"): 28 | dataset.epoch = epoch 29 | if hasattr(dataset, "set_epoch"): 30 | dataset.set_epoch(epoch) 31 | 32 | 33 | class Subset(TorchSubset): 34 | def __init__(self, dataset, indices) -> None: 35 | super(Subset, self).__init__(dataset, indices) 36 | 37 | self.repeat_factors = dataset.repeat_factors[indices] 38 | assert len(indices) == len(self.repeat_factors) 39 | 40 | 41 | # Adapted from Detectron2 42 | class RepeatFactorWrapper(Dataset): 43 | """ 44 | Thin wrapper around a dataset to implement repeat factor sampling. 45 | The underlying dataset must have a repeat_factors member to indicate the per-image factor. 46 | Set it to uniformly ones to disable repeat factor sampling 47 | """ 48 | 49 | def __init__(self, dataset, seed: int = 0): 50 | self.dataset = dataset 51 | self.epoch_ids = None 52 | self._seed = seed 53 | 54 | # Split into whole number (_int_part) and fractional (_frac_part) parts. 55 | self._int_part = torch.trunc(dataset.repeat_factors) 56 | self._frac_part = dataset.repeat_factors - self._int_part 57 | 58 | def _get_epoch_indices(self, generator): 59 | """ 60 | Create a list of dataset indices (with repeats) to use for one epoch. 61 | 62 | Args: 63 | generator (torch.Generator): pseudo random number generator used for 64 | stochastic rounding. 65 | 66 | Returns: 67 | torch.Tensor: list of dataset indices to use in one epoch. Each index 68 | is repeated based on its calculated repeat factor. 69 | """ 70 | # Since repeat factors are fractional, we use stochastic rounding so 71 | # that the target repeat factor is achieved in expectation over the 72 | # course of training 73 | rands = torch.rand(len(self._frac_part), generator=generator) 74 | rep_factors = self._int_part + (rands < self._frac_part).float() 75 | # Construct a list of indices in which we repeat images as specified 76 | indices = [] 77 | for dataset_index, rep_factor in enumerate(rep_factors): 78 | indices.extend([dataset_index] * int(rep_factor.item())) 79 | return torch.tensor(indices, dtype=torch.int64) 80 | 81 | def __len__(self): 82 | if self.epoch_ids is None: 83 | # Here we raise an error instead of returning original len(self.dataset) avoid 84 | # accidentally using unwrapped length. Otherwise it's error-prone since the 85 | # length changes to `len(self.epoch_ids)`changes after set_epoch is called. 86 | raise RuntimeError("please call set_epoch first to get wrapped length") 87 | # return len(self.dataset) 88 | 89 | return len(self.epoch_ids) 90 | 91 | def set_epoch(self, epoch: int): 92 | g = torch.Generator() 93 | g.manual_seed(self._seed + epoch) 94 | self.epoch_ids = self._get_epoch_indices(g) 95 | if hasattr(self.dataset, "set_epoch"): 96 | self.dataset.set_epoch(epoch) 97 | 98 | def __getitem__(self, idx): 99 | if self.epoch_ids is None: 100 | raise RuntimeError( 101 | "Repeat ids haven't been computed. Did you forget to call set_epoch?" 
102 | ) 103 | 104 | return self.dataset[self.epoch_ids[idx]] 105 | -------------------------------------------------------------------------------- /mmmg_eval/sam2/sam2/configs/sam2.1/sam2.1_hiera_l.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # Model 4 | model: 5 | _target_: sam2.modeling.sam2_base.SAM2Base 6 | image_encoder: 7 | _target_: sam2.modeling.backbones.image_encoder.ImageEncoder 8 | scalp: 1 9 | trunk: 10 | _target_: sam2.modeling.backbones.hieradet.Hiera 11 | embed_dim: 144 12 | num_heads: 2 13 | stages: [2, 6, 36, 4] 14 | global_att_blocks: [23, 33, 43] 15 | window_pos_embed_bkg_spatial_size: [7, 7] 16 | window_spec: [8, 4, 16, 8] 17 | neck: 18 | _target_: sam2.modeling.backbones.image_encoder.FpnNeck 19 | position_encoding: 20 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine 21 | num_pos_feats: 256 22 | normalize: true 23 | scale: null 24 | temperature: 10000 25 | d_model: 256 26 | backbone_channel_list: [1152, 576, 288, 144] 27 | fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features 28 | fpn_interp_model: nearest 29 | 30 | memory_attention: 31 | _target_: sam2.modeling.memory_attention.MemoryAttention 32 | d_model: 256 33 | pos_enc_at_input: true 34 | layer: 35 | _target_: sam2.modeling.memory_attention.MemoryAttentionLayer 36 | activation: relu 37 | dim_feedforward: 2048 38 | dropout: 0.1 39 | pos_enc_at_attn: false 40 | self_attention: 41 | _target_: sam2.modeling.sam.transformer.RoPEAttention 42 | rope_theta: 10000.0 43 | feat_sizes: [64, 64] 44 | embedding_dim: 256 45 | num_heads: 1 46 | downsample_rate: 1 47 | dropout: 0.1 48 | d_model: 256 49 | pos_enc_at_cross_attn_keys: true 50 | pos_enc_at_cross_attn_queries: false 51 | cross_attention: 52 | _target_: sam2.modeling.sam.transformer.RoPEAttention 53 | rope_theta: 10000.0 54 | feat_sizes: [64, 64] 55 | rope_k_repeat: True 56 | embedding_dim: 256 57 | num_heads: 1 58 | downsample_rate: 1 59 | dropout: 0.1 60 | kv_in_dim: 64 61 | num_layers: 4 62 | 63 | memory_encoder: 64 | _target_: sam2.modeling.memory_encoder.MemoryEncoder 65 | out_dim: 64 66 | position_encoding: 67 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine 68 | num_pos_feats: 64 69 | normalize: true 70 | scale: null 71 | temperature: 10000 72 | mask_downsampler: 73 | _target_: sam2.modeling.memory_encoder.MaskDownSampler 74 | kernel_size: 3 75 | stride: 2 76 | padding: 1 77 | fuser: 78 | _target_: sam2.modeling.memory_encoder.Fuser 79 | layer: 80 | _target_: sam2.modeling.memory_encoder.CXBlock 81 | dim: 256 82 | kernel_size: 7 83 | padding: 3 84 | layer_scale_init_value: 1e-6 85 | use_dwconv: True # depth-wise convs 86 | num_layers: 2 87 | 88 | num_maskmem: 7 89 | image_size: 1024 90 | # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask 91 | sigmoid_scale_for_mem_enc: 20.0 92 | sigmoid_bias_for_mem_enc: -10.0 93 | use_mask_input_as_output_without_sam: true 94 | # Memory 95 | directly_add_no_mem_embed: true 96 | no_obj_embed_spatial: true 97 | # use high-resolution feature map in the SAM mask decoder 98 | use_high_res_features_in_sam: true 99 | # output 3 masks on the first click on initial conditioning frames 100 | multimask_output_in_sam: true 101 | # SAM heads 102 | iou_prediction_use_sigmoid: True 103 | # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder 104 | use_obj_ptrs_in_encoder: true 105 | add_tpos_enc_to_obj_ptrs: true 
106 | proj_tpos_enc_in_obj_ptrs: true 107 | use_signed_tpos_enc_to_obj_ptrs: true 108 | only_obj_ptrs_in_the_past_for_eval: true 109 | # object occlusion prediction 110 | pred_obj_scores: true 111 | pred_obj_scores_mlp: true 112 | fixed_no_obj_ptr: true 113 | # multimask tracking settings 114 | multimask_output_for_tracking: true 115 | use_multimask_token_for_obj_ptr: true 116 | multimask_min_pt_num: 0 117 | multimask_max_pt_num: 1 118 | use_mlp_for_obj_ptr_proj: true 119 | # Compilation flag 120 | compile_image_encoder: False 121 | -------------------------------------------------------------------------------- /mmmg_eval/sam2/sam2/configs/sam2.1/sam2.1_hiera_t.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # Model 4 | model: 5 | _target_: sam2.modeling.sam2_base.SAM2Base 6 | image_encoder: 7 | _target_: sam2.modeling.backbones.image_encoder.ImageEncoder 8 | scalp: 1 9 | trunk: 10 | _target_: sam2.modeling.backbones.hieradet.Hiera 11 | embed_dim: 96 12 | num_heads: 1 13 | stages: [1, 2, 7, 2] 14 | global_att_blocks: [5, 7, 9] 15 | window_pos_embed_bkg_spatial_size: [7, 7] 16 | neck: 17 | _target_: sam2.modeling.backbones.image_encoder.FpnNeck 18 | position_encoding: 19 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine 20 | num_pos_feats: 256 21 | normalize: true 22 | scale: null 23 | temperature: 10000 24 | d_model: 256 25 | backbone_channel_list: [768, 384, 192, 96] 26 | fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features 27 | fpn_interp_model: nearest 28 | 29 | memory_attention: 30 | _target_: sam2.modeling.memory_attention.MemoryAttention 31 | d_model: 256 32 | pos_enc_at_input: true 33 | layer: 34 | _target_: sam2.modeling.memory_attention.MemoryAttentionLayer 35 | activation: relu 36 | dim_feedforward: 2048 37 | dropout: 0.1 38 | pos_enc_at_attn: false 39 | self_attention: 40 | _target_: sam2.modeling.sam.transformer.RoPEAttention 41 | rope_theta: 10000.0 42 | feat_sizes: [64, 64] 43 | embedding_dim: 256 44 | num_heads: 1 45 | downsample_rate: 1 46 | dropout: 0.1 47 | d_model: 256 48 | pos_enc_at_cross_attn_keys: true 49 | pos_enc_at_cross_attn_queries: false 50 | cross_attention: 51 | _target_: sam2.modeling.sam.transformer.RoPEAttention 52 | rope_theta: 10000.0 53 | feat_sizes: [64, 64] 54 | rope_k_repeat: True 55 | embedding_dim: 256 56 | num_heads: 1 57 | downsample_rate: 1 58 | dropout: 0.1 59 | kv_in_dim: 64 60 | num_layers: 4 61 | 62 | memory_encoder: 63 | _target_: sam2.modeling.memory_encoder.MemoryEncoder 64 | out_dim: 64 65 | position_encoding: 66 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine 67 | num_pos_feats: 64 68 | normalize: true 69 | scale: null 70 | temperature: 10000 71 | mask_downsampler: 72 | _target_: sam2.modeling.memory_encoder.MaskDownSampler 73 | kernel_size: 3 74 | stride: 2 75 | padding: 1 76 | fuser: 77 | _target_: sam2.modeling.memory_encoder.Fuser 78 | layer: 79 | _target_: sam2.modeling.memory_encoder.CXBlock 80 | dim: 256 81 | kernel_size: 7 82 | padding: 3 83 | layer_scale_init_value: 1e-6 84 | use_dwconv: True # depth-wise convs 85 | num_layers: 2 86 | 87 | num_maskmem: 7 88 | image_size: 1024 89 | # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask 90 | # SAM decoder 91 | sigmoid_scale_for_mem_enc: 20.0 92 | sigmoid_bias_for_mem_enc: -10.0 93 | use_mask_input_as_output_without_sam: true 94 | # Memory 95 | directly_add_no_mem_embed: true 96 | 
no_obj_embed_spatial: true 97 | # use high-resolution feature map in the SAM mask decoder 98 | use_high_res_features_in_sam: true 99 | # output 3 masks on the first click on initial conditioning frames 100 | multimask_output_in_sam: true 101 | # SAM heads 102 | iou_prediction_use_sigmoid: True 103 | # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder 104 | use_obj_ptrs_in_encoder: true 105 | add_tpos_enc_to_obj_ptrs: true 106 | proj_tpos_enc_in_obj_ptrs: true 107 | use_signed_tpos_enc_to_obj_ptrs: true 108 | only_obj_ptrs_in_the_past_for_eval: true 109 | # object occlusion prediction 110 | pred_obj_scores: true 111 | pred_obj_scores_mlp: true 112 | fixed_no_obj_ptr: true 113 | # multimask tracking settings 114 | multimask_output_for_tracking: true 115 | use_multimask_token_for_obj_ptr: true 116 | multimask_min_pt_num: 0 117 | multimask_max_pt_num: 1 118 | use_mlp_for_obj_ptr_proj: true 119 | # Compilation flag 120 | # HieraT does not currently support compilation, should always be set to False 121 | compile_image_encoder: False 122 | -------------------------------------------------------------------------------- /mmmg_eval/sam2/sam2/modeling/backbones/image_encoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from typing import List, Optional 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | 13 | 14 | class ImageEncoder(nn.Module): 15 | def __init__( 16 | self, 17 | trunk: nn.Module, 18 | neck: nn.Module, 19 | scalp: int = 0, 20 | ): 21 | super().__init__() 22 | self.trunk = trunk 23 | self.neck = neck 24 | self.scalp = scalp 25 | assert ( 26 | self.trunk.channel_list == self.neck.backbone_channel_list 27 | ), f"Channel dims of trunk and neck do not match. 
Trunk: {self.trunk.channel_list}, neck: {self.neck.backbone_channel_list}" 28 | 29 | def forward(self, sample: torch.Tensor): 30 | # Forward through backbone 31 | features, pos = self.neck(self.trunk(sample)) 32 | if self.scalp > 0: 33 | # Discard the lowest resolution features 34 | features, pos = features[: -self.scalp], pos[: -self.scalp] 35 | 36 | src = features[-1] 37 | output = { 38 | "vision_features": src, 39 | "vision_pos_enc": pos, 40 | "backbone_fpn": features, 41 | } 42 | return output 43 | 44 | 45 | class FpnNeck(nn.Module): 46 | """ 47 | A modified variant of Feature Pyramid Network (FPN) neck 48 | (we remove output conv and also do bicubic interpolation similar to ViT 49 | pos embed interpolation) 50 | """ 51 | 52 | def __init__( 53 | self, 54 | position_encoding: nn.Module, 55 | d_model: int, 56 | backbone_channel_list: List[int], 57 | kernel_size: int = 1, 58 | stride: int = 1, 59 | padding: int = 0, 60 | fpn_interp_model: str = "bilinear", 61 | fuse_type: str = "sum", 62 | fpn_top_down_levels: Optional[List[int]] = None, 63 | ): 64 | """Initialize the neck 65 | :param trunk: the backbone 66 | :param position_encoding: the positional encoding to use 67 | :param d_model: the dimension of the model 68 | :param neck_norm: the normalization to use 69 | """ 70 | super().__init__() 71 | self.position_encoding = position_encoding 72 | self.convs = nn.ModuleList() 73 | self.backbone_channel_list = backbone_channel_list 74 | self.d_model = d_model 75 | for dim in backbone_channel_list: 76 | current = nn.Sequential() 77 | current.add_module( 78 | "conv", 79 | nn.Conv2d( 80 | in_channels=dim, 81 | out_channels=d_model, 82 | kernel_size=kernel_size, 83 | stride=stride, 84 | padding=padding, 85 | ), 86 | ) 87 | 88 | self.convs.append(current) 89 | self.fpn_interp_model = fpn_interp_model 90 | assert fuse_type in ["sum", "avg"] 91 | self.fuse_type = fuse_type 92 | 93 | # levels to have top-down features in its outputs 94 | # e.g. if fpn_top_down_levels is [2, 3], then only outputs of level 2 and 3 95 | # have top-down propagation, while outputs of level 0 and level 1 have only 96 | # lateral features from the same backbone level. 
97 | if fpn_top_down_levels is None: 98 | # default is to have top-down features on all levels 99 | fpn_top_down_levels = range(len(self.convs)) 100 | self.fpn_top_down_levels = list(fpn_top_down_levels) 101 | 102 | def forward(self, xs: List[torch.Tensor]): 103 | 104 | out = [None] * len(self.convs) 105 | pos = [None] * len(self.convs) 106 | assert len(xs) == len(self.convs) 107 | # fpn forward pass 108 | # see https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/fpn.py 109 | prev_features = None 110 | # forward in top-down order (from low to high resolution) 111 | n = len(self.convs) - 1 112 | for i in range(n, -1, -1): 113 | x = xs[i] 114 | lateral_features = self.convs[n - i](x) 115 | if i in self.fpn_top_down_levels and prev_features is not None: 116 | top_down_features = F.interpolate( 117 | prev_features.to(dtype=torch.float32), 118 | scale_factor=2.0, 119 | mode=self.fpn_interp_model, 120 | align_corners=( 121 | None if self.fpn_interp_model == "nearest" else False 122 | ), 123 | antialias=False, 124 | ) 125 | prev_features = lateral_features + top_down_features 126 | if self.fuse_type == "avg": 127 | prev_features /= 2 128 | else: 129 | prev_features = lateral_features 130 | x_out = prev_features 131 | out[i] = x_out 132 | pos[i] = self.position_encoding(x_out).to(x_out.dtype) 133 | 134 | return out, pos 135 | -------------------------------------------------------------------------------- /mmmg_eval/sam2/sam2/utils/transforms.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import warnings 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | from torchvision.transforms import Normalize, Resize, ToTensor 13 | 14 | 15 | class SAM2Transforms(nn.Module): 16 | def __init__( 17 | self, resolution, mask_threshold, max_hole_area=0.0, max_sprinkle_area=0.0 18 | ): 19 | """ 20 | Transforms for SAM2. 21 | """ 22 | super().__init__() 23 | self.resolution = resolution 24 | self.mask_threshold = mask_threshold 25 | self.max_hole_area = max_hole_area 26 | self.max_sprinkle_area = max_sprinkle_area 27 | self.mean = [0.485, 0.456, 0.406] 28 | self.std = [0.229, 0.224, 0.225] 29 | self.to_tensor = ToTensor() 30 | self.transforms = torch.jit.script( 31 | nn.Sequential( 32 | Resize((self.resolution, self.resolution)), 33 | Normalize(self.mean, self.std), 34 | ) 35 | ) 36 | 37 | def __call__(self, x): 38 | x = self.to_tensor(x) 39 | return self.transforms(x) 40 | 41 | def forward_batch(self, img_list): 42 | img_batch = [self.transforms(self.to_tensor(img)) for img in img_list] 43 | img_batch = torch.stack(img_batch, dim=0) 44 | return img_batch 45 | 46 | def transform_coords( 47 | self, coords: torch.Tensor, normalize=False, orig_hw=None 48 | ) -> torch.Tensor: 49 | """ 50 | Expects a torch tensor with length 2 in the last dimension. The coordinates can be in absolute image or normalized coordinates, 51 | If the coords are in absolute image coordinates, normalize should be set to True and original image size is required. 52 | 53 | Returns 54 | Un-normalized coordinates in the range of [0, 1] which is expected by the SAM2 model. 
55 | """ 56 | if normalize: 57 | assert orig_hw is not None 58 | h, w = orig_hw 59 | coords = coords.clone() 60 | coords[..., 0] = coords[..., 0] / w 61 | coords[..., 1] = coords[..., 1] / h 62 | 63 | coords = coords * self.resolution # unnormalize coords 64 | return coords 65 | 66 | def transform_boxes( 67 | self, boxes: torch.Tensor, normalize=False, orig_hw=None 68 | ) -> torch.Tensor: 69 | """ 70 | Expects a tensor of shape Bx4. The coordinates can be in absolute image or normalized coordinates, 71 | if the coords are in absolute image coordinates, normalize should be set to True and original image size is required. 72 | """ 73 | boxes = self.transform_coords(boxes.reshape(-1, 2, 2), normalize, orig_hw) 74 | return boxes 75 | 76 | def postprocess_masks(self, masks: torch.Tensor, orig_hw) -> torch.Tensor: 77 | """ 78 | Perform PostProcessing on output masks. 79 | """ 80 | from sam2.utils.misc import get_connected_components 81 | 82 | masks = masks.float() 83 | input_masks = masks 84 | mask_flat = masks.flatten(0, 1).unsqueeze(1) # flatten as 1-channel image 85 | try: 86 | if self.max_hole_area > 0: 87 | # Holes are those connected components in background with area <= self.fill_hole_area 88 | # (background regions are those with mask scores <= self.mask_threshold) 89 | labels, areas = get_connected_components( 90 | mask_flat <= self.mask_threshold 91 | ) 92 | is_hole = (labels > 0) & (areas <= self.max_hole_area) 93 | is_hole = is_hole.reshape_as(masks) 94 | # We fill holes with a small positive mask score (10.0) to change them to foreground. 95 | masks = torch.where(is_hole, self.mask_threshold + 10.0, masks) 96 | 97 | if self.max_sprinkle_area > 0: 98 | labels, areas = get_connected_components( 99 | mask_flat > self.mask_threshold 100 | ) 101 | is_hole = (labels > 0) & (areas <= self.max_sprinkle_area) 102 | is_hole = is_hole.reshape_as(masks) 103 | # We fill holes with negative mask score (-10.0) to change them to background. 104 | masks = torch.where(is_hole, self.mask_threshold - 10.0, masks) 105 | except Exception as e: 106 | # Skip the post-processing step if the CUDA kernel fails 107 | warnings.warn( 108 | f"{e}\n\nSkipping the post-processing step due to the error above. You can " 109 | "still use SAM 2 and it's OK to ignore the error above, although some post-processing " 110 | "functionality may be limited (which doesn't affect the results in most cases; see " 111 | "https://github.com/facebookresearch/sam2/blob/main/INSTALL.md).", 112 | category=UserWarning, 113 | stacklevel=2, 114 | ) 115 | masks = input_masks 116 | 117 | masks = F.interpolate(masks, orig_hw, mode="bilinear", align_corners=False) 118 | return masks 119 | -------------------------------------------------------------------------------- /mmmg_eval/sam2/training/scripts/sav_frame_extraction_submitit.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 
3 | import argparse 4 | import os 5 | from pathlib import Path 6 | 7 | import cv2 8 | 9 | import numpy as np 10 | import submitit 11 | import tqdm 12 | 13 | 14 | def get_args_parser(): 15 | parser = argparse.ArgumentParser( 16 | description="[SA-V Preprocessing] Extracting JPEG frames", 17 | formatter_class=argparse.ArgumentDefaultsHelpFormatter, 18 | ) 19 | 20 | # ------------ 21 | # DATA 22 | # ------------ 23 | data_parser = parser.add_argument_group( 24 | title="SA-V dataset data root", 25 | description="What data to load and how to process it.", 26 | ) 27 | data_parser.add_argument( 28 | "--sav-vid-dir", 29 | type=str, 30 | required=True, 31 | help=("Where to find the SAV videos"), 32 | ) 33 | data_parser.add_argument( 34 | "--sav-frame-sample-rate", 35 | type=int, 36 | default=4, 37 | help="Rate at which to sub-sample frames", 38 | ) 39 | 40 | # ------------ 41 | # LAUNCH 42 | # ------------ 43 | launch_parser = parser.add_argument_group( 44 | title="Cluster launch settings", 45 | description="Number of jobs and retry settings.", 46 | ) 47 | launch_parser.add_argument( 48 | "--n-jobs", 49 | type=int, 50 | required=True, 51 | help="Shard the run over this many jobs.", 52 | ) 53 | launch_parser.add_argument( 54 | "--timeout", type=int, required=True, help="SLURM timeout parameter in minutes." 55 | ) 56 | launch_parser.add_argument( 57 | "--partition", type=str, required=True, help="Partition to launch on." 58 | ) 59 | launch_parser.add_argument( 60 | "--account", type=str, required=True, help="Partition to launch on." 61 | ) 62 | launch_parser.add_argument("--qos", type=str, required=True, help="QOS.") 63 | 64 | # ------------ 65 | # OUTPUT 66 | # ------------ 67 | output_parser = parser.add_argument_group( 68 | title="Setting for results output", description="Where and how to save results." 
69 | ) 70 | output_parser.add_argument( 71 | "--output-dir", 72 | type=str, 73 | required=True, 74 | help=("Where to dump the extracted jpeg frames"), 75 | ) 76 | output_parser.add_argument( 77 | "--slurm-output-root-dir", 78 | type=str, 79 | required=True, 80 | help=("Where to save slurm outputs"), 81 | ) 82 | return parser 83 | 84 | 85 | def decode_video(video_path: str): 86 | assert os.path.exists(video_path) 87 | video = cv2.VideoCapture(video_path) 88 | video_frames = [] 89 | while video.isOpened(): 90 | ret, frame = video.read() 91 | if ret: 92 | video_frames.append(frame) 93 | else: 94 | break 95 | return video_frames 96 | 97 | 98 | def extract_frames(video_path, sample_rate): 99 | frames = decode_video(video_path) 100 | return frames[::sample_rate] 101 | 102 | 103 | def submitit_launch(video_paths, sample_rate, save_root): 104 | for path in tqdm.tqdm(video_paths): 105 | frames = extract_frames(path, sample_rate) 106 | output_folder = os.path.join(save_root, Path(path).stem) 107 | if not os.path.exists(output_folder): 108 | os.makedirs(output_folder) 109 | for fid, frame in enumerate(frames): 110 | frame_path = os.path.join(output_folder, f"{fid*sample_rate:05d}.jpg") 111 | cv2.imwrite(frame_path, frame) 112 | print(f"Saved output to {save_root}") 113 | 114 | 115 | if __name__ == "__main__": 116 | parser = get_args_parser() 117 | args = parser.parse_args() 118 | 119 | sav_vid_dir = args.sav_vid_dir 120 | save_root = args.output_dir 121 | sample_rate = args.sav_frame_sample_rate 122 | 123 | # List all SA-V videos 124 | mp4_files = sorted([str(p) for p in Path(sav_vid_dir).glob("*/*.mp4")]) 125 | mp4_files = np.array(mp4_files) 126 | chunked_mp4_files = [x.tolist() for x in np.array_split(mp4_files, args.n_jobs)] 127 | 128 | print(f"Processing videos in: {sav_vid_dir}") 129 | print(f"Processing {len(mp4_files)} files") 130 | print(f"Beginning processing in {args.n_jobs} processes") 131 | 132 | # Submitit params 133 | jobs_dir = os.path.join(args.slurm_output_root_dir, "%j") 134 | cpus_per_task = 4 135 | executor = submitit.AutoExecutor(folder=jobs_dir) 136 | executor.update_parameters( 137 | timeout_min=args.timeout, 138 | gpus_per_node=0, 139 | tasks_per_node=1, 140 | slurm_array_parallelism=args.n_jobs, 141 | cpus_per_task=cpus_per_task, 142 | slurm_partition=args.partition, 143 | slurm_account=args.account, 144 | slurm_qos=args.qos, 145 | ) 146 | executor.update_parameters(slurm_srun_args=["-vv", "--cpu-bind", "none"]) 147 | 148 | # Launch 149 | jobs = [] 150 | with executor.batch(): 151 | for _, mp4_chunk in tqdm.tqdm(enumerate(chunked_mp4_files)): 152 | job = executor.submit( 153 | submitit_launch, 154 | video_paths=mp4_chunk, 155 | sample_rate=sample_rate, 156 | save_root=save_root, 157 | ) 158 | jobs.append(job) 159 | 160 | for j in jobs: 161 | print(f"Slurm JobID: {j.job_id}") 162 | print(f"Saving outputs to {save_root}") 163 | print(f"Slurm outputs at {args.slurm_output_root_dir}") 164 | -------------------------------------------------------------------------------- /mmmg_eval/sam2/setup.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | import os 7 | 8 | from setuptools import find_packages, setup 9 | 10 | # Package metadata 11 | NAME = "SAM-2" 12 | VERSION = "1.0" 13 | DESCRIPTION = "SAM 2: Segment Anything in Images and Videos" 14 | URL = "https://github.com/facebookresearch/sam2" 15 | AUTHOR = "Meta AI" 16 | AUTHOR_EMAIL = "segment-anything@meta.com" 17 | LICENSE = "Apache 2.0" 18 | 19 | # Read the contents of README file 20 | with open("README.md", "r", encoding="utf-8") as f: 21 | LONG_DESCRIPTION = f.read() 22 | 23 | # Required dependencies 24 | REQUIRED_PACKAGES = [ 25 | "torch>=2.5.1", 26 | "torchvision>=0.20.1", 27 | "numpy>=1.24.4", 28 | "tqdm>=4.66.1", 29 | "hydra-core>=1.3.2", 30 | "iopath>=0.1.10", 31 | "pillow>=9.4.0", 32 | ] 33 | 34 | EXTRA_PACKAGES = { 35 | "notebooks": [ 36 | "matplotlib>=3.9.1", 37 | "jupyter>=1.0.0", 38 | "opencv-python>=4.7.0", 39 | "eva-decord>=0.6.1", 40 | ], 41 | "interactive-demo": [ 42 | "Flask>=3.0.3", 43 | "Flask-Cors>=5.0.0", 44 | "av>=13.0.0", 45 | "dataclasses-json>=0.6.7", 46 | "eva-decord>=0.6.1", 47 | "gunicorn>=23.0.0", 48 | "imagesize>=1.4.1", 49 | "pycocotools>=2.0.8", 50 | "strawberry-graphql>=0.243.0", 51 | ], 52 | "dev": [ 53 | "black==24.2.0", 54 | "usort==1.0.2", 55 | "ufmt==2.0.0b2", 56 | "fvcore>=0.1.5.post20221221", 57 | "pandas>=2.2.2", 58 | "scikit-image>=0.24.0", 59 | "tensorboard>=2.17.0", 60 | "pycocotools>=2.0.8", 61 | "tensordict>=0.6.0", 62 | "opencv-python>=4.7.0", 63 | "submitit>=1.5.1", 64 | ], 65 | } 66 | 67 | # By default, we also build the SAM 2 CUDA extension. 68 | # You may turn off CUDA build with `export SAM2_BUILD_CUDA=0`. 69 | BUILD_CUDA = os.getenv("SAM2_BUILD_CUDA", "1") == "1" 70 | # By default, we allow SAM 2 installation to proceed even with build errors. 71 | # You may force stopping on errors with `export SAM2_BUILD_ALLOW_ERRORS=0`. 72 | BUILD_ALLOW_ERRORS = os.getenv("SAM2_BUILD_ALLOW_ERRORS", "1") == "1" 73 | 74 | # Catch and skip errors during extension building and print a warning message 75 | # (note that this message only shows up under verbose build mode 76 | # "pip install -v -e ." or "python setup.py build_ext -v") 77 | CUDA_ERROR_MSG = ( 78 | "{}\n\n" 79 | "Failed to build the SAM 2 CUDA extension due to the error above. 
" 80 | "You can still use SAM 2 and it's OK to ignore the error above, although some " 81 | "post-processing functionality may be limited (which doesn't affect the results in most cases; " 82 | "(see https://github.com/facebookresearch/sam2/blob/main/INSTALL.md).\n" 83 | ) 84 | 85 | 86 | def get_extensions(): 87 | if not BUILD_CUDA: 88 | return [] 89 | 90 | try: 91 | from torch.utils.cpp_extension import CUDAExtension 92 | 93 | srcs = ["sam2/csrc/connected_components.cu"] 94 | compile_args = { 95 | "cxx": [], 96 | "nvcc": [ 97 | "-DCUDA_HAS_FP16=1", 98 | "-D__CUDA_NO_HALF_OPERATORS__", 99 | "-D__CUDA_NO_HALF_CONVERSIONS__", 100 | "-D__CUDA_NO_HALF2_OPERATORS__", 101 | ], 102 | } 103 | ext_modules = [CUDAExtension("sam2._C", srcs, extra_compile_args=compile_args)] 104 | except Exception as e: 105 | if BUILD_ALLOW_ERRORS: 106 | print(CUDA_ERROR_MSG.format(e)) 107 | ext_modules = [] 108 | else: 109 | raise e 110 | 111 | return ext_modules 112 | 113 | 114 | try: 115 | from torch.utils.cpp_extension import BuildExtension 116 | 117 | class BuildExtensionIgnoreErrors(BuildExtension): 118 | 119 | def finalize_options(self): 120 | try: 121 | super().finalize_options() 122 | except Exception as e: 123 | print(CUDA_ERROR_MSG.format(e)) 124 | self.extensions = [] 125 | 126 | def build_extensions(self): 127 | try: 128 | super().build_extensions() 129 | except Exception as e: 130 | print(CUDA_ERROR_MSG.format(e)) 131 | self.extensions = [] 132 | 133 | def get_ext_filename(self, ext_name): 134 | try: 135 | return super().get_ext_filename(ext_name) 136 | except Exception as e: 137 | print(CUDA_ERROR_MSG.format(e)) 138 | self.extensions = [] 139 | return "_C.so" 140 | 141 | cmdclass = { 142 | "build_ext": ( 143 | BuildExtensionIgnoreErrors.with_options(no_python_abi_suffix=True) 144 | if BUILD_ALLOW_ERRORS 145 | else BuildExtension.with_options(no_python_abi_suffix=True) 146 | ) 147 | } 148 | except Exception as e: 149 | cmdclass = {} 150 | if BUILD_ALLOW_ERRORS: 151 | print(CUDA_ERROR_MSG.format(e)) 152 | else: 153 | raise e 154 | 155 | 156 | # Setup configuration 157 | setup( 158 | name=NAME, 159 | version=VERSION, 160 | description=DESCRIPTION, 161 | long_description=LONG_DESCRIPTION, 162 | long_description_content_type="text/markdown", 163 | url=URL, 164 | author=AUTHOR, 165 | author_email=AUTHOR_EMAIL, 166 | license=LICENSE, 167 | packages=find_packages(exclude="notebooks"), 168 | include_package_data=True, 169 | install_requires=REQUIRED_PACKAGES, 170 | extras_require=EXTRA_PACKAGES, 171 | python_requires=">=3.10.0", 172 | ext_modules=get_extensions(), 173 | cmdclass=cmdclass, 174 | ) 175 | -------------------------------------------------------------------------------- /mmmg_eval/sam2/sam2/modeling/memory_attention.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | from typing import Optional 8 | 9 | import torch 10 | from torch import nn, Tensor 11 | 12 | from sam2.modeling.sam.transformer import RoPEAttention 13 | 14 | from sam2.modeling.sam2_utils import get_activation_fn, get_clones 15 | 16 | 17 | class MemoryAttentionLayer(nn.Module): 18 | 19 | def __init__( 20 | self, 21 | activation: str, 22 | cross_attention: nn.Module, 23 | d_model: int, 24 | dim_feedforward: int, 25 | dropout: float, 26 | pos_enc_at_attn: bool, 27 | pos_enc_at_cross_attn_keys: bool, 28 | pos_enc_at_cross_attn_queries: bool, 29 | self_attention: nn.Module, 30 | ): 31 | super().__init__() 32 | self.d_model = d_model 33 | self.dim_feedforward = dim_feedforward 34 | self.dropout_value = dropout 35 | self.self_attn = self_attention 36 | self.cross_attn_image = cross_attention 37 | 38 | # Implementation of Feedforward model 39 | self.linear1 = nn.Linear(d_model, dim_feedforward) 40 | self.dropout = nn.Dropout(dropout) 41 | self.linear2 = nn.Linear(dim_feedforward, d_model) 42 | 43 | self.norm1 = nn.LayerNorm(d_model) 44 | self.norm2 = nn.LayerNorm(d_model) 45 | self.norm3 = nn.LayerNorm(d_model) 46 | self.dropout1 = nn.Dropout(dropout) 47 | self.dropout2 = nn.Dropout(dropout) 48 | self.dropout3 = nn.Dropout(dropout) 49 | 50 | self.activation_str = activation 51 | self.activation = get_activation_fn(activation) 52 | 53 | # Where to add pos enc 54 | self.pos_enc_at_attn = pos_enc_at_attn 55 | self.pos_enc_at_cross_attn_queries = pos_enc_at_cross_attn_queries 56 | self.pos_enc_at_cross_attn_keys = pos_enc_at_cross_attn_keys 57 | 58 | def _forward_sa(self, tgt, query_pos): 59 | # Self-Attention 60 | tgt2 = self.norm1(tgt) 61 | q = k = tgt2 + query_pos if self.pos_enc_at_attn else tgt2 62 | tgt2 = self.self_attn(q, k, v=tgt2) 63 | tgt = tgt + self.dropout1(tgt2) 64 | return tgt 65 | 66 | def _forward_ca(self, tgt, memory, query_pos, pos, num_k_exclude_rope=0): 67 | kwds = {} 68 | if num_k_exclude_rope > 0: 69 | assert isinstance(self.cross_attn_image, RoPEAttention) 70 | kwds = {"num_k_exclude_rope": num_k_exclude_rope} 71 | 72 | # Cross-Attention 73 | tgt2 = self.norm2(tgt) 74 | tgt2 = self.cross_attn_image( 75 | q=tgt2 + query_pos if self.pos_enc_at_cross_attn_queries else tgt2, 76 | k=memory + pos if self.pos_enc_at_cross_attn_keys else memory, 77 | v=memory, 78 | **kwds, 79 | ) 80 | tgt = tgt + self.dropout2(tgt2) 81 | return tgt 82 | 83 | def forward( 84 | self, 85 | tgt, 86 | memory, 87 | pos: Optional[Tensor] = None, 88 | query_pos: Optional[Tensor] = None, 89 | num_k_exclude_rope: int = 0, 90 | ) -> torch.Tensor: 91 | 92 | # Self-Attn, Cross-Attn 93 | tgt = self._forward_sa(tgt, query_pos) 94 | tgt = self._forward_ca(tgt, memory, query_pos, pos, num_k_exclude_rope) 95 | # MLP 96 | tgt2 = self.norm3(tgt) 97 | tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2)))) 98 | tgt = tgt + self.dropout3(tgt2) 99 | return tgt 100 | 101 | 102 | class MemoryAttention(nn.Module): 103 | def __init__( 104 | self, 105 | d_model: int, 106 | pos_enc_at_input: bool, 107 | layer: nn.Module, 108 | num_layers: int, 109 | batch_first: bool = True, # Do layers expect batch first input? 
110 | ): 111 | super().__init__() 112 | self.d_model = d_model 113 | self.layers = get_clones(layer, num_layers) 114 | self.num_layers = num_layers 115 | self.norm = nn.LayerNorm(d_model) 116 | self.pos_enc_at_input = pos_enc_at_input 117 | self.batch_first = batch_first 118 | 119 | def forward( 120 | self, 121 | curr: torch.Tensor, # self-attention inputs 122 | memory: torch.Tensor, # cross-attention inputs 123 | curr_pos: Optional[Tensor] = None, # pos_enc for self-attention inputs 124 | memory_pos: Optional[Tensor] = None, # pos_enc for cross-attention inputs 125 | num_obj_ptr_tokens: int = 0, # number of object pointer *tokens* 126 | ): 127 | if isinstance(curr, list): 128 | assert isinstance(curr_pos, list) 129 | assert len(curr) == len(curr_pos) == 1 130 | curr, curr_pos = ( 131 | curr[0], 132 | curr_pos[0], 133 | ) 134 | 135 | assert ( 136 | curr.shape[1] == memory.shape[1] 137 | ), "Batch size must be the same for curr and memory" 138 | 139 | output = curr 140 | if self.pos_enc_at_input and curr_pos is not None: 141 | output = output + 0.1 * curr_pos 142 | 143 | if self.batch_first: 144 | # Convert to batch first 145 | output = output.transpose(0, 1) 146 | curr_pos = curr_pos.transpose(0, 1) 147 | memory = memory.transpose(0, 1) 148 | memory_pos = memory_pos.transpose(0, 1) 149 | 150 | for layer in self.layers: 151 | kwds = {} 152 | if isinstance(layer.cross_attn_image, RoPEAttention): 153 | kwds = {"num_k_exclude_rope": num_obj_ptr_tokens} 154 | 155 | output = layer( 156 | tgt=output, 157 | memory=memory, 158 | pos=memory_pos, 159 | query_pos=curr_pos, 160 | **kwds, 161 | ) 162 | normed_output = self.norm(output) 163 | 164 | if self.batch_first: 165 | # Convert back to seq first 166 | normed_output = normed_output.transpose(0, 1) 167 | curr_pos = curr_pos.transpose(0, 1) 168 | 169 | return normed_output 170 | -------------------------------------------------------------------------------- /mmmg_eval/sam2/sam2/modeling/memory_encoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import math 8 | from typing import Tuple 9 | 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | 14 | from sam2.modeling.sam2_utils import DropPath, get_clones, LayerNorm2d 15 | 16 | 17 | class MaskDownSampler(nn.Module): 18 | """ 19 | Progressively downsample a mask by total_stride, each time by stride. 20 | Note that LayerNorm is applied per *token*, like in ViT. 21 | 22 | With each downsample (by a factor stride**2), channel capacity increases by the same factor. 23 | In the end, we linearly project to embed_dim channels. 
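    For example, with the defaults below (stride=4, total_stride=16, embed_dim=256),
    num_layers is 2 and the channel count grows 1 -> 16 -> 256 across the two conv
    blocks before the final 1x1 projection to embed_dim.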
24 | """ 25 | 26 | def __init__( 27 | self, 28 | embed_dim=256, 29 | kernel_size=4, 30 | stride=4, 31 | padding=0, 32 | total_stride=16, 33 | activation=nn.GELU, 34 | ): 35 | super().__init__() 36 | num_layers = int(math.log2(total_stride) // math.log2(stride)) 37 | assert stride**num_layers == total_stride 38 | self.encoder = nn.Sequential() 39 | mask_in_chans, mask_out_chans = 1, 1 40 | for _ in range(num_layers): 41 | mask_out_chans = mask_in_chans * (stride**2) 42 | self.encoder.append( 43 | nn.Conv2d( 44 | mask_in_chans, 45 | mask_out_chans, 46 | kernel_size=kernel_size, 47 | stride=stride, 48 | padding=padding, 49 | ) 50 | ) 51 | self.encoder.append(LayerNorm2d(mask_out_chans)) 52 | self.encoder.append(activation()) 53 | mask_in_chans = mask_out_chans 54 | 55 | self.encoder.append(nn.Conv2d(mask_out_chans, embed_dim, kernel_size=1)) 56 | 57 | def forward(self, x): 58 | return self.encoder(x) 59 | 60 | 61 | # Lightly adapted from ConvNext (https://github.com/facebookresearch/ConvNeXt) 62 | class CXBlock(nn.Module): 63 | r"""ConvNeXt Block. There are two equivalent implementations: 64 | (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W) 65 | (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back 66 | We use (2) as we find it slightly faster in PyTorch 67 | 68 | Args: 69 | dim (int): Number of input channels. 70 | drop_path (float): Stochastic depth rate. Default: 0.0 71 | layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6. 72 | """ 73 | 74 | def __init__( 75 | self, 76 | dim, 77 | kernel_size=7, 78 | padding=3, 79 | drop_path=0.0, 80 | layer_scale_init_value=1e-6, 81 | use_dwconv=True, 82 | ): 83 | super().__init__() 84 | self.dwconv = nn.Conv2d( 85 | dim, 86 | dim, 87 | kernel_size=kernel_size, 88 | padding=padding, 89 | groups=dim if use_dwconv else 1, 90 | ) # depthwise conv 91 | self.norm = LayerNorm2d(dim, eps=1e-6) 92 | self.pwconv1 = nn.Linear( 93 | dim, 4 * dim 94 | ) # pointwise/1x1 convs, implemented with linear layers 95 | self.act = nn.GELU() 96 | self.pwconv2 = nn.Linear(4 * dim, dim) 97 | self.gamma = ( 98 | nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True) 99 | if layer_scale_init_value > 0 100 | else None 101 | ) 102 | self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() 103 | 104 | def forward(self, x): 105 | input = x 106 | x = self.dwconv(x) 107 | x = self.norm(x) 108 | x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C) 109 | x = self.pwconv1(x) 110 | x = self.act(x) 111 | x = self.pwconv2(x) 112 | if self.gamma is not None: 113 | x = self.gamma * x 114 | x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W) 115 | 116 | x = input + self.drop_path(x) 117 | return x 118 | 119 | 120 | class Fuser(nn.Module): 121 | def __init__(self, layer, num_layers, dim=None, input_projection=False): 122 | super().__init__() 123 | self.proj = nn.Identity() 124 | self.layers = get_clones(layer, num_layers) 125 | 126 | if input_projection: 127 | assert dim is not None 128 | self.proj = nn.Conv2d(dim, dim, kernel_size=1) 129 | 130 | def forward(self, x): 131 | # normally x: (N, C, H, W) 132 | x = self.proj(x) 133 | for layer in self.layers: 134 | x = layer(x) 135 | return x 136 | 137 | 138 | class MemoryEncoder(nn.Module): 139 | def __init__( 140 | self, 141 | out_dim, 142 | mask_downsampler, 143 | fuser, 144 | position_encoding, 145 | in_dim=256, # in_dim of pix_feats 146 | ): 147 | super().__init__() 148 | 149 | 
self.mask_downsampler = mask_downsampler 150 | 151 | self.pix_feat_proj = nn.Conv2d(in_dim, in_dim, kernel_size=1) 152 | self.fuser = fuser 153 | self.position_encoding = position_encoding 154 | self.out_proj = nn.Identity() 155 | if out_dim != in_dim: 156 | self.out_proj = nn.Conv2d(in_dim, out_dim, kernel_size=1) 157 | 158 | def forward( 159 | self, 160 | pix_feat: torch.Tensor, 161 | masks: torch.Tensor, 162 | skip_mask_sigmoid: bool = False, 163 | ) -> Tuple[torch.Tensor, torch.Tensor]: 164 | ## Process masks 165 | # sigmoid, so that less domain shift from gt masks which are bool 166 | if not skip_mask_sigmoid: 167 | masks = F.sigmoid(masks) 168 | masks = self.mask_downsampler(masks) 169 | 170 | ## Fuse pix_feats and downsampled masks 171 | # in case the visual features are on CPU, cast them to CUDA 172 | pix_feat = pix_feat.to(masks.device) 173 | 174 | x = self.pix_feat_proj(pix_feat) 175 | x = x + masks 176 | x = self.fuser(x) 177 | x = self.out_proj(x) 178 | 179 | pos = self.position_encoding(x).to(x.dtype) 180 | 181 | return {"vision_features": x, "vision_pos_enc": [pos]} 182 | -------------------------------------------------------------------------------- /mmmg_eval/sam2/training/dataset/vos_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import logging 8 | import random 9 | from copy import deepcopy 10 | 11 | import numpy as np 12 | 13 | import torch 14 | from iopath.common.file_io import g_pathmgr 15 | from PIL import Image as PILImage 16 | from torchvision.datasets.vision import VisionDataset 17 | 18 | from training.dataset.vos_raw_dataset import VOSRawDataset 19 | from training.dataset.vos_sampler import VOSSampler 20 | from training.dataset.vos_segment_loader import JSONSegmentLoader 21 | 22 | from training.utils.data_utils import Frame, Object, VideoDatapoint 23 | 24 | MAX_RETRIES = 100 25 | 26 | 27 | class VOSDataset(VisionDataset): 28 | def __init__( 29 | self, 30 | transforms, 31 | training: bool, 32 | video_dataset: VOSRawDataset, 33 | sampler: VOSSampler, 34 | multiplier: int, 35 | always_target=True, 36 | target_segments_available=True, 37 | ): 38 | self._transforms = transforms 39 | self.training = training 40 | self.video_dataset = video_dataset 41 | self.sampler = sampler 42 | 43 | self.repeat_factors = torch.ones(len(self.video_dataset), dtype=torch.float32) 44 | self.repeat_factors *= multiplier 45 | print(f"Raw dataset length = {len(self.video_dataset)}") 46 | 47 | self.curr_epoch = 0 # Used in case data loader behavior changes across epochs 48 | self.always_target = always_target 49 | self.target_segments_available = target_segments_available 50 | 51 | def _get_datapoint(self, idx): 52 | 53 | for retry in range(MAX_RETRIES): 54 | try: 55 | if isinstance(idx, torch.Tensor): 56 | idx = idx.item() 57 | # sample a video 58 | video, segment_loader = self.video_dataset.get_video(idx) 59 | # sample frames and object indices to be used in a datapoint 60 | sampled_frms_and_objs = self.sampler.sample( 61 | video, segment_loader, epoch=self.curr_epoch 62 | ) 63 | break # Succesfully loaded video 64 | except Exception as e: 65 | if self.training: 66 | logging.warning( 67 | f"Loading failed (id={idx}); Retry {retry} with exception: {e}" 68 | ) 69 | idx = random.randrange(0, len(self.video_dataset)) 70 | else: 
71 | # Shouldn't fail to load a val video 72 | raise e 73 | 74 | datapoint = self.construct(video, sampled_frms_and_objs, segment_loader) 75 | for transform in self._transforms: 76 | datapoint = transform(datapoint, epoch=self.curr_epoch) 77 | return datapoint 78 | 79 | def construct(self, video, sampled_frms_and_objs, segment_loader): 80 | """ 81 | Constructs a VideoDatapoint sample to pass to transforms 82 | """ 83 | sampled_frames = sampled_frms_and_objs.frames 84 | sampled_object_ids = sampled_frms_and_objs.object_ids 85 | 86 | images = [] 87 | rgb_images = load_images(sampled_frames) 88 | # Iterate over the sampled frames and store their rgb data and object data (bbox, segment) 89 | for frame_idx, frame in enumerate(sampled_frames): 90 | w, h = rgb_images[frame_idx].size 91 | images.append( 92 | Frame( 93 | data=rgb_images[frame_idx], 94 | objects=[], 95 | ) 96 | ) 97 | # We load the gt segments associated with the current frame 98 | if isinstance(segment_loader, JSONSegmentLoader): 99 | segments = segment_loader.load( 100 | frame.frame_idx, obj_ids=sampled_object_ids 101 | ) 102 | else: 103 | segments = segment_loader.load(frame.frame_idx) 104 | for obj_id in sampled_object_ids: 105 | # Extract the segment 106 | if obj_id in segments: 107 | assert ( 108 | segments[obj_id] is not None 109 | ), "None targets are not supported" 110 | # segment is uint8 and remains uint8 throughout the transforms 111 | segment = segments[obj_id].to(torch.uint8) 112 | else: 113 | # There is no target, we either use a zero mask target or drop this object 114 | if not self.always_target: 115 | continue 116 | segment = torch.zeros(h, w, dtype=torch.uint8) 117 | 118 | images[frame_idx].objects.append( 119 | Object( 120 | object_id=obj_id, 121 | frame_index=frame.frame_idx, 122 | segment=segment, 123 | ) 124 | ) 125 | return VideoDatapoint( 126 | frames=images, 127 | video_id=video.video_id, 128 | size=(h, w), 129 | ) 130 | 131 | def __getitem__(self, idx): 132 | return self._get_datapoint(idx) 133 | 134 | def __len__(self): 135 | return len(self.video_dataset) 136 | 137 | 138 | def load_images(frames): 139 | all_images = [] 140 | cache = {} 141 | for frame in frames: 142 | if frame.data is None: 143 | # Load the frame rgb data from file 144 | path = frame.image_path 145 | if path in cache: 146 | all_images.append(deepcopy(all_images[cache[path]])) 147 | continue 148 | with g_pathmgr.open(path, "rb") as fopen: 149 | all_images.append(PILImage.open(fopen).convert("RGB")) 150 | cache[path] = len(all_images) - 1 151 | else: 152 | # The frame rgb data has already been loaded 153 | # Convert it to a PILImage 154 | all_images.append(tensor_2_PIL(frame.data)) 155 | 156 | return all_images 157 | 158 | 159 | def tensor_2_PIL(data: torch.Tensor) -> PILImage.Image: 160 | data = data.cpu().numpy().transpose((1, 2, 0)) * 255.0 161 | data = data.astype(np.uint8) 162 | return PILImage.fromarray(data) 163 | -------------------------------------------------------------------------------- /mmmg_eval/sam2/training/utils/data_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | """ 8 | Misc functions, including distributed helpers. 9 | 10 | Mostly copy-paste from torchvision references. 
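This module defines the tensorclass containers (BatchedVideoMetaData,
BatchedVideoDatapoint), the per-frame dataclasses (Object, Frame, VideoDatapoint),
and the collate_fn used to batch VideoDatapoints for training.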
11 | """ 12 | 13 | from dataclasses import dataclass 14 | from typing import List, Optional, Tuple, Union 15 | 16 | import torch 17 | 18 | from PIL import Image as PILImage 19 | from tensordict import tensorclass 20 | 21 | 22 | @tensorclass 23 | class BatchedVideoMetaData: 24 | """ 25 | This class represents metadata about a batch of videos. 26 | Attributes: 27 | unique_objects_identifier: A tensor of shape Bx3 containing unique identifiers for each object in the batch. Index consists of (video_id, obj_id, frame_id) 28 | frame_orig_size: A tensor of shape Bx2 containing the original size of each frame in the batch. 29 | """ 30 | 31 | unique_objects_identifier: torch.LongTensor 32 | frame_orig_size: torch.LongTensor 33 | 34 | 35 | @tensorclass 36 | class BatchedVideoDatapoint: 37 | """ 38 | This class represents a batch of videos with associated annotations and metadata. 39 | Attributes: 40 | img_batch: A [TxBxCxHxW] tensor containing the image data for each frame in the batch, where T is the number of frames per video, and B is the number of videos in the batch. 41 | obj_to_frame_idx: A [TxOx2] tensor containing the image_batch index which the object belongs to. O is the number of objects in the batch. 42 | masks: A [TxOxHxW] tensor containing binary masks for each object in the batch. 43 | metadata: An instance of BatchedVideoMetaData containing metadata about the batch. 44 | dict_key: A string key used to identify the batch. 45 | """ 46 | 47 | img_batch: torch.FloatTensor 48 | obj_to_frame_idx: torch.IntTensor 49 | masks: torch.BoolTensor 50 | metadata: BatchedVideoMetaData 51 | 52 | dict_key: str 53 | 54 | def pin_memory(self, device=None): 55 | return self.apply(torch.Tensor.pin_memory, device=device) 56 | 57 | @property 58 | def num_frames(self) -> int: 59 | """ 60 | Returns the number of frames per video. 61 | """ 62 | return self.batch_size[0] 63 | 64 | @property 65 | def num_videos(self) -> int: 66 | """ 67 | Returns the number of videos in the batch. 68 | """ 69 | return self.img_batch.shape[1] 70 | 71 | @property 72 | def flat_obj_to_img_idx(self) -> torch.IntTensor: 73 | """ 74 | Returns a flattened tensor containing the object to img index. 75 | The flat index can be used to access a flattened img_batch of shape [(T*B)xCxHxW] 76 | """ 77 | frame_idx, video_idx = self.obj_to_frame_idx.unbind(dim=-1) 78 | flat_idx = video_idx * self.num_frames + frame_idx 79 | return flat_idx 80 | 81 | @property 82 | def flat_img_batch(self) -> torch.FloatTensor: 83 | """ 84 | Returns a flattened img_batch_tensor of shape [(B*T)xCxHxW] 85 | """ 86 | 87 | return self.img_batch.transpose(0, 1).flatten(0, 1) 88 | 89 | 90 | @dataclass 91 | class Object: 92 | # Id of the object in the media 93 | object_id: int 94 | # Index of the frame in the media (0 if single image) 95 | frame_index: int 96 | segment: Union[torch.Tensor, dict] # RLE dict or binary mask 97 | 98 | 99 | @dataclass 100 | class Frame: 101 | data: Union[torch.Tensor, PILImage.Image] 102 | objects: List[Object] 103 | 104 | 105 | @dataclass 106 | class VideoDatapoint: 107 | """Refers to an image/video and all its annotations""" 108 | 109 | frames: List[Frame] 110 | video_id: int 111 | size: Tuple[int, int] 112 | 113 | 114 | def collate_fn( 115 | batch: List[VideoDatapoint], 116 | dict_key, 117 | ) -> BatchedVideoDatapoint: 118 | """ 119 | Args: 120 | batch: A list of VideoDatapoint instances. 121 | dict_key (str): A string key used to identify the batch. 
122 | """ 123 | img_batch = [] 124 | for video in batch: 125 | img_batch += [torch.stack([frame.data for frame in video.frames], dim=0)] 126 | 127 | img_batch = torch.stack(img_batch, dim=0).permute((1, 0, 2, 3, 4)) 128 | T = img_batch.shape[0] 129 | # Prepare data structures for sequential processing. Per-frame processing but batched across videos. 130 | step_t_objects_identifier = [[] for _ in range(T)] 131 | step_t_frame_orig_size = [[] for _ in range(T)] 132 | 133 | step_t_masks = [[] for _ in range(T)] 134 | step_t_obj_to_frame_idx = [ 135 | [] for _ in range(T) 136 | ] # List to store frame indices for each time step 137 | 138 | for video_idx, video in enumerate(batch): 139 | orig_video_id = video.video_id 140 | orig_frame_size = video.size 141 | for t, frame in enumerate(video.frames): 142 | objects = frame.objects 143 | for obj in objects: 144 | orig_obj_id = obj.object_id 145 | orig_frame_idx = obj.frame_index 146 | step_t_obj_to_frame_idx[t].append( 147 | torch.tensor([t, video_idx], dtype=torch.int) 148 | ) 149 | step_t_masks[t].append(obj.segment.to(torch.bool)) 150 | step_t_objects_identifier[t].append( 151 | torch.tensor([orig_video_id, orig_obj_id, orig_frame_idx]) 152 | ) 153 | step_t_frame_orig_size[t].append(torch.tensor(orig_frame_size)) 154 | 155 | obj_to_frame_idx = torch.stack( 156 | [ 157 | torch.stack(obj_to_frame_idx, dim=0) 158 | for obj_to_frame_idx in step_t_obj_to_frame_idx 159 | ], 160 | dim=0, 161 | ) 162 | masks = torch.stack([torch.stack(masks, dim=0) for masks in step_t_masks], dim=0) 163 | objects_identifier = torch.stack( 164 | [torch.stack(id, dim=0) for id in step_t_objects_identifier], dim=0 165 | ) 166 | frame_orig_size = torch.stack( 167 | [torch.stack(id, dim=0) for id in step_t_frame_orig_size], dim=0 168 | ) 169 | return BatchedVideoDatapoint( 170 | img_batch=img_batch, 171 | obj_to_frame_idx=obj_to_frame_idx, 172 | masks=masks, 173 | metadata=BatchedVideoMetaData( 174 | unique_objects_identifier=objects_identifier, 175 | frame_orig_size=frame_orig_size, 176 | ), 177 | dict_key=dict_key, 178 | batch_size=[T], 179 | ) 180 | -------------------------------------------------------------------------------- /mmmg_eval/utils/stat_knowledge.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | import re 5 | from pathlib import Path 6 | from all_configs import * 7 | from PIL import Image 8 | from tqdm import tqdm 9 | 10 | def clean_key(k): 11 | 12 | s = k.strip() 13 | s = re.sub(r"^[\*`]+", "", s) 14 | s = re.sub(r"[\*`]+$", "", s) 15 | s = re.sub(r"[::]$", "", s) 16 | return s.strip() 17 | 18 | def parse_output_to_checklist(image_id, text, elements, dependencies): 19 | all_keys = elements + dependencies 20 | 21 | escaped_keys = [re.escape(k) for k in all_keys] 22 | 23 | text = text.replace("**", "") 24 | text = text.replace("[yes]", "yes") 25 | text = text.replace("[no]", "no") 26 | pattern = rf"^\s*(?:[*\-•]|\d+\.)?\s*`*\s*({'|'.join(escaped_keys)})`*\s*[::]\s*(yes|no|[yes]|[no]|YES|NO|Yes|No)\b" 27 | matches = re.findall(pattern, text, flags=re.IGNORECASE | re.MULTILINE) 28 | 29 | if len (matches) == 0: 30 | return { 31 | "key": image_id, 32 | "elements": {key: False for key in elements}, 33 | "dependencies": {key: False for key in dependencies} 34 | } 35 | 36 | element_dict = {key: False for key in elements} 37 | dependency_dict = {key: False for key in dependencies} 38 | 39 | for raw_k, v in matches: 40 | k_clean = clean_key(raw_k) 41 | v_bool = v.strip().lower() == 
"yes" 42 | if k_clean in element_dict: 43 | element_dict[k_clean] = v_bool 44 | elif k_clean in dependency_dict: 45 | dependency_dict[k_clean] = v_bool 46 | else: 47 | find=False 48 | for ele in element_dict.keys(): 49 | if k_clean.lower() == ele.lower(): 50 | element_dict[ele] = v_bool 51 | find=True 52 | break 53 | if not find: 54 | for ele in dependency_dict.keys(): 55 | if k_clean.lower() == ele.lower(): 56 | dependency_dict[ele] = v_bool 57 | find=True 58 | break 59 | 60 | 61 | return { 62 | "key": image_id, 63 | "elements": element_dict, 64 | "dependencies": dependency_dict 65 | } 66 | 67 | 68 | def main(): 69 | parser = argparse.ArgumentParser( 70 | description="Merge model answers with MMMG ground truth." 71 | ) 72 | parser.add_argument("--result_folder", "-o", required=True, 73 | help="Path containing model-generated *.json files") 74 | parser.add_argument("--image_folder", "-i", required=True, 75 | help="Root folder holding six sub-folders of images") 76 | parser.add_argument("--api_name", "-a", required=True, 77 | help="Key used in each result JSON to fetch the model output") 78 | parser.add_argument("--output_dir", required=True, 79 | help="Folder to save merged result JSON") 80 | parser.add_argument("--save_name", default="step2_summarize", 81 | help="Output file name (without .json)") 82 | parser.add_argument("--hf_cache", default="./data/MMMG", 83 | help="HuggingFace cache dir") 84 | args = parser.parse_args() 85 | 86 | # ---------------- load MMMG ground truth --------------------------------- 87 | full_dataset = load_all_mmmg_configs( 88 | cache_dir=args.hf_cache, max_workers=16 89 | ) 90 | 91 | # build lookup: {grade: {image_id: {elements, dependencies}}} 92 | gt = {} 93 | for sample in full_dataset: 94 | grade = str(sample["Education"]) # e.g. 
"preschool" 95 | image_id = sample["key"] 96 | kg = json.loads(sample["Knowledge_Graph"]) 97 | gt.setdefault(grade, {})[image_id] = kg 98 | 99 | 100 | grade_map = { # match folder names if they differ 101 | "preschool": "0_preschool", 102 | "primaryschool": "1_primaryschool", 103 | "secondaryschool": "2_secondaryschool", 104 | "highschool": "3_highschool", 105 | "undergraduate": "4_undergraduate", 106 | "PhD": "5_PhD", 107 | } 108 | 109 | merged = {g: {} for g in grade_map} 110 | 111 | # ---------------- scan model result files -------------------------------- 112 | result_folder = Path(args.result_folder) 113 | image_folder = Path(args.image_folder) 114 | 115 | for fn in tqdm(sorted(result_folder.glob("*.json")), desc="merging"): 116 | try: 117 | data = json.loads(fn.read_text(encoding="utf-8")) 118 | except Exception as e: 119 | print(f"[skip] cannot parse {fn.name}: {e}") 120 | continue 121 | 122 | if args.api_name not in data: 123 | print(f"[skip] {fn.name}: no key '{args.api_name}'") 124 | continue 125 | 126 | text = data[args.api_name] # LLM checklist 127 | image_id = data["key"] # assert match later 128 | 129 | parts = fn.stem.split("__") # __.json 130 | if len(parts) == 2: 131 | grade, img_key = parts 132 | else: 133 | grade, img_key,= parts[:2] 134 | 135 | if img_key != image_id: 136 | print(f"[warn] ID mismatch in {fn.name}") 137 | 138 | # locate image (png/jpg fallback) 139 | base = image_folder / grade 140 | if not base.exists(): 141 | base = image_folder / grade_map.get(grade, grade) 142 | png_path = base / f"{image_id}.png" 143 | jpg_path = base / f"{image_id}.jpg" 144 | img_path = png_path if png_path.exists() else jpg_path 145 | if not img_path.exists(): 146 | print(f"[skip] image not found for {image_id}") 147 | continue 148 | 149 | # ground-truth KG 150 | try: 151 | gt_entry = gt[grade][image_id] 152 | elements = gt_entry["elements"] 153 | dependencies = gt_entry["dependencies"] 154 | except KeyError: 155 | print(f"[skip] GT not found for {grade}/{image_id}") 156 | continue 157 | 158 | merged.setdefault(grade, {})[image_id] = { 159 | "img_path": str(img_path), 160 | "result": parse_output_to_checklist(image_id, text, elements, dependencies), 161 | } 162 | 163 | # ---------------- save ---------------------------------------------------- 164 | out_dir = Path(args.output_dir) 165 | out_dir.mkdir(parents=True, exist_ok=True) 166 | out_path = out_dir / f"{args.save_name}.json" 167 | out_path.write_text(json.dumps(merged, indent=2, ensure_ascii=False)) 168 | print(f"Stage1 Evaluation Done ✔️. Saved to {out_path}") 169 | 170 | 171 | if __name__ == "__main__": 172 | main() -------------------------------------------------------------------------------- /mmmg_eval/sam2/training/README.md: -------------------------------------------------------------------------------- 1 | # Training Code for SAM 2 2 | 3 | This folder contains the training code for SAM 2, a foundation model for promptable visual segmentation in images and videos. 4 | The code allows users to train and fine-tune SAM 2 on their own datasets (image, video, or both). 5 | 6 | ## Structure 7 | 8 | The training code is organized into the following subfolders: 9 | 10 | * `dataset`: This folder contains image and video dataset and dataloader classes as well as their transforms. 11 | * `model`: This folder contains the main model class (`SAM2Train`) for training/fine-tuning. `SAM2Train` inherits from `SAM2Base` model and provides functions to enable training or fine-tuning SAM 2. 
It also accepts all training-time parameters used for simulating user prompts (e.g. iterative point sampling).
12 | * `utils`: This folder contains training utils such as loggers and distributed training utils.
13 | * `scripts`: This folder contains the script to extract the frames of the SA-V dataset to be used in training.
14 | * `loss_fns.py`: This file has the main loss class (`MultiStepMultiMasksAndIous`) used for training.
15 | * `optimizer.py`: This file contains all optimizer utils that support arbitrary schedulers.
16 | * `trainer.py`: This file contains the `Trainer` class that accepts all the `Hydra` configurable modules (model, optimizer, datasets, etc.) and implements the main train/eval loop.
17 | * `train.py`: This script is used to launch training jobs. It supports single and multi-node jobs. For usage, please check the [Getting Started](README.md#getting-started) section or run `python training/train.py -h`.
18 | 
19 | ## Getting Started
20 | 
21 | To get started with the training code, we provide a simple example to fine-tune our checkpoints on the [MOSE](https://henghuiding.github.io/MOSE/) dataset, which can be extended to your custom datasets.
22 | 
23 | #### Requirements:
24 | - We assume training on A100 GPUs with **80 GB** of memory.
25 | - Download the MOSE dataset using one of the provided links from [here](https://github.com/henghuiding/MOSE-api?tab=readme-ov-file#download).
26 | 
27 | #### Steps to fine-tune on MOSE:
28 | - Install the packages required for training by running `pip install -e ".[dev]"`.
29 | - Set the paths for the MOSE dataset in `configs/sam2.1_training/sam2.1_hiera_b+_MOSE_finetune.yaml`.
30 | ```yaml
31 | dataset:
32 |   # PATHS to Dataset
33 |   img_folder: null # PATH to MOSE JPEGImages folder
34 |   gt_folder: null # PATH to MOSE Annotations folder
35 |   file_list_txt: null # Optional PATH to filelist containing a subset of videos to be used for training
36 | ```
37 | - To fine-tune the base model on MOSE using 8 GPUs, run
38 | 
39 | ```bash
40 | python training/train.py \
41 |     -c configs/sam2.1_training/sam2.1_hiera_b+_MOSE_finetune.yaml \
42 |     --use-cluster 0 \
43 |     --num-gpus 8
44 | ```
45 | 
46 | We also support multi-node training on a cluster using [SLURM](https://slurm.schedmd.com/documentation.html). For example, you can train on 2 nodes by running
47 | 
48 | ```bash
49 | python training/train.py \
50 |     -c configs/sam2.1_training/sam2.1_hiera_b+_MOSE_finetune.yaml \
51 |     --use-cluster 1 \
52 |     --num-gpus 8 \
53 |     --num-nodes 2 \
54 |     --partition $PARTITION \
55 |     --qos $QOS \
56 |     --account $ACCOUNT
57 | ```
58 | where `partition`, `qos`, and `account` are optional and depend on your SLURM configuration.
59 | By default, the checkpoint and logs will be saved under the `sam2_logs` directory in the root of the repo. Alternatively, you can set the experiment log directory in the config file as follows:
60 | 
61 | ```yaml
62 | experiment_log_dir: null # Path to log directory, defaults to ./sam2_logs/${config_name}
63 | ```
64 | The training losses can be monitored using `tensorboard` logs stored under `tensorboard/` in the experiment log directory. We also provide a sample validation [split](../training/assets/MOSE_sample_val_list.txt) for evaluation purposes. To generate predictions, follow this [guide](../tools/README.md) on how to use our `vos_inference.py` script. After generating the predictions, you can run the `sav_evaluator.py` as detailed [here](../sav_dataset/README.md#sa-v-val-and-test-evaluation).
The expected MOSE J&F after fine-tuning the Base plus model is 79.4. 65 | 66 | 67 | After training/fine-tuning, you can then use the new checkpoint (saved in `checkpoints/` in the experiment log directory) similar to SAM 2 released checkpoints (as illustrated [here](../README.md#image-prediction)). 68 | ## Training on images and videos 69 | The code supports training on images and videos (similar to how SAM 2 is trained). We provide classes for loading SA-1B as a sample image dataset, SA-V as a sample video dataset, as well as any DAVIS-style video dataset (e.g. MOSE). Note that to train on SA-V, you must first extract all videos to JPEG frames using the provided extraction [script](./scripts/sav_frame_extraction_submitit.py). Below is an example of how to setup the datasets in your config to train on a mix of image and video datasets: 70 | 71 | ```yaml 72 | data: 73 | train: 74 | _target_: training.dataset.sam2_datasets.TorchTrainMixedDataset 75 | phases_per_epoch: ${phases_per_epoch} # Chunks a single epoch into smaller phases 76 | batch_sizes: # List of batch sizes corresponding to each dataset 77 | - ${bs1} # Batch size of dataset 1 78 | - ${bs2} # Batch size of dataset 2 79 | datasets: 80 | # SA1B as an example of an image dataset 81 | - _target_: training.dataset.vos_dataset.VOSDataset 82 | training: true 83 | video_dataset: 84 | _target_: training.dataset.vos_raw_dataset.SA1BRawDataset 85 | img_folder: ${path_to_img_folder} 86 | gt_folder: ${path_to_gt_folder} 87 | file_list_txt: ${path_to_train_filelist} # Optional 88 | sampler: 89 | _target_: training.dataset.vos_sampler.RandomUniformSampler 90 | num_frames: 1 91 | max_num_objects: ${max_num_objects_per_image} 92 | transforms: ${image_transforms} 93 | # SA-V as an example of a video dataset 94 | - _target_: training.dataset.vos_dataset.VOSDataset 95 | training: true 96 | video_dataset: 97 | _target_: training.dataset.vos_raw_dataset.JSONRawDataset 98 | img_folder: ${path_to_img_folder} 99 | gt_folder: ${path_to_gt_folder} 100 | file_list_txt: ${path_to_train_filelist} # Optional 101 | ann_every: 4 102 | sampler: 103 | _target_: training.dataset.vos_sampler.RandomUniformSampler 104 | num_frames: 8 # Number of frames per video 105 | max_num_objects: ${max_num_objects_per_video} 106 | reverse_time_prob: ${reverse_time_prob} # probability to reverse video 107 | transforms: ${video_transforms} 108 | shuffle: True 109 | num_workers: ${num_train_workers} 110 | pin_memory: True 111 | drop_last: True 112 | collate_fn: 113 | _target_: training.utils.data_utils.collate_fn 114 | _partial_: true 115 | dict_key: all 116 | ``` 117 | -------------------------------------------------------------------------------- /mmmg_eval/sam2/sam2/build_sam.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import logging 8 | import os 9 | 10 | import torch 11 | from hydra import compose 12 | from hydra.utils import instantiate 13 | from omegaconf import OmegaConf 14 | 15 | import sam2 16 | 17 | # Check if the user is running Python from the parent directory of the sam2 repo 18 | # (i.e. the directory where this repo is cloned into) -- this is not supported since 19 | # it could shadow the sam2 package and cause issues. 
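# (Illustrative sanity check, not part of the build logic: running
#     python -c "import sam2; print(sam2.__path__)"
# from your launch directory should print the installed package location; if it
# prints the cloned repository root instead, the check below will raise.)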
20 | if os.path.isdir(os.path.join(sam2.__path__[0], "sam2")): 21 | # If the user has "sam2/sam2" in their path, they are likey importing the repo itself 22 | # as "sam2" rather than importing the "sam2" python package (i.e. "sam2/sam2" directory). 23 | # This typically happens because the user is running Python from the parent directory 24 | # that contains the sam2 repo they cloned. 25 | raise RuntimeError( 26 | "You're likely running Python from the parent directory of the sam2 repository " 27 | "(i.e. the directory where https://github.com/facebookresearch/sam2 is cloned into). " 28 | "This is not supported since the `sam2` Python package could be shadowed by the " 29 | "repository name (the repository is also named `sam2` and contains the Python package " 30 | "in `sam2/sam2`). Please run Python from another directory (e.g. from the repo dir " 31 | "rather than its parent dir, or from your home directory) after installing SAM 2." 32 | ) 33 | 34 | 35 | HF_MODEL_ID_TO_FILENAMES = { 36 | "facebook/sam2-hiera-tiny": ( 37 | "configs/sam2/sam2_hiera_t.yaml", 38 | "sam2_hiera_tiny.pt", 39 | ), 40 | "facebook/sam2-hiera-small": ( 41 | "configs/sam2/sam2_hiera_s.yaml", 42 | "sam2_hiera_small.pt", 43 | ), 44 | "facebook/sam2-hiera-base-plus": ( 45 | "configs/sam2/sam2_hiera_b+.yaml", 46 | "sam2_hiera_base_plus.pt", 47 | ), 48 | "facebook/sam2-hiera-large": ( 49 | "configs/sam2/sam2_hiera_l.yaml", 50 | "sam2_hiera_large.pt", 51 | ), 52 | "facebook/sam2.1-hiera-tiny": ( 53 | "configs/sam2.1/sam2.1_hiera_t.yaml", 54 | "sam2.1_hiera_tiny.pt", 55 | ), 56 | "facebook/sam2.1-hiera-small": ( 57 | "configs/sam2.1/sam2.1_hiera_s.yaml", 58 | "sam2.1_hiera_small.pt", 59 | ), 60 | "facebook/sam2.1-hiera-base-plus": ( 61 | "configs/sam2.1/sam2.1_hiera_b+.yaml", 62 | "sam2.1_hiera_base_plus.pt", 63 | ), 64 | "facebook/sam2.1-hiera-large": ( 65 | "configs/sam2.1/sam2.1_hiera_l.yaml", 66 | "sam2.1_hiera_large.pt", 67 | ), 68 | } 69 | 70 | 71 | def build_sam2( 72 | config_file, 73 | ckpt_path=None, 74 | device="cuda", 75 | mode="eval", 76 | hydra_overrides_extra=[], 77 | apply_postprocessing=True, 78 | **kwargs, 79 | ): 80 | 81 | if apply_postprocessing: 82 | hydra_overrides_extra = hydra_overrides_extra.copy() 83 | hydra_overrides_extra += [ 84 | # dynamically fall back to multi-mask if the single mask is not stable 85 | "++model.sam_mask_decoder_extra_args.dynamic_multimask_via_stability=true", 86 | "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_delta=0.05", 87 | "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_thresh=0.98", 88 | ] 89 | # Read config and init model 90 | cfg = compose(config_name=config_file, overrides=hydra_overrides_extra) 91 | OmegaConf.resolve(cfg) 92 | model = instantiate(cfg.model, _recursive_=True) 93 | _load_checkpoint(model, ckpt_path) 94 | model = model.to(device) 95 | if mode == "eval": 96 | model.eval() 97 | return model 98 | 99 | 100 | def build_sam2_video_predictor( 101 | config_file, 102 | ckpt_path=None, 103 | device="cuda", 104 | mode="eval", 105 | hydra_overrides_extra=[], 106 | apply_postprocessing=True, 107 | vos_optimized=False, 108 | **kwargs, 109 | ): 110 | hydra_overrides = [ 111 | "++model._target_=sam2.sam2_video_predictor.SAM2VideoPredictor", 112 | ] 113 | if vos_optimized: 114 | hydra_overrides = [ 115 | "++model._target_=sam2.sam2_video_predictor.SAM2VideoPredictorVOS", 116 | "++model.compile_image_encoder=True", # Let sam2_base handle this 117 | ] 118 | 119 | if apply_postprocessing: 120 | hydra_overrides_extra = 
hydra_overrides_extra.copy() 121 | hydra_overrides_extra += [ 122 | # dynamically fall back to multi-mask if the single mask is not stable 123 | "++model.sam_mask_decoder_extra_args.dynamic_multimask_via_stability=true", 124 | "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_delta=0.05", 125 | "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_thresh=0.98", 126 | # the sigmoid mask logits on interacted frames with clicks in the memory encoder so that the encoded masks are exactly as what users see from clicking 127 | "++model.binarize_mask_from_pts_for_mem_enc=true", 128 | # fill small holes in the low-res masks up to `fill_hole_area` (before resizing them to the original video resolution) 129 | "++model.fill_hole_area=8", 130 | ] 131 | hydra_overrides.extend(hydra_overrides_extra) 132 | 133 | # Read config and init model 134 | cfg = compose(config_name=config_file, overrides=hydra_overrides) 135 | OmegaConf.resolve(cfg) 136 | model = instantiate(cfg.model, _recursive_=True) 137 | _load_checkpoint(model, ckpt_path) 138 | model = model.to(device) 139 | if mode == "eval": 140 | model.eval() 141 | return model 142 | 143 | 144 | def _hf_download(model_id): 145 | from huggingface_hub import hf_hub_download 146 | 147 | config_name, checkpoint_name = HF_MODEL_ID_TO_FILENAMES[model_id] 148 | ckpt_path = hf_hub_download(repo_id=model_id, filename=checkpoint_name) 149 | return config_name, ckpt_path 150 | 151 | 152 | def build_sam2_hf(model_id, **kwargs): 153 | config_name, ckpt_path = _hf_download(model_id) 154 | return build_sam2(config_file=config_name, ckpt_path=ckpt_path, **kwargs) 155 | 156 | 157 | def build_sam2_video_predictor_hf(model_id, **kwargs): 158 | config_name, ckpt_path = _hf_download(model_id) 159 | return build_sam2_video_predictor( 160 | config_file=config_name, ckpt_path=ckpt_path, **kwargs 161 | ) 162 | 163 | 164 | def _load_checkpoint(model, ckpt_path): 165 | if ckpt_path is not None: 166 | sd = torch.load(ckpt_path, map_location="cpu", weights_only=True)["model"] 167 | missing_keys, unexpected_keys = model.load_state_dict(sd) 168 | if missing_keys: 169 | logging.error(missing_keys) 170 | raise RuntimeError() 171 | if unexpected_keys: 172 | logging.error(unexpected_keys) 173 | raise RuntimeError() 174 | logging.info("Loaded checkpoint sucessfully") 175 | -------------------------------------------------------------------------------- /mmmg_eval/step1_knowledge_integrity.py: -------------------------------------------------------------------------------- 1 | import os, json 2 | from openai import AzureOpenAI 3 | from PIL import Image 4 | import io 5 | import base64 6 | from datetime import datetime 7 | from multiprocessing import Pool, cpu_count 8 | import time 9 | from tqdm import tqdm 10 | import argparse 11 | from PIL import Image 12 | from utils.all_configs import * 13 | from utils.gpt_api_pool import gpt_api_pool 14 | from utils.instruction_prompt import instruction_prompt 15 | 16 | 17 | def encode_image(image): 18 | buffered = io.BytesIO() 19 | image.save(buffered, format=image.format) 20 | return base64.b64encode(buffered.getvalue()).decode('utf-8') 21 | 22 | 23 | def get_response_text_azure_img(image1, question, gpt_api_i): 24 | 25 | # init azure openAI credential. 
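    # Each gpt_api_i entry is expected to provide the fields read below
    # (illustrative placeholders, not real values):
    #   {"api_key": "<azure-key>", "api_version": "<api-version>",
    #    "azure_endpoint": "https://<resource>.openai.azure.com/",
    #    "deployment_name": "<deployment>"}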
26 | aoiclient = AzureOpenAI( 27 | api_key=gpt_api_i["api_key"], 28 | api_version=gpt_api_i["api_version"], 29 | azure_endpoint=gpt_api_i["azure_endpoint"] 30 | ) 31 | deployment_name = gpt_api_i["deployment_name"] 32 | if image1 == None: 33 | print("Error: No image path provided.") 34 | return None 35 | try: 36 | png_file = Image.open(image1) 37 | base64_image = encode_image(png_file) 38 | except Exception as e: 39 | print(f"Error opening image file {image1}: {e}") 40 | return None 41 | 42 | full_prompt = question 43 | 44 | num_iters = 3 45 | for itx in range(num_iters): 46 | try: 47 | response = aoiclient.chat.completions.create( 48 | messages=[{ 49 | "role": "user", 50 | "content": [{ 51 | "type": "text", 52 | "text": full_prompt 53 | }, { 54 | "type": "image_url", 55 | "image_url": { 56 | "url": f"data:image/png;base64,{base64_image}" 57 | } 58 | }] 59 | }], 60 | model=deployment_name, #Azure uses deployment_name, not model ID 61 | max_tokens=2048, #o3 62 | #max_completion_tokens=2048, #o1 63 | temperature=0, 64 | top_p=1, 65 | ) 66 | return response.choices[0].message.content 67 | except Exception as e: 68 | print(f"Error for image {image1} occurs when calling Azure OpenAI API: {e}") 69 | time.sleep(5*(1+itx)) 70 | 71 | return None 72 | 73 | def process_item(args_tuple): 74 | grades = [ 75 | "preschool", 76 | "primaryschool", 77 | "secondaryschool", 78 | "highschool", 79 | "undergraduate", 80 | "PhD" 81 | ] 82 | sample, t2i_method, args, gpt_api_i, save_root = args_tuple 83 | 84 | image_name = sample["key"] 85 | gradeschool = sample["Education"] 86 | base_dir = os.path.join(args.img_dir, gradeschool) 87 | image_path = os.path.join(base_dir, f"{image_name}.png") 88 | 89 | json_save_name = os.path.join( 90 | save_root, f"{gradeschool}__{image_name}__eval_by__{args.api_name}.json") 91 | os.makedirs(f"{save_root}", exist_ok=True) 92 | if os.path.exists(json_save_name): 93 | try: 94 | with open(json_save_name, "r") as f: 95 | try_text = json.load(f) 96 | if args.api_name in try_text.keys(): 97 | return f"Already exists: {json_save_name}" 98 | else: 99 | print("blank json file, continue...") 100 | except Exception as e: 101 | print("data corrupted, continue...") 102 | 103 | try: 104 | image = Image.open(image_path) 105 | assert image is not None, f"Image {image_path} is None" 106 | except Exception as e: 107 | image_path = os.path.join(base_dir, f"{image_name}.jpg") 108 | print(f"Processing {image_path}...") 109 | try: 110 | image = Image.open(image_path) 111 | assert image is not None, f"Image {image_path} is None" 112 | except Exception as e: 113 | # ambiguous check to avoid saving the same image in different folders 114 | for indxx, gradex in enumerate(grades): 115 | if gradex in image_name: 116 | image_name = str(indxx)+"_" + image_name 117 | image_path = os.path.join(args.img_dir, gradex, f"{image_name}.png") 118 | try: 119 | image = Image.open(image_path) 120 | assert image is not None, f"Image {image_path} is None" 121 | break 122 | except Exception as e: 123 | print(f"Error opening image file {image_path}: {e}") 124 | return 125 | 126 | save_dict_i = { 127 | "key": image_name, 128 | "t2i_method": t2i_method 129 | } 130 | save_dict_i["image_path"] = image_path 131 | 132 | instruction_prompt_i = instruction_prompt.replace( 133 | "[ELEM_DEPEND]", 134 | str(sample["Knowledge_Graph"])+"\n"+str(sample["Annotation"]) 135 | ) 136 | 137 | response = get_response_text_azure_img( 138 | image_path, 139 | instruction_prompt_i, 140 | gpt_api_i 141 | ) 142 | 143 | if response is None: 144 | return 
f"Error: No response for {image_name} - {t2i_method}- {image_path}" 145 | 146 | save_dict_i[args.api_name] = response 147 | 148 | with open(json_save_name, "w") as f: 149 | json.dump(save_dict_i, f, indent=4) 150 | 151 | time.sleep(0.25) 152 | return f"Saved to {json_save_name}" 153 | 154 | if __name__ == "__main__": 155 | parser = argparse.ArgumentParser(description="MMMG-Evaluation") 156 | parser.add_argument("--img_dir", "-i", type=str) 157 | parser.add_argument("--output_dir", "-o", type=str) 158 | parser.add_argument("--t2i_method", "-m", type=str, default="0") 159 | parser.add_argument("--api_name", "-a", type=str, default="o3", help="API type for OpenAI.") 160 | parser.add_argument("--hf_cache", "-c", type=str) 161 | parser.add_argument("--num_workers", type=int, default=24) 162 | 163 | args = parser.parse_args() 164 | 165 | 166 | grades = [ 167 | ("preschool", "0_preschool"), 168 | ("primaryschool", "1_primaryschool"), 169 | ("secondaryschool", "2_secondaryschool"), 170 | ("highschool", "3_highschool"), 171 | ("undergraduate", "4_undergraduate"), 172 | ("PhD", "5_PhD") 173 | ] 174 | 175 | # ---------------- Load HF dataset once ---------------- 176 | MMMG = load_all_mmmg_configs( 177 | cache_dir=args.hf_cache, 178 | max_workers=args.num_workers 179 | ) 180 | print(f"Start Stage1 Eval. Loaded {len(MMMG)} samples.") 181 | save_root = args.output_dir 182 | t2i_method = args.t2i_method 183 | num_keys = len(gpt_api_pool) 184 | all_jobs = [ 185 | ( 186 | dict(sample), 187 | t2i_method, 188 | args, 189 | gpt_api_pool[i % num_keys], 190 | save_root, 191 | ) 192 | for i, sample in enumerate(MMMG) 193 | ] 194 | print("Jobs prepared for processing.") 195 | 196 | with Pool(processes=args.num_workers) as pool: 197 | for msg in tqdm(pool.imap_unordered(process_item, all_jobs), total=len(all_jobs)): 198 | print(msg) 199 | -------------------------------------------------------------------------------- /mmmg_eval/sam2/training/dataset/sam2_datasets.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import logging 8 | import math 9 | from typing import Callable, Iterable, List, Optional, Sequence 10 | 11 | import torch 12 | 13 | from torch.utils.data import BatchSampler, DataLoader, Dataset, IterableDataset, Subset 14 | 15 | from torch.utils.data.distributed import DistributedSampler 16 | 17 | 18 | class MixedDataLoader: 19 | def __init__(self, dataloaders: List[DataLoader], mixing_prob: torch.FloatTensor): 20 | """ 21 | Args: 22 | dataloaders (List[DataLoader]): List of DataLoaders to be mixed. 
23 | mixing_prob (torch.FloatTensor): Probability of each dataloader to be sampled from 24 | 25 | """ 26 | assert len(dataloaders) == mixing_prob.shape[0] 27 | self.dataloaders = dataloaders 28 | self.mixing_prob = mixing_prob 29 | # Iterator state 30 | self._iter_dls = None 31 | self._iter_mixing_prob = None 32 | self.random_generator = torch.Generator() 33 | 34 | def __len__(self): 35 | return sum([len(d) for d in self.dataloaders]) 36 | 37 | def __iter__(self): 38 | # Synchronize dataloader seeds 39 | self.random_generator.manual_seed(42) 40 | self._iter_dls = [iter(loader) for loader in self.dataloaders] 41 | self._iter_mixing_prob = self.mixing_prob.clone() 42 | return self 43 | 44 | def __next__(self): 45 | """ 46 | Sample a dataloader to sample from based on mixing probabilities. If one of the dataloaders is exhausted, we continue sampling from the other loaders until all are exhausted. 47 | """ 48 | if self._iter_dls is None: 49 | raise TypeError(f"{type(self).__name__} object is not an iterator") 50 | 51 | while self._iter_mixing_prob.any(): # at least one D-Loader with non-zero prob. 52 | dataset_idx = self._iter_mixing_prob.multinomial( 53 | 1, generator=self.random_generator 54 | ).item() 55 | try: 56 | item = next(self._iter_dls[dataset_idx]) 57 | return item 58 | except StopIteration: 59 | # No more iterations for this dataset, set it's mixing probability to zero and try again. 60 | self._iter_mixing_prob[dataset_idx] = 0 61 | except Exception as e: 62 | # log and raise any other unexpected error. 63 | logging.error(e) 64 | raise e 65 | 66 | # Exhausted all iterators 67 | raise StopIteration 68 | 69 | 70 | class TorchTrainMixedDataset: 71 | def __init__( 72 | self, 73 | datasets: List[Dataset], 74 | batch_sizes: List[int], 75 | num_workers: int, 76 | shuffle: bool, 77 | pin_memory: bool, 78 | drop_last: bool, 79 | collate_fn: Optional[Callable] = None, 80 | worker_init_fn: Optional[Callable] = None, 81 | phases_per_epoch: int = 1, 82 | dataset_prob: Optional[List[float]] = None, 83 | ) -> None: 84 | """ 85 | Args: 86 | datasets (List[Dataset]): List of Datasets to be mixed. 87 | batch_sizes (List[int]): Batch sizes for each dataset in the list. 88 | num_workers (int): Number of workers per dataloader. 89 | shuffle (bool): Whether or not to shuffle data. 90 | pin_memory (bool): If True, use pinned memory when loading tensors from disk. 91 | drop_last (bool): Whether or not to drop the last batch of data. 92 | collate_fn (Callable): Function to merge a list of samples into a mini-batch. 93 | worker_init_fn (Callable): Function to init each dataloader worker. 94 | phases_per_epoch (int): Number of phases per epoch. 95 | dataset_prob (List[float]): Probability of choosing the dataloader to sample from. Should sum to 1.0 96 | """ 97 | 98 | self.datasets = datasets 99 | self.batch_sizes = batch_sizes 100 | self.num_workers = num_workers 101 | self.shuffle = shuffle 102 | self.pin_memory = pin_memory 103 | self.drop_last = drop_last 104 | self.collate_fn = collate_fn 105 | self.worker_init_fn = worker_init_fn 106 | assert len(self.datasets) > 0 107 | for dataset in self.datasets: 108 | assert not isinstance(dataset, IterableDataset), "Not supported" 109 | # `RepeatFactorWrapper` requires calling set_epoch first to get its length 110 | self._set_dataset_epoch(dataset, 0) 111 | self.phases_per_epoch = phases_per_epoch 112 | self.chunks = [None] * len(datasets) 113 | if dataset_prob is None: 114 | # If not provided, assign each dataset a probability proportional to its length. 
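            # e.g. (illustrative) two datasets yielding 250 and 750 batches per
            # epoch would get dataset_prob = [0.25, 0.75].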
115 |             dataset_lens = [
116 |                 (math.floor(len(d) / bs) if drop_last else math.ceil(len(d) / bs))
117 |                 for d, bs in zip(datasets, batch_sizes)
118 |             ]
119 |             total_len = sum(dataset_lens)
120 |             dataset_prob = torch.tensor([d_len / total_len for d_len in dataset_lens])
121 |         else:
122 |             assert len(dataset_prob) == len(datasets)
123 |             dataset_prob = torch.tensor(dataset_prob)
124 | 
125 |         logging.info(f"Dataset mixing probabilities: {dataset_prob.tolist()}")
126 |         assert dataset_prob.sum().item() == 1.0, "Probabilities should sum to 1.0"
127 |         self.dataset_prob = dataset_prob
128 | 
129 |     def _set_dataset_epoch(self, dataset, epoch: int) -> None:
130 |         if hasattr(dataset, "epoch"):
131 |             dataset.epoch = epoch
132 |         if hasattr(dataset, "set_epoch"):
133 |             dataset.set_epoch(epoch)
134 | 
135 |     def get_loader(self, epoch) -> Iterable:
136 |         dataloaders = []
137 |         for d_idx, (dataset, batch_size) in enumerate(
138 |             zip(self.datasets, self.batch_sizes)
139 |         ):
140 |             if self.phases_per_epoch > 1:
141 |                 # Major epoch that loops over the entire dataset
142 |                 # len(main_epoch) == phases_per_epoch * len(epoch)
143 |                 main_epoch = epoch // self.phases_per_epoch
144 | 
145 |                 # Phase within the main epoch
146 |                 local_phase = epoch % self.phases_per_epoch
147 | 
148 |                 # Start of a new data-epoch, or the job was resumed after preemption.
149 |                 if local_phase == 0 or self.chunks[d_idx] is None:
150 |                     # set seed for dataset epoch
151 |                     # If using RepeatFactorWrapper, this step correctly re-samples indices before chunking.
152 |                     self._set_dataset_epoch(dataset, main_epoch)
153 | 
154 |                     # Separate random generator for subset sampling
155 |                     g = torch.Generator()
156 |                     g.manual_seed(main_epoch)
157 |                     self.chunks[d_idx] = torch.chunk(
158 |                         torch.randperm(len(dataset), generator=g),
159 |                         self.phases_per_epoch,
160 |                     )
161 | 
162 |                 dataset = Subset(dataset, self.chunks[d_idx][local_phase])
163 |             else:
164 |                 self._set_dataset_epoch(dataset, epoch)
165 | 
166 |             sampler = DistributedSampler(dataset, shuffle=self.shuffle)
167 |             sampler.set_epoch(epoch)
168 | 
169 |             batch_sampler = BatchSampler(sampler, batch_size, drop_last=self.drop_last)
170 |             dataloaders.append(
171 |                 DataLoader(
172 |                     dataset,
173 |                     num_workers=self.num_workers,
174 |                     pin_memory=self.pin_memory,
175 |                     batch_sampler=batch_sampler,
176 |                     collate_fn=self.collate_fn,
177 |                     worker_init_fn=self.worker_init_fn,
178 |                 )
179 |             )
180 |         return MixedDataLoader(dataloaders, self.dataset_prob)
181 | 
--------------------------------------------------------------------------------
/mmmg_eval/sam2/sam2/modeling/sam/prompt_encoder.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | 
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | 
7 | from typing import Optional, Tuple, Type
8 | 
9 | import torch
10 | from torch import nn
11 | 
12 | from sam2.modeling.position_encoding import PositionEmbeddingRandom
13 | 
14 | from sam2.modeling.sam2_utils import LayerNorm2d
15 | 
16 | 
17 | class PromptEncoder(nn.Module):
18 |     def __init__(
19 |         self,
20 |         embed_dim: int,
21 |         image_embedding_size: Tuple[int, int],
22 |         input_image_size: Tuple[int, int],
23 |         mask_in_chans: int,
24 |         activation: Type[nn.Module] = nn.GELU,
25 |     ) -> None:
26 |         """
27 |         Encodes prompts for input to SAM's mask decoder.
28 | 29 | Arguments: 30 | embed_dim (int): The prompts' embedding dimension 31 | image_embedding_size (tuple(int, int)): The spatial size of the 32 | image embedding, as (H, W). 33 | input_image_size (int): The padded size of the image as input 34 | to the image encoder, as (H, W). 35 | mask_in_chans (int): The number of hidden channels used for 36 | encoding input masks. 37 | activation (nn.Module): The activation to use when encoding 38 | input masks. 39 | """ 40 | super().__init__() 41 | self.embed_dim = embed_dim 42 | self.input_image_size = input_image_size 43 | self.image_embedding_size = image_embedding_size 44 | self.pe_layer = PositionEmbeddingRandom(embed_dim // 2) 45 | 46 | self.num_point_embeddings: int = 4 # pos/neg point + 2 box corners 47 | point_embeddings = [ 48 | nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings) 49 | ] 50 | self.point_embeddings = nn.ModuleList(point_embeddings) 51 | self.not_a_point_embed = nn.Embedding(1, embed_dim) 52 | 53 | self.mask_input_size = ( 54 | 4 * image_embedding_size[0], 55 | 4 * image_embedding_size[1], 56 | ) 57 | self.mask_downscaling = nn.Sequential( 58 | nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2), 59 | LayerNorm2d(mask_in_chans // 4), 60 | activation(), 61 | nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2), 62 | LayerNorm2d(mask_in_chans), 63 | activation(), 64 | nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1), 65 | ) 66 | self.no_mask_embed = nn.Embedding(1, embed_dim) 67 | 68 | def get_dense_pe(self) -> torch.Tensor: 69 | """ 70 | Returns the positional encoding used to encode point prompts, 71 | applied to a dense set of points the shape of the image encoding. 72 | 73 | Returns: 74 | torch.Tensor: Positional encoding with shape 75 | 1x(embed_dim)x(embedding_h)x(embedding_w) 76 | """ 77 | return self.pe_layer(self.image_embedding_size).unsqueeze(0) 78 | 79 | def _embed_points( 80 | self, 81 | points: torch.Tensor, 82 | labels: torch.Tensor, 83 | pad: bool, 84 | ) -> torch.Tensor: 85 | """Embeds point prompts.""" 86 | points = points + 0.5 # Shift to center of pixel 87 | if pad: 88 | padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device) 89 | padding_label = -torch.ones((labels.shape[0], 1), device=labels.device) 90 | points = torch.cat([points, padding_point], dim=1) 91 | labels = torch.cat([labels, padding_label], dim=1) 92 | point_embedding = self.pe_layer.forward_with_coords( 93 | points, self.input_image_size 94 | ) 95 | 96 | point_embedding = torch.where( 97 | (labels == -1).unsqueeze(-1), 98 | torch.zeros_like(point_embedding) + self.not_a_point_embed.weight, 99 | point_embedding, 100 | ) 101 | point_embedding = torch.where( 102 | (labels == 0).unsqueeze(-1), 103 | point_embedding + self.point_embeddings[0].weight, 104 | point_embedding, 105 | ) 106 | point_embedding = torch.where( 107 | (labels == 1).unsqueeze(-1), 108 | point_embedding + self.point_embeddings[1].weight, 109 | point_embedding, 110 | ) 111 | point_embedding = torch.where( 112 | (labels == 2).unsqueeze(-1), 113 | point_embedding + self.point_embeddings[2].weight, 114 | point_embedding, 115 | ) 116 | point_embedding = torch.where( 117 | (labels == 3).unsqueeze(-1), 118 | point_embedding + self.point_embeddings[3].weight, 119 | point_embedding, 120 | ) 121 | return point_embedding 122 | 123 | def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor: 124 | """Embeds box prompts.""" 125 | boxes = boxes + 0.5 # Shift to center of pixel 126 | coords = boxes.reshape(-1, 2, 2) 127 | 
corner_embedding = self.pe_layer.forward_with_coords( 128 | coords, self.input_image_size 129 | ) 130 | corner_embedding[:, 0, :] += self.point_embeddings[2].weight 131 | corner_embedding[:, 1, :] += self.point_embeddings[3].weight 132 | return corner_embedding 133 | 134 | def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor: 135 | """Embeds mask inputs.""" 136 | mask_embedding = self.mask_downscaling(masks) 137 | return mask_embedding 138 | 139 | def _get_batch_size( 140 | self, 141 | points: Optional[Tuple[torch.Tensor, torch.Tensor]], 142 | boxes: Optional[torch.Tensor], 143 | masks: Optional[torch.Tensor], 144 | ) -> int: 145 | """ 146 | Gets the batch size of the output given the batch size of the input prompts. 147 | """ 148 | if points is not None: 149 | return points[0].shape[0] 150 | elif boxes is not None: 151 | return boxes.shape[0] 152 | elif masks is not None: 153 | return masks.shape[0] 154 | else: 155 | return 1 156 | 157 | def _get_device(self) -> torch.device: 158 | return self.point_embeddings[0].weight.device 159 | 160 | def forward( 161 | self, 162 | points: Optional[Tuple[torch.Tensor, torch.Tensor]], 163 | boxes: Optional[torch.Tensor], 164 | masks: Optional[torch.Tensor], 165 | ) -> Tuple[torch.Tensor, torch.Tensor]: 166 | """ 167 | Embeds different types of prompts, returning both sparse and dense 168 | embeddings. 169 | 170 | Arguments: 171 | points (tuple(torch.Tensor, torch.Tensor) or none): point coordinates 172 | and labels to embed. 173 | boxes (torch.Tensor or none): boxes to embed 174 | masks (torch.Tensor or none): masks to embed 175 | 176 | Returns: 177 | torch.Tensor: sparse embeddings for the points and boxes, with shape 178 | BxNx(embed_dim), where N is determined by the number of input points 179 | and boxes. 180 | torch.Tensor: dense embeddings for the masks, in the shape 181 | Bx(embed_dim)x(embed_H)x(embed_W) 182 | """ 183 | bs = self._get_batch_size(points, boxes, masks) 184 | sparse_embeddings = torch.empty( 185 | (bs, 0, self.embed_dim), device=self._get_device() 186 | ) 187 | if points is not None: 188 | coords, labels = points 189 | point_embeddings = self._embed_points(coords, labels, pad=(boxes is None)) 190 | sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1) 191 | if boxes is not None: 192 | box_embeddings = self._embed_boxes(boxes) 193 | sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1) 194 | 195 | if masks is not None: 196 | dense_embeddings = self._embed_masks(masks) 197 | else: 198 | dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand( 199 | bs, -1, self.image_embedding_size[0], self.image_embedding_size[1] 200 | ) 201 | 202 | return sparse_embeddings, dense_embeddings 203 | -------------------------------------------------------------------------------- /mmmg_eval/sam2/training/utils/logger.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
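# Module overview (summary added for orientation; all names below come from the code in
# this file): training/utils/logger.py provides the training-time logging utilities for
# SAM 2 -- a TensorBoard-backed scalar logger (`TensorBoardWriterWrapper` /
# `TensorBoardLogger`), a thin `Logger` facade instantiated from the Hydra logging
# config, and the `setup_logging` / `shutdown_logging` helpers that configure stdout and
# file handlers, with file output written on rank 0 only.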
6 | 7 | # Code borrowed from TLC - https://www.internalfb.com/code/fbsource/fbcode/pytorch/tlc/torchtlc/loggers/tensorboard.py 8 | import atexit 9 | import functools 10 | import logging 11 | import sys 12 | import uuid 13 | from typing import Any, Dict, Optional, Union 14 | 15 | from hydra.utils import instantiate 16 | 17 | from iopath.common.file_io import g_pathmgr 18 | from numpy import ndarray 19 | from torch import Tensor 20 | from torch.utils.tensorboard import SummaryWriter 21 | 22 | from training.utils.train_utils import get_machine_local_and_dist_rank, makedir 23 | 24 | Scalar = Union[Tensor, ndarray, int, float] 25 | 26 | 27 | def make_tensorboard_logger(log_dir: str, **writer_kwargs: Any): 28 | makedir(log_dir) 29 | summary_writer_method = SummaryWriter 30 | return TensorBoardLogger( 31 | path=log_dir, summary_writer_method=summary_writer_method, **writer_kwargs 32 | ) 33 | 34 | 35 | class TensorBoardWriterWrapper: 36 | """ 37 | A wrapper around a SummaryWriter object. 38 | """ 39 | 40 | def __init__( 41 | self, 42 | path: str, 43 | *args: Any, 44 | filename_suffix: str = None, 45 | summary_writer_method: Any = SummaryWriter, 46 | **kwargs: Any, 47 | ) -> None: 48 | """Create a new TensorBoard logger. 49 | On construction, the logger creates a new events file that logs 50 | will be written to. If the environment variable `RANK` is defined, 51 | logger will only log if RANK = 0. 52 | 53 | NOTE: If using the logger with distributed training: 54 | - This logger can call collective operations 55 | - Logs will be written on rank 0 only 56 | - Logger must be constructed synchronously *after* initializing distributed process group. 57 | 58 | Args: 59 | path (str): path to write logs to 60 | *args, **kwargs: Extra arguments to pass to SummaryWriter 61 | """ 62 | self._writer: Optional[SummaryWriter] = None 63 | _, self._rank = get_machine_local_and_dist_rank() 64 | self._path: str = path 65 | if self._rank == 0: 66 | logging.info( 67 | f"TensorBoard SummaryWriter instantiated. Files will be stored in: {path}" 68 | ) 69 | self._writer = summary_writer_method( 70 | log_dir=path, 71 | *args, 72 | filename_suffix=filename_suffix or str(uuid.uuid4()), 73 | **kwargs, 74 | ) 75 | else: 76 | logging.debug( 77 | f"Not logging meters on this host because env RANK: {self._rank} != 0" 78 | ) 79 | atexit.register(self.close) 80 | 81 | @property 82 | def writer(self) -> Optional[SummaryWriter]: 83 | return self._writer 84 | 85 | @property 86 | def path(self) -> str: 87 | return self._path 88 | 89 | def flush(self) -> None: 90 | """Writes pending logs to disk.""" 91 | 92 | if not self._writer: 93 | return 94 | 95 | self._writer.flush() 96 | 97 | def close(self) -> None: 98 | """Close writer, flushing pending logs to disk. 99 | Logs cannot be written after `close` is called. 100 | """ 101 | 102 | if not self._writer: 103 | return 104 | 105 | self._writer.close() 106 | self._writer = None 107 | 108 | 109 | class TensorBoardLogger(TensorBoardWriterWrapper): 110 | """ 111 | A simple logger for TensorBoard. 112 | """ 113 | 114 | def log_dict(self, payload: Dict[str, Scalar], step: int) -> None: 115 | """Add multiple scalar values to TensorBoard. 116 | 117 | Args: 118 | payload (dict): dictionary of tag name and scalar value 119 | step (int, Optional): step value to record 120 | """ 121 | if not self._writer: 122 | return 123 | for k, v in payload.items(): 124 | self.log(k, v, step) 125 | 126 | def log(self, name: str, data: Scalar, step: int) -> None: 127 | """Add scalar data to TensorBoard. 
128 | 
129 |         Args:
130 |             name (string): tag name used to group scalars
131 |             data (float/int/Tensor): scalar data to log
132 |             step (int, optional): step value to record
133 |         """
134 |         if not self._writer:
135 |             return
136 |         self._writer.add_scalar(name, data, global_step=step, new_style=True)
137 | 
138 |     def log_hparams(
139 |         self, hparams: Dict[str, Scalar], meters: Dict[str, Scalar]
140 |     ) -> None:
141 |         """Add hyperparameter data to TensorBoard.
142 | 
143 |         Args:
144 |             hparams (dict): dictionary of hyperparameter names and corresponding values
145 |             meters (dict): dictionary of meter names and corresponding values
146 |         """
147 |         if not self._writer:
148 |             return
149 |         self._writer.add_hparams(hparams, meters)
150 | 
151 | 
152 | class Logger:
153 |     """
154 |     A logger class that can interface with multiple loggers. It currently supports TensorBoard only, but can be extended with other loggers.
155 |     """
156 | 
157 |     def __init__(self, logging_conf):
158 |         # allow turning off TensorBoard with "should_log: false" in config
159 |         tb_config = logging_conf.tensorboard_writer
160 |         tb_should_log = tb_config and tb_config.pop("should_log", True)
161 |         self.tb_logger = instantiate(tb_config) if tb_should_log else None
162 | 
163 |     def log_dict(self, payload: Dict[str, Scalar], step: int) -> None:
164 |         if self.tb_logger:
165 |             self.tb_logger.log_dict(payload, step)
166 | 
167 |     def log(self, name: str, data: Scalar, step: int) -> None:
168 |         if self.tb_logger:
169 |             self.tb_logger.log(name, data, step)
170 | 
171 |     def log_hparams(
172 |         self, hparams: Dict[str, Scalar], meters: Dict[str, Scalar]
173 |     ) -> None:
174 |         if self.tb_logger:
175 |             self.tb_logger.log_hparams(hparams, meters)
176 | 
177 | 
178 | # cache the opened file object, so that different calls to `setup_logger`
179 | # with the same file name can safely write to the same file.
180 | @functools.lru_cache(maxsize=None)
181 | def _cached_log_stream(filename):
182 |     # we tune the buffering value so that the logs are updated
183 |     # frequently.
184 |     log_buffer_kb = 10 * 1024  # 10KB
185 |     io = g_pathmgr.open(filename, mode="a", buffering=log_buffer_kb)
186 |     atexit.register(io.close)
187 |     return io
188 | 
189 | 
190 | def setup_logging(
191 |     name,
192 |     output_dir=None,
193 |     rank=0,
194 |     log_level_primary="INFO",
195 |     log_level_secondary="ERROR",
196 | ):
197 |     """
198 |     Set up various logging streams: stdout and file handlers.
199 |     File handlers are only set up on the master GPU (rank 0).
200 | """ 201 | # get the filename if we want to log to the file as well 202 | log_filename = None 203 | if output_dir: 204 | makedir(output_dir) 205 | if rank == 0: 206 | log_filename = f"{output_dir}/log.txt" 207 | 208 | logger = logging.getLogger(name) 209 | logger.setLevel(log_level_primary) 210 | 211 | # create formatter 212 | FORMAT = "%(levelname)s %(asctime)s %(filename)s:%(lineno)4d: %(message)s" 213 | formatter = logging.Formatter(FORMAT) 214 | 215 | # Cleanup any existing handlers 216 | for h in logger.handlers: 217 | logger.removeHandler(h) 218 | logger.root.handlers = [] 219 | 220 | # setup the console handler 221 | console_handler = logging.StreamHandler(sys.stdout) 222 | console_handler.setFormatter(formatter) 223 | logger.addHandler(console_handler) 224 | if rank == 0: 225 | console_handler.setLevel(log_level_primary) 226 | else: 227 | console_handler.setLevel(log_level_secondary) 228 | 229 | # we log to file as well if user wants 230 | if log_filename and rank == 0: 231 | file_handler = logging.StreamHandler(_cached_log_stream(log_filename)) 232 | file_handler.setLevel(log_level_primary) 233 | file_handler.setFormatter(formatter) 234 | logger.addHandler(file_handler) 235 | 236 | logging.root = logger 237 | 238 | 239 | def shutdown_logging(): 240 | """ 241 | After training is done, we ensure to shut down all the logger streams. 242 | """ 243 | logging.info("Shutting down loggers...") 244 | handlers = logging.root.handlers 245 | for handler in handlers: 246 | handler.close() 247 | -------------------------------------------------------------------------------- /mmmg_eval/sam2/sam2/csrc/connected_components.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. 2 | // All rights reserved. 3 | 4 | // This source code is licensed under the license found in the 5 | // LICENSE file in the root directory of this source tree. 6 | 7 | // adapted from https://github.com/zsef123/Connected_components_PyTorch 8 | // with license found in the LICENSE_cctorch file in the root directory. 
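//
// High-level flow of this file: the kernels below implement block-based
// union-find connected-components labeling over 2x2 pixel blocks of a binary
// uint8 mask. `init_labeling` seeds each block with its own label, `merge`
// unions neighbouring blocks that share foreground pixels, `compression`
// flattens the union-find trees, `final_labeling` writes the per-pixel labels,
// and `init_counting` / `final_counting` produce the per-pixel size of the
// component each pixel belongs to. The host entry point returns
// {labels, counts} as two int32 tensors of the same shape as the input.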
9 | #include <ATen/cuda/CUDAContext.h>
10 | #include <cuda.h>
11 | #include <cuda_runtime.h>
12 | #include <torch/extension.h>
13 | #include <torch/script.h>
14 | #include <vector>
15 | 
16 | // 2d
17 | #define BLOCK_ROWS 16
18 | #define BLOCK_COLS 16
19 | 
20 | namespace cc2d {
21 | 
22 | template <typename T>
23 | __device__ __forceinline__ unsigned char hasBit(T bitmap, unsigned char pos) {
24 |   return (bitmap >> pos) & 1;
25 | }
26 | 
27 | __device__ int32_t find(const int32_t* s_buf, int32_t n) {
28 |   while (s_buf[n] != n)
29 |     n = s_buf[n];
30 |   return n;
31 | }
32 | 
33 | __device__ int32_t find_n_compress(int32_t* s_buf, int32_t n) {
34 |   const int32_t id = n;
35 |   while (s_buf[n] != n) {
36 |     n = s_buf[n];
37 |     s_buf[id] = n;
38 |   }
39 |   return n;
40 | }
41 | 
42 | __device__ void union_(int32_t* s_buf, int32_t a, int32_t b) {
43 |   bool done;
44 |   do {
45 |     a = find(s_buf, a);
46 |     b = find(s_buf, b);
47 | 
48 |     if (a < b) {
49 |       int32_t old = atomicMin(s_buf + b, a);
50 |       done = (old == b);
51 |       b = old;
52 |     } else if (b < a) {
53 |       int32_t old = atomicMin(s_buf + a, b);
54 |       done = (old == a);
55 |       a = old;
56 |     } else
57 |       done = true;
58 | 
59 |   } while (!done);
60 | }
61 | 
62 | __global__ void
63 | init_labeling(int32_t* label, const uint32_t W, const uint32_t H) {
64 |   const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y) * 2;
65 |   const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x) * 2;
66 |   const uint32_t idx = row * W + col;
67 | 
68 |   if (row < H && col < W)
69 |     label[idx] = idx;
70 | }
71 | 
72 | __global__ void
73 | merge(uint8_t* img, int32_t* label, const uint32_t W, const uint32_t H) {
74 |   const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y) * 2;
75 |   const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x) * 2;
76 |   const uint32_t idx = row * W + col;
77 | 
78 |   if (row >= H || col >= W)
79 |     return;
80 | 
81 |   uint32_t P = 0;
82 | 
83 |   if (img[idx])
84 |     P |= 0x777;
85 |   if (row + 1 < H && img[idx + W])
86 |     P |= 0x777 << 4;
87 |   if (col + 1 < W && img[idx + 1])
88 |     P |= 0x777 << 1;
89 | 
90 |   if (col == 0)
91 |     P &= 0xEEEE;
92 |   if (col + 1 >= W)
93 |     P &= 0x3333;
94 |   else if (col + 2 >= W)
95 |     P &= 0x7777;
96 | 
97 |   if (row == 0)
98 |     P &= 0xFFF0;
99 |   if (row + 1 >= H)
100 |     P &= 0xFF;
101 | 
102 |   if (P > 0) {
103 |     // If we need to check the top-left pixel (the first bit of P is set) and
104 |     // that top-left pixel is foreground
105 |     if (hasBit(P, 0) && img[idx - W - 1]) {
106 |       union_(label, idx, idx - 2 * W - 2); // top left block
107 |     }
108 | 
109 |     if ((hasBit(P, 1) && img[idx - W]) || (hasBit(P, 2) && img[idx - W + 1]))
110 |       union_(label, idx, idx - 2 * W); // top bottom block
111 | 
112 |     if (hasBit(P, 3) && img[idx + 2 - W])
113 |       union_(label, idx, idx - 2 * W + 2); // top right block
114 | 
115 |     if ((hasBit(P, 4) && img[idx - 1]) || (hasBit(P, 8) && img[idx + W - 1]))
116 |       union_(label, idx, idx - 2); // just left block
117 |   }
118 | }
119 | 
120 | __global__ void compression(int32_t* label, const int32_t W, const int32_t H) {
121 |   const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y) * 2;
122 |   const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x) * 2;
123 |   const uint32_t idx = row * W + col;
124 | 
125 |   if (row < H && col < W)
126 |     find_n_compress(label, idx);
127 | }
128 | 
129 | __global__ void final_labeling(
130 |     const uint8_t* img,
131 |     int32_t* label,
132 |     const int32_t W,
133 |     const int32_t H) {
134 |   const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y) * 2;
135 |   const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x) * 2;
136 |   const uint32_t idx = row * W + col;
137 | 
138 |   if (row >= H || col >= W)
139 |     return;
140 | 
141 |   int32_t y = label[idx] + 1;
142 | 
143 |   if (img[idx])
144 |     label[idx] = y;
145 |   else
146 |     label[idx] = 0;
147 | 
148 |   if (col + 1 < W) {
149 |     if (img[idx + 1])
150 |       label[idx + 1] = y;
151 |     else
152 |       label[idx + 1] = 0;
153 | 
154 |     if (row + 1 < H) {
155 |       if (img[idx + W + 1])
156 |         label[idx + W + 1] = y;
157 |       else
158 |         label[idx + W + 1] = 0;
159 |     }
160 |   }
161 | 
162 |   if (row + 1 < H) {
163 |     if (img[idx + W])
164 |       label[idx + W] = y;
165 |     else
166 |       label[idx + W] = 0;
167 |   }
168 | }
169 | 
170 | __global__ void init_counting(
171 |     const int32_t* label,
172 |     int32_t* count_init,
173 |     const int32_t W,
174 |     const int32_t H) {
175 |   const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y);
176 |   const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x);
177 |   const uint32_t idx = row * W + col;
178 | 
179 |   if (row >= H || col >= W)
180 |     return;
181 | 
182 |   int32_t y = label[idx];
183 |   if (y > 0) {
184 |     int32_t count_idx = y - 1;
185 |     atomicAdd(count_init + count_idx, 1);
186 |   }
187 | }
188 | 
189 | __global__ void final_counting(
190 |     const int32_t* label,
191 |     const int32_t* count_init,
192 |     int32_t* count_final,
193 |     const int32_t W,
194 |     const int32_t H) {
195 |   const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y);
196 |   const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x);
197 |   const uint32_t idx = row * W + col;
198 | 
199 |   if (row >= H || col >= W)
200 |     return;
201 | 
202 |   int32_t y = label[idx];
203 |   if (y > 0) {
204 |     int32_t count_idx = y - 1;
205 |     count_final[idx] = count_init[count_idx];
206 |   } else {
207 |     count_final[idx] = 0;
208 |   }
209 | }
210 | 
211 | } // namespace cc2d
212 | 
213 | std::vector<torch::Tensor> get_connected_componnets(
214 |     const torch::Tensor& inputs) {
215 |   AT_ASSERTM(inputs.is_cuda(), "inputs must be a CUDA tensor");
216 |   AT_ASSERTM(inputs.ndimension() == 4, "inputs must be [N, 1, H, W] shape");
217 |   AT_ASSERTM(
218 |       inputs.scalar_type() == torch::kUInt8, "inputs must be a uint8 type");
219 | 
220 |   const uint32_t N = inputs.size(0);
221 |   const uint32_t C = inputs.size(1);
222 |   const uint32_t H = inputs.size(2);
223 |   const uint32_t W = inputs.size(3);
224 | 
225 |   AT_ASSERTM(C == 1, "inputs must be [N, 1, H, W] shape");
226 |   AT_ASSERTM((H % 2) == 0, "height must be an even number");
227 |   AT_ASSERTM((W % 2) == 0, "width must be an even number");
228 | 
229 |   // labels are stored as int32
230 |   auto label_options =
231 |       torch::TensorOptions().dtype(torch::kInt32).device(inputs.device());
232 |   torch::Tensor labels = torch::zeros({N, C, H, W}, label_options);
233 |   torch::Tensor counts_init = torch::zeros({N, C, H, W}, label_options);
234 |   torch::Tensor counts_final = torch::zeros({N, C, H, W}, label_options);
235 | 
236 |   dim3 grid = dim3(
237 |       ((W + 1) / 2 + BLOCK_COLS - 1) / BLOCK_COLS,
238 |       ((H + 1) / 2 + BLOCK_ROWS - 1) / BLOCK_ROWS);
239 |   dim3 block = dim3(BLOCK_COLS, BLOCK_ROWS);
240 |   dim3 grid_count =
241 |       dim3((W + BLOCK_COLS) / BLOCK_COLS, (H + BLOCK_ROWS) / BLOCK_ROWS);
242 |   dim3 block_count = dim3(BLOCK_COLS, BLOCK_ROWS);
243 |   cudaStream_t stream = at::cuda::getCurrentCUDAStream();
244 | 
245 |   for (int n = 0; n < N; n++) {
246 |     uint32_t offset = n * H * W;
247 | 
248 |     cc2d::init_labeling<<<grid, block, 0, stream>>>(
249 |         labels.data_ptr<int32_t>() + offset, W, H);
250 |     cc2d::merge<<<grid, block, 0, stream>>>(
251 |         inputs.data_ptr<uint8_t>() + offset,
252 |         labels.data_ptr<int32_t>() + offset,
253 |         W,
254 |         H);
255 |     cc2d::compression<<<grid, block, 0, stream>>>(
256 |         labels.data_ptr<int32_t>() + offset, W, H);
257 |     cc2d::final_labeling<<<grid, block, 0, stream>>>(
258 |         inputs.data_ptr<uint8_t>() + offset,
259 |         labels.data_ptr<int32_t>() + offset,
260 |         W,
261 |         H);
262 | 
263 |     // get the counting of each pixel
264 |     cc2d::init_counting<<<grid_count, block_count, 0, stream>>>(
265 |         labels.data_ptr<int32_t>() + offset,
266 |         counts_init.data_ptr<int32_t>() + offset,
267 |         W,
268 |         H);
269 |     cc2d::final_counting<<<grid_count, block_count, 0, stream>>>(
270 |         labels.data_ptr<int32_t>() + offset,
271 |         counts_init.data_ptr<int32_t>() + offset,
272 |         counts_final.data_ptr<int32_t>() + offset,
273 |         W,
274 |         H);
275 |   }
276 | 
277 |   // returned values are [labels, counts]
278 |   std::vector<torch::Tensor> outputs;
279 |   outputs.push_back(labels);
280 |   outputs.push_back(counts_final);
281 |   return outputs;
282 | }
283 | 
284 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
285 |   m.def(
286 |       "get_connected_componnets",
287 |       &get_connected_componnets,
288 |       "get_connected_componnets");
289 | }
290 | 
--------------------------------------------------------------------------------
/mmmg_eval/sam2/sam2/modeling/position_encoding.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | 
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | 
7 | import math
8 | from typing import Any, Optional, Tuple
9 | 
10 | import numpy as np
11 | 
12 | import torch
13 | from torch import nn
14 | 
15 | 
16 | class PositionEmbeddingSine(nn.Module):
17 |     """
18 |     This is a more standard version of the position embedding, very similar to the one
19 |     used by the Attention Is All You Need paper, generalized to work on images.
20 |     """
21 | 
22 |     def __init__(
23 |         self,
24 |         num_pos_feats,
25 |         temperature: int = 10000,
26 |         normalize: bool = True,
27 |         scale: Optional[float] = None,
28 |         # The following settings are only relevant
29 |         # for warming up the cache for compilation
30 |         warmup_cache: bool = True,
31 |         image_size: int = 1024,
32 |         strides: Tuple[int] = (4, 8, 16, 32),
33 |     ):
34 |         super().__init__()
35 |         assert num_pos_feats % 2 == 0, "Expecting even model width"
36 |         self.num_pos_feats = num_pos_feats // 2
37 |         self.temperature = temperature
38 |         self.normalize = normalize
39 |         if scale is not None and normalize is False:
40 |             raise ValueError("normalize should be True if scale is passed")
41 |         if scale is None:
42 |             scale = 2 * math.pi
43 |         self.scale = scale
44 | 
45 |         self.cache = {}
46 |         if warmup_cache and torch.cuda.is_available():
47 |             # Warmup cache for cuda, to help with compilation
48 |             device = torch.device("cuda")
49 |             for stride in strides:
50 |                 cache_key = (image_size // stride, image_size // stride)
51 |                 self._pe(1, device, *cache_key)
52 | 
53 |     def _encode_xy(self, x, y):
54 |         # The positions are expected to be normalized
55 |         assert len(x) == len(y) and x.ndim == y.ndim == 1
56 |         x_embed = x * self.scale
57 |         y_embed = y * self.scale
58 | 
59 |         dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
60 |         dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
61 | 
62 |         pos_x = x_embed[:, None] / dim_t
63 |         pos_y = y_embed[:, None] / dim_t
64 |         pos_x = torch.stack(
65 |             (pos_x[:, 0::2].sin(), pos_x[:, 1::2].cos()), dim=2
66 |         ).flatten(1)
67 |         pos_y = torch.stack(
68 |             (pos_y[:, 0::2].sin(), pos_y[:, 1::2].cos()), dim=2
69 |         ).flatten(1)
70 |         return pos_x, pos_y
71 | 
72 |     @torch.no_grad()
73 |     def encode_boxes(self, x, y, w, h):
74 |         pos_x, pos_y = self._encode_xy(x, y)
75 |         pos = torch.cat((pos_y, pos_x, h[:, None], w[:, None]), dim=1)
76 |         return pos
77 | 
78 |     encode = encode_boxes  # Backwards compatibility
79 | 
80 |     @torch.no_grad()
81 | def encode_points(self, x, y, labels): 82 | (bx, nx), (by, ny), (bl, nl) = x.shape, y.shape, labels.shape 83 | assert bx == by and nx == ny and bx == bl and nx == nl 84 | pos_x, pos_y = self._encode_xy(x.flatten(), y.flatten()) 85 | pos_x, pos_y = pos_x.reshape(bx, nx, -1), pos_y.reshape(by, ny, -1) 86 | pos = torch.cat((pos_y, pos_x, labels[:, :, None]), dim=2) 87 | return pos 88 | 89 | @torch.no_grad() 90 | def _pe(self, B, device, *cache_key): 91 | H, W = cache_key 92 | if cache_key in self.cache: 93 | return self.cache[cache_key].to(device)[None].repeat(B, 1, 1, 1) 94 | 95 | y_embed = ( 96 | torch.arange(1, H + 1, dtype=torch.float32, device=device) 97 | .view(1, -1, 1) 98 | .repeat(B, 1, W) 99 | ) 100 | x_embed = ( 101 | torch.arange(1, W + 1, dtype=torch.float32, device=device) 102 | .view(1, 1, -1) 103 | .repeat(B, H, 1) 104 | ) 105 | 106 | if self.normalize: 107 | eps = 1e-6 108 | y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale 109 | x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale 110 | 111 | dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=device) 112 | dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) 113 | 114 | pos_x = x_embed[:, :, :, None] / dim_t 115 | pos_y = y_embed[:, :, :, None] / dim_t 116 | pos_x = torch.stack( 117 | (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4 118 | ).flatten(3) 119 | pos_y = torch.stack( 120 | (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4 121 | ).flatten(3) 122 | pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) 123 | self.cache[cache_key] = pos[0] 124 | return pos 125 | 126 | @torch.no_grad() 127 | def forward(self, x: torch.Tensor): 128 | B = x.shape[0] 129 | cache_key = (x.shape[-2], x.shape[-1]) 130 | return self._pe(B, x.device, *cache_key) 131 | 132 | 133 | class PositionEmbeddingRandom(nn.Module): 134 | """ 135 | Positional encoding using random spatial frequencies. 136 | """ 137 | 138 | def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None: 139 | super().__init__() 140 | if scale is None or scale <= 0.0: 141 | scale = 1.0 142 | self.register_buffer( 143 | "positional_encoding_gaussian_matrix", 144 | scale * torch.randn((2, num_pos_feats)), 145 | ) 146 | 147 | def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor: 148 | """Positionally encode points that are normalized to [0,1].""" 149 | # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape 150 | coords = 2 * coords - 1 151 | coords = coords @ self.positional_encoding_gaussian_matrix 152 | coords = 2 * np.pi * coords 153 | # outputs d_1 x ... 
x d_n x C shape 154 | return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1) 155 | 156 | def forward(self, size: Tuple[int, int]) -> torch.Tensor: 157 | """Generate positional encoding for a grid of the specified size.""" 158 | h, w = size 159 | device: Any = self.positional_encoding_gaussian_matrix.device 160 | grid = torch.ones((h, w), device=device, dtype=torch.float32) 161 | y_embed = grid.cumsum(dim=0) - 0.5 162 | x_embed = grid.cumsum(dim=1) - 0.5 163 | y_embed = y_embed / h 164 | x_embed = x_embed / w 165 | 166 | pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1)) 167 | return pe.permute(2, 0, 1) # C x H x W 168 | 169 | def forward_with_coords( 170 | self, coords_input: torch.Tensor, image_size: Tuple[int, int] 171 | ) -> torch.Tensor: 172 | """Positionally encode points that are not normalized to [0,1].""" 173 | coords = coords_input.clone() 174 | coords[:, :, 0] = coords[:, :, 0] / image_size[1] 175 | coords[:, :, 1] = coords[:, :, 1] / image_size[0] 176 | return self._pe_encoding(coords.to(torch.float)) # B x N x C 177 | 178 | 179 | # Rotary Positional Encoding, adapted from: 180 | # 1. https://github.com/meta-llama/codellama/blob/main/llama/model.py 181 | # 2. https://github.com/naver-ai/rope-vit 182 | # 3. https://github.com/lucidrains/rotary-embedding-torch 183 | 184 | 185 | def init_t_xy(end_x: int, end_y: int): 186 | t = torch.arange(end_x * end_y, dtype=torch.float32) 187 | t_x = (t % end_x).float() 188 | t_y = torch.div(t, end_x, rounding_mode="floor").float() 189 | return t_x, t_y 190 | 191 | 192 | def compute_axial_cis(dim: int, end_x: int, end_y: int, theta: float = 10000.0): 193 | freqs_x = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) 194 | freqs_y = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) 195 | 196 | t_x, t_y = init_t_xy(end_x, end_y) 197 | freqs_x = torch.outer(t_x, freqs_x) 198 | freqs_y = torch.outer(t_y, freqs_y) 199 | freqs_cis_x = torch.polar(torch.ones_like(freqs_x), freqs_x) 200 | freqs_cis_y = torch.polar(torch.ones_like(freqs_y), freqs_y) 201 | return torch.cat([freqs_cis_x, freqs_cis_y], dim=-1) 202 | 203 | 204 | def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor): 205 | ndim = x.ndim 206 | assert 0 <= 1 < ndim 207 | assert freqs_cis.shape == (x.shape[-2], x.shape[-1]) 208 | shape = [d if i >= ndim - 2 else 1 for i, d in enumerate(x.shape)] 209 | return freqs_cis.view(*shape) 210 | 211 | 212 | def apply_rotary_enc( 213 | xq: torch.Tensor, 214 | xk: torch.Tensor, 215 | freqs_cis: torch.Tensor, 216 | repeat_freqs_k: bool = False, 217 | ): 218 | xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) 219 | xk_ = ( 220 | torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) 221 | if xk.shape[-2] != 0 222 | else None 223 | ) 224 | freqs_cis = reshape_for_broadcast(freqs_cis, xq_) 225 | xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) 226 | if xk_ is None: 227 | # no keys to rotate, due to dropout 228 | return xq_out.type_as(xq).to(xq.device), xk 229 | # repeat freqs along seq_len dim to match k seq_len 230 | if repeat_freqs_k: 231 | r = xk_.shape[-2] // xq_.shape[-2] 232 | if freqs_cis.is_cuda: 233 | freqs_cis = freqs_cis.repeat(*([1] * (freqs_cis.ndim - 2)), r, 1) 234 | else: 235 | # torch.repeat on complex numbers may not be supported on non-CUDA devices 236 | # (freqs_cis has 4 dims and we repeat on dim 2) so we use expand + flatten 237 | freqs_cis = freqs_cis.unsqueeze(2).expand(-1, -1, r, -1, -1).flatten(2, 3) 238 | xk_out = 
torch.view_as_real(xk_ * freqs_cis).flatten(3) 239 | return xq_out.type_as(xq).to(xq.device), xk_out.type_as(xk).to(xk.device) 240 | --------------------------------------------------------------------------------
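For orientation, below is a minimal sketch of how the axial rotary helpers in position_encoding.py above (compute_axial_cis and apply_rotary_enc) can be exercised. The batch size, head count, head width, and the 8x8 token grid are illustrative assumptions, not values taken from the repository; the import path simply mirrors the file location shown above.

import torch

from sam2.modeling.position_encoding import apply_rotary_enc, compute_axial_cis

# Hypothetical attention shapes for illustration only.
B, n_heads, head_dim = 2, 4, 64
end_x = end_y = 8                 # an 8x8 grid of spatial tokens
seq_len = end_x * end_y

# Complex-valued axial frequencies: one rotation per (x, y) token position,
# with shape (seq_len, head_dim // 2), which is what apply_rotary_enc expects.
freqs_cis = compute_axial_cis(dim=head_dim, end_x=end_x, end_y=end_y)

xq = torch.randn(B, n_heads, seq_len, head_dim)
xk = torch.randn(B, n_heads, seq_len, head_dim)

# Rotate queries and keys instead of adding position embeddings.
xq_rot, xk_rot = apply_rotary_enc(xq, xk, freqs_cis=freqs_cis)
assert xq_rot.shape == xq.shape and xk_rot.shape == xk.shape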