├── .gitignore ├── LICENSE ├── README.md ├── assets ├── avatar.gif ├── bees.gif ├── figure-1.png ├── horsejump-high.gif ├── interactive-camel.gif ├── interactive-drift-straight.gif ├── interactive-loading.gif └── street.gif ├── configs ├── demo.yaml ├── logging │ ├── base.yaml │ ├── vis_eval.yaml │ ├── vos_eval.yaml │ └── wandb │ │ └── base.yaml ├── model │ ├── point_tracker │ │ ├── cotracker.yaml │ │ ├── pips.yaml │ │ ├── pips_plus_plus.yaml │ │ ├── raft.yaml │ │ ├── superglue.yaml │ │ ├── tapir.yaml │ │ └── tapnet.yaml │ ├── sam │ │ ├── image_encoder │ │ │ ├── vit_base.yaml │ │ │ ├── vit_huge.yaml │ │ │ └── vit_large.yaml │ │ ├── mask_decoder │ │ │ └── sam.yaml │ │ ├── prompt_encoder │ │ │ └── sam.yaml │ │ ├── sam_mobile_vit_tiny.yaml │ │ ├── sam_vit_base.yaml │ │ ├── sam_vit_huge.yaml │ │ ├── sam_vit_large.yaml │ │ ├── samhq_light_vit_tiny.yaml │ │ └── samhq_vit_huge.yaml │ └── sam_pt.yaml ├── vis_eval_root.yaml ├── vis_eval_sam_pt.yaml └── vos_eval_root.yaml ├── data └── demo_data │ ├── README.md │ ├── bees.mp4 │ ├── query_points__bees.txt │ ├── query_points__street.txt │ └── street.mp4 ├── demo ├── __init__.py └── demo.py ├── docs ├── 01-getting-started.md ├── 02-prepare-datasets.md ├── 03-prepare-checkpoints.md └── 04-running-experiments.md ├── requirements-jax.txt ├── requirements.txt ├── sam_pt ├── __init__.py ├── modeling │ ├── __init__.py │ ├── sam.py │ ├── sam_pt.py │ ├── sam_pt_interactive.py │ └── vis_to_vos_adapter.py ├── point_tracker │ ├── __init__.py │ ├── cotracker │ │ ├── __init__.py │ │ └── tracker.py │ ├── pips │ │ ├── __init__.py │ │ ├── pips.py │ │ └── tracker.py │ ├── pips_plus_plus │ │ ├── __init__.py │ │ ├── pips_plus_plus.py │ │ └── tracker.py │ ├── raft │ │ ├── __init__.py │ │ ├── raft_core │ │ │ ├── __init__.py │ │ │ ├── corr.py │ │ │ ├── extractor.py │ │ │ ├── raft.py │ │ │ ├── update.py │ │ │ └── util.py │ │ ├── raftnet.py │ │ └── tracker.py │ ├── superglue │ │ ├── __init__.py │ │ ├── match_pairs.py │ │ ├── models │ │ │ ├── __init__.py │ │ │ ├── matching.py │ │ │ ├── superglue.py │ │ │ ├── superpoint.py │ │ │ └── utils.py │ │ └── tracker.py │ ├── tapir │ │ ├── __init__.py │ │ ├── configs │ │ │ └── tapir_config.py │ │ ├── demo.py │ │ ├── models │ │ │ ├── __init__.py │ │ │ └── resnet.py │ │ ├── tapir_model.py │ │ ├── tracker.py │ │ └── utils │ │ │ ├── __init__.py │ │ │ ├── model_utils.py │ │ │ └── transforms.py │ ├── tapnet │ │ ├── __init__.py │ │ ├── configs │ │ │ └── tapnet_config.py │ │ ├── demo.py │ │ ├── models │ │ │ ├── __init__.py │ │ │ ├── tsm_resnet.py │ │ │ └── tsm_utils.py │ │ ├── tapnet_model.py │ │ ├── tracker.py │ │ └── utils │ │ │ ├── __init__.py │ │ │ └── transforms.py │ ├── tracker.py │ └── utils │ │ ├── __init__.py │ │ ├── basic.py │ │ ├── improc.py │ │ ├── misc.py │ │ ├── samp.py │ │ ├── saverloader.py │ │ └── test.py ├── utils │ ├── __init__.py │ ├── query_points.py │ └── util.py ├── vis_eval │ ├── __init__.py │ ├── eval.py │ ├── mask2former │ │ ├── __init__.py │ │ └── config.py │ ├── mask2former_video │ │ ├── README.md │ │ ├── __init__.py │ │ ├── config.py │ │ └── data_video │ │ │ ├── __init__.py │ │ │ ├── augmentation.py │ │ │ ├── build.py │ │ │ ├── dataset_mapper.py │ │ │ ├── datasets │ │ │ ├── __init__.py │ │ │ ├── builtin.py │ │ │ ├── uvo.py │ │ │ ├── ytvis.py │ │ │ └── ytvis_api │ │ │ │ ├── __init__.py │ │ │ │ ├── ytvos.py │ │ │ │ └── ytvoseval.py │ │ │ └── ytvis_eval.py │ └── train_net_video.py └── vos_eval │ ├── __init__.py │ ├── bdd100keval.py │ ├── data │ ├── __init__.py │ ├── mask_mapper.py │ ├── test_datasets.py │ └── 
video_reader.py │ ├── davis2017eval.py │ ├── eval.py │ └── evaluator.py └── scripts ├── annotation_comparison_gif.py ├── bdd100k_from_instance_seg_to_vos_annotations.py ├── clean_tapnet_checkpoint.py ├── davis_mask_to_contour.py ├── uvo_video2frames.py └── visualize_point_sampling_methods.py /.gitignore: -------------------------------------------------------------------------------- 1 | artifacts 2 | /logs 3 | /wandb 4 | /data/** 5 | !/data/demo_data 6 | !/data/demo_data/README.md 7 | !/data/demo_data/bees.mp4 8 | !/data/demo_data/street.mp4 9 | !/data/demo_data/query_points__bees.txt 10 | !/data/demo_data/query_points__street.txt 11 | /outputs 12 | /output 13 | instant_test_output 14 | inference_test_output 15 | experiments 16 | 17 | *.png 18 | *.json 19 | *.diff 20 | *.jpg 21 | !/projects/DensePose/doc/images/*.jpg 22 | 23 | # compilation and distribution 24 | __pycache__ 25 | _ext 26 | *.pyc 27 | *.pyd 28 | *.so 29 | *.dll 30 | *.egg-info/ 31 | build/ 32 | dist/ 33 | wheels/ 34 | 35 | # pytorch/python/numpy formats 36 | *.pth 37 | *.pkl 38 | *.npy 39 | *.ts 40 | model_ts*.txt 41 | 42 | # ipython/jupyter notebooks 43 | *.ipynb 44 | **/.ipynb_checkpoints/ 45 | 46 | # Editor temporaries 47 | *.swn 48 | *.swo 49 | *.swp 50 | *~ 51 | 52 | # editor settings 53 | .idea 54 | .vscode 55 | _darcs 56 | 57 | # project dirs 58 | /detectron2/model_zoo/configs 59 | /datasets/* 60 | !/datasets/*.* 61 | /projects/*/datasets 62 | /models 63 | /snippet 64 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Segment Anything Meets Point Tracking 2 | 3 | > [**Segment Anything Meets Point Tracking**](https://arxiv.org/abs/2307.01197) \ 4 | > [Frano Rajič](https://m43.github.io/), [Lei Ke](http://www.kelei.site/), [Yu-Wing Tai](https://yuwingtai.github.io/), [Chi-Keung Tang](http://home.cse.ust.hk/~cktang/bio.html), [Martin Danelljan](https://martin-danelljan.github.io/), [Fisher Yu](https://www.yf.io/) \ 5 | > ETH Zürich, HKUST, EPFL 6 | 7 | 8 | ![SAM-PT design](assets/figure-1.png?raw=true) 9 | 10 | We propose SAM-PT, an extension of the [Segment Anything Model](https://github.com/facebookresearch/segment-anything) (SAM) for zero-shot video segmentation. Our work offers a simple yet effective point-based perspective in video object segmentation research. For more details, refer to our paper. 11 | 12 | ## Video Object Segmentation Demo 13 | 14 | Annotators only provide a few points to denote the target object at the first video frame to get video segmentation results. Please visit our [project page](https://www.vis.xyz/pub/sam-pt/) for more visualizations, including qualitative results on DAVIS 2017 videos and more Avatar clips. 15 |

16 | ![street](assets/street.gif?raw=true) 17 | ![bees](assets/bees.gif?raw=true) 18 | ![avatar](assets/avatar.gif?raw=true) 19 | ![horsejump-high](assets/horsejump-high.gif?raw=true) 20 |
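The query points can also be supplied as a small text file instead of interactive clicks; the format is documented in [data/demo_data/README.md](data/demo_data/README.md). For instance, [data/demo_data/query_points__bees.txt](data/demo_data/query_points__bees.txt), used for the bees clip above, contains:

```
3
0;282.5,241.25 336.25,240.0 357.5,276.25 531.25,242.5
0;1116.25,428.75 1137.5,411.25 1165.0,412.5 1130.0,272.5
```

(Three positive points followed by one negative point per line, one line for each of the two target objects, all given at frame 0.)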

21 | 22 | ## Interactive Point-Based Video Segmentation 23 | 24 | Annotators can interactively add or remove points to refine the segmentation results. 25 |

26 | ![camel](assets/interactive-camel.gif?raw=true) 27 | ![drift](assets/interactive-drift-straight.gif?raw=true) 28 | ![loading](assets/interactive-loading.gif?raw=true) 29 |
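To try this locally, the interactive demo can be launched as described in [docs/01-getting-started.md](docs/01-getting-started.md), once the demo frames have been extracted and the minimal SAM and PIPS checkpoints are in place (the command below is reproduced from those docs and assumes a GUI is available for the pop-up window):

```bash
export HYDRA_FULL_ERROR=1
python -m demo.demo \
    frames_path='${hydra:runtime.cwd}/data/demo_data/bees/' \
    query_points_path=null \
    longest_side_length=1024 frame_stride=1 max_frames=-1
```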

30 | 31 | ## Documentation 32 | 33 | Explore our step-by-step guides to get up and running: 34 | 35 | 1. [Getting Started](./docs/01-getting-started.md): Learn how to set up your environment and run the demo. 36 | 2. [Prepare Datasets](./docs/02-prepare-datasets.md): Instructions on acquiring and prepping necessary datasets. 37 | 3. [Prepare Checkpoints](./docs/03-prepare-checkpoints.md): Steps to fetch model checkpoints. 38 | 4. [Running Experiments](./docs/04-running-experiments.md): Details on how to execute experiments. 39 | 40 | ## Acknowledgments 41 | 42 | We want to thank [SAM](https://github.com/facebookresearch/segment-anything), [PIPS](https://github.com/aharley/pips), [CoTracker](https://github.com/facebookresearch/co-tracker), [HQ-SAM](https://github.com/SysCV/sam-hq), [MobileSAM](https://github.com/ChaoningZhang/MobileSAM), [XMem](https://github.com/hkchengrex/XMem), and [Mask2Former](https://github.com/facebookresearch/Mask2Former) for publicly releasing their code and pretrained models. 43 | 44 | ## Citation 45 | 46 | If you find SAM-PT useful in your research or if you refer to the results mentioned in our work, please star :star: this repository and consider citing :pencil:: 47 | ```bibtex 48 | @article{sam-pt, 49 | title = {Segment Anything Meets Point Tracking}, 50 | author = {Rajič, Frano and Ke, Lei and Tai, Yu-Wing and Tang, Chi-Keung and Danelljan, Martin and Yu, Fisher}, 51 | journal = {arXiv:2307.01197}, 52 | year = {2023} 53 | } 54 | ``` 55 | -------------------------------------------------------------------------------- /assets/avatar.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SysCV/sam-pt/874ff7e73d6ab05418a494d7a02ca233c0b31e8c/assets/avatar.gif -------------------------------------------------------------------------------- /assets/bees.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SysCV/sam-pt/874ff7e73d6ab05418a494d7a02ca233c0b31e8c/assets/bees.gif -------------------------------------------------------------------------------- /assets/figure-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SysCV/sam-pt/874ff7e73d6ab05418a494d7a02ca233c0b31e8c/assets/figure-1.png -------------------------------------------------------------------------------- /assets/horsejump-high.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SysCV/sam-pt/874ff7e73d6ab05418a494d7a02ca233c0b31e8c/assets/horsejump-high.gif -------------------------------------------------------------------------------- /assets/interactive-camel.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SysCV/sam-pt/874ff7e73d6ab05418a494d7a02ca233c0b31e8c/assets/interactive-camel.gif -------------------------------------------------------------------------------- /assets/interactive-drift-straight.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SysCV/sam-pt/874ff7e73d6ab05418a494d7a02ca233c0b31e8c/assets/interactive-drift-straight.gif -------------------------------------------------------------------------------- /assets/interactive-loading.gif: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/SysCV/sam-pt/874ff7e73d6ab05418a494d7a02ca233c0b31e8c/assets/interactive-loading.gif -------------------------------------------------------------------------------- /assets/street.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SysCV/sam-pt/874ff7e73d6ab05418a494d7a02ca233c0b31e8c/assets/street.gif -------------------------------------------------------------------------------- /configs/demo.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - model: sam_pt 3 | - logging: base 4 | - _self_ 5 | 6 | logging: 7 | wandb: 8 | project: demo 9 | 10 | model: 11 | iterative_refinement_iterations: 12 12 | add_other_objects_positive_points_as_negative_points: true 13 | use_point_reinit: false 14 | positive_points_per_mask: -1 15 | negative_points_per_mask: -1 16 | 17 | frames_path: ${hydra:runtime.cwd}/data/demo_data/bees # Path to the folder with frames of the video 18 | query_points_path: ${hydra:runtime.cwd}/data/demo_data/query_points__bees.txt # Path or null 19 | 20 | longest_side_length: 1024 # Resize the image so that the longest side is of this length 21 | frame_stride: 1 # Evaluate on every n frames 22 | max_frames: null # Maximum number of video frames to evaluate for 23 | 24 | seed: 72 25 | 26 | annot_size: 16 # Size of the point annotations in visualisations 27 | annot_line_width: 6 # Line width of the point annotations in visualisations -------------------------------------------------------------------------------- /configs/logging/base.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - wandb: base 3 | 4 | debug: false 5 | exp_id: debug 6 | -------------------------------------------------------------------------------- /configs/logging/vis_eval.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - wandb: base 3 | 4 | exp_id: eval 5 | wandb: 6 | project: point-tracking-for-vis 7 | -------------------------------------------------------------------------------- /configs/logging/vos_eval.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - wandb: base 3 | 4 | exp_id: eval 5 | wandb: 6 | project: point-tracking-for-vos 7 | -------------------------------------------------------------------------------- /configs/logging/wandb/base.yaml: -------------------------------------------------------------------------------- 1 | entity: null 2 | project: ??? 
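# '???' is OmegaConf's mandatory-value marker: the W&B project name must be supplied by a
# consuming config (e.g. configs/demo.yaml and the eval logging configs set it) or by a
# command-line override before the run starts.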
3 | tensorboard: true 4 | log_code_path: ${hydra:runtime.cwd}/sam_pt 5 | -------------------------------------------------------------------------------- /configs/model/point_tracker/cotracker.yaml: -------------------------------------------------------------------------------- 1 | _target_: sam_pt.point_tracker.cotracker.CoTrackerPointTracker 2 | checkpoint_path: "${hydra:runtime.cwd}/models/cotracker_ckpts/cotracker_stride_4_wind_8.pth" 3 | #checkpoint_path: "${hydra:runtime.cwd}/models/cotracker_ckpts/cotracker_stride_4_wind_12.pth" 4 | #checkpoint_path: "${hydra:runtime.cwd}/models/cotracker_ckpts/cotracker_stride_8_wind_16.pth" 5 | 6 | interp_shape: [384, 512] 7 | visibility_threshold: 0.7 8 | support_grid_size: 2 9 | support_grid_every_n_frames: 12 10 | 11 | add_debug_visualisations: false 12 | -------------------------------------------------------------------------------- /configs/model/point_tracker/pips.yaml: -------------------------------------------------------------------------------- 1 | _target_: sam_pt.point_tracker.pips.PipsPointTracker 2 | checkpoint_path: "${hydra:runtime.cwd}/models/pips_ckpts/reference_model" 3 | stride: 4 4 | s: 8 5 | initial_next_frame_visibility_threshold: 0.9 6 | -------------------------------------------------------------------------------- /configs/model/point_tracker/pips_plus_plus.yaml: -------------------------------------------------------------------------------- 1 | _target_: sam_pt.point_tracker.pips_plus_plus.PipsPlusPlusPointTracker 2 | checkpoint_path: "${hydra:runtime.cwd}/models/pips_plus_plus_ckpts/reference_model" 3 | stride: 8 4 | max_sequence_length: 128 5 | iters: 16 6 | image_size: null # [ 512, 896 ] 7 | -------------------------------------------------------------------------------- /configs/model/point_tracker/raft.yaml: -------------------------------------------------------------------------------- 1 | _target_: sam_pt.point_tracker.raft.RaftPointTracker 2 | checkpoint_path: "${hydra:runtime.cwd}/models/raft_ckpts/raft-things.pth" 3 | -------------------------------------------------------------------------------- /configs/model/point_tracker/superglue.yaml: -------------------------------------------------------------------------------- 1 | _target_: sam_pt.point_tracker.superglue.SuperGluePointTracker 2 | 3 | positive_points_per_mask: ${..positive_points_per_mask} 4 | negative_points_per_mask: ${..negative_points_per_mask} 5 | 6 | #resize: [ 640, 480 ] 7 | resize: [ -1, -1 ] 8 | 9 | matching_config: 10 | superpoint: 11 | checkpoint: ${hydra:runtime.cwd}/models/superglue_ckpts/superpoint_v1.pth 12 | nms_radius: 3 13 | keypoint_threshold: 0.005 14 | max_keypoints: -1 15 | descriptor_dim: 256 16 | remove_borders: 4 17 | superglue: 18 | #checkpoint: ${hydra:runtime.cwd}/models/superglue_ckpts/superglue_indoor.pth 19 | checkpoint: ${hydra:runtime.cwd}/models/superglue_ckpts/superglue_outdoor.pth 20 | sinkhorn_iterations: 20 21 | match_threshold: 0.2 22 | -------------------------------------------------------------------------------- /configs/model/point_tracker/tapir.yaml: -------------------------------------------------------------------------------- 1 | _target_: sam_pt.point_tracker.tapir.TapirPointTracker 2 | checkpoint_path: "${hydra:runtime.cwd}/models/tapir_ckpts/open_source_ckpt/tapir_checkpoint_panning.npy" 3 | visibility_threshold: 0.1 4 | -------------------------------------------------------------------------------- /configs/model/point_tracker/tapnet.yaml: 
-------------------------------------------------------------------------------- 1 | _target_: sam_pt.point_tracker.tapnet.TapnetPointTracker 2 | checkpoint_path: "${hydra:runtime.cwd}/models/tapnet_ckpts/open_source_ckpt/checkpoint_wo_optstate.npy" 3 | visibility_threshold: 0.5 4 | -------------------------------------------------------------------------------- /configs/model/sam/image_encoder/vit_base.yaml: -------------------------------------------------------------------------------- 1 | _target_: segment_anything.modeling.image_encoder.ImageEncoderViT 2 | depth: 12 3 | embed_dim: 768 4 | img_size: ${ ..image_size } 5 | mlp_ratio: 4 6 | norm_layer: 7 | _partial_: true 8 | _target_: torch.nn.LayerNorm 9 | eps: 1e-6 10 | num_heads: 12 11 | patch_size: ${ ..vit_patch_size } 12 | qkv_bias: True 13 | use_rel_pos: True 14 | global_attn_indexes: [ 2, 5, 8, 11 ] 15 | window_size: 14 16 | out_chans: ${ ..prompt_embed_dim } 17 | -------------------------------------------------------------------------------- /configs/model/sam/image_encoder/vit_huge.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - vit_base 3 | depth: 32 4 | embed_dim: 1280 5 | num_heads: 16 6 | global_attn_indexes: [ 7, 15, 23, 31 ] 7 | -------------------------------------------------------------------------------- /configs/model/sam/image_encoder/vit_large.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - vit_base 3 | depth: 24 4 | embed_dim: 1024 5 | num_heads: 16 6 | global_attn_indexes: [ 5, 11, 17, 23 ] 7 | -------------------------------------------------------------------------------- /configs/model/sam/mask_decoder/sam.yaml: -------------------------------------------------------------------------------- 1 | _target_: segment_anything.modeling.mask_decoder.MaskDecoder 2 | num_multimask_outputs: 3 3 | transformer: 4 | _target_: segment_anything.modeling.transformer.TwoWayTransformer 5 | depth: 2 6 | embedding_dim: ${ ...prompt_embed_dim } 7 | mlp_dim: 2048 8 | num_heads: 8 9 | transformer_dim: ${ ..prompt_embed_dim } 10 | iou_head_depth: 3 11 | iou_head_hidden_dim: 256 12 | -------------------------------------------------------------------------------- /configs/model/sam/prompt_encoder/sam.yaml: -------------------------------------------------------------------------------- 1 | _target_: segment_anything.modeling.prompt_encoder.PromptEncoder 2 | embed_dim: ${ ..prompt_embed_dim } 3 | image_embedding_size: 4 | - ${ ...image_embedding_size } 5 | - ${ ...image_embedding_size } 6 | input_image_size: 7 | - ${ ...image_size } 8 | - ${ ...image_size } 9 | mask_in_chans: 16 10 | -------------------------------------------------------------------------------- /configs/model/sam/sam_mobile_vit_tiny.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - prompt_encoder: sam 3 | - mask_decoder: sam 4 | - _self_ 5 | 6 | _target_: sam_pt.modeling.sam.MobileSamHydra 7 | 8 | checkpoint: ${hydra:runtime.cwd}/models/sam_mobile_ckpts/sam_mobile_vit_t.pth 9 | 10 | prompt_embed_dim: 256 11 | image_size: 1024 12 | vit_patch_size: 16 13 | image_embedding_size: 64 14 | 15 | pixel_mean: [ 123.675, 116.28, 103.53 ] 16 | pixel_std: [ 58.395, 57.12, 57.375 ] 17 | 18 | image_encoder: 19 | _target_: mobile_sam.modeling.TinyViT 20 | img_size: ${..image_size} 21 | in_chans: 3 22 | num_classes: 1000 23 | embed_dims: [ 64, 128, 160, 320 ] 24 | depths: [ 2, 2, 6, 2 ] 25 | num_heads: [ 2, 
4, 5, 10 ] 26 | window_sizes: [ 7, 7, 14, 7 ] 27 | mlp_ratio: 4. 28 | drop_rate: 0. 29 | drop_path_rate: 0.0 30 | use_checkpoint: False 31 | mbconv_expand_ratio: 4.0 32 | local_conv_size: 3 33 | layer_lr_decay: 0.8 34 | 35 | prompt_encoder: 36 | _target_: mobile_sam.modeling.prompt_encoder.PromptEncoder 37 | 38 | mask_decoder: 39 | _target_: mobile_sam.modeling.mask_decoder.MaskDecoder 40 | transformer: 41 | _target_: mobile_sam.modeling.transformer.TwoWayTransformer 42 | -------------------------------------------------------------------------------- /configs/model/sam/sam_vit_base.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - image_encoder: vit_base 3 | - prompt_encoder: sam 4 | - mask_decoder: sam 5 | 6 | _target_: sam_pt.modeling.sam.SamHydra 7 | 8 | checkpoint: ${hydra:runtime.cwd}/models/sam_ckpts/sam_vit_b_01ec64.pth 9 | 10 | prompt_embed_dim: 256 11 | image_size: 1024 12 | vit_patch_size: 16 13 | image_embedding_size: 64 14 | 15 | pixel_mean: [ 123.675, 116.28, 103.53 ] 16 | pixel_std: [ 58.395, 57.12, 57.375 ] 17 | -------------------------------------------------------------------------------- /configs/model/sam/sam_vit_huge.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - sam_vit_base 3 | - override image_encoder: vit_huge 4 | 5 | checkpoint: ${hydra:runtime.cwd}/models/sam_ckpts/sam_vit_h_4b8939.pth 6 | -------------------------------------------------------------------------------- /configs/model/sam/sam_vit_large.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - sam_vit_base 3 | - override image_encoder: vit_large 4 | 5 | checkpoint: ${hydra:runtime.cwd}/models/sam_ckpts/sam_vit_l_0b3195.pth 6 | -------------------------------------------------------------------------------- /configs/model/sam/samhq_light_vit_tiny.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - prompt_encoder: sam 3 | - mask_decoder: sam 4 | - _self_ 5 | 6 | _target_: sam_pt.modeling.sam.SamHQHydra 7 | 8 | checkpoint: ${hydra:runtime.cwd}/models/samhq_ckpts/sam_hq_vit_t.pth 9 | 10 | prompt_embed_dim: 256 11 | image_size: 1024 12 | vit_patch_size: 16 13 | image_embedding_size: 64 14 | 15 | pixel_mean: [ 123.675, 116.28, 103.53 ] 16 | pixel_std: [ 58.395, 57.12, 57.375 ] 17 | 18 | image_encoder: 19 | _target_: segment_anything_hq.modeling.TinyViT 20 | img_size: ${..image_size} 21 | in_chans: 3 22 | num_classes: 1000 23 | embed_dims: [ 64, 128, 160, 320 ] 24 | depths: [ 2, 2, 6, 2 ] 25 | num_heads: [ 2, 4, 5, 10 ] 26 | window_sizes: [ 7, 7, 14, 7 ] 27 | mlp_ratio: 4. 28 | drop_rate: 0. 
29 | drop_path_rate: 0.0 30 | use_checkpoint: False 31 | mbconv_expand_ratio: 4.0 32 | local_conv_size: 3 33 | layer_lr_decay: 0.8 34 | 35 | prompt_encoder: 36 | _target_: segment_anything_hq.modeling.prompt_encoder.PromptEncoder 37 | 38 | mask_decoder: 39 | _target_: segment_anything_hq.modeling.mask_decoder_hq.MaskDecoderHQ 40 | vit_dim: 160 41 | -------------------------------------------------------------------------------- /configs/model/sam/samhq_vit_huge.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - image_encoder: vit_huge 3 | - prompt_encoder: sam 4 | - mask_decoder: sam 5 | - _self_ 6 | 7 | _target_: sam_pt.modeling.sam.SamHQHydra 8 | 9 | checkpoint: ${hydra:runtime.cwd}/models/samhq_ckpts/sam_hq_vit_h.pth 10 | 11 | prompt_embed_dim: 256 12 | image_size: 1024 13 | vit_patch_size: 16 14 | image_embedding_size: 64 15 | 16 | pixel_mean: [ 123.675, 116.28, 103.53 ] 17 | pixel_std: [ 58.395, 57.12, 57.375 ] 18 | 19 | image_encoder: 20 | _target_: segment_anything_hq.modeling.image_encoder.ImageEncoderViT 21 | 22 | prompt_encoder: 23 | _target_: segment_anything_hq.modeling.prompt_encoder.PromptEncoder 24 | 25 | mask_decoder: 26 | _target_: segment_anything_hq.modeling.mask_decoder_hq.MaskDecoderHQ 27 | vit_dim: ${..image_encoder.embed_dim} 28 | -------------------------------------------------------------------------------- /configs/model/sam_pt.yaml: -------------------------------------------------------------------------------- 1 | _target_: sam_pt.modeling.sam_pt.SamPt 2 | 3 | defaults: 4 | - point_tracker: cotracker 5 | - sam@sam_predictor.sam_model: samhq_vit_huge 6 | 7 | sam_predictor: 8 | _target_: segment_anything_hq.predictor.SamPredictor 9 | 10 | sam_iou_threshold: 0.7 11 | 12 | iterative_refinement_iterations: 12 13 | 14 | positive_point_selection_method: "kmedoids" # kmedoids, shi-tomasi, random, mixed 15 | negative_point_selection_method: "mixed" # kmedoids, shi-tomasi, random, mixed 16 | positive_points_per_mask: 16 17 | negative_points_per_mask: 1 18 | add_other_objects_positive_points_as_negative_points: true 19 | max_other_objects_positive_points: null 20 | 21 | point_tracker_mask_batch_size: 5 22 | 23 | use_patch_matching_filtering: false 24 | patch_size: 3 25 | patch_similarity_threshold: 0.01 26 | 27 | use_point_reinit: false 28 | reinit_point_tracker_horizon: 24 29 | reinit_horizon: 24 30 | reinit_variant: "reinit-at-median-of-area-diff" 31 | # Reinitialization variants: 32 | # A) reinit-on-horizon-and-sync-masks: 33 | # - simplest variant: reinitialize the points after a fixed number of 34 | # frames (e.g., every 8 frames) can fail if the mask happens to be 35 | # empty at the reinitialization timestep 36 | # - as fast as not using reinit 37 | # B) reinit-at-median-of-area-diff: 38 | # - reinitialize points for the non-empty mask with the mean mask area 39 | # - multiple times slower than no reinit, as many sam masks will be 40 | # rejected (e.g., 8 masks were computed, but we might reinit on the 41 | # second one, recomputing the rejected masks in the next step again) 42 | # C) reinit-on-similar-mask-area: 43 | # - reinit when the mask area is similar to the initial mask area 44 | # - multiple times slower than no reinit 45 | # D) reinit-on-similar-mask-area-and-sync-masks: 46 | # - reinit when the mask area is similar to the initial mask area for 47 | # all masks in the batch and synchronize the masks to be tracked from 48 | # the same timestep, as to be able to use negative points from other 49 | # 
masks when querying sam 50 | # - multiple times slower than no reinit 51 | -------------------------------------------------------------------------------- /configs/vis_eval_root.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - logging: vis_eval 3 | - _self_ 4 | - model/sam@model.sam_generator.model: sam_vit_huge 5 | - model@model.model: ??? 6 | 7 | model: 8 | _target_: sam_pt.modeling.vis_to_vos_adapter.SamBasedVisToVosAdapter 9 | max_num_masks: 100 10 | masks_batch_size: 100 11 | visualize_results: true 12 | max_videos_to_visualize: 30 13 | sam_generator: 14 | _target_: segment_anything.automatic_mask_generator.SamAutomaticMaskGenerator 15 | model: ??? 16 | points_per_side: 32 17 | points_per_batch: 64 18 | pred_iou_thresh: 0.88 19 | stability_score_thresh: 0.95 20 | stability_score_offset: 1.0 21 | box_nms_thresh: 0.7 22 | crop_n_layers: 0 23 | crop_nms_thresh: 0.7 24 | crop_overlap_ratio: 512 / 1500 25 | crop_n_points_downscale_factor: 1 26 | point_grids: null 27 | min_mask_region_area: 0 28 | output_mode: "binary_mask" 29 | 30 | output: results 31 | 32 | device: cuda 33 | num_gpus_per_machine: 1 34 | num_machines: 1 35 | machine_rank: 0 36 | dist_url: tcp://127.0.0.1:27036 37 | 38 | DETECTRON2_CONFIG: 39 | CUDNN_BENCHMARK: false 40 | DATALOADER: 41 | ASPECT_RATIO_GROUPING: true 42 | FILTER_EMPTY_ANNOTATIONS: false 43 | NUM_WORKERS: 0 44 | REPEAT_THRESHOLD: 0.0 45 | SAMPLER_TRAIN: TrainingSampler 46 | DATASETS: 47 | PRECOMPUTED_PROPOSAL_TOPK_TEST: 1000 48 | PRECOMPUTED_PROPOSAL_TOPK_TRAIN: 2000 49 | PROPOSAL_FILES_TEST: [ ] 50 | PROPOSAL_FILES_TRAIN: [ ] 51 | TEST: 52 | - uvo_v1_val 53 | TRAIN: 54 | - null 55 | GLOBAL: 56 | HACK: 1.0 57 | INPUT: 58 | AUGMENTATIONS: [ ] 59 | COLOR_AUG_SSD: false 60 | CROP: 61 | ENABLED: false 62 | SINGLE_CATEGORY_MAX_AREA: 1.0 63 | SIZE: 64 | - 600 65 | - 720 66 | TYPE: absolute_range 67 | DATASET_MAPPER_NAME: mask_former_semantic 68 | FORMAT: RGB 69 | IMAGE_SIZE: 1024 70 | MASK_FORMAT: polygon 71 | MAX_SCALE: 2.0 72 | MAX_SIZE_TEST: 1333 73 | MAX_SIZE_TRAIN: 1333 74 | MIN_SCALE: 0.1 75 | MIN_SIZE_TEST: 360 76 | MIN_SIZE_TRAIN: 77 | - 360 78 | - 480 79 | MIN_SIZE_TRAIN_SAMPLING: choice_by_clip 80 | RANDOM_FLIP: flip_by_clip 81 | SAMPLING_FRAME_NUM: 2 82 | SAMPLING_FRAME_RANGE: 20 83 | SAMPLING_FRAME_SHUFFLE: false 84 | SIZE_DIVISIBILITY: -1 85 | MODEL: 86 | MASK_ON: false 87 | SEM_SEG_HEAD: 88 | NUM_CLASSES: 54 89 | LOAD_PROPOSALS: false 90 | OUTPUT_DIR: ${ output } 91 | SEED: -1 92 | TEST: 93 | AUG: 94 | ENABLED: false 95 | FLIP: true 96 | MAX_SIZE: 4000 97 | MIN_SIZES: 98 | - 400 99 | - 500 100 | - 600 101 | - 700 102 | - 800 103 | - 900 104 | - 1000 105 | - 1100 106 | - 1200 107 | DETECTIONS_PER_IMAGE: 10 108 | EVAL_PERIOD: 0 109 | EXPECTED_RESULTS: [ ] 110 | KEYPOINT_OKS_SIGMAS: [ ] 111 | PRECISE_BN: 112 | ENABLED: false 113 | NUM_ITER: 200 114 | VERSION: 2 115 | VIS_PERIOD: 0 116 | -------------------------------------------------------------------------------- /configs/vis_eval_sam_pt.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - vis_eval_root 3 | - override model@model.model: sam_pt 4 | - _self_ 5 | 6 | model: 7 | model: 8 | point_tracker_mask_batch_size: 100 9 | sam_predictor: 10 | sam_model: ${ ...sam_generator.model } 11 | -------------------------------------------------------------------------------- /configs/vos_eval_root.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 
| - model: sam_pt 3 | - logging: vos_eval 4 | - _self_ 5 | 6 | evaluator: 7 | _target_: sam_pt.vos_eval.evaluator.SamPtEvaluator 8 | _recursive_: false 9 | 10 | dataset: D17 # D16/D17/Y18/Y19/LV1/LV3/MOSE/BDD100K/G 11 | split: val # val/test 12 | simulate_interactive_point_correction: false 13 | masks_batch_size: 100 14 | seed: 72 15 | 16 | d16_path: ${hydra:runtime.cwd}/data/DAVIS/2016 17 | d17_path: ${hydra:runtime.cwd}/data/DAVIS/2017 18 | y18_path: ${hydra:runtime.cwd}/data/YouTube2018 19 | y19_path: ${hydra:runtime.cwd}/data/YouTube 20 | lv_path: ${hydra:runtime.cwd}/data/long_video_set 21 | mose_path: ${hydra:runtime.cwd}/data/mose 22 | bdd100k_path: ${hydra:runtime.cwd}/data/bdd100k/vos 23 | generic_path: null # For generic (G) evaluation, point to a folder that contains "JPEGImages" and "Annotations" 24 | 25 | input_only_one_gt_mask_point: false # If true, only one gt mask point will be used for evaluation 26 | 27 | size: -1 # Resize the shorter side to this size. -1 to use original resolution 28 | longest_size: 1024 # Resize the longest side to this size. null to use original resolution. Must be used with size=-1 29 | 30 | flip: false 31 | 32 | output: eval_${dataset}_${split} # Path to save the results. If None, will save to the default path 33 | save_all: false # Save all frames. Useful only in YouTubeVOS/long-time video 34 | save_scores: false 35 | save_overlapping_masks: false # Save overlapping masks along with non-overlapping multi-object masks 36 | 37 | visualize_results: true # Whether to visualize the results using wandb 38 | verbose_visualisations: false # Whether to visualize the results in a verbose way (e.g. with input GIFs), slower 39 | vid_ids: null # Evaluate only on the videos specified in the list, e.g. [0,1,2] (or vid_ids=\[0,1,2\] in command line) 40 | max_videos_to_visualize: 30 # Max number of videos to visualize, used when visualize_results flag is set, videos with id >= max_videos_to_visualize will not be visualized 41 | vid_ids_to_visualize: [ 0, 1, 2, 15 ] # Videos to visualize, used when visualize_results flag is set, null for all videos 42 | log_fmt: gif # gif/mp4 43 | 44 | max_videos: null # Max number of videos to process, useful with the visualize_results flag and for debugging 45 | max_frames: null # Max number of frames to process per video. Useful for debugging 46 | 47 | logging: 48 | exp_id_verbose: ${logging.exp_id}_${dataset}_${split}_${seed}_${now:%Y.%m.%d_%H.%M.%S} 49 | 50 | 51 | hydra: 52 | job: 53 | chdir: True 54 | run: 55 | dir: outputs/${logging.exp_id_verbose} 56 | -------------------------------------------------------------------------------- /data/demo_data/README.md: -------------------------------------------------------------------------------- 1 | # Demo Data 2 | 3 | This directory contains demo data that users can use to understand the structure and format of input data. Below, we've detailed the sources of our demo data and provided an in-depth explanation of the query points format. 4 | 5 | ## Data Sources 6 | 7 | The provided clips in this directory serve as sample data for the demo and were obtained from Pixabay: 8 | 9 | 1. [`street.mp4`](.street.mp4) - [Video source](https://pixabay.com/videos/street-bus-village-bus-stop-city-38590/). 10 | 2. [`bees.mp4`](bees.mp4) - [Video source](https://pixabay.com/videos/bees-honey-bees-insect-pollen-35093/). 
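Note that the demo consumes folders of lexicographically ordered frames rather than these `.mp4` files directly. The conversion commands below are reproduced from [docs/01-getting-started.md](../../docs/01-getting-started.md) for convenience:

```bash
mkdir data/demo_data/bees
ffmpeg -i data/demo_data/bees.mp4 -vf fps=5 data/demo_data/bees/frame-%05d.png

mkdir data/demo_data/street
ffmpeg -i data/demo_data/street.mp4 -vf fps=10 data/demo_data/street/frame-%05d.png
```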
11 | 12 | ## Query Points Format 13 | 14 | Query points are crucial for our application as they define the target object (positive points) and the background/non-target objects (negative points). 15 | 16 | They can be provided interactively by the user or derived from a ground truth mask. The following section explains how they're structured when saved to a text file: 17 | 18 | ```bash 19 | number_of_positive_points 20 | mask_1_timestep ; pos_x_1,pos_y_1 ... pos_x_n,pos_y_n neg_x_1,neg_y_1 ... neg_x_m,neg_y_m 21 | mask_2_timestep ; pos_x_1,pos_y_1 ... pos_x_n,pos_y_n neg_x_1,neg_y_1 ... neg_x_m,neg_y_m 22 | ... 23 | ``` 24 | 25 | - `number_of_positive_points` - Specifies the number of positive points 26 | - `mask_x_timestep` - The timestamp for each mask 27 | - `pos_x_i,pos_y_i` - x, y coordinates of the positive points 28 | - `neg_x_i,neg_y_i` - x, y coordinates of the negative points 29 | 30 | Note: The number of negative points is inferred from the total number of points minus the number of positive points. 31 | 32 | Here is a simple example of a query point file with two masks: 33 | 34 | ```sh 35 | 1 36 | 0 ; 10,20 30,30 40,40 37 | 4 ; 123.123,456.456 72,72 5,6 38 | ``` 39 | 40 | In this example, each mask has one positive point and two negative points. The positive query point for the first mask, for instance, has (x,y) coordinates of (10,20). Here, the value '10' denotes a distance of 10 pixels from the left image border, and '20' indicates a distance of 20 pixels from the top image border (as the coordinate system begins at the top left corner of the image). 41 | -------------------------------------------------------------------------------- /data/demo_data/bees.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SysCV/sam-pt/874ff7e73d6ab05418a494d7a02ca233c0b31e8c/data/demo_data/bees.mp4 -------------------------------------------------------------------------------- /data/demo_data/query_points__bees.txt: -------------------------------------------------------------------------------- 1 | 3 2 | 0;282.5,241.25 336.25,240.0 357.5,276.25 531.25,242.5 3 | 0;1116.25,428.75 1137.5,411.25 1165.0,412.5 1130.0,272.5 4 | -------------------------------------------------------------------------------- /data/demo_data/query_points__street.txt: -------------------------------------------------------------------------------- 1 | 4 2 | 0;403.75,426.25 382.5,481.25 423.75,507.5 403.75,561.25 425.0,603.75 350.0,360.0 3 | 0;668.75,520.0 631.25,493.75 648.75,443.75 681.25,427.5 870.0,513.75 691.25,316.25 4 | 0;307.5,393.75 337.5,540.0 335.0,296.25 497.5,231.25 265.0,286.25 330.0,588.75 5 | -------------------------------------------------------------------------------- /data/demo_data/street.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SysCV/sam-pt/874ff7e73d6ab05418a494d7a02ca233c0b31e8c/data/demo_data/street.mp4 -------------------------------------------------------------------------------- /demo/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SysCV/sam-pt/874ff7e73d6ab05418a494d7a02ca233c0b31e8c/demo/__init__.py -------------------------------------------------------------------------------- /docs/01-getting-started.md: -------------------------------------------------------------------------------- 1 | # Getting Started 2 | 3 | ## Setting Up the Environment 4 | 5 | This 
codebase has been tested and confirmed to be compatible with the package versions listed in [`requirements.txt`](../requirements.txt), along with PyTorch 1.12.0, and Python 3.8.16. These versions were tested on Manjaro Linux and Debian GNU/Linux 10 (buster) systems. 6 | 7 | Start by cloning the repository: 8 | 9 | ```bash 10 | git clone https://github.com/SysCV/sam-pt.git 11 | cd sam-pt 12 | ``` 13 | 14 | With the repository now cloned, we recommend creating a new [conda](https://docs.conda.io/en/latest/) virtual environment: 15 | 16 | ```bash 17 | conda create --name sam-pt python=3.8.16 -y 18 | conda activate sam-pt 19 | ``` 20 | 21 | Next, install [PyTorch](https://pytorch.org/) 1.12.0 and [torchvision](https://pytorch.org/vision/stable/index.html) 0.13.0, for example with CUDA 11 support: 22 | 23 | ```bash 24 | conda install pytorch==1.12.0 torchvision==0.13.0 torchaudio==0.12.0 cudatoolkit=11.3 -c pytorch 25 | ``` 26 | 27 | Finally, install the required packages: 28 | 29 | ```bash 30 | pip install -r requirements.txt 31 | ``` 32 | 33 | If you wish to use TapNet (or TAPIR) as a point tracker, it's necessary to configure JAX on your system. The required packages, including JAX library version [0.4.11](https://github.com/google/jax/tree/jax-v0.4.11) and others needed by TapNet, can be found in the [`requirements-jax.txt`](../requirements-jax.txt) file. To install JAX, we recommend following the [official installation instructions](https://github.com/google/jax#installation). In some environments, like ours, it may be necessary to build PyTorch and JAX from source. 34 | 35 | ## Running the Demo 36 | 37 | To run the demo, start by preparing your demo data. This can either be one of the clips provided in `data/demo_data`, or a clip of your own. You can also use the horse jumping video `data/DAVIS/2017/trainval/JPEGImages/Full-Resolution/horsejump-high` from [DAVIS 2017](02-prepare-datasets.md#davis-2017). 38 | 39 | The demo expects a sequence of images as input. If your data is a video clip, convert it to images ensuring their filenames are lexicographically ordered (e.g., `frame-000.png`, `frame-001.png`, etc.). For example, the `ffmpeg` command can be used to convert the provided demo clips as follows: 40 | 41 | ```bash 42 | # List the content of the demo_data directory 43 | ls data/demo_data 44 | # bees.mp4 street.mp4 ... 45 | 46 | # Convert bees.mp4 to png frames 47 | mkdir data/demo_data/bees 48 | ffmpeg -i data/demo_data/bees.mp4 -vf fps=5 data/demo_data/bees/frame-%05d.png 49 | 50 | # Convert street.mp4 to png frames 51 | mkdir data/demo_data/street 52 | ffmpeg -i data/demo_data/street.mp4 -vf fps=10 data/demo_data/street/frame-%05d.png 53 | ``` 54 | 55 | Before running the demo, you additionally have to make sure to have the SAM and PIPS checkpoints downloaded, as described under [minimal checkpoints](03-prepare-checkpoints.md#minimal-checkpoints). 56 | 57 | ### Running the Interactive Demo 58 | 59 | The interactive demo allows you to specify query points using mouse clicks on a pop-up window. This requires a GUI environment, which is typically available on personal computers. If you're using remote GPUs, you may need to set up X forwarding. 60 | 61 | 62 | Note that the [`${hydra:runtime.cwd}`](https://hydra.cc/docs/1.3/configure_hydra/intro/#hydraruntime) prefix in the commands below needs to be used to prefix relative paths. This is because we launch demos within a [working directory created by Hydra](https://hydra.cc/docs/1.3/tutorials/basic/running_your_app/working_directory/). 
Follow the instructions displayed in your terminal after launching the interactive demo. 63 | 64 | 65 | ```bash 66 | # Run demo on bees.mp4 67 | export HYDRA_FULL_ERROR=1 68 | python -m demo.demo \ 69 | frames_path='${hydra:runtime.cwd}/data/demo_data/bees/' \ 70 | query_points_path=null \ 71 | longest_side_length=1024 frame_stride=1 max_frames=-1 72 | 73 | # Run demo on street.mp4 74 | export HYDRA_FULL_ERROR=1 75 | python -m demo.demo \ 76 | frames_path='${hydra:runtime.cwd}/data/demo_data/street/' \ 77 | query_points_path=null \ 78 | longest_side_length=1024 frame_stride=1 max_frames=-1 79 | ``` 80 | 81 | ### Running the Non-interactive Demo 82 | 83 | You also have the option to run the demo in a non-interactive mode where query points are predefined in a file. You can create the content of a query points file using the interactive demo, which will print a string of the query points. This string can be saved and used for running the non-interactive demo. More details about the format of the query points file can be found in [`data/demo_data/README.md`](../data/demo_data/README.md). Examples of query point files for the [bees](../data/demo_data/query_points__bees.txt) and [street](../data/demo_data/query_points__street.txt) clips are also provided and can be used as in the following commands: 84 | 85 | ```bash 86 | # Run non-interactive demo on bees.mp4 87 | export HYDRA_FULL_ERROR=1 88 | python -m demo.demo \ 89 | frames_path='${hydra:runtime.cwd}/data/demo_data/bees/' \ 90 | query_points_path='${hydra:runtime.cwd}/data/demo_data/query_points__bees.txt' \ 91 | longest_side_length=1024 frame_stride=1 max_frames=-1 92 | 93 | # Run non-interactive demo on street.mp4 94 | export HYDRA_FULL_ERROR=1 95 | python -m demo.demo \ 96 | frames_path='${hydra:runtime.cwd}/data/demo_data/street/' \ 97 | query_points_path='${hydra:runtime.cwd}/data/demo_data/query_points__street.txt' \ 98 | longest_side_length=1024 frame_stride=1 max_frames=-1 99 | ``` 100 | 101 | ## Codebase Overview 102 | 103 | Here's a quick overview of our project's codebase and its structure: 104 | 105 | - [`assets`](../assets): Assets related to the GitHub repository 106 | - [`configs`](../configs): YAML configuration files used with Hydra 107 | - [`data`](../data): Directory to store data 108 | - [`demo_data`](../data/demo_data): Demo data with README for data sources and query points file format 109 | - [`demo`](../demo): Code for running the demo 110 | - [`docs`](../docs): Documentation on how to use the codebase 111 | - [`sam_pt`](../sam_pt): Source for SAM-PT 112 | - [`modeling`](../sam_pt/modeling): Main code for SAM-PT 113 | - [`point_tracker`](../sam_pt/point_tracker): Code for different point trackers 114 | - [`utils`](../sam_pt/utils): Utilities used within the SAM-PT module 115 | - [`vis_eval`](../sam_pt/vis_eval): Code for evaluating on Video Instance Segmentation (VIS) 116 | - [`vos_eval`](../sam_pt/vos_eval): Code for evaluating on Video Object Segmentation (VOS) 117 | - [`scripts`](../scripts): Scripts used for small tasks 118 | - [`README.md`](../README.md): Main README file 119 | - [`requirements.txt`](../requirements.txt): General project requirements 120 | - [`requirements-jax.txt`](../requirements-jax.txt): Requirements for using the JAX-based TapNet and TAPIR point trackers 121 | 122 | 123 | ## What's Next? 
124 | 125 | Once you are comfortable with running the demo, you might want to explore [how to prepare the data](02-prepare-datasets.md) and [how to prepare the checkpoints](03-prepare-checkpoints.md) that are necessary for [running our VOS and VIS experiments](04-running-experiments.md). 126 | -------------------------------------------------------------------------------- /requirements-jax.txt: -------------------------------------------------------------------------------- 1 | # JAX version requirements, recommended to be installed following JAX's instructions here: https://github.com/google/jax#installation 2 | # jax==0.4.11 3 | # jaxlib==0.4.11 4 | 5 | jaxline==0.0.5 6 | chex==0.1.7 7 | dm-haiku==0.0.9 8 | optax==0.1.5 9 | einshape@git+https://github.com/deepmind/einshape -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # Torch version requirements, recommended to be installed w/o pip: 2 | # torch==1.12.0 3 | # torchvision==0.13.0 4 | 5 | tensorflow==2.12.1 6 | einops==0.4.1 7 | opencv-python==4.7.0.72 8 | timm==0.9.2 9 | flow_vis==0.1 10 | 11 | numpy==1.24.3 12 | h5py==3.9.0 13 | Pillow==9.5.0 14 | pandas==1.5.3 15 | matplotlib==3.5.1 16 | seaborn==0.12.2 17 | scikit-learn==1.1.1 18 | scikit-learn-extra==0.3.0 19 | 20 | hydra-core==1.3.2 21 | wandb==0.15.3 22 | imageio==2.31.1 23 | moviepy==1.0.3 24 | mediapy==1.1.8 25 | 26 | git+https://github.com/facebookresearch/detectron2@v0.6 27 | git+https://github.com/m43/davis2016-davis2017-davis2019-evaluation.git@35401a5619757359673d9d1a7d9e02c177f06f7f 28 | git+https://github.com/facebookresearch/segment-anything.git@aac76a1fb03cf90dc7cb2ad481d511642e51aeba 29 | git+https://github.com/ChaoningZhang/MobileSAM.git@01ea8d0f5590082f0c1ceb0a3e2272593f20154b 30 | git+https://github.com/m43/sam-hq.git@75c73fa27b32435f33119d08a47788db4601e1da 31 | git+https://github.com/facebookresearch/co-tracker.git@4f297a92fe1a684b1b0980da138b706d62e45472 32 | -------------------------------------------------------------------------------- /sam_pt/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SysCV/sam-pt/874ff7e73d6ab05418a494d7a02ca233c0b31e8c/sam_pt/__init__.py -------------------------------------------------------------------------------- /sam_pt/modeling/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SysCV/sam-pt/874ff7e73d6ab05418a494d7a02ca233c0b31e8c/sam_pt/modeling/__init__.py -------------------------------------------------------------------------------- /sam_pt/modeling/sam.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module contains hydra wrapper classes for different types of Sam models. Each hydra wrapper provides functionality 3 | for loading checkpoints and storing additional parameters that we used for variable interpolation within Hydra. 4 | """ 5 | 6 | import torch 7 | from mobile_sam.modeling import Sam as MobileSam 8 | from segment_anything.modeling import Sam 9 | from segment_anything_hq.modeling import Sam as SamHQ 10 | 11 | 12 | class BaseHydra: 13 | """ 14 | Base class for hydra wrappers that loads the model checkpoint and stores additional parameters that we used for 15 | variable interpolation within Hydra. 
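    When a checkpoint path is given, its state dict is loaded with strict=False, so key mismatches
    between the wrapper and the checkpoint file are tolerated rather than raising an error.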
16 | """ 17 | 18 | def __init__(self, model, checkpoint, prompt_embed_dim, image_size, vit_patch_size, image_embedding_size, **kwargs): 19 | super().__init__(**kwargs) 20 | 21 | if checkpoint is not None: 22 | with open(checkpoint, "rb") as f: 23 | state_dict = torch.load(f) 24 | model.load_state_dict(self, state_dict, strict=False) 25 | print(f"Loaded checkpoint from {checkpoint}.") 26 | 27 | # Store additional parameters used for variable interpolation within Hydra 28 | self.prompt_embed_dim = prompt_embed_dim 29 | self.image_size = image_size 30 | self.vit_patch_size = vit_patch_size 31 | self.image_embedding_size = image_embedding_size 32 | 33 | 34 | class SamHydra(BaseHydra, Sam): 35 | """ 36 | Wrapper for the Sam model that allows for loading a checkpoint 37 | and setting additional parameters used for variable interpolation. 38 | """ 39 | 40 | def __init__(self, *args, **kwargs): 41 | super().__init__(Sam, *args, **kwargs) 42 | 43 | 44 | class SamHQHydra(BaseHydra, SamHQ): 45 | """ 46 | Wrapper for the SamHQ model that allows for loading a checkpoint 47 | and setting additional parameters used for variable interpolation. 48 | """ 49 | 50 | def __init__(self, *args, **kwargs): 51 | super().__init__(SamHQ, *args, **kwargs) 52 | 53 | 54 | class MobileSamHydra(BaseHydra, MobileSam): 55 | """ 56 | Wrapper for the MobileSAM model that allows for loading a checkpoint 57 | and setting additional parameters used for variable interpolation. 58 | """ 59 | 60 | def __init__(self, *args, **kwargs): 61 | super().__init__(MobileSam, *args, **kwargs) 62 | -------------------------------------------------------------------------------- /sam_pt/point_tracker/__init__.py: -------------------------------------------------------------------------------- 1 | from .tracker import PointTracker 2 | from .pips import PipsPointTracker 3 | from .raft import RaftPointTracker 4 | from .superglue import SuperGluePointTracker 5 | from .tapir import TapirPointTracker 6 | from .tapnet import TapnetPointTracker 7 | from .cotracker import CoTrackerPointTracker -------------------------------------------------------------------------------- /sam_pt/point_tracker/cotracker/__init__.py: -------------------------------------------------------------------------------- 1 | from .tracker import CoTrackerPointTracker 2 | -------------------------------------------------------------------------------- /sam_pt/point_tracker/pips/__init__.py: -------------------------------------------------------------------------------- 1 | from .pips import Pips 2 | from .tracker import PipsPointTracker 3 | -------------------------------------------------------------------------------- /sam_pt/point_tracker/pips_plus_plus/__init__.py: -------------------------------------------------------------------------------- 1 | from .pips_plus_plus import PipsPlusPlus 2 | from .tracker import PipsPlusPlusPointTracker 3 | -------------------------------------------------------------------------------- /sam_pt/point_tracker/pips_plus_plus/tracker.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | 3 | import torch 4 | 5 | from sam_pt.point_tracker import PointTracker 6 | from sam_pt.point_tracker.pips_plus_plus import PipsPlusPlus 7 | from sam_pt.point_tracker.utils import saverloader 8 | 9 | 10 | class PipsPlusPlusPointTracker(PointTracker): 11 | 12 | def __init__(self, checkpoint_path, stride=8, max_sequence_length=128, iters=16, image_size=(512, 896)): 13 | 
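        # Parameter roles, as inferred from how they are used in this class (not from upstream PIPS++ docs):
        #   checkpoint_path      - PIPS++ reference checkpoint restored below via saverloader.load
        #   stride               - feature stride forwarded to the PipsPlusPlus backbone
        #   max_sequence_length  - number of frames processed per chunk in the chunked forward pass
        #   iters                - refinement iterations per chunk
        #   image_size           - optional (height, width) to resize frames to before tracking;
        #                          None keeps the original resolution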
super().__init__() 14 | self.checkpoint_path = checkpoint_path 15 | self.stride = stride 16 | self.max_sequence_length = max_sequence_length 17 | self.iters = iters 18 | self.image_size = tuple(image_size) if image_size is not None else None 19 | 20 | print(f"Loading PIPS++ model from {self.checkpoint_path}") 21 | self.model = PipsPlusPlus(stride=self.stride) 22 | self._loaded_checkpoint_step = saverloader.load(self.checkpoint_path, self.model, 23 | device="cuda" if torch.cuda.is_available() else "cpu") 24 | 25 | def _forward(self, rgbs, query_points): 26 | """ 27 | Single direction forward pass. 28 | """ 29 | B, S, C, H, W = rgbs.shape 30 | assert query_points.ndim == 2 31 | assert query_points.shape[1] == 2 32 | 33 | # zero-vel init 34 | trajs_e = query_points[None, None, :, :].repeat(1, rgbs.shape[1], 1, 1) 35 | 36 | cur_frame = 0 37 | done = False 38 | feat_init = None 39 | while not done: 40 | end_frame = cur_frame + self.max_sequence_length 41 | 42 | if end_frame > S: 43 | diff = end_frame - S 44 | end_frame = end_frame - diff 45 | cur_frame = max(cur_frame - diff, 0) 46 | 47 | traj_seq = trajs_e[:, cur_frame:end_frame] 48 | rgb_seq = rgbs[:, cur_frame:end_frame] 49 | S_local = rgb_seq.shape[1] 50 | 51 | if feat_init is not None: 52 | feat_init = [fi[:, :S_local] for fi in feat_init] 53 | 54 | preds, preds_anim, feat_init, _ = self.model(traj_seq, rgb_seq, iters=self.iters, feat_init=feat_init) 55 | 56 | trajs_e[:, cur_frame:end_frame] = preds[-1][:, :S_local] 57 | trajs_e[:, end_frame:] = trajs_e[:, end_frame - 1:end_frame] # update the future with new zero-vel 58 | 59 | if end_frame >= S: 60 | done = True 61 | else: 62 | cur_frame = cur_frame + self.max_sequence_length - 1 63 | 64 | visibilities = torch.ones_like(trajs_e[:, :, :, 0]) 65 | return trajs_e, visibilities 66 | 67 | def forward(self, rgbs, query_points): 68 | """ 69 | Forward function for the tracker. 70 | """ 71 | batch_size, num_frames, C, H, W = rgbs.shape 72 | if self.image_size is not None: 73 | rgbs = rgbs.reshape(batch_size * num_frames, C, H, W) 74 | rgbs = rgbs / 255.0 75 | rgbs = torch.nn.functional.interpolate(rgbs, size=tuple(self.image_size), mode="bilinear") 76 | rgbs = rgbs * 255.0 77 | rgbs = rgbs.reshape(batch_size, num_frames, C, *self.image_size) 78 | query_points[:, :, 1] *= self.image_size[0] / H 79 | query_points[:, :, 2] *= self.image_size[1] / W 80 | 81 | # Group query points by their time-step 82 | groups = defaultdict(list) 83 | assert query_points.shape[0] == batch_size == 1, "Only batch size 1 is supported." 
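        # Each query point stores its query frame index in component 0. Points are grouped by that frame
        # so that every group can be tracked forward (from t to the last frame) and backward (from t to the
        # first frame); the two passes are merged further below.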
84 | for idx, point in enumerate(query_points[0]): 85 | t = int(point[0].item()) 86 | groups[t].append((idx, point[1:].tolist())) 87 | 88 | # Dictionary to store results 89 | trajectories_dict = {} 90 | visibilities_dict = {} 91 | 92 | for t, points_with_indices in groups.items(): 93 | points = [x[1] for x in points_with_indices] 94 | 95 | # Left to right 96 | if t == num_frames - 1: 97 | left_trajectories = torch.empty((batch_size, 0, len(points), 2), dtype=torch.float32).cuda() 98 | left_visibilities = torch.empty((batch_size, 0, len(points)), dtype=torch.float32).cuda() 99 | else: 100 | left_rgbs = rgbs[:, t:] 101 | left_query = torch.tensor(points, dtype=torch.float32).cuda() 102 | left_trajectories, left_visibilities = self._forward(left_rgbs, left_query) 103 | 104 | # Right to left 105 | if t == 0: 106 | right_trajectories = torch.empty((batch_size, 0, len(points), 2), dtype=torch.float32).cuda() 107 | right_visibilities = torch.empty((batch_size, 0, len(points)), dtype=torch.float32).cuda() 108 | else: 109 | right_rgbs = rgbs[:, :t + 1].flip(1) 110 | right_query = torch.tensor(points, dtype=torch.float32).cuda() 111 | right_trajectories, right_visibilities = self._forward(right_rgbs, right_query) 112 | right_trajectories = right_trajectories.flip(1) 113 | right_visibilities = right_visibilities.flip(1) 114 | 115 | # Merge the results 116 | trajectories = torch.cat([right_trajectories[:, :-1], left_trajectories], dim=1) 117 | visibilities = torch.cat([right_visibilities[:, :-1], left_visibilities], dim=1) 118 | 119 | # Store in dictionary 120 | for idx, (idx, _) in enumerate(points_with_indices): 121 | trajectories_dict[idx] = trajectories[:, :, idx, :] 122 | visibilities_dict[idx] = visibilities[:, :, idx] 123 | 124 | # Assemble the results back in the order of the input query points 125 | n_points = query_points.shape[1] 126 | final_trajectories = torch.stack([trajectories_dict[i] for i in range(n_points)], dim=2) 127 | final_visibilities = torch.stack([visibilities_dict[i] for i in range(n_points)], dim=2) 128 | 129 | # Rescale trajectories back to the original size 130 | if self.image_size is not None: 131 | final_trajectories[:, :, :, 0] *= H / self.image_size[0] 132 | final_trajectories[:, :, :, 1] *= W / self.image_size[1] 133 | 134 | return final_trajectories, final_visibilities 135 | -------------------------------------------------------------------------------- /sam_pt/point_tracker/raft/__init__.py: -------------------------------------------------------------------------------- 1 | from .tracker import RaftPointTracker 2 | -------------------------------------------------------------------------------- /sam_pt/point_tracker/raft/raft_core/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SysCV/sam-pt/874ff7e73d6ab05418a494d7a02ca233c0b31e8c/sam_pt/point_tracker/raft/raft_core/__init__.py -------------------------------------------------------------------------------- /sam_pt/point_tracker/raft/raft_core/corr.py: -------------------------------------------------------------------------------- 1 | # Taken from: https://github.com/princeton-vl/RAFT/blob/aac9dd54726caf2cf81d8661b07663e220c5586d/core/corr.py 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | from .util import bilinear_sampler 7 | 8 | try: 9 | import alt_cuda_corr 10 | except: 11 | # alt_cuda_corr is not compiled 12 | pass 13 | 14 | 15 | class CorrBlock: 16 | def __init__(self, fmap1, fmap2, num_levels=4, radius=4): 17 
| self.num_levels = num_levels 18 | self.radius = radius 19 | self.corr_pyramid = [] 20 | 21 | # all pairs correlation 22 | corr = CorrBlock.corr(fmap1, fmap2) 23 | 24 | batch, h1, w1, dim, h2, w2 = corr.shape 25 | corr = corr.reshape(batch * h1 * w1, dim, h2, w2) 26 | 27 | self.corr_pyramid.append(corr) 28 | for i in range(self.num_levels - 1): 29 | corr = F.avg_pool2d(corr, 2, stride=2) 30 | self.corr_pyramid.append(corr) 31 | 32 | def __call__(self, coords): 33 | r = self.radius 34 | coords = coords.permute(0, 2, 3, 1) 35 | batch, h1, w1, _ = coords.shape 36 | 37 | out_pyramid = [] 38 | for i in range(self.num_levels): 39 | corr = self.corr_pyramid[i] 40 | dx = torch.linspace(-r, r, 2 * r + 1) 41 | dy = torch.linspace(-r, r, 2 * r + 1) 42 | delta = torch.stack(torch.meshgrid(dy, dx), axis=-1).to(coords.device) 43 | 44 | centroid_lvl = coords.reshape(batch * h1 * w1, 1, 1, 2) / 2 ** i 45 | delta_lvl = delta.view(1, 2 * r + 1, 2 * r + 1, 2) 46 | coords_lvl = centroid_lvl + delta_lvl 47 | 48 | corr = bilinear_sampler(corr, coords_lvl) 49 | corr = corr.view(batch, h1, w1, -1) 50 | out_pyramid.append(corr) 51 | 52 | out = torch.cat(out_pyramid, dim=-1) 53 | return out.permute(0, 3, 1, 2).contiguous().float() 54 | 55 | @staticmethod 56 | def corr(fmap1, fmap2): 57 | batch, dim, ht, wd = fmap1.shape 58 | fmap1 = fmap1.view(batch, dim, ht * wd) 59 | fmap2 = fmap2.view(batch, dim, ht * wd) 60 | 61 | corr = torch.matmul(fmap1.transpose(1, 2), fmap2) 62 | corr = corr.view(batch, ht, wd, 1, ht, wd) 63 | return corr / torch.sqrt(torch.tensor(dim).float()) 64 | 65 | 66 | class AlternateCorrBlock: 67 | def __init__(self, fmap1, fmap2, num_levels=4, radius=4): 68 | self.num_levels = num_levels 69 | self.radius = radius 70 | 71 | self.pyramid = [(fmap1, fmap2)] 72 | for i in range(self.num_levels): 73 | fmap1 = F.avg_pool2d(fmap1, 2, stride=2) 74 | fmap2 = F.avg_pool2d(fmap2, 2, stride=2) 75 | self.pyramid.append((fmap1, fmap2)) 76 | 77 | def __call__(self, coords): 78 | coords = coords.permute(0, 2, 3, 1) 79 | B, H, W, _ = coords.shape 80 | dim = self.pyramid[0][0].shape[1] 81 | 82 | corr_list = [] 83 | for i in range(self.num_levels): 84 | r = self.radius 85 | fmap1_i = self.pyramid[0][0].permute(0, 2, 3, 1).contiguous() 86 | fmap2_i = self.pyramid[i][1].permute(0, 2, 3, 1).contiguous() 87 | 88 | coords_i = (coords / 2 ** i).reshape(B, 1, H, W, 2).contiguous() 89 | corr, = alt_cuda_corr.forward(fmap1_i, fmap2_i, coords_i, r) 90 | corr_list.append(corr.squeeze(1)) 91 | 92 | corr = torch.stack(corr_list, dim=1) 93 | corr = corr.reshape(B, -1, H, W) 94 | return corr / torch.sqrt(torch.tensor(dim).float()) 95 | -------------------------------------------------------------------------------- /sam_pt/point_tracker/raft/raft_core/raft.py: -------------------------------------------------------------------------------- 1 | # Taken from: https://github.com/princeton-vl/RAFT/blob/aac9dd54726caf2cf81d8661b07663e220c5586d/core/raft.py 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | from .corr import CorrBlock, AlternateCorrBlock 8 | from .extractor import BasicEncoder, SmallEncoder 9 | from .update import BasicUpdateBlock, SmallUpdateBlock 10 | from .util import coords_grid, upflow8 11 | 12 | try: 13 | autocast = torch.cuda.amp.autocast 14 | except: 15 | # dummy autocast for PyTorch < 1.6 16 | class autocast: 17 | def __init__(self, enabled): 18 | pass 19 | 20 | def __enter__(self): 21 | pass 22 | 23 | def __exit__(self, *args): 24 | pass 25 | 26 | 27 | class 
RAFT(nn.Module): 28 | def __init__(self, args): 29 | super(RAFT, self).__init__() 30 | self.args = args 31 | 32 | if args.small: 33 | self.hidden_dim = hdim = 96 34 | self.context_dim = cdim = 64 35 | args.corr_levels = 4 36 | args.corr_radius = 3 37 | 38 | else: 39 | self.hidden_dim = hdim = 128 40 | self.context_dim = cdim = 128 41 | args.corr_levels = 4 42 | args.corr_radius = 4 43 | 44 | if 'dropout' not in self.args: 45 | self.args.dropout = 0 46 | 47 | if 'alternate_corr' not in self.args: 48 | self.args.alternate_corr = False 49 | 50 | # feature network, context network, and update block 51 | if args.small: 52 | self.fnet = SmallEncoder(output_dim=128, norm_fn='instance', dropout=args.dropout) 53 | self.cnet = SmallEncoder(output_dim=hdim + cdim, norm_fn='none', dropout=args.dropout) 54 | self.update_block = SmallUpdateBlock(self.args, hidden_dim=hdim) 55 | 56 | else: 57 | self.fnet = BasicEncoder(output_dim=256, norm_fn='instance', dropout=args.dropout) 58 | self.cnet = BasicEncoder(output_dim=hdim + cdim, norm_fn='batch', dropout=args.dropout) 59 | self.update_block = BasicUpdateBlock(self.args, hidden_dim=hdim) 60 | 61 | def freeze_bn(self): 62 | for m in self.modules(): 63 | if isinstance(m, nn.BatchNorm2d): 64 | m.eval() 65 | 66 | def initialize_flow(self, img): 67 | """ Flow is represented as difference between two coordinate grids flow = coords1 - coords0""" 68 | N, C, H, W = img.shape 69 | coords0 = coords_grid(N, H // 8, W // 8).to(img.device) 70 | coords1 = coords_grid(N, H // 8, W // 8).to(img.device) 71 | 72 | # optical flow computed as difference: flow = coords1 - coords0 73 | return coords0, coords1 74 | 75 | def upsample_flow(self, flow, mask): 76 | """ Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination """ 77 | N, _, H, W = flow.shape 78 | mask = mask.view(N, 1, 9, 8, 8, H, W) 79 | mask = torch.softmax(mask, dim=2) 80 | 81 | up_flow = F.unfold(8 * flow, [3, 3], padding=1) 82 | up_flow = up_flow.view(N, 2, 9, 1, 1, H, W) 83 | 84 | up_flow = torch.sum(mask * up_flow, dim=2) 85 | up_flow = up_flow.permute(0, 1, 4, 2, 5, 3) 86 | return up_flow.reshape(N, 2, 8 * H, 8 * W) 87 | 88 | def forward(self, image1, image2, iters=12, flow_init=None, upsample=True, test_mode=False): 89 | """ Estimate optical flow between pair of frames """ 90 | 91 | image1 = 2 * (image1 / 255.0) - 1.0 92 | image2 = 2 * (image2 / 255.0) - 1.0 93 | 94 | image1 = image1.contiguous() 95 | image2 = image2.contiguous() 96 | 97 | hdim = self.hidden_dim 98 | cdim = self.context_dim 99 | 100 | # run the feature network 101 | with autocast(enabled=self.args.mixed_precision): 102 | fmap1, fmap2 = self.fnet([image1, image2]) 103 | 104 | fmap1 = fmap1.float() 105 | fmap2 = fmap2.float() 106 | if self.args.alternate_corr: 107 | corr_fn = AlternateCorrBlock(fmap1, fmap2, radius=self.args.corr_radius) 108 | else: 109 | corr_fn = CorrBlock(fmap1, fmap2, radius=self.args.corr_radius) 110 | 111 | # run the context network 112 | with autocast(enabled=self.args.mixed_precision): 113 | cnet = self.cnet(image1) 114 | net, inp = torch.split(cnet, [hdim, cdim], dim=1) 115 | net = torch.tanh(net) 116 | inp = torch.relu(inp) 117 | 118 | coords0, coords1 = self.initialize_flow(image1) 119 | 120 | if flow_init is not None: 121 | coords1 = coords1 + flow_init 122 | 123 | flow_predictions = [] 124 | for itr in range(iters): 125 | coords1 = coords1.detach() 126 | corr = corr_fn(coords1) # index correlation volume 127 | 128 | flow = coords1 - coords0 129 | with autocast(enabled=self.args.mixed_precision): 130 
| net, up_mask, delta_flow = self.update_block(net, inp, corr, flow) 131 | 132 | # F(t+1) = F(t) + \Delta(t) 133 | coords1 = coords1 + delta_flow 134 | 135 | # upsample predictions 136 | if up_mask is None: 137 | flow_up = upflow8(coords1 - coords0) 138 | else: 139 | flow_up = self.upsample_flow(coords1 - coords0, up_mask) 140 | 141 | flow_predictions.append(flow_up) 142 | 143 | if test_mode: 144 | corr = corr_fn(coords1) # index correlation volume 145 | # feat = torch.cat([inp, corr], dim=1) 146 | feat = inp 147 | return coords1 - coords0, flow_up, (feat, fmap1, fmap2) 148 | 149 | return flow_predictions 150 | -------------------------------------------------------------------------------- /sam_pt/point_tracker/raft/raft_core/update.py: -------------------------------------------------------------------------------- 1 | # Taken from: https://github.com/princeton-vl/RAFT/blob/aac9dd54726caf2cf81d8661b07663e220c5586d/core/update.py 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | 8 | class FlowHead(nn.Module): 9 | def __init__(self, input_dim=128, hidden_dim=256): 10 | super(FlowHead, self).__init__() 11 | self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1) 12 | self.conv2 = nn.Conv2d(hidden_dim, 2, 3, padding=1) 13 | self.relu = nn.ReLU(inplace=True) 14 | 15 | def forward(self, x): 16 | return self.conv2(self.relu(self.conv1(x))) 17 | 18 | 19 | class ConvGRU(nn.Module): 20 | def __init__(self, hidden_dim=128, input_dim=192 + 128): 21 | super(ConvGRU, self).__init__() 22 | self.convz = nn.Conv2d(hidden_dim + input_dim, hidden_dim, 3, padding=1) 23 | self.convr = nn.Conv2d(hidden_dim + input_dim, hidden_dim, 3, padding=1) 24 | self.convq = nn.Conv2d(hidden_dim + input_dim, hidden_dim, 3, padding=1) 25 | 26 | def forward(self, h, x): 27 | hx = torch.cat([h, x], dim=1) 28 | 29 | z = torch.sigmoid(self.convz(hx)) 30 | r = torch.sigmoid(self.convr(hx)) 31 | q = torch.tanh(self.convq(torch.cat([r * h, x], dim=1))) 32 | 33 | h = (1 - z) * h + z * q 34 | return h 35 | 36 | 37 | class SepConvGRU(nn.Module): 38 | def __init__(self, hidden_dim=128, input_dim=192 + 128): 39 | super(SepConvGRU, self).__init__() 40 | self.convz1 = nn.Conv2d(hidden_dim + input_dim, hidden_dim, (1, 5), padding=(0, 2)) 41 | self.convr1 = nn.Conv2d(hidden_dim + input_dim, hidden_dim, (1, 5), padding=(0, 2)) 42 | self.convq1 = nn.Conv2d(hidden_dim + input_dim, hidden_dim, (1, 5), padding=(0, 2)) 43 | 44 | self.convz2 = nn.Conv2d(hidden_dim + input_dim, hidden_dim, (5, 1), padding=(2, 0)) 45 | self.convr2 = nn.Conv2d(hidden_dim + input_dim, hidden_dim, (5, 1), padding=(2, 0)) 46 | self.convq2 = nn.Conv2d(hidden_dim + input_dim, hidden_dim, (5, 1), padding=(2, 0)) 47 | 48 | def forward(self, h, x): 49 | # horizontal 50 | hx = torch.cat([h, x], dim=1) 51 | z = torch.sigmoid(self.convz1(hx)) 52 | r = torch.sigmoid(self.convr1(hx)) 53 | q = torch.tanh(self.convq1(torch.cat([r * h, x], dim=1))) 54 | h = (1 - z) * h + z * q 55 | 56 | # vertical 57 | hx = torch.cat([h, x], dim=1) 58 | z = torch.sigmoid(self.convz2(hx)) 59 | r = torch.sigmoid(self.convr2(hx)) 60 | q = torch.tanh(self.convq2(torch.cat([r * h, x], dim=1))) 61 | h = (1 - z) * h + z * q 62 | 63 | return h 64 | 65 | 66 | class SmallMotionEncoder(nn.Module): 67 | def __init__(self, args): 68 | super(SmallMotionEncoder, self).__init__() 69 | cor_planes = args.corr_levels * (2 * args.corr_radius + 1) ** 2 70 | self.convc1 = nn.Conv2d(cor_planes, 96, 1, padding=0) 71 | self.convf1 = nn.Conv2d(2, 64, 7, padding=3) 72 | 
self.convf2 = nn.Conv2d(64, 32, 3, padding=1) 73 | self.conv = nn.Conv2d(128, 80, 3, padding=1) 74 | 75 | def forward(self, flow, corr): 76 | cor = F.relu(self.convc1(corr)) 77 | flo = F.relu(self.convf1(flow)) 78 | flo = F.relu(self.convf2(flo)) 79 | cor_flo = torch.cat([cor, flo], dim=1) 80 | out = F.relu(self.conv(cor_flo)) 81 | return torch.cat([out, flow], dim=1) 82 | 83 | 84 | class BasicMotionEncoder(nn.Module): 85 | def __init__(self, args): 86 | super(BasicMotionEncoder, self).__init__() 87 | cor_planes = args.corr_levels * (2 * args.corr_radius + 1) ** 2 88 | self.convc1 = nn.Conv2d(cor_planes, 256, 1, padding=0) 89 | self.convc2 = nn.Conv2d(256, 192, 3, padding=1) 90 | self.convf1 = nn.Conv2d(2, 128, 7, padding=3) 91 | self.convf2 = nn.Conv2d(128, 64, 3, padding=1) 92 | self.conv = nn.Conv2d(64 + 192, 128 - 2, 3, padding=1) 93 | 94 | def forward(self, flow, corr): 95 | cor = F.relu(self.convc1(corr)) 96 | cor = F.relu(self.convc2(cor)) 97 | flo = F.relu(self.convf1(flow)) 98 | flo = F.relu(self.convf2(flo)) 99 | 100 | cor_flo = torch.cat([cor, flo], dim=1) 101 | out = F.relu(self.conv(cor_flo)) 102 | return torch.cat([out, flow], dim=1) 103 | 104 | 105 | class SmallUpdateBlock(nn.Module): 106 | def __init__(self, args, hidden_dim=96): 107 | super(SmallUpdateBlock, self).__init__() 108 | self.encoder = SmallMotionEncoder(args) 109 | self.gru = ConvGRU(hidden_dim=hidden_dim, input_dim=82 + 64) 110 | self.flow_head = FlowHead(hidden_dim, hidden_dim=128) 111 | 112 | def forward(self, net, inp, corr, flow): 113 | motion_features = self.encoder(flow, corr) 114 | inp = torch.cat([inp, motion_features], dim=1) 115 | net = self.gru(net, inp) 116 | delta_flow = self.flow_head(net) 117 | 118 | return net, None, delta_flow 119 | 120 | 121 | class BasicUpdateBlock(nn.Module): 122 | def __init__(self, args, hidden_dim=128, input_dim=128): 123 | super(BasicUpdateBlock, self).__init__() 124 | self.args = args 125 | self.encoder = BasicMotionEncoder(args) 126 | self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=128 + hidden_dim) 127 | self.flow_head = FlowHead(hidden_dim, hidden_dim=256) 128 | 129 | self.mask = nn.Sequential( 130 | nn.Conv2d(128, 256, 3, padding=1), 131 | nn.ReLU(inplace=True), 132 | nn.Conv2d(256, 64 * 9, 1, padding=0)) 133 | 134 | def forward(self, net, inp, corr, flow, upsample=True): 135 | motion_features = self.encoder(flow, corr) 136 | inp = torch.cat([inp, motion_features], dim=1) 137 | 138 | net = self.gru(net, inp) 139 | delta_flow = self.flow_head(net) 140 | 141 | # scale mask to balence gradients 142 | mask = .25 * self.mask(net) 143 | return net, mask, delta_flow 144 | -------------------------------------------------------------------------------- /sam_pt/point_tracker/raft/raft_core/util.py: -------------------------------------------------------------------------------- 1 | # Taken from: https://github.com/princeton-vl/RAFT/blob/aac9dd54726caf2cf81d8661b07663e220c5586d/core/utils/utils.py 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn.functional as F 6 | from scipy import interpolate 7 | 8 | 9 | class InputPadder: 10 | """ Pads images such that dimensions are divisible by 8 """ 11 | 12 | def __init__(self, dims, mode='sintel'): 13 | self.ht, self.wd = dims[-2:] 14 | pad_ht = (((self.ht // 8) + 1) * 8 - self.ht) % 8 15 | pad_wd = (((self.wd // 8) + 1) * 8 - self.wd) % 8 16 | if mode == 'sintel': 17 | self._pad = [pad_wd // 2, pad_wd - pad_wd // 2, pad_ht // 2, pad_ht - pad_ht // 2] 18 | else: 19 | self._pad = [pad_wd // 2, pad_wd - pad_wd // 2, 
0, pad_ht] 20 | 21 | def pad(self, *inputs): 22 | return [F.pad(x, self._pad, mode='replicate') for x in inputs] 23 | 24 | def unpad(self, x): 25 | ht, wd = x.shape[-2:] 26 | c = [self._pad[2], ht - self._pad[3], self._pad[0], wd - self._pad[1]] 27 | return x[..., c[0]:c[1], c[2]:c[3]] 28 | 29 | 30 | def forward_interpolate(flow): 31 | flow = flow.detach().cpu().numpy() 32 | dx, dy = flow[0], flow[1] 33 | 34 | ht, wd = dx.shape 35 | x0, y0 = np.meshgrid(np.arange(wd), np.arange(ht)) 36 | 37 | x1 = x0 + dx 38 | y1 = y0 + dy 39 | 40 | x1 = x1.reshape(-1) 41 | y1 = y1.reshape(-1) 42 | dx = dx.reshape(-1) 43 | dy = dy.reshape(-1) 44 | 45 | valid = (x1 > 0) & (x1 < wd) & (y1 > 0) & (y1 < ht) 46 | x1 = x1[valid] 47 | y1 = y1[valid] 48 | dx = dx[valid] 49 | dy = dy[valid] 50 | 51 | flow_x = interpolate.griddata( 52 | (x1, y1), dx, (x0, y0), method='nearest', fill_value=0) 53 | 54 | flow_y = interpolate.griddata( 55 | (x1, y1), dy, (x0, y0), method='nearest', fill_value=0) 56 | 57 | flow = np.stack([flow_x, flow_y], axis=0) 58 | return torch.from_numpy(flow).float() 59 | 60 | 61 | def bilinear_sampler(img, coords, mode='bilinear', mask=False): 62 | """ Wrapper for grid_sample, uses pixel coordinates """ 63 | H, W = img.shape[-2:] 64 | xgrid, ygrid = coords.split([1, 1], dim=-1) 65 | xgrid = 2 * xgrid / (W - 1) - 1 66 | ygrid = 2 * ygrid / (H - 1) - 1 67 | 68 | grid = torch.cat([xgrid, ygrid], dim=-1) 69 | img = F.grid_sample(img, grid, align_corners=True) 70 | 71 | if mask: 72 | mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1) 73 | return img, mask.float() 74 | 75 | return img 76 | 77 | 78 | def coords_grid(batch, ht, wd): 79 | coords = torch.meshgrid(torch.arange(ht), torch.arange(wd)) 80 | coords = torch.stack(coords[::-1], dim=0).float() 81 | return coords[None].repeat(batch, 1, 1, 1) 82 | 83 | 84 | def upflow8(flow, mode='bilinear'): 85 | new_size = (8 * flow.shape[2], 8 * flow.shape[3]) 86 | return 8 * F.interpolate(flow, size=new_size, mode=mode, align_corners=True) 87 | -------------------------------------------------------------------------------- /sam_pt/point_tracker/raft/raftnet.py: -------------------------------------------------------------------------------- 1 | # Adapted from: https://github.com/aharley/pips/blob/486124b4236bb228a20750b496f0fa8aa6343157/nets/raftnet.py 2 | 3 | import argparse 4 | 5 | import torch 6 | import torch.nn as nn 7 | 8 | from .raft_core.raft import RAFT 9 | from .raft_core.util import InputPadder 10 | 11 | 12 | class Raftnet(nn.Module): 13 | def __init__(self, ckpt_name=None, small=False, alternate_corr=False, mixed_precision=True): 14 | super(Raftnet, self).__init__() 15 | args = argparse.Namespace() 16 | args.small = small 17 | args.alternate_corr = alternate_corr 18 | args.mixed_precision = mixed_precision 19 | self.model = RAFT(args) 20 | if ckpt_name is not None: 21 | state_dict = torch.load(ckpt_name) 22 | state_dict = { # The checkpoint was saved as wrapped in nn.DataParallel, this removes the wrapper 23 | k.replace('module.', ''): v 24 | for k, v in state_dict.items() 25 | if k != 'module' 26 | } 27 | self.model.load_state_dict(state_dict) 28 | 29 | def forward(self, image1, image2, iters=20, test_mode=True): 30 | # input images are in [-0.5, 0.5] 31 | # raftnet wants the images to be in [0,255] 32 | image1 = (image1 + 0.5) * 255.0 33 | image2 = (image2 + 0.5) * 255.0 34 | 35 | padder = InputPadder(image1.shape) 36 | image1, image2 = padder.pad(image1, image2) 37 | if test_mode: 38 | flow_low, flow_up, feat = 
self.model(image1=image1, image2=image2, iters=iters, test_mode=test_mode) 39 | flow_up = padder.unpad(flow_up) 40 | return flow_up, feat 41 | else: 42 | flow_predictions = self.model(image1=image1, image2=image2, iters=iters, test_mode=test_mode) 43 | return flow_predictions 44 | -------------------------------------------------------------------------------- /sam_pt/point_tracker/raft/tracker.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | 4 | import sam_pt.point_tracker.utils.improc 5 | import sam_pt.point_tracker.utils.samp 6 | from sam_pt.point_tracker import PointTracker 7 | from .raftnet import Raftnet 8 | 9 | 10 | class RaftPointTracker(PointTracker): 11 | """ 12 | Implements a point tracker that uses the RAFT algorithm for optical flow estimation 13 | from https://arxiv.org/abs/2003.12039. The tracker computes forward and backward flows 14 | for each frame in a video sequence and uses these to estimate the trajectories of given points. 15 | """ 16 | 17 | def __init__(self, checkpoint_path): 18 | """ 19 | Args: 20 | checkpoint_path (str): The path to the trained RAFT model checkpoint. 21 | """ 22 | super().__init__() 23 | self.checkpoint_path = checkpoint_path 24 | if self.checkpoint_path is not None and not os.path.exists(self.checkpoint_path): 25 | raise FileNotFoundError(f"Raft checkpoint not found at {self.checkpoint_path}") 26 | print(f"Loading Raft model from {self.checkpoint_path}") 27 | self.model = Raftnet(ckpt_name=self.checkpoint_path) 28 | 29 | def forward(self, rgbs, query_points, summary_writer=None): 30 | batch_size, n_frames, channels, height, width = rgbs.shape 31 | n_points = query_points.shape[1] 32 | 33 | prep_rgbs = sam_pt.point_tracker.utils.improc.preprocess_color(rgbs) 34 | 35 | flows_forward = [] 36 | flows_backward = [] 37 | for t in range(1, n_frames): 38 | rgb0 = prep_rgbs[:, t - 1] 39 | rgb1 = prep_rgbs[:, t] 40 | flows_forward.append(self.model.forward(rgb0, rgb1, iters=32)[0]) 41 | flows_backward.append(self.model.forward(rgb1, rgb0, iters=32)[0]) 42 | flows_forward = torch.stack(flows_forward, dim=1) 43 | flows_backward = torch.stack(flows_backward, dim=1) 44 | assert flows_forward.shape == flows_backward.shape == (batch_size, n_frames - 1, 2, height, width) 45 | 46 | coords = [] 47 | for t in range(n_frames): 48 | if t == 0: 49 | coord = torch.zeros_like(query_points[:, :, 1:]) 50 | else: 51 | prev_coord = coords[t - 1] 52 | delta = sam_pt.point_tracker.utils.samp.bilinear_sample2d( 53 | im=flows_forward[:, t - 1], 54 | x=prev_coord[:, :, 0], 55 | y=prev_coord[:, :, 1], 56 | ).permute(0, 2, 1) 57 | assert delta.shape == (batch_size, n_points, 2), "Forward flow at the discrete points" 58 | coord = prev_coord + delta 59 | 60 | # Set the ground truth query point location if the timestep is correct 61 | query_point_mask = query_points[:, :, 0] == t 62 | coord = coord * ~query_point_mask.unsqueeze(-1) + query_points[:, :, 1:] * query_point_mask.unsqueeze(-1) 63 | 64 | coords.append(coord) 65 | 66 | for t in range(n_frames - 2, -1, -1): 67 | coord = coords[t] 68 | successor_coord = coords[t + 1] 69 | 70 | delta = sam_pt.point_tracker.utils.samp.bilinear_sample2d( 71 | im=flows_backward[:, t], 72 | x=successor_coord[:, :, 0], 73 | y=successor_coord[:, :, 1], 74 | ).permute(0, 2, 1) 75 | assert delta.shape == (batch_size, n_points, 2), "Backward flow at the discrete points" 76 | 77 | # Update only the points that are located prior to the query point 78 | prior_to_query_point_mask = t < 
query_points[:, :, 0] 79 | coord = (coord * ~prior_to_query_point_mask.unsqueeze(-1) + 80 | (successor_coord + delta) * prior_to_query_point_mask.unsqueeze(-1)) 81 | coords[t] = coord 82 | 83 | trajectories = torch.stack(coords, dim=1) 84 | visibilities = (trajectories[:, :, :, 0] >= 0) & \ 85 | (trajectories[:, :, :, 1] >= 0) & \ 86 | (trajectories[:, :, :, 0] < width) & \ 87 | (trajectories[:, :, :, 1] < height) 88 | return trajectories, visibilities 89 | -------------------------------------------------------------------------------- /sam_pt/point_tracker/superglue/__init__.py: -------------------------------------------------------------------------------- 1 | from .tracker import SuperGluePointTracker 2 | -------------------------------------------------------------------------------- /sam_pt/point_tracker/superglue/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SysCV/sam-pt/874ff7e73d6ab05418a494d7a02ca233c0b31e8c/sam_pt/point_tracker/superglue/models/__init__.py -------------------------------------------------------------------------------- /sam_pt/point_tracker/superglue/models/matching.py: -------------------------------------------------------------------------------- 1 | # %BANNER_BEGIN% 2 | # --------------------------------------------------------------------- 3 | # %COPYRIGHT_BEGIN% 4 | # 5 | # Magic Leap, Inc. ("COMPANY") CONFIDENTIAL 6 | # 7 | # Unpublished Copyright (c) 2020 8 | # Magic Leap, Inc., All Rights Reserved. 9 | # 10 | # NOTICE: All information contained herein is, and remains the property 11 | # of COMPANY. The intellectual and technical concepts contained herein 12 | # are proprietary to COMPANY and may be covered by U.S. and Foreign 13 | # Patents, patents in process, and are protected by trade secret or 14 | # copyright law. Dissemination of this information or reproduction of 15 | # this material is strictly forbidden unless prior written permission is 16 | # obtained from COMPANY. Access to the source code contained herein is 17 | # hereby forbidden to anyone except current COMPANY employees, managers 18 | # or contractors who have executed Confidentiality and Non-disclosure 19 | # agreements explicitly covering such access. 20 | # 21 | # The copyright notice above does not evidence any actual or intended 22 | # publication or disclosure of this source code, which includes 23 | # information that is confidential and/or proprietary, and is a trade 24 | # secret, of COMPANY. ANY REPRODUCTION, MODIFICATION, DISTRIBUTION, 25 | # PUBLIC PERFORMANCE, OR PUBLIC DISPLAY OF OR THROUGH USE OF THIS 26 | # SOURCE CODE WITHOUT THE EXPRESS WRITTEN CONSENT OF COMPANY IS 27 | # STRICTLY PROHIBITED, AND IN VIOLATION OF APPLICABLE LAWS AND 28 | # INTERNATIONAL TREATIES. THE RECEIPT OR POSSESSION OF THIS SOURCE 29 | # CODE AND/OR RELATED INFORMATION DOES NOT CONVEY OR IMPLY ANY RIGHTS 30 | # TO REPRODUCE, DISCLOSE OR DISTRIBUTE ITS CONTENTS, OR TO MANUFACTURE, 31 | # USE, OR SELL ANYTHING THAT IT MAY DESCRIBE, IN WHOLE OR IN PART. 
32 | # 33 | # %COPYRIGHT_END% 34 | # ---------------------------------------------------------------------- 35 | # %AUTHORS_BEGIN% 36 | # 37 | # Originating Authors: Paul-Edouard Sarlin 38 | # 39 | # %AUTHORS_END% 40 | # --------------------------------------------------------------------*/ 41 | # %BANNER_END% 42 | 43 | # Taken from: https://github.com/magicleap/SuperGluePretrainedNetwork/blob/ddcf11f42e7e0732a0c4607648f9448ea8d73590/models/matching.py 44 | 45 | import torch 46 | 47 | from .superglue import SuperGlue 48 | from .superpoint import SuperPoint 49 | 50 | 51 | class Matching(torch.nn.Module): 52 | """ Image Matching Frontend (SuperPoint + SuperGlue) """ 53 | 54 | def __init__(self, config={}): 55 | super().__init__() 56 | self.superpoint = SuperPoint(config.get('superpoint', {})) 57 | self.superglue = SuperGlue(config.get('superglue', {})) 58 | 59 | def forward(self, data): 60 | """ Run SuperPoint (optionally) and SuperGlue 61 | SuperPoint is skipped if ['keypoints0', 'keypoints1'] exist in input 62 | Args: 63 | data: dictionary with minimal keys: ['image0', 'image1'] 64 | """ 65 | pred = {} 66 | 67 | # Extract SuperPoint (keypoints, scores, descriptors) if not provided 68 | if 'keypoints0' not in data: 69 | pred0 = self.superpoint({'image': data['image0']}) 70 | pred = {**pred, **{k + '0': v for k, v in pred0.items()}} 71 | if 'keypoints1' not in data: 72 | pred1 = self.superpoint({'image': data['image1']}) 73 | pred = {**pred, **{k + '1': v for k, v in pred1.items()}} 74 | 75 | # Batch all features 76 | # We should either have i) one image per batch, or 77 | # ii) the same number of local features for all images in the batch. 78 | data = {**data, **pred} 79 | 80 | for k in data: 81 | if isinstance(data[k], (list, tuple)): 82 | data[k] = torch.stack(data[k]) 83 | 84 | # Perform the matching 85 | pred = {**pred, **self.superglue(data)} 86 | 87 | return pred 88 | -------------------------------------------------------------------------------- /sam_pt/point_tracker/tapir/__init__.py: -------------------------------------------------------------------------------- 1 | from .tracker import TapirPointTracker 2 | -------------------------------------------------------------------------------- /sam_pt/point_tracker/tapir/configs/tapir_config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | # Taken from: https://github.com/deepmind/tapnet/blob/ba1a8c8f2576d81f7b8d69dbee1e58e8b7d321e1/configs/tapir_config.py 17 | 18 | """Default config to train the TAPIR.""" 19 | 20 | from jaxline import base_config 21 | from ml_collections import config_dict 22 | 23 | TRAIN_SIZE = (24, 256, 256, 3) # (num_frames, height, width, channels) 24 | 25 | 26 | # We define the experiment launch config in the same file as the experiment to 27 | # keep things self-contained in a single file. 28 | def get_config() -> config_dict.ConfigDict: 29 | """Return config object for training.""" 30 | config = base_config.get_base_config() 31 | 32 | # Experiment config. 33 | config.training_steps = 100000 34 | 35 | # NOTE: duplicates not allowed. 36 | config.shared_module_names = ('tapir_model',) 37 | 38 | config.dataset_names = ('kubric',) 39 | # Note: eval modes must always start with 'eval_'. 40 | config.eval_modes = ( 41 | 'eval_davis_points', 42 | 'eval_jhmdb', 43 | 'eval_robotics_points', 44 | 'eval_kinetics_points', 45 | ) 46 | config.checkpoint_dir = '/tmp/tapnet_training/' 47 | config.evaluate_every = 10000 48 | 49 | config.experiment_kwargs = config_dict.ConfigDict( 50 | dict( 51 | config=dict( 52 | sweep_name='default_sweep', 53 | save_final_checkpoint_as_npy=True, 54 | # `enable_double_transpose` should always be false when using 1D. 55 | # For other D It is also completely untested and very unlikely 56 | # to work. 57 | optimizer=dict( 58 | base_lr=1e-3, 59 | max_norm=-1, # < 0 to turn off. 60 | weight_decay=1e-1, 61 | schedule_type='cosine', 62 | cosine_decay_kwargs=dict( 63 | init_value=0.0, 64 | warmup_steps=1000, 65 | end_value=0.0, 66 | ), 67 | optimizer='adam', 68 | # Optimizer-specific kwargs. 69 | adam_kwargs=dict( 70 | b1=0.9, 71 | b2=0.95, 72 | eps=1e-8, 73 | ), 74 | ), 75 | fast_variables=tuple(), 76 | shared_modules=dict( 77 | shared_module_names=config.get_oneway_ref( 78 | 'shared_module_names', 79 | ), 80 | tapir_model_kwargs=dict( 81 | bilinear_interp_with_depthwise_conv=True, 82 | use_causal_conv=False, 83 | ), 84 | ), 85 | datasets=dict( 86 | dataset_names=config.get_oneway_ref('dataset_names'), 87 | kubric_kwargs=dict( 88 | batch_dims=8, 89 | shuffle_buffer_size=128, 90 | train_size=TRAIN_SIZE[1:3], 91 | ), 92 | ), 93 | supervised_point_prediction_kwargs=dict( 94 | prediction_algo='cost_volume_regressor', 95 | model_key='tapir_model', 96 | ), 97 | checkpoint_dir=config.get_oneway_ref('checkpoint_dir'), 98 | evaluate_every=config.get_oneway_ref('evaluate_every'), 99 | eval_modes=config.get_oneway_ref('eval_modes'), 100 | # If true, run evaluate() on the experiment once before 101 | # you load a checkpoint. 102 | # This is useful for getting initial values of metrics 103 | # at random weights, or when debugging locally if you 104 | # do not have any train job running. 105 | davis_points_path='', 106 | jhmdb_path='', 107 | robotics_points_path='', 108 | training=dict( 109 | # Note: to sweep n_training_steps, DO NOT sweep these 110 | # fields directly. Instead sweep config.training_steps. 111 | # Otherwise, decay/stopping logic 112 | # is not guaranteed to be consistent. 
113 | n_training_steps=config.get_oneway_ref('training_steps'), 114 | ), 115 | inference=dict( 116 | input_video_path='', 117 | output_video_path='', 118 | resize_height=256, # video height resized to before inference 119 | resize_width=256, # video width resized to before inference 120 | num_points=20, # number of random points to sample 121 | ), 122 | ) 123 | ) 124 | ) 125 | 126 | # Set up where to store the resulting model. 127 | config.train_checkpoint_all_hosts = False 128 | config.save_checkpoint_interval = 10 129 | config.eval_initial_weights = True 130 | 131 | # Prevents accidentally setting keys that aren't recognized (e.g. in tests). 132 | config.lock() 133 | 134 | return config 135 | -------------------------------------------------------------------------------- /sam_pt/point_tracker/tapir/demo.py: -------------------------------------------------------------------------------- 1 | """ 2 | Demo program for TAPIR, to make sure that pytorch+jax has been set up correctly. 3 | The following snippet should run without error, and ideally use GPU/TPU to be fast when benchmarking. 4 | 5 | Example usage: 6 | ``` 7 | python -m sam_pt.point_tracker.tapir.demo 8 | ``` 9 | """ 10 | import time 11 | 12 | import haiku as hk 13 | import jax 14 | import matplotlib.cm as cm 15 | import matplotlib.pyplot as plt 16 | import numpy as np 17 | import tensorflow as tf 18 | import torch 19 | from torch.nn import functional as F 20 | 21 | from demo.demo import load_demo_data 22 | from . import tapir_model 23 | from .configs.tapir_config import get_config 24 | 25 | if __name__ == '__main__': 26 | # 1. Prepare config 27 | config = get_config() 28 | checkpoint_dir = "./models/tapir_ckpts/open_source_ckpt/" 29 | # Keep TF off the GPU; otherwise it hogs all the memory and leaves none for JAX. 30 | tf.config.experimental.set_visible_devices([], 'GPU') 31 | tf.config.experimental.set_visible_devices([], 'TPU') 32 | 33 | # 2. Prepare model 34 | checkpoint = np.load(checkpoint_dir + "tapir_checkpoint_panning.npy", allow_pickle=True).item() 35 | params, state = checkpoint["params"], checkpoint["state"] 36 | # tapir_model_kwargs = config.experiment_kwargs.config.shared_modules["tapir_model_kwargs"] 37 | tapir_model_kwargs = { 38 | "bilinear_interp_with_depthwise_conv": False, 39 | "pyramid_level": 0, 40 | "use_causal_conv": False, 41 | } 42 | 43 | 44 | def forward(rgbs, query_points): 45 | tapir = tapir_model.TAPIR(**tapir_model_kwargs) 46 | outputs = tapir( 47 | video=rgbs[None, ...], 48 | query_points=query_points[None, ...], 49 | query_chunk_size=64, 50 | is_training=False, 51 | ) 52 | return outputs 53 | 54 | 55 | transform = hk.transform_with_state(forward) 56 | 57 | 58 | def f(rgbs_tapir, query_points_tapir): 59 | rng = jax.random.PRNGKey(72) 60 | outputs, _ = transform.apply(params, state, rng, rgbs_tapir, query_points_tapir) 61 | return outputs 62 | 63 | 64 | jitted_f = jax.jit(f) 65 | 66 | # 3. 
Prepare data 67 | rgbs, _, query_points = load_demo_data( 68 | frames_path="data/demo_data/bees", 69 | query_points_path="data/demo_data/query_points__bees.txt", 70 | ) 71 | original_hw = rgbs.shape[-2:] 72 | tapir_input_hw = ( 73 | config.experiment_kwargs.config.inference.resize_height, config.experiment_kwargs.config.inference.resize_width) 74 | rescale_factor_hw = torch.tensor(tapir_input_hw) / torch.tensor(original_hw) 75 | rgbs_tapir = F.interpolate(rgbs / 255, tapir_input_hw, mode="bilinear", align_corners=False, antialias=True) 76 | rgbs_tapir = rgbs_tapir.numpy() * 2 - 1 77 | rgbs_tapir = rgbs_tapir.transpose(0, 2, 3, 1) 78 | 79 | ## Take the loaded query points 80 | # query_points = query_points 81 | ## Or make a 16x16 grid of query points 82 | query_points = torch.zeros((1, 16, 16, 3), dtype=torch.float32) 83 | query_points[:, :, :, 0] = 1 84 | query_points[:, :, :, 1] = torch.linspace(1, original_hw[1] - 1, 16) 85 | query_points[:, :, :, 2] = torch.linspace(1, original_hw[0] - 1, 16).unsqueeze(-1) 86 | query_points = query_points.reshape(1, -1, 3) 87 | 88 | query_points_tapir = query_points.clone() 89 | query_points_tapir[:, :, 1:] *= rescale_factor_hw.flip(0) 90 | query_points_tapir = query_points_tapir.flatten(0, 1) 91 | query_points_tapir[:, 1:] = query_points_tapir[:, 1:].flip(-1) 92 | query_points_tapir = query_points_tapir.numpy() 93 | 94 | # 4. Run model 95 | outputs = jitted_f(rgbs_tapir, query_points_tapir) 96 | 97 | n_frames = rgbs.shape[0] 98 | n_masks, n_points_per_mask, _ = query_points.shape 99 | 100 | # 5. Postprocess 101 | tapir_visibility_threshold = 0.5 102 | 103 | expected_dist = torch.from_numpy(np.asarray(outputs["expected_dist"][0]).copy()).permute(1, 0) 104 | expected_dist = expected_dist.unflatten(1, (n_masks, n_points_per_mask)) 105 | 106 | occlussion_logits = torch.from_numpy(np.asarray(outputs["occlusion"][0]).copy()).permute(1, 0) 107 | occlussion_logits = occlussion_logits.unflatten(1, (n_masks, n_points_per_mask)) 108 | visibilities_probs = (1 - torch.sigmoid(occlussion_logits)) * (1 - torch.sigmoid(expected_dist)) 109 | visibilities = visibilities_probs > tapir_visibility_threshold 110 | 111 | trajectories = torch.from_numpy(np.asarray(outputs["tracks"][0]).copy()).permute(1, 0, 2) 112 | trajectories = trajectories.unflatten(1, (n_masks, n_points_per_mask)) 113 | trajectories = trajectories / rescale_factor_hw.flip(-1) 114 | 115 | # 6. Visualize 116 | mask_idx = -1 117 | for frame_idx in range(n_frames): 118 | h, w = rgbs.shape[2], rgbs.shape[3] 119 | dpi = 100 120 | plt.figure(figsize=(w / dpi, h / dpi)) 121 | plt.imshow(rgbs[frame_idx].permute(1, 2, 0).numpy(), interpolation="none") 122 | x = trajectories[frame_idx, mask_idx, :, 0] 123 | y = trajectories[frame_idx, mask_idx, :, 1] 124 | colors = cm.rainbow(np.linspace(0, 1, len(y))) 125 | v = visibilities[frame_idx, mask_idx, :] 126 | # v = (visibilities[frame_idx, mask_idx, :] * 0) == 0 127 | x = x[v] 128 | y = y[v] 129 | colors = colors[v] 130 | plt.title(f"F{frame_idx:02}-M{mask_idx:02}-V{(visibilities_probs[frame_idx, mask_idx, :5] * 1)}") 131 | plt.scatter(x, y, color=colors, linewidths=6) 132 | plt.xlim(trajectories[..., 0].min(), trajectories[..., 0].max()) 133 | plt.ylim(trajectories[..., 1].max(), trajectories[..., 1].min()) 134 | plt.axis("off") 135 | plt.tight_layout(pad=0) 136 | plt.show() 137 | time.sleep(0.1) 138 | 139 | # 7. 
Benchmark forward pass speed in for loop 140 | n_loops = 100 141 | start_time = time.time() 142 | for _ in range(n_loops): 143 | outputs = jitted_f(rgbs_tapir, query_points_tapir) 144 | end_time = time.time() 145 | print(f"Forward pass speed: {(end_time - start_time) / n_loops * 1000} ms") 146 | 147 | print("Done") 148 | -------------------------------------------------------------------------------- /sam_pt/point_tracker/tapir/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SysCV/sam-pt/874ff7e73d6ab05418a494d7a02ca233c0b31e8c/sam_pt/point_tracker/tapir/models/__init__.py -------------------------------------------------------------------------------- /sam_pt/point_tracker/tapir/tracker.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | import torch 4 | from torch.nn import functional as F 5 | 6 | from sam_pt.point_tracker import PointTracker 7 | 8 | 9 | class TapirPointTracker(PointTracker): 10 | """ 11 | A point tracker that uses TAPIR from https://arxiv.org/abs/2306.08637 to track points. 12 | """ 13 | 14 | def __init__(self, checkpoint_path, visibility_threshold): 15 | from .configs.tapir_config import get_config 16 | super().__init__() 17 | 18 | # Keep TF off the GPU; otherwise it hogs all the memory and leaves none for JAX 19 | tf.config.experimental.set_visible_devices([], 'GPU') 20 | tf.config.experimental.set_visible_devices([], 'TPU') 21 | 22 | # # v1: use the last GPU 23 | # # Hardcode JAX to use the last GPU (the first is reserved for other modules from PyTorch) 24 | # # The environmental flag `XLA_PYTHON_CLIENT_PREALLOCATE=false` is also required along with this 25 | # gpus = jax.devices('gpu') 26 | # device = gpus[-1] 27 | # jax.jit ... device=device 28 | 29 | # v2: share the gpu with Sam since they are run sequentially 30 | # but make jax free up the allocated memory once it is done 31 | # by setting the environmental variable `XLA_PYTHON_CLIENT_ALLOCATOR=platform` 32 | 33 | assert checkpoint_path is not None 34 | self.checkpoint_path = checkpoint_path 35 | self.config = get_config() 36 | self.visibility_threshold = visibility_threshold 37 | self.jitted_forward = self._create_jitted_forward() 38 | 39 | def _create_jitted_forward(self): 40 | import haiku as hk 41 | import jax 42 | from . import tapir_model 43 | 44 | checkpoint = np.load(self.checkpoint_path, allow_pickle=True).item() 45 | params, state = checkpoint["params"], checkpoint["state"] 46 | # tapir_model_kwargs = self.config.experiment_kwargs.config.shared_modules["tapir_model_kwargs"] 47 | tapir_model_kwargs = { 48 | "bilinear_interp_with_depthwise_conv": False, 49 | "pyramid_level": 0, 50 | "use_causal_conv": False, 51 | } 52 | 53 | def _forward(rgbs, query_points): 54 | tapir = tapir_model.TAPIR(**tapir_model_kwargs) 55 | outputs = tapir( 56 | video=rgbs, 57 | query_points=query_points, 58 | query_chunk_size=64, 59 | is_training=False, 60 | ) 61 | return outputs 62 | 63 | transform = hk.transform_with_state(_forward) 64 | 65 | def forward(rgbs_tapir, query_points_tapir): 66 | rng = jax.random.PRNGKey(72) 67 | outputs, _ = transform.apply(params, state, rng, rgbs_tapir, query_points_tapir) 68 | return outputs 69 | 70 | return jax.jit(forward) 71 | 72 | def forward(self, rgbs, query_points, summary_writer=None): 73 | batch_size, n_frames, channels, height, width = rgbs.shape 74 | n_points = query_points.shape[1] 75 | 76 | # 1. 
Prepare image resizing 77 | original_hw = (height, width) 78 | tapir_input_hw = ( 79 | self.config.experiment_kwargs.config.inference.resize_height, 80 | self.config.experiment_kwargs.config.inference.resize_width, 81 | ) 82 | rescale_factor_hw = torch.tensor(tapir_input_hw) / torch.tensor(original_hw) 83 | 84 | # 2. Prepare inputs 85 | assert rgbs.dtype == torch.uint8 86 | rgbs_tapir = F.interpolate(rgbs.flatten(0, 1) / 255, tapir_input_hw, mode="bilinear", align_corners=False, 87 | antialias=True) 88 | rgbs_tapir = rgbs_tapir.unflatten(0, (batch_size, n_frames)) 89 | rgbs_tapir = rgbs_tapir.cpu().numpy() * 2 - 1 90 | rgbs_tapir = rgbs_tapir.transpose(0, 1, 3, 4, 2) 91 | query_points_tapir = query_points.cpu().clone() 92 | query_points_tapir[:, :, 1:] *= rescale_factor_hw.flip(0) 93 | query_points_tapir[:, :, 1:] = query_points_tapir[:, :, 1:].flip(-1) # flip x and y 94 | query_points_tapir = query_points_tapir.numpy() 95 | 96 | # 3. Run model 97 | self._create_jitted_forward() # TODO: Cannot the function be compiled only once? 98 | outputs = self.jitted_forward(rgbs_tapir, query_points_tapir) 99 | 100 | # 4. Postprocess outputs 101 | expected_dist = torch.from_numpy(np.asarray(outputs["expected_dist"]).copy()).permute(0, 2, 1) 102 | occlussion_logits = torch.from_numpy(np.asarray(outputs["occlusion"]).copy()).permute(0, 2, 1) 103 | visibilities_probs = (1 - torch.sigmoid(occlussion_logits)) * (1 - torch.sigmoid(expected_dist)) 104 | visibilities = visibilities_probs > self.visibility_threshold 105 | trajectories = torch.from_numpy(np.asarray(outputs["tracks"]).copy()).permute(0, 2, 1, 3) 106 | trajectories = trajectories / rescale_factor_hw.flip(-1) 107 | 108 | return trajectories, visibilities 109 | -------------------------------------------------------------------------------- /sam_pt/point_tracker/tapir/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SysCV/sam-pt/874ff7e73d6ab05418a494d7a02ca233c0b31e8c/sam_pt/point_tracker/tapir/utils/__init__.py -------------------------------------------------------------------------------- /sam_pt/point_tracker/tapir/utils/transforms.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | # Taken from: https://github.com/deepmind/tapnet/blob/ba1a8c8f2576d81f7b8d69dbee1e58e8b7d321e1/utils/transforms.py 17 | 18 | """Utilities for transforming image coordinates.""" 19 | 20 | from typing import Sequence 21 | 22 | import numpy as np 23 | 24 | 25 | def convert_grid_coordinates( 26 | coords: np.ndarray, 27 | input_grid_size: Sequence[int], 28 | output_grid_size: Sequence[int], 29 | coordinate_format: str = 'xy', 30 | ) -> np.ndarray: 31 | """Convert image coordinates between image grids of different sizes. 
32 | 33 | By default, it assumes that the image corners are aligned. Therefore, 34 | it adds .5 (since (0,0) is assumed to be the center of the upper-left grid 35 | cell), multiplies by the size ratio, and then subtracts .5. 36 | 37 | Args: 38 | coords: The coordinates to be converted. It is of shape [..., 2] if 39 | coordinate_format is 'xy' or [..., 3] if coordinate_format is 'tyx'. 40 | input_grid_size: The size of the image/grid that the coordinates currently 41 | are with respect to. This is a 2-tuple of the format [width, height] 42 | if coordinate_format is 'xy' or a 3-tuple of the format 43 | [num_frames, height, width] if coordinate_format is 'tyx'. 44 | output_grid_size: The size of the target image/grid that you want the 45 | coordinates to be with respect to. This is a 2-tuple of the format 46 | [width, height] if coordinate_format is 'xy' or a 3-tuple of the format 47 | [num_frames, height, width] if coordinate_format is 'tyx'. 48 | coordinate_format: Which format the coordinates are in. This can be one 49 | of 'xy' (the default) or 'tyx', which are the only formats used in this 50 | project. 51 | 52 | Returns: 53 | The transformed coordinates, of the same shape as coordinates. 54 | 55 | Raises: 56 | ValueError: if coordinates don't match the given format. 57 | """ 58 | if isinstance(input_grid_size, tuple): 59 | input_grid_size = np.array(input_grid_size) 60 | if isinstance(output_grid_size, tuple): 61 | output_grid_size = np.array(output_grid_size) 62 | 63 | if coordinate_format == 'xy': 64 | if input_grid_size.shape[0] != 2 or output_grid_size.shape[0] != 2: 65 | raise ValueError( 66 | 'If coordinate_format is xy, the shapes must be length 2.') 67 | elif coordinate_format == 'tyx': 68 | if input_grid_size.shape[0] != 3 or output_grid_size.shape[0] != 3: 69 | raise ValueError( 70 | 'If coordinate_format is tyx, the shapes must be length 3.') 71 | if input_grid_size[0] != output_grid_size[0]: 72 | raise ValueError('converting frame count is not supported.') 73 | else: 74 | raise ValueError('Recognized coordinate formats are xy and tyx.') 75 | 76 | position_in_grid = coords 77 | position_in_grid = position_in_grid * output_grid_size / input_grid_size 78 | 79 | return position_in_grid 80 | -------------------------------------------------------------------------------- /sam_pt/point_tracker/tapnet/__init__.py: -------------------------------------------------------------------------------- 1 | from .tracker import TapnetPointTracker 2 | -------------------------------------------------------------------------------- /sam_pt/point_tracker/tapnet/configs/tapnet_config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | # Taken from: https://github.com/deepmind/tapnet/blob/ba1a8c8f2576d81f7b8d69dbee1e58e8b7d321e1/configs/tapnet_config.py 17 | 18 | """Default config to train the TapNet.""" 19 | 20 | from jaxline import base_config 21 | from ml_collections import config_dict 22 | 23 | TRAIN_SIZE = (24, 256, 256, 3) # (num_frames, height, width, channels) 24 | 25 | 26 | # We define the experiment launch config in the same file as the experiment to 27 | # keep things self-contained in a single file. 28 | def get_config() -> config_dict.ConfigDict: 29 | """Return config object for training.""" 30 | config = base_config.get_base_config() 31 | 32 | # Experiment config. 33 | config.training_steps = 100000 34 | 35 | # NOTE: duplicates not allowed. 36 | config.shared_module_names = ('tapnet_model',) 37 | 38 | config.dataset_names = ('kubric',) 39 | # Note: eval modes must always start with 'eval_'. 40 | config.eval_modes = ( 41 | 'eval_davis_points', 42 | 'eval_jhmdb', 43 | 'eval_robotics_points', 44 | 'eval_kinetics_points', 45 | ) 46 | config.checkpoint_dir = 'logs/tapnet_training/' 47 | config.evaluate_every = 100 48 | 49 | config.experiment_kwargs = config_dict.ConfigDict( 50 | dict( 51 | config=dict( 52 | sweep_name='default_sweep', 53 | save_final_checkpoint_as_npy=True, 54 | # `enable_double_transpose` should always be false when using 1D. 55 | # For other D It is also completely untested and very unlikely 56 | # to work. 57 | optimizer=dict( 58 | base_lr=2e-3, 59 | max_norm=-1, # < 0 to turn off. 60 | weight_decay=1e-2, 61 | schedule_type='cosine', 62 | cosine_decay_kwargs=dict( 63 | init_value=0.0, 64 | warmup_steps=5000, 65 | end_value=0.0, 66 | ), 67 | optimizer='adam', 68 | # Optimizer-specific kwargs. 69 | adam_kwargs=dict( 70 | b1=0.9, 71 | b2=0.95, 72 | eps=1e-8, 73 | ), 74 | ), 75 | fast_variables=tuple(), 76 | shared_modules=dict( 77 | shared_module_names=config.get_oneway_ref( 78 | 'shared_module_names', 79 | ), 80 | tapnet_model_kwargs=dict(), 81 | ), 82 | datasets=dict( 83 | dataset_names=config.get_oneway_ref('dataset_names'), 84 | kubric_kwargs=dict( 85 | batch_dims=8, 86 | shuffle_buffer_size=128, 87 | train_size=TRAIN_SIZE[1:3], 88 | ), 89 | ), 90 | supervised_point_prediction_kwargs=dict( 91 | prediction_algo='cost_volume_regressor', 92 | ), 93 | checkpoint_dir=config.get_oneway_ref('checkpoint_dir'), 94 | evaluate_every=config.get_oneway_ref('evaluate_every'), 95 | eval_modes=config.get_oneway_ref('eval_modes'), 96 | # If true, run evaluate() on the experiment once before 97 | # you load a checkpoint. 98 | # This is useful for getting initial values of metrics 99 | # at random weights, or when debugging locally if you 100 | # do not have any train job running. 101 | davis_points_path='', 102 | jhmdb_path='', 103 | robotics_points_path='', 104 | training=dict( 105 | # Note: to sweep n_training_steps, DO NOT sweep these 106 | # fields directly. Instead, sweep config.training_steps. 107 | # Otherwise, decay/stopping logic 108 | # is not guaranteed to be consistent. 109 | n_training_steps=config.get_oneway_ref('training_steps'), 110 | ), 111 | inference=dict( 112 | input_video_path='', 113 | output_video_path='', 114 | resize_height=256, # video height resized to before inference 115 | resize_width=256, # video width resized to before inference 116 | num_points=20, # number of random points to sample 117 | ), 118 | ) 119 | ) 120 | ) 121 | 122 | # Set up where to store the resulting model. 
123 | config.train_checkpoint_all_hosts = False 124 | config.save_checkpoint_interval = 10 125 | config.eval_initial_weights = True 126 | 127 | # Prevents accidentally setting keys that aren't recognized (e.g. in tests). 128 | config.lock() 129 | 130 | return config 131 | -------------------------------------------------------------------------------- /sam_pt/point_tracker/tapnet/demo.py: -------------------------------------------------------------------------------- 1 | """ 2 | Demo program for TAPNet, to make sure that pytorch+jax has been set up correctly. 3 | The following snippet should run without error, and ideally use GPU/TPU to be fast when benchmarking. 4 | 5 | Example usage: 6 | ``` 7 | python -m sam_pt.point_tracker.tapnet.demo 8 | ``` 9 | """ 10 | import time 11 | 12 | import haiku as hk 13 | import jax 14 | import matplotlib.pyplot as plt 15 | import numpy as np 16 | import tensorflow as tf 17 | import torch 18 | from torch.nn import functional as F 19 | 20 | from demo.demo import load_demo_data 21 | from . import tapnet_model 22 | from .configs.tapnet_config import get_config 23 | 24 | if __name__ == '__main__': 25 | # 1. Prepare config 26 | config = get_config() 27 | checkpoint_dir = "./models/tapnet_ckpts/open_source_ckpt/" 28 | # Keep TF off the GPU; otherwise it hogs all the memory and leaves none for JAX. 29 | tf.config.experimental.set_visible_devices([], 'GPU') 30 | tf.config.experimental.set_visible_devices([], 'TPU') 31 | 32 | # 2. Prepare model 33 | checkpoint = np.load(checkpoint_dir + "checkpoint_wo_optstate.npy", allow_pickle=True).item() 34 | params, state = checkpoint["params"], checkpoint["state"] 35 | tapnet_model_kwargs = config.experiment_kwargs.config.shared_modules["tapnet_model_kwargs"] 36 | 37 | 38 | def forward(rgbs, query_points): 39 | tapnet = tapnet_model.TAPNet(**tapnet_model_kwargs) 40 | outputs = tapnet( 41 | video=rgbs[None, ...], 42 | query_points=query_points[None, ...], 43 | query_chunk_size=16, 44 | get_query_feats=True, 45 | is_training=False, 46 | ) 47 | return outputs 48 | 49 | 50 | transform = hk.transform_with_state(forward) 51 | 52 | 53 | def f(rgbs_tapnet, query_points_tapnet): 54 | rng = jax.random.PRNGKey(72) 55 | outputs, _ = transform.apply(params, state, rng, rgbs_tapnet, query_points_tapnet) 56 | return outputs 57 | 58 | 59 | jitted_f = jax.jit(f) 60 | 61 | # 3. Prepare data 62 | rgbs, _, query_points = load_demo_data( 63 | frames_path="data/demo_data/bees", 64 | query_points_path="data/demo_data/query_points__bees.txt", 65 | ) 66 | original_hw = rgbs.shape[-2:] 67 | tapnet_input_hw = ( 68 | config.experiment_kwargs.config.inference.resize_height, config.experiment_kwargs.config.inference.resize_width) 69 | rescale_factor_hw = torch.tensor(tapnet_input_hw) / torch.tensor(original_hw) 70 | rgbs_tapnet = F.interpolate(rgbs / 255, tapnet_input_hw, mode="bilinear", align_corners=False, antialias=True) 71 | rgbs_tapnet = rgbs_tapnet.numpy() * 2 - 1 72 | rgbs_tapnet = rgbs_tapnet.transpose(0, 2, 3, 1) 73 | query_points_tapnet = query_points.clone() 74 | query_points_tapnet[:, :, 1:] *= rescale_factor_hw.flip(0) 75 | query_points_tapnet = query_points_tapnet.flatten(0, 1) 76 | query_points_tapnet[:, 1:] = query_points_tapnet[:, 1:].flip(-1) 77 | query_points_tapnet = query_points_tapnet.numpy() 78 | query_points_tapnet = query_points_tapnet 79 | 80 | # 4. Run model 81 | outputs = jitted_f(rgbs_tapnet, query_points_tapnet) 82 | 83 | n_frames = rgbs.shape[0] 84 | n_masks, n_points_per_mask, _ = query_points.shape 85 | 86 | # 5. 
Postprocess 87 | tapnet_visibility_threshold = 0.5 88 | 89 | occlussion_logits = torch.from_numpy(np.asarray(outputs["occlusion"][0]).copy()).permute(1, 0) 90 | occlussion_logits = occlussion_logits.unflatten(1, (n_masks, n_points_per_mask)) 91 | occlussion_probs = torch.sigmoid(occlussion_logits) 92 | visibilities_probs = 1 - occlussion_probs 93 | visibilities = visibilities_probs > tapnet_visibility_threshold 94 | 95 | trajectories = torch.from_numpy(np.asarray(outputs["tracks"][0]).copy()).permute(1, 0, 2) 96 | trajectories = trajectories.unflatten(1, (n_masks, n_points_per_mask)) 97 | trajectories = trajectories / rescale_factor_hw.flip(-1) 98 | 99 | # 6. Visualize 100 | for mask_idx in range(n_masks): 101 | if mask_idx != 2: 102 | continue 103 | for frame_idx in range(n_frames): 104 | h, w = rgbs.shape[2], rgbs.shape[3] 105 | dpi = 100 106 | plt.figure(figsize=(w / dpi, h / dpi)) 107 | plt.imshow(rgbs[frame_idx].permute(1, 2, 0).numpy(), interpolation="none") 108 | plt.scatter(trajectories[frame_idx, mask_idx, :, 0], trajectories[frame_idx, mask_idx, :, 1]) 109 | plt.axis("off") 110 | plt.tight_layout(pad=0) 111 | plt.show() 112 | 113 | # 7. Benchmark forward pass speed in for loop 114 | n_loops = 100 115 | start_time = time.time() 116 | for _ in range(n_loops): 117 | outputs = jitted_f(rgbs_tapnet, query_points_tapnet) 118 | end_time = time.time() 119 | print(f"Forward pass speed: {(end_time - start_time) / n_loops * 1000} ms") 120 | 121 | print("Done") 122 | -------------------------------------------------------------------------------- /sam_pt/point_tracker/tapnet/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SysCV/sam-pt/874ff7e73d6ab05418a494d7a02ca233c0b31e8c/sam_pt/point_tracker/tapnet/models/__init__.py -------------------------------------------------------------------------------- /sam_pt/point_tracker/tapnet/tracker.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | import torch 4 | from torch.nn import functional as F 5 | 6 | from sam_pt.point_tracker import PointTracker 7 | 8 | 9 | class TapnetPointTracker(PointTracker): 10 | """ 11 | A point tracker that uses TapNet from https://arxiv.org/abs/2211.03726 to track points. 12 | """ 13 | def __init__(self, checkpoint_path, visibility_threshold): 14 | from .configs.tapnet_config import get_config 15 | super().__init__() 16 | 17 | # Keep TF off the GPU; otherwise it hogs all the memory and leaves none for JAX 18 | tf.config.experimental.set_visible_devices([], 'GPU') 19 | tf.config.experimental.set_visible_devices([], 'TPU') 20 | 21 | # # v1: use the last GPU 22 | # # Hardcode JAX to use the last GPU (the first is reserved for other modules from PyTorch) 23 | # # The environmental flag `XLA_PYTHON_CLIENT_PREALLOCATE=false` is also required along with this 24 | # gpus = jax.devices('gpu') 25 | # device = gpus[-1] 26 | # jax.jit ... 
device=device 27 | 28 | # v2: share the gpu with Sam since they are run sequentially 29 | # but make jax free up the allocated memory once it is done 30 | # by setting the environmental variable `XLA_PYTHON_CLIENT_ALLOCATOR=platform` 31 | 32 | assert checkpoint_path is not None 33 | self.checkpoint_path = checkpoint_path 34 | self.config = get_config() 35 | self.visibility_threshold = visibility_threshold 36 | self.jitted_forward = self._create_jitted_forward() 37 | 38 | def _create_jitted_forward(self): 39 | import haiku as hk 40 | import jax 41 | from . import tapnet_model 42 | 43 | checkpoint = np.load(self.checkpoint_path, allow_pickle=True).item() 44 | params, state = checkpoint["params"], checkpoint["state"] 45 | tapnet_model_kwargs = self.config.experiment_kwargs.config.shared_modules["tapnet_model_kwargs"] 46 | 47 | def _forward(rgbs, query_points): 48 | tapnet = tapnet_model.TAPNet(**tapnet_model_kwargs) 49 | outputs = tapnet( 50 | video=rgbs, 51 | query_points=query_points, 52 | query_chunk_size=16, 53 | get_query_feats=True, 54 | is_training=False, 55 | ) 56 | return outputs 57 | 58 | transform = hk.transform_with_state(_forward) 59 | 60 | def forward(rgbs_tapnet, query_points_tapnet): 61 | rng = jax.random.PRNGKey(72) 62 | outputs, _ = transform.apply(params, state, rng, rgbs_tapnet, query_points_tapnet) 63 | return outputs 64 | 65 | return jax.jit(forward) 66 | 67 | def forward(self, rgbs, query_points, summary_writer=None): 68 | batch_size, n_frames, channels, height, width = rgbs.shape 69 | n_points = query_points.shape[1] 70 | 71 | # 1. Prepare image resizing 72 | original_hw = (height, width) 73 | tapnet_input_hw = ( 74 | self.config.experiment_kwargs.config.inference.resize_height, 75 | self.config.experiment_kwargs.config.inference.resize_width, 76 | ) 77 | rescale_factor_hw = torch.tensor(tapnet_input_hw) / torch.tensor(original_hw) 78 | 79 | # 2. Prepare inputs 80 | rgbs_tapnet = F.interpolate(rgbs.flatten(0, 1) / 255, tapnet_input_hw, mode="bilinear", align_corners=False, 81 | antialias=True) 82 | rgbs_tapnet = rgbs_tapnet.unflatten(0, (batch_size, n_frames)) 83 | rgbs_tapnet = rgbs_tapnet.cpu().numpy() * 2 - 1 84 | rgbs_tapnet = rgbs_tapnet.transpose(0, 1, 3, 4, 2) 85 | query_points_tapnet = query_points.cpu().clone() 86 | query_points_tapnet[:, :, 1:] *= rescale_factor_hw.flip(0) 87 | query_points_tapnet[:, :, 1:] = query_points_tapnet[:, :, 1:].flip(-1) # flip x and y 88 | query_points_tapnet = query_points_tapnet.numpy() 89 | 90 | # 3. Run model 91 | self._create_jitted_forward() # TODO: Cannot the function be compiled only once? 92 | outputs = self.jitted_forward(rgbs_tapnet, query_points_tapnet) 93 | 94 | # 4. 
Postprocess outputs 95 | occlussion_logits = torch.from_numpy(np.asarray(outputs["occlusion"]).copy()).permute(0, 2, 1) 96 | occlussion_probs = torch.sigmoid(occlussion_logits) 97 | visibilities_probs = 1 - occlussion_probs 98 | visibilities = visibilities_probs > self.visibility_threshold 99 | 100 | trajectories = torch.from_numpy(np.asarray(outputs["tracks"]).copy()).permute(0, 2, 1, 3) 101 | trajectories = trajectories / rescale_factor_hw.flip(-1) 102 | 103 | return trajectories, visibilities 104 | -------------------------------------------------------------------------------- /sam_pt/point_tracker/tapnet/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SysCV/sam-pt/874ff7e73d6ab05418a494d7a02ca233c0b31e8c/sam_pt/point_tracker/tapnet/utils/__init__.py -------------------------------------------------------------------------------- /sam_pt/point_tracker/tapnet/utils/transforms.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | # Taken from: https://github.com/deepmind/tapnet/blob/ba1a8c8f2576d81f7b8d69dbee1e58e8b7d321e1/utils/transforms.py 17 | 18 | """Utilities for transforming image coordinates.""" 19 | 20 | from typing import Sequence 21 | 22 | import numpy as np 23 | 24 | 25 | def convert_grid_coordinates( 26 | coords: np.ndarray, 27 | input_grid_size: Sequence[int], 28 | output_grid_size: Sequence[int], 29 | coordinate_format: str = 'xy', 30 | ) -> np.ndarray: 31 | """Convert image coordinates between image grids of different sizes. 32 | 33 | By default, it assumes that the image corners are aligned. Therefore, 34 | it adds .5 (since (0,0) is assumed to be the center of the upper-left grid 35 | cell), multiplies by the size ratio, and then subtracts .5. 36 | 37 | Args: 38 | coords: The coordinates to be converted. It is of shape [..., 2] if 39 | coordinate_format is 'xy' or [..., 3] if coordinate_format is 'tyx'. 40 | input_grid_size: The size of the image/grid that the coordinates currently 41 | are with respect to. This is a 2-tuple of the format [width, height] 42 | if coordinate_format is 'xy' or a 3-tuple of the format 43 | [num_frames, height, width] if coordinate_format is 'tyx'. 44 | output_grid_size: The size of the target image/grid that you want the 45 | coordinates to be with respect to. This is a 2-tuple of the format 46 | [width, height] if coordinate_format is 'xy' or a 3-tuple of the format 47 | [num_frames, height, width] if coordinate_format is 'tyx'. 48 | coordinate_format: Which format the coordinates are in. This can be one 49 | of 'xy' (the default) or 'tyx', which are the only formats used in this 50 | project. 51 | 52 | Returns: 53 | The transformed coordinates, of the same shape as coordinates. 
54 | 55 | Raises: 56 | ValueError: if coordinates don't match the given format. 57 | """ 58 | if isinstance(input_grid_size, tuple): 59 | input_grid_size = np.array(input_grid_size) 60 | if isinstance(output_grid_size, tuple): 61 | output_grid_size = np.array(output_grid_size) 62 | 63 | if coordinate_format == 'xy': 64 | if input_grid_size.shape[0] != 2 or output_grid_size.shape[0] != 2: 65 | raise ValueError( 66 | 'If coordinate_format is xy, the shapes must be length 2.') 67 | elif coordinate_format == 'tyx': 68 | if input_grid_size.shape[0] != 3 or output_grid_size.shape[0] != 3: 69 | raise ValueError( 70 | 'If coordinate_format is tyx, the shapes must be length 3.') 71 | if input_grid_size[0] != output_grid_size[0]: 72 | raise ValueError('converting frame count is not supported.') 73 | else: 74 | raise ValueError('Recognized coordinate formats are xy and tyx.') 75 | 76 | position_in_grid = coords 77 | position_in_grid = position_in_grid * output_grid_size / input_grid_size 78 | 79 | return position_in_grid 80 | -------------------------------------------------------------------------------- /sam_pt/point_tracker/tracker.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from abc import ABC, abstractmethod 3 | from torch import nn 4 | from typing import Tuple 5 | 6 | 7 | class PointTracker(ABC, nn.Module): 8 | """ 9 | Abstract class for point trackers. 10 | 11 | Methods 12 | ------- 13 | forward(rgbs, query_points) 14 | Performs a forward pass through the model and returns the predicted trajectories and visibilities. 15 | evaluate_batch(rgbs, query_points, trajectories_gt=None, visibilities_gt=None) 16 | Evaluates a batch of videos and returns the results. 17 | unpack_results(packed_results, batch_idx) 18 | Unpacks the results for all point and all videos in the batch. 19 | """ 20 | 21 | @abstractmethod 22 | def forward(self, rgbs, query_points) -> Tuple[torch.Tensor, torch.Tensor]: 23 | """ 24 | Performs a forward pass through the model and returns the predicted trajectories and visibilities. 25 | 26 | Parameters 27 | ---------- 28 | rgbs : torch.Tensor 29 | A tensor of shape (batch_size, n_frames, channels, height, width) 30 | containing the RGB images in uint8 [0-255] format. 31 | query_points : torch.Tensor 32 | A tensor of shape (batch_size, n_points, 3) containing the query points, 33 | each point being (t, x, y). 34 | 35 | Returns 36 | ------- 37 | tuple of two torch.Tensor 38 | Returns a tuple of (trajectories, visibilities). 39 | - `trajectories`: Predicted point trajectories with shape (batch_size, n_frames, n_points, 2), where each 40 | trajectory represents a series of (x, y) coordinates in the video for a specific point. 41 | - `visibilities`: Predicted point visibilities with shape (batch_size, n_frames, n_points), where each 42 | visibility represents the likelihood of a point being visible in the corresponding frame 43 | of the video. 44 | """ 45 | pass 46 | 47 | def evaluate_batch(self, rgbs, query_points, trajectories_gt=None, visibilities_gt=None): 48 | """ 49 | Evaluates a batch of data and returns the results. 50 | 51 | Parameters 52 | ---------- 53 | rgbs : torch.Tensor 54 | A tensor of shape (batch_size, n_frames, channels, height, width) 55 | containing the RGB images in uint8 [0-255] format. 56 | query_points : torch.Tensor 57 | A tensor of shape (batch_size, n_points, 3) containing the query points, 58 | each point being (t, x, y). 
59 | trajectories_gt : torch.Tensor, optional 60 | A 4D tensor representing the ground-truth trajectory. Its shape is (batch_size, n_frames, n_points, 2). 61 | visibilities_gt : torch.Tensor, optional 62 | A 3D tensor representing the ground-truth visibilities. Its shape is (batch_size, n_frames, n_points). 63 | 64 | Returns 65 | ------- 66 | dict 67 | A dictionary containing the results. 68 | """ 69 | trajectories_pred, visibilities_pred = self.forward(rgbs, query_points) 70 | batch_size = rgbs.shape[0] 71 | n_frames = rgbs.shape[1] 72 | n_points = query_points.shape[1] 73 | assert trajectories_pred.shape == (batch_size, n_frames, n_points, 2) 74 | 75 | results = { 76 | "trajectories_pred": trajectories_pred.detach().clone().cpu(), 77 | "visibilities_pred": visibilities_pred.detach().clone().cpu(), 78 | "query_points": query_points.detach().clone().cpu(), 79 | "trajectories_gt": trajectories_gt.detach().clone().cpu() if trajectories_gt is not None else None, 80 | "visibilities_gt": visibilities_gt.detach().clone().cpu() if visibilities_gt is not None else None, 81 | } 82 | 83 | return results 84 | 85 | @classmethod 86 | def unpack_results(cls, packed_results, batch_idx): 87 | """ 88 | Unpacks the results for all point and all videos in the batch. 89 | 90 | Parameters 91 | ---------- 92 | packed_results : dict 93 | The dictionary containing the packed results, for all videos in the batch and all points in the video. 94 | batch_idx : int 95 | The index of the current batch. 96 | 97 | Returns 98 | ------- 99 | list 100 | A list of dictionaries, each containing the unpacked results for a data point. 101 | """ 102 | unpacked_results_list = [] 103 | for b in range(packed_results["trajectories_pred"].shape[0]): 104 | for n in range(packed_results["trajectories_pred"].shape[2]): 105 | result = { 106 | "idx": f"{batch_idx}_{b}_{n}", 107 | "iter": batch_idx, 108 | "video_idx": b, 109 | "point_idx_in_video": n, 110 | "query_point": packed_results["query_points"][b, n, :], 111 | "trajectory_pred": packed_results["trajectories_pred"][b, :, n, :], 112 | "visibility_pred": packed_results["visibilities_pred"][b, :, n], 113 | } 114 | if packed_results["trajectories_gt"] is not None: 115 | result["trajectory_gt"] = packed_results["trajectories_gt"][b, :, n, :] 116 | result["visibility_gt"] = packed_results["visibilities_gt"][b, :, n] 117 | unpacked_results_list += [result] 118 | return unpacked_results_list 119 | -------------------------------------------------------------------------------- /sam_pt/point_tracker/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SysCV/sam-pt/874ff7e73d6ab05418a494d7a02ca233c0b31e8c/sam_pt/point_tracker/utils/__init__.py -------------------------------------------------------------------------------- /sam_pt/point_tracker/utils/misc.py: -------------------------------------------------------------------------------- 1 | # Adapted from: 2 | # - https://github.com/aharley/pips/blob/486124b4236bb228a20750b496f0fa8aa6343157/utils/misc.py 3 | # - https://github.com/aharley/pips2/blob/06bff81f25f2866728ff94f5d3a02c00893a8f15/utils/misc.py 4 | 5 | 6 | import numpy as np 7 | import torch 8 | 9 | 10 | def posemb_sincos_2d_xy(xy, C, temperature=10000, cat_coords=False): 11 | device = xy.device 12 | dtype = xy.dtype 13 | B, S, D = xy.shape 14 | assert (D == 2) 15 | x = xy[:, :, 0] 16 | y = xy[:, :, 1] 17 | assert (C % 4) == 0, 'feature dimension must be multiple of 4 for sincos emb' 18 | omega = 
torch.arange(C // 4, device=device) / (C // 4 - 1) 19 | omega = 1. / (temperature ** omega) 20 | 21 | y = y.flatten()[:, None] * omega[None, :] 22 | x = x.flatten()[:, None] * omega[None, :] 23 | pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim=1) 24 | pe = pe.reshape(B, S, C).type(dtype) 25 | if cat_coords: 26 | pe = torch.cat([pe, xy], dim=2) # B,N,C+2 27 | return pe 28 | 29 | 30 | def get_3d_embedding(xyz, C, cat_coords=True): 31 | B, N, D = xyz.shape 32 | assert (D == 3) 33 | 34 | x = xyz[:, :, 0:1] 35 | y = xyz[:, :, 1:2] 36 | z = xyz[:, :, 2:3] 37 | div_term = (torch.arange(0, C, 2, device=xyz.device, dtype=torch.float32) * (1000.0 / C)).reshape(1, 1, int(C / 2)) 38 | 39 | pe_x = torch.zeros(B, N, C, device=xyz.device, dtype=torch.float32) 40 | pe_y = torch.zeros(B, N, C, device=xyz.device, dtype=torch.float32) 41 | pe_z = torch.zeros(B, N, C, device=xyz.device, dtype=torch.float32) 42 | 43 | pe_x[:, :, 0::2] = torch.sin(x * div_term) 44 | pe_x[:, :, 1::2] = torch.cos(x * div_term) 45 | 46 | pe_y[:, :, 0::2] = torch.sin(y * div_term) 47 | pe_y[:, :, 1::2] = torch.cos(y * div_term) 48 | 49 | pe_z[:, :, 0::2] = torch.sin(z * div_term) 50 | pe_z[:, :, 1::2] = torch.cos(z * div_term) 51 | 52 | pe = torch.cat([pe_x, pe_y, pe_z], dim=2) # B, N, C*3 53 | if cat_coords: 54 | pe = torch.cat([pe, xyz], dim=2) # B, N, C*3+3 55 | return pe 56 | 57 | 58 | class SimplePool(): 59 | def __init__(self, pool_size, version='pt'): 60 | self.pool_size = pool_size 61 | self.version = version 62 | # random.seed(125) 63 | if self.pool_size > 0: 64 | self.num = 0 65 | self.items = [] 66 | if not (version == 'pt' or version == 'np'): 67 | print('version = %s; please choose pt or np') 68 | assert (False) # please choose pt or np 69 | 70 | def __len__(self): 71 | return len(self.items) 72 | 73 | def mean(self, min_size='none'): 74 | if min_size == 'half': 75 | pool_size_thresh = self.pool_size / 2 76 | else: 77 | pool_size_thresh = 1 78 | 79 | if self.version == 'np': 80 | if len(self.items) >= pool_size_thresh: 81 | return np.sum(self.items) / float(len(self.items)) 82 | else: 83 | return np.nan 84 | if self.version == 'pt': 85 | if len(self.items) >= pool_size_thresh: 86 | return torch.sum(self.items) / float(len(self.items)) 87 | else: 88 | return torch.from_numpy(np.nan) 89 | 90 | def sample(self): 91 | idx = np.random.randint(len(self.items)) 92 | return self.items[idx] 93 | 94 | def fetch(self, num=None): 95 | if self.version == 'pt': 96 | item_array = torch.stack(self.items) 97 | elif self.version == 'np': 98 | item_array = np.stack(self.items) 99 | if num is not None: 100 | # there better be some items 101 | assert (len(self.items) >= num) 102 | 103 | # if there are not that many elements just return however many there are 104 | if len(self.items) < num: 105 | return item_array 106 | else: 107 | idxs = np.random.randint(len(self.items), size=num) 108 | return item_array[idxs] 109 | else: 110 | return item_array 111 | 112 | def is_full(self): 113 | full = self.num == self.pool_size 114 | # print 'num = %d; full = %s' % (self.num, full) 115 | return full 116 | 117 | def empty(self): 118 | self.items = [] 119 | self.num = 0 120 | 121 | def update(self, items): 122 | for item in items: 123 | if self.num < self.pool_size: 124 | # the pool is not full, so let's add this in 125 | self.num = self.num + 1 126 | else: 127 | # the pool is full 128 | # pop from the front 129 | self.items.pop(0) 130 | # add to the back 131 | self.items.append(item) 132 | return self.items 133 | 
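The embedding helpers above (`posemb_sincos_2d_xy` and `get_3d_embedding`) turn per-point coordinates into fixed sinusoidal features, optionally with the raw coordinates appended. Below is a minimal usage sketch, assuming the functions are imported from `sam_pt.point_tracker.utils.misc`; the batch, point, and channel sizes are arbitrary illustrative values, not settings used anywhere in this repository.

```python
import torch

from sam_pt.point_tracker.utils.misc import get_3d_embedding, posemb_sincos_2d_xy

# Illustrative inputs: 2 batches of 16 points with pixel-scale coordinates.
xy = torch.rand(2, 16, 2) * 256   # (B, N, 2) with (x, y)
xyz = torch.rand(2, 16, 3) * 256  # (B, N, 3) with (x, y, z)

# C must be a multiple of 4 for the 2D sincos embedding.
pe_xy = posemb_sincos_2d_xy(xy, C=64, cat_coords=True)
print(pe_xy.shape)   # torch.Size([2, 16, 66])  -> C + 2 appended coordinates

pe_xyz = get_3d_embedding(xyz, C=32, cat_coords=True)
print(pe_xyz.shape)  # torch.Size([2, 16, 99])  -> 3*C + 3 appended coordinates
```

With `cat_coords=True`, the raw coordinates are concatenated after the sinusoidal features, which is what the shape comments in the functions above (`B,N,C+2` and `B, N, C*3+3`) refer to.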
-------------------------------------------------------------------------------- /sam_pt/point_tracker/utils/samp.py: -------------------------------------------------------------------------------- 1 | # Taken from: https://github.com/aharley/pips/blob/486124b4236bb228a20750b496f0fa8aa6343157/utils/samp.py 2 | 3 | import torch 4 | 5 | 6 | def bilinear_sample2d(im, x, y, return_inbounds=False): 7 | # x and y are each B, N 8 | # output is B, C, N 9 | B, C, H, W = list(im.shape) 10 | N = list(x.shape)[1] 11 | 12 | x = x.float() 13 | y = y.float() 14 | H_f = torch.tensor(H, dtype=torch.float32) 15 | W_f = torch.tensor(W, dtype=torch.float32) 16 | 17 | # inbound_mask = (x>-0.5).float()*(y>-0.5).float()*(x -0.5).byte() & (x < float(W_f - 0.5)).byte() 74 | y_valid = (y > -0.5).byte() & (y < float(H_f - 0.5)).byte() 75 | inbounds = (x_valid & y_valid).float() 76 | inbounds = inbounds.reshape(B, 77 | N) # something seems wrong here for B>1; i'm getting an error here (or downstream if i put -1) 78 | return output, inbounds 79 | 80 | return output # B, C, N 81 | -------------------------------------------------------------------------------- /sam_pt/point_tracker/utils/saverloader.py: -------------------------------------------------------------------------------- 1 | # Adapted from: https://github.com/aharley/pips/blob/486124b4236bb228a20750b496f0fa8aa6343157/saverloader.py 2 | 3 | import os 4 | import pathlib 5 | 6 | import torch 7 | 8 | 9 | def save(ckpt_dir, optimizer, model, global_step, scheduler=None, model_ema=None, keep_latest=5, model_name='model'): 10 | if not os.path.exists(ckpt_dir): 11 | os.makedirs(ckpt_dir) 12 | 13 | prev_ckpts = list(pathlib.Path(ckpt_dir).glob('%s-*' % model_name)) 14 | prev_ckpts.sort(key=lambda p: p.stat().st_mtime, reverse=True) 15 | if len(prev_ckpts) > keep_latest - 1: 16 | for f in prev_ckpts[keep_latest - 1:]: 17 | f.unlink() 18 | model_path = '%s/%s-%09d.pth' % (ckpt_dir, model_name, global_step) 19 | 20 | ckpt = {'optimizer_state_dict': optimizer.state_dict()} 21 | ckpt['model_state_dict'] = model.state_dict() 22 | if scheduler is not None: 23 | ckpt['scheduler_state_dict'] = scheduler.state_dict() 24 | if model_ema is not None: 25 | ckpt['ema_model_state_dict'] = model_ema.state_dict() 26 | torch.save(ckpt, model_path) 27 | print("saved a checkpoint: %s" % (model_path)) 28 | 29 | 30 | def load(ckpt_dir, model, device=None, optimizer=None, scheduler=None, model_ema=None, step=0, model_name='model', 31 | ignore_load=None): 32 | print('reading ckpt from %s' % ckpt_dir) 33 | assert os.path.exists(ckpt_dir) 34 | 35 | ckpt_names = os.listdir(ckpt_dir) 36 | steps = [int((i.split('-')[1]).split('.')[0]) for i in ckpt_names] 37 | assert len(ckpt_names) > 0 38 | 39 | if step == 0: 40 | step = max(steps) 41 | model_name = '%s-%09d.pth' % (model_name, step) 42 | path = os.path.join(ckpt_dir, model_name) 43 | print('...found checkpoint %s' % (path)) 44 | 45 | if ignore_load is not None: 46 | 47 | print('ignoring', ignore_load) 48 | 49 | checkpoint = torch.load(path)['model_state_dict'] 50 | 51 | model_dict = model.state_dict() 52 | 53 | # 1. filter out ignored keys 54 | pretrained_dict = {k: v for k, v in checkpoint.items()} 55 | for ign in ignore_load: 56 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if not ign in k} 57 | 58 | # 2. overwrite entries in the existing state dict 59 | model_dict.update(pretrained_dict) 60 | # 3. 
load the new state dict 61 | model.load_state_dict(model_dict, strict=False) 62 | else: 63 | checkpoint = torch.load(path, map_location=device) 64 | model.load_state_dict(checkpoint['model_state_dict'], strict=False) 65 | 66 | if optimizer is not None: 67 | optimizer.load_state_dict(checkpoint['optimizer_state_dict']) 68 | if scheduler is not None: 69 | scheduler.load_state_dict(checkpoint['scheduler_state_dict']) 70 | if model_ema is not None: 71 | model_ema.load_state_dict(checkpoint['ema_model_state_dict']) 72 | 73 | return step 74 | -------------------------------------------------------------------------------- /sam_pt/point_tracker/utils/test.py: -------------------------------------------------------------------------------- 1 | # Taken from: https://github.com/aharley/pips/blob/486124b4236bb228a20750b496f0fa8aa6343157/utils/test.py 2 | 3 | import cv2 4 | import numpy as np 5 | import torch 6 | import torch.nn.functional as F 7 | 8 | 9 | def prep_frame_for_dino(img, scale_size=[192]): 10 | """ 11 | read a single frame & preprocess 12 | """ 13 | ori_h, ori_w, _ = img.shape 14 | if len(scale_size) == 1: 15 | if (ori_h > ori_w): 16 | tw = scale_size[0] 17 | th = (tw * ori_h) / ori_w 18 | th = int((th // 64) * 64) 19 | else: 20 | th = scale_size[0] 21 | tw = (th * ori_w) / ori_h 22 | tw = int((tw // 64) * 64) 23 | else: 24 | th, tw = scale_size 25 | img = cv2.resize(img, (tw, th)) 26 | img = img.astype(np.float32) 27 | img = img / 255.0 28 | img = img[:, :, ::-1] 29 | img = np.transpose(img.copy(), (2, 0, 1)) 30 | img = torch.from_numpy(img).float() 31 | 32 | def color_normalize(x, mean=[0.485, 0.456, 0.406], std=[0.228, 0.224, 0.225]): 33 | for t, m, s in zip(x, mean, std): 34 | t.sub_(m) 35 | t.div_(s) 36 | return x 37 | 38 | img = color_normalize(img) 39 | return img, ori_h, ori_w 40 | 41 | 42 | def get_feats_from_dino(model, frame): 43 | # batch version of the other func 44 | B = frame.shape[0] 45 | patch_size = model.patch_embed.patch_size 46 | h, w = int(frame.shape[2] / patch_size), int(frame.shape[3] / patch_size) 47 | out = model.get_intermediate_layers(frame.cuda(), n=1)[0] # B, 1+h*w, dim 48 | dim = out.shape[-1] 49 | out = out[:, 1:, :] # discard the [CLS] token 50 | outmap = out.permute(0, 2, 1).reshape(B, dim, h, w) 51 | return out, outmap, h, w 52 | 53 | 54 | def restrict_neighborhood(h, w): 55 | size_mask_neighborhood = 12 56 | # We restrict the set of source nodes considered to a spatial neighborhood of the query node (i.e. 
``local attention'') 57 | mask = torch.zeros(h, w, h, w) 58 | for i in range(h): 59 | for j in range(w): 60 | for p in range(2 * size_mask_neighborhood + 1): 61 | for q in range(2 * size_mask_neighborhood + 1): 62 | if i - size_mask_neighborhood + p < 0 or i - size_mask_neighborhood + p >= h: 63 | continue 64 | if j - size_mask_neighborhood + q < 0 or j - size_mask_neighborhood + q >= w: 65 | continue 66 | mask[i, j, i - size_mask_neighborhood + p, j - size_mask_neighborhood + q] = 1 67 | 68 | mask = mask.reshape(h * w, h * w) 69 | return mask.cuda(non_blocking=True) 70 | 71 | 72 | def label_propagation(h, w, feat_tar, list_frame_feats, list_segs, mask_neighborhood=None): 73 | ncontext = len(list_frame_feats) 74 | feat_sources = torch.stack(list_frame_feats) # nmb_context x dim x h*w 75 | 76 | feat_tar = F.normalize(feat_tar, dim=1, p=2) 77 | feat_sources = F.normalize(feat_sources, dim=1, p=2) 78 | 79 | # print('feat_tar', feat_tar.shape) 80 | # print('feat_sources', feat_sources.shape) 81 | 82 | feat_tar = feat_tar.unsqueeze(0).repeat(ncontext, 1, 1) 83 | aff = torch.exp(torch.bmm(feat_tar, feat_sources) / 0.1) 84 | 85 | size_mask_neighborhood = 12 86 | if size_mask_neighborhood > 0: 87 | if mask_neighborhood is None: 88 | mask_neighborhood = restrict_neighborhood(h, w) 89 | mask_neighborhood = mask_neighborhood.unsqueeze(0).repeat(ncontext, 1, 1) 90 | aff *= mask_neighborhood 91 | 92 | aff = aff.transpose(2, 1).reshape(-1, h * w) # nmb_context*h*w (source: keys) x h*w (tar: queries) 93 | topk = 5 94 | tk_val, _ = torch.topk(aff, dim=0, k=topk) 95 | tk_val_min, _ = torch.min(tk_val, dim=0) 96 | aff[aff < tk_val_min] = 0 97 | 98 | aff = aff / torch.sum(aff, keepdim=True, axis=0) 99 | 100 | list_segs = [s.cuda() for s in list_segs] 101 | segs = torch.cat(list_segs) 102 | nmb_context, C, h, w = segs.shape 103 | segs = segs.reshape(nmb_context, C, -1).transpose(2, 1).reshape(-1, C).T # C x nmb_context*h*w 104 | seg_tar = torch.mm(segs, aff) 105 | seg_tar = seg_tar.reshape(1, C, h, w) 106 | 107 | return seg_tar, mask_neighborhood 108 | 109 | 110 | def norm_mask(mask): 111 | c, h, w = mask.size() 112 | for cnt in range(c): 113 | mask_cnt = mask[cnt, :, :] 114 | if (mask_cnt.max() > 0): 115 | mask_cnt = (mask_cnt - mask_cnt.min()) 116 | mask_cnt = mask_cnt / mask_cnt.max() 117 | mask[cnt, :, :] = mask_cnt 118 | return mask 119 | 120 | 121 | def get_dino_output(dino, rgbs, trajs_g, vis_g): 122 | B, S, C, H, W = rgbs.shape 123 | 124 | B1, S1, N, D = trajs_g.shape 125 | assert (B1 == B) 126 | assert (S1 == S) 127 | assert (D == 2) 128 | 129 | assert (B == 1) 130 | xy0 = trajs_g[:, 0] # B, N, 2 131 | 132 | # The queue stores the n preceeding frames 133 | import queue 134 | import copy 135 | n_last_frames = 7 136 | que = queue.Queue(n_last_frames) 137 | 138 | # run dino 139 | prep_rgbs = [] 140 | for s in range(S): 141 | prep_rgb, ori_h, ori_w = prep_frame_for_dino(rgbs[0, s].permute(1, 2, 0).detach().cpu().numpy(), scale_size=[H]) 142 | prep_rgbs.append(prep_rgb) 143 | prep_rgbs = torch.stack(prep_rgbs, dim=0) # S, 3, H, W 144 | with torch.no_grad(): 145 | bs = 8 146 | idx = 0 147 | featmaps = [] 148 | while idx < S: 149 | end_id = min(S, idx + bs) 150 | _, featmaps_cur, h, w = get_feats_from_dino(dino, prep_rgbs[idx:end_id]) # S, C, h, w 151 | idx = end_id 152 | featmaps.append(featmaps_cur) 153 | featmaps = torch.cat(featmaps, dim=0) 154 | C = featmaps.shape[1] 155 | featmaps = featmaps.unsqueeze(0) # 1, S, C, h, w 156 | # featmaps = F.normalize(featmaps, dim=2, p=2) 157 | 158 | xy0 = trajs_g[:, 
0, :] # B, N, 2 159 | patch_size = dino.patch_embed.patch_size 160 | first_seg = torch.zeros((1, N, H // patch_size, W // patch_size)) 161 | for n in range(N): 162 | first_seg[0, n, (xy0[0, n, 1] / patch_size).long(), (xy0[0, n, 0] / patch_size).long()] = 1 163 | 164 | frame1_feat = featmaps[0, 0].reshape(C, h * w) # dim x h*w 165 | mask_neighborhood = None 166 | accs = [] 167 | trajs_e = torch.zeros_like(trajs_g) 168 | trajs_e[0, 0] = trajs_g[0, 0] 169 | for cnt in range(1, S): 170 | used_frame_feats = [frame1_feat] + [pair[0] for pair in list(que.queue)] 171 | used_segs = [first_seg] + [pair[1] for pair in list(que.queue)] 172 | 173 | feat_tar = featmaps[0, cnt].reshape(C, h * w) 174 | 175 | frame_tar_avg, mask_neighborhood = label_propagation(h, w, feat_tar.T, used_frame_feats, used_segs, 176 | mask_neighborhood) 177 | 178 | # pop out oldest frame if neccessary 179 | if que.qsize() == n_last_frames: 180 | que.get() 181 | # push current results into queue 182 | seg = copy.deepcopy(frame_tar_avg) 183 | que.put([feat_tar, seg]) 184 | 185 | # upsampling & argmax 186 | frame_tar_avg = F.interpolate(frame_tar_avg, scale_factor=patch_size, mode='bilinear', align_corners=False, 187 | recompute_scale_factor=False)[0] 188 | frame_tar_avg = norm_mask(frame_tar_avg) 189 | _, frame_tar_seg = torch.max(frame_tar_avg, dim=0) 190 | 191 | for n in range(N): 192 | vis = vis_g[0, cnt, n] 193 | if len(torch.nonzero(frame_tar_avg[n])) > 0: 194 | # weighted average 195 | nz = torch.nonzero(frame_tar_avg[n]) 196 | coord_e = torch.sum(frame_tar_avg[n][nz[:, 0], nz[:, 1]].reshape(-1, 1) * nz.float(), 0) / \ 197 | frame_tar_avg[n][nz[:, 0], nz[:, 1]].sum() # 2 198 | coord_e = coord_e[[1, 0]] 199 | else: 200 | # stay where it was 201 | coord_e = trajs_e[0, cnt - 1, n] 202 | 203 | trajs_e[0, cnt, n] = coord_e 204 | return trajs_e 205 | -------------------------------------------------------------------------------- /sam_pt/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SysCV/sam-pt/874ff7e73d6ab05418a494d7a02ca233c0b31e8c/sam_pt/utils/__init__.py -------------------------------------------------------------------------------- /sam_pt/vis_eval/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SysCV/sam-pt/874ff7e73d6ab05418a494d7a02ca233c0b31e8c/sam_pt/vis_eval/__init__.py -------------------------------------------------------------------------------- /sam_pt/vis_eval/eval.py: -------------------------------------------------------------------------------- 1 | import hydra 2 | import pandas as pd 3 | import wandb 4 | from hydra.core.hydra_config import HydraConfig 5 | from hydra.utils import instantiate 6 | from omegaconf import DictConfig, OmegaConf 7 | 8 | from .train_net_video import * 9 | 10 | 11 | def main_inner(cfg: DictConfig) -> None: 12 | # Setup config 13 | detectron2_config = cfg.DETECTRON2_CONFIG 14 | default_setup(detectron2_config, {"eval_only": True}) 15 | 16 | # Setup logging 17 | setup_logger(name="point_tracking_vis_eval") 18 | setup_logger(output=cfg.DETECTRON2_CONFIG.OUTPUT_DIR, distributed_rank=comm.get_rank(), name="point_tracking_video") 19 | if comm.is_main_process(): 20 | wandb.init( 21 | entity=cfg.logging.wandb.entity, 22 | project=cfg.logging.wandb.project, 23 | name=cfg.logging.exp_id, 24 | group=cfg.logging.exp_id, 25 | config={ 26 | "cfg": OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True), 27 | "work_dir": 
os.getcwd(), 28 | "hydra_cfg": HydraConfig.get() if HydraConfig.instance().cfg is not None else None, 29 | }, 30 | ) 31 | wandb.run.log_code(cfg.logging.wandb.log_code_path) 32 | wandb.run.summary["work_dir"] = os.path.abspath(os.getcwd()) 33 | 34 | # Load model 35 | model = instantiate(cfg.model) 36 | model = model.to(cfg.device) 37 | model = model.eval() 38 | 39 | # Evaluate model 40 | results = Trainer.test(detectron2_config, model) 41 | print(f"Process {comm.get_rank()} has finished evaluation. Results: {results}") 42 | if detectron2_config.TEST.AUG.ENABLED: 43 | raise NotImplementedError 44 | if comm.is_main_process(): 45 | print("Results verification by the main process has started") 46 | verify_results(detectron2_config, results) 47 | print("Results verification has finished") 48 | 49 | df_global = pd.DataFrame.from_dict(results["segm"], orient="index").T 50 | wandb.log({"df_global": wandb.Table(dataframe=df_global)}) 51 | wandb.run.summary["score"] = df_global["AR100"].item() 52 | 53 | 54 | @hydra.main(config_path="../../configs", config_name="vis_eval_sam_pt", version_base="1.1") 55 | def main(cfg: DictConfig) -> None: 56 | print(OmegaConf.to_yaml(cfg)) 57 | OmegaConf.resolve(cfg) 58 | OmegaConf.set_readonly(cfg, True) 59 | launch( 60 | main_inner, 61 | num_gpus_per_machine=cfg.num_gpus_per_machine, 62 | num_machines=cfg.num_machines, 63 | machine_rank=cfg.machine_rank, 64 | dist_url=cfg.dist_url, 65 | args=(cfg,), 66 | ) 67 | 68 | 69 | if __name__ == "__main__": 70 | main() 71 | -------------------------------------------------------------------------------- /sam_pt/vis_eval/mask2former/__init__.py: -------------------------------------------------------------------------------- 1 | from .config import add_maskformer2_config 2 | -------------------------------------------------------------------------------- /sam_pt/vis_eval/mask2former/config.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # Taken from: https://github.com/facebookresearch/Mask2Former/blob/9b0651c6c1d5b3af2e6da0589b719c514ec0d69a/mask2former/config.py 4 | 5 | from detectron2.config import CfgNode as CN 6 | 7 | 8 | def add_maskformer2_config(cfg): 9 | """ 10 | Add config for MASK_FORMER. 11 | """ 12 | # NOTE: configs from original maskformer 13 | # data config 14 | # select the dataset mapper 15 | cfg.INPUT.DATASET_MAPPER_NAME = "mask_former_semantic" 16 | # Color augmentation 17 | cfg.INPUT.COLOR_AUG_SSD = False 18 | # We retry random cropping until no single category in semantic segmentation GT occupies more 19 | # than `SINGLE_CATEGORY_MAX_AREA` part of the crop. 20 | cfg.INPUT.CROP.SINGLE_CATEGORY_MAX_AREA = 1.0 21 | # Pad image and segmentation GT in dataset mapper. 
22 | cfg.INPUT.SIZE_DIVISIBILITY = -1 23 | 24 | # solver config 25 | # weight decay on embedding 26 | cfg.SOLVER.WEIGHT_DECAY_EMBED = 0.0 27 | # optimizer 28 | cfg.SOLVER.OPTIMIZER = "ADAMW" 29 | cfg.SOLVER.BACKBONE_MULTIPLIER = 0.1 30 | 31 | # mask_former model config 32 | cfg.MODEL.MASK_FORMER = CN() 33 | 34 | # loss 35 | cfg.MODEL.MASK_FORMER.DEEP_SUPERVISION = True 36 | cfg.MODEL.MASK_FORMER.NO_OBJECT_WEIGHT = 0.1 37 | cfg.MODEL.MASK_FORMER.CLASS_WEIGHT = 1.0 38 | cfg.MODEL.MASK_FORMER.DICE_WEIGHT = 1.0 39 | cfg.MODEL.MASK_FORMER.MASK_WEIGHT = 20.0 40 | 41 | # transformer config 42 | cfg.MODEL.MASK_FORMER.NHEADS = 8 43 | cfg.MODEL.MASK_FORMER.DROPOUT = 0.1 44 | cfg.MODEL.MASK_FORMER.DIM_FEEDFORWARD = 2048 45 | cfg.MODEL.MASK_FORMER.ENC_LAYERS = 0 46 | cfg.MODEL.MASK_FORMER.DEC_LAYERS = 6 47 | cfg.MODEL.MASK_FORMER.PRE_NORM = False 48 | 49 | cfg.MODEL.MASK_FORMER.HIDDEN_DIM = 256 50 | cfg.MODEL.MASK_FORMER.NUM_OBJECT_QUERIES = 100 51 | 52 | cfg.MODEL.MASK_FORMER.TRANSFORMER_IN_FEATURE = "res5" 53 | cfg.MODEL.MASK_FORMER.ENFORCE_INPUT_PROJ = False 54 | 55 | # mask_former inference config 56 | cfg.MODEL.MASK_FORMER.TEST = CN() 57 | cfg.MODEL.MASK_FORMER.TEST.SEMANTIC_ON = True 58 | cfg.MODEL.MASK_FORMER.TEST.INSTANCE_ON = False 59 | cfg.MODEL.MASK_FORMER.TEST.PANOPTIC_ON = False 60 | cfg.MODEL.MASK_FORMER.TEST.OBJECT_MASK_THRESHOLD = 0.0 61 | cfg.MODEL.MASK_FORMER.TEST.OVERLAP_THRESHOLD = 0.0 62 | cfg.MODEL.MASK_FORMER.TEST.SEM_SEG_POSTPROCESSING_BEFORE_INFERENCE = False 63 | 64 | # Sometimes `backbone.size_divisibility` is set to 0 for some backbone (e.g. ResNet) 65 | # you can use this config to override 66 | cfg.MODEL.MASK_FORMER.SIZE_DIVISIBILITY = 32 67 | 68 | # pixel decoder config 69 | cfg.MODEL.SEM_SEG_HEAD.MASK_DIM = 256 70 | # adding transformer in pixel decoder 71 | cfg.MODEL.SEM_SEG_HEAD.TRANSFORMER_ENC_LAYERS = 0 72 | # pixel decoder 73 | cfg.MODEL.SEM_SEG_HEAD.PIXEL_DECODER_NAME = "BasePixelDecoder" 74 | 75 | # swin transformer backbone 76 | cfg.MODEL.SWIN = CN() 77 | cfg.MODEL.SWIN.PRETRAIN_IMG_SIZE = 224 78 | cfg.MODEL.SWIN.PATCH_SIZE = 4 79 | cfg.MODEL.SWIN.EMBED_DIM = 96 80 | cfg.MODEL.SWIN.DEPTHS = [2, 2, 6, 2] 81 | cfg.MODEL.SWIN.NUM_HEADS = [3, 6, 12, 24] 82 | cfg.MODEL.SWIN.WINDOW_SIZE = 7 83 | cfg.MODEL.SWIN.MLP_RATIO = 4.0 84 | cfg.MODEL.SWIN.QKV_BIAS = True 85 | cfg.MODEL.SWIN.QK_SCALE = None 86 | cfg.MODEL.SWIN.DROP_RATE = 0.0 87 | cfg.MODEL.SWIN.ATTN_DROP_RATE = 0.0 88 | cfg.MODEL.SWIN.DROP_PATH_RATE = 0.3 89 | cfg.MODEL.SWIN.APE = False 90 | cfg.MODEL.SWIN.PATCH_NORM = True 91 | cfg.MODEL.SWIN.OUT_FEATURES = ["res2", "res3", "res4", "res5"] 92 | cfg.MODEL.SWIN.USE_CHECKPOINT = False 93 | 94 | # NOTE: maskformer2 extra configs 95 | # transformer module 96 | cfg.MODEL.MASK_FORMER.TRANSFORMER_DECODER_NAME = "MultiScaleMaskedTransformerDecoder" 97 | 98 | # LSJ aug 99 | cfg.INPUT.IMAGE_SIZE = 1024 100 | cfg.INPUT.MIN_SCALE = 0.1 101 | cfg.INPUT.MAX_SCALE = 2.0 102 | 103 | # MSDeformAttn encoder configs 104 | cfg.MODEL.SEM_SEG_HEAD.DEFORMABLE_TRANSFORMER_ENCODER_IN_FEATURES = ["res3", "res4", "res5"] 105 | cfg.MODEL.SEM_SEG_HEAD.DEFORMABLE_TRANSFORMER_ENCODER_N_POINTS = 4 106 | cfg.MODEL.SEM_SEG_HEAD.DEFORMABLE_TRANSFORMER_ENCODER_N_HEADS = 8 107 | 108 | # point loss configs 109 | # Number of points sampled during training for a mask point head. 110 | cfg.MODEL.MASK_FORMER.TRAIN_NUM_POINTS = 112 * 112 111 | # Oversampling parameter for PointRend point sampling during training. Parameter `k` in the 112 | # original paper. 
113 | cfg.MODEL.MASK_FORMER.OVERSAMPLE_RATIO = 3.0 114 | # Importance sampling parameter for PointRend point sampling during training. Parametr `beta` in 115 | # the original paper. 116 | cfg.MODEL.MASK_FORMER.IMPORTANCE_SAMPLE_RATIO = 0.75 117 | -------------------------------------------------------------------------------- /sam_pt/vis_eval/mask2former_video/README.md: -------------------------------------------------------------------------------- 1 | # License 2 | 3 | This directory contains code adapted from the [Mask2Former](https://github.com/facebookresearch/Mask2Former/tree/9b0651c6c1d5b3af2e6da0589b719c514ec0d69a) project by Facebook Research, which was released under the MIT license as follows: 4 | 5 | ```txt 6 | Copyright (c) 2022 Meta, Inc. 7 | 8 | Permission is hereby granted, free of charge, to any person obtaining a copy 9 | of this software and associated documentation files (the "Software"), to deal 10 | in the Software without restriction, including without limitation the rights 11 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 12 | copies of the Software, and to permit persons to whom the Software is 13 | furnished to do so, subject to the following conditions: 14 | 15 | The above copyright notice and this permission notice shall be included in all 16 | copies or substantial portions of the Software. 17 | 18 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 19 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 20 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 21 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 22 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 23 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 24 | SOFTWARE. 25 | ``` 26 | -------------------------------------------------------------------------------- /sam_pt/vis_eval/mask2former_video/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | 3 | # config 4 | from .config import add_maskformer2_video_config 5 | 6 | # video 7 | from .data_video import ( 8 | YTVISDatasetMapper, 9 | YTVISEvaluator, 10 | build_detection_train_loader, 11 | build_detection_test_loader, 12 | get_detection_dataset_dicts, 13 | ) 14 | -------------------------------------------------------------------------------- /sam_pt/vis_eval/mask2former_video/config.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | from detectron2.config import CfgNode as CN 4 | 5 | 6 | def add_maskformer2_video_config(cfg): 7 | # video data 8 | # DataLoader 9 | cfg.INPUT.SAMPLING_FRAME_NUM = 2 10 | cfg.INPUT.SAMPLING_FRAME_RANGE = 20 11 | cfg.INPUT.SAMPLING_FRAME_SHUFFLE = False 12 | cfg.INPUT.AUGMENTATIONS = [] # "brightness", "contrast", "saturation", "rotation" 13 | -------------------------------------------------------------------------------- /sam_pt/vis_eval/mask2former_video/data_video/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
2 | # Modified by Bowen Cheng from https://github.com/sukjunhwang/IFC 3 | 4 | from .dataset_mapper import YTVISDatasetMapper, CocoClipDatasetMapper 5 | from .build import * 6 | 7 | from .datasets import * 8 | from .ytvis_eval import YTVISEvaluator 9 | -------------------------------------------------------------------------------- /sam_pt/vis_eval/mask2former_video/data_video/augmentation.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # Modified by Bowen Cheng from https://github.com/sukjunhwang/IFC 3 | 4 | import logging 5 | import sys 6 | 7 | import numpy as np 8 | from PIL import Image 9 | from detectron2.data import transforms as T 10 | from fvcore.transforms.transform import ( 11 | HFlipTransform, 12 | NoOpTransform, 13 | VFlipTransform, 14 | ) 15 | 16 | 17 | class ResizeShortestEdge(T.Augmentation): 18 | """ 19 | Scale the shorter edge to the given size, with a limit of `max_size` on the longer edge. 20 | If `max_size` is reached, then downscale so that the longer edge does not exceed max_size. 21 | """ 22 | 23 | def __init__( 24 | self, short_edge_length, max_size=sys.maxsize, sample_style="range", interp=Image.BILINEAR, clip_frame_cnt=1 25 | ): 26 | """ 27 | Args: 28 | short_edge_length (list[int]): If ``sample_style=="range"``, 29 | a [min, max] interval from which to sample the shortest edge length. 30 | If ``sample_style=="choice"``, a list of shortest edge lengths to sample from. 31 | max_size (int): maximum allowed longest edge length. 32 | sample_style (str): either "range" or "choice". 33 | """ 34 | super().__init__() 35 | assert sample_style in ["range", "choice", "range_by_clip", "choice_by_clip"], sample_style 36 | 37 | self.is_range = ("range" in sample_style) 38 | if isinstance(short_edge_length, int): 39 | short_edge_length = (short_edge_length, short_edge_length) 40 | if self.is_range: 41 | assert len(short_edge_length) == 2, ( 42 | "short_edge_length must be two values using 'range' sample style." 43 | f" Got {short_edge_length}!" 44 | ) 45 | self._cnt = 0 46 | self._init(locals()) 47 | self.inerp = interp 48 | 49 | def get_transform(self, image): 50 | if self._cnt % self.clip_frame_cnt == 0: 51 | if self.is_range: 52 | self.size = np.random.randint(self.short_edge_length[0], self.short_edge_length[1] + 1) 53 | else: 54 | self.size = np.random.choice(self.short_edge_length) 55 | if self.size == 0: 56 | return NoOpTransform() 57 | 58 | self._cnt = 0 # avoiding overflow 59 | self._cnt += 1 60 | 61 | h, w = image.shape[:2] 62 | 63 | scale = self.size * 1.0 / min(h, w) 64 | if h < w: 65 | newh, neww = self.size, scale * w 66 | else: 67 | newh, neww = scale * h, self.size 68 | if max(newh, neww) > self.max_size: 69 | scale = self.max_size * 1.0 / max(newh, neww) 70 | newh = newh * scale 71 | neww = neww * scale 72 | neww = int(neww + 0.5) 73 | newh = int(newh + 0.5) 74 | return T.ResizeTransform(h, w, newh, neww, self.interp) 75 | 76 | 77 | class RandomFlip(T.Augmentation): 78 | """ 79 | Flip the image horizontally or vertically with the given probability. 80 | """ 81 | 82 | def __init__(self, prob=0.5, *, horizontal=True, vertical=False, clip_frame_cnt=1): 83 | """ 84 | Args: 85 | prob (float): probability of flip. 86 | horizontal (boolean): whether to apply horizontal flipping 87 | vertical (boolean): whether to apply vertical flipping 88 | """ 89 | super().__init__() 90 | 91 | if horizontal and vertical: 92 | raise ValueError("Cannot do both horiz and vert. 
Please use two Flip instead.") 93 | if not horizontal and not vertical: 94 | raise ValueError("At least one of horiz or vert has to be True!") 95 | self._cnt = 0 96 | 97 | self._init(locals()) 98 | 99 | def get_transform(self, image): 100 | if self._cnt % self.clip_frame_cnt == 0: 101 | self.do = self._rand_range() < self.prob 102 | self._cnt = 0 # avoiding overflow 103 | self._cnt += 1 104 | 105 | h, w = image.shape[:2] 106 | 107 | if self.do: 108 | if self.horizontal: 109 | return HFlipTransform(w) 110 | elif self.vertical: 111 | return VFlipTransform(h) 112 | else: 113 | return NoOpTransform() 114 | 115 | 116 | def build_augmentation(cfg, is_train): 117 | logger = logging.getLogger(__name__) 118 | aug_list = [] 119 | if is_train: 120 | # Crop 121 | if cfg.INPUT.CROP.ENABLED: 122 | aug_list.append(T.RandomCrop(cfg.INPUT.CROP.TYPE, cfg.INPUT.CROP.SIZE)) 123 | 124 | # Resize 125 | min_size = cfg.INPUT.MIN_SIZE_TRAIN 126 | max_size = cfg.INPUT.MAX_SIZE_TRAIN 127 | sample_style = cfg.INPUT.MIN_SIZE_TRAIN_SAMPLING 128 | ms_clip_frame_cnt = cfg.INPUT.SAMPLING_FRAME_NUM if "by_clip" in cfg.INPUT.MIN_SIZE_TRAIN_SAMPLING else 1 129 | aug_list.append(ResizeShortestEdge(min_size, max_size, sample_style, clip_frame_cnt=ms_clip_frame_cnt)) 130 | 131 | # Flip 132 | if cfg.INPUT.RANDOM_FLIP != "none": 133 | if cfg.INPUT.RANDOM_FLIP == "flip_by_clip": 134 | flip_clip_frame_cnt = cfg.INPUT.SAMPLING_FRAME_NUM 135 | else: 136 | flip_clip_frame_cnt = 1 137 | 138 | aug_list.append( 139 | # NOTE using RandomFlip modified for the support of flip maintenance 140 | RandomFlip( 141 | horizontal=(cfg.INPUT.RANDOM_FLIP == "horizontal") or (cfg.INPUT.RANDOM_FLIP == "flip_by_clip"), 142 | vertical=cfg.INPUT.RANDOM_FLIP == "vertical", 143 | clip_frame_cnt=flip_clip_frame_cnt, 144 | ) 145 | ) 146 | 147 | # Additional augmentations : brightness, contrast, saturation, rotation 148 | augmentations = cfg.INPUT.AUGMENTATIONS 149 | if "brightness" in augmentations: 150 | aug_list.append(T.RandomBrightness(0.9, 1.1)) 151 | if "contrast" in augmentations: 152 | aug_list.append(T.RandomContrast(0.9, 1.1)) 153 | if "saturation" in augmentations: 154 | aug_list.append(T.RandomSaturation(0.9, 1.1)) 155 | if "rotation" in augmentations: 156 | aug_list.append( 157 | T.RandomRotation( 158 | [-15, 15], expand=False, center=[(0.4, 0.4), (0.6, 0.6)], sample_style="range" 159 | ) 160 | ) 161 | else: 162 | # Resize 163 | min_size = cfg.INPUT.MIN_SIZE_TEST 164 | max_size = cfg.INPUT.MAX_SIZE_TEST 165 | sample_style = "choice" 166 | aug_list.append(T.ResizeShortestEdge(min_size, max_size, sample_style)) 167 | 168 | return aug_list 169 | -------------------------------------------------------------------------------- /sam_pt/vis_eval/mask2former_video/data_video/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # Modified by Bowen Cheng from https://github.com/sukjunhwang/IFC 3 | 4 | from . import builtin # ensure the builtin datasets are registered 5 | 6 | __all__ = [k for k in globals().keys() if "builtin" not in k and not k.startswith("_")] 7 | -------------------------------------------------------------------------------- /sam_pt/vis_eval/mask2former_video/data_video/datasets/builtin.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
2 | # Modified by Bowen Cheng from https://github.com/sukjunhwang/IFC 3 | 4 | import os 5 | 6 | from .uvo import ( 7 | _get_uvo_v1_instances_meta, 8 | ) 9 | from .ytvis import ( 10 | register_ytvis_instances, 11 | _get_ytvis_2019_instances_meta, 12 | _get_ytvis_2021_instances_meta, 13 | ) 14 | 15 | # ==== Predefined splits for YTVIS 2019 =========== 16 | _PREDEFINED_SPLITS_YTVIS_2019 = { 17 | "ytvis_2019_train": ("ytvis_2019/train/JPEGImages", 18 | "ytvis_2019/train.json"), 19 | "ytvis_2019_val": ("ytvis_2019/valid/JPEGImages", 20 | "ytvis_2019/valid.json"), 21 | "ytvis_2019_test": ("ytvis_2019/test/JPEGImages", 22 | "ytvis_2019/test.json"), 23 | } 24 | 25 | # ==== Predefined splits for YTVIS 2021 =========== 26 | _PREDEFINED_SPLITS_YTVIS_2021 = { 27 | "ytvis_2021_train": ("ytvis_2021/train/JPEGImages", 28 | "ytvis_2021/train/instances.json"), 29 | "ytvis_2021_train_mini": ("ytvis_2021/train/JPEGImages", 30 | "ytvis_2021/train/instances.mini.27.json"), 31 | "ytvis_2021_train_tiny": ( 32 | # cat data/ytvis_2021/train/instances.mini.27.json | jq '.videos |= [.[0]] | .annotations |= [.[0,1]]' > data/ytvis_2021/train/instances.tiny.1.json 33 | "ytvis_2021/train/JPEGImages", 34 | "ytvis_2021/train/instances.tiny.1.json", 35 | ), 36 | "ytvis_2021_val": ("ytvis_2021/valid/JPEGImages", 37 | "ytvis_2021/valid/instances.json"), 38 | "ytvis_2021_val_mini": ("ytvis_2021/valid/JPEGImages", 39 | "ytvis_2021/valid/instances.mini.27.json"), 40 | "ytvis_2021_val_tiny": ("ytvis_2021/valid/JPEGImages", 41 | "ytvis_2021/valid/instances.mini.1.json"), 42 | "ytvis_2021_test": ("ytvis_2021/test/JPEGImages", 43 | "ytvis_2021/test/instances.json"), 44 | } 45 | 46 | _PREDEFINED_SPLITS_UVO_V1 = { 47 | "uvo_v1_train": ("UVOv1.0/uvo_videos_dense_frames/", 48 | "UVOv1.0/VideoDenseSet/UVO_video_train_dense.json"), 49 | "uvo_v1_val": ("UVOv1.0/uvo_videos_dense_frames/", 50 | "UVOv1.0/VideoDenseSet/UVO_video_val_dense.json"), 51 | "uvo_v1_val_tiny": ( 52 | # Contains only 1 video 53 | # Split created using jq: `cat data/UVOv1.0/VideoDenseSet/UVO_video_val_dense.json | jq '.videos |= [.[0]] | .annotations |= [.[0,1,2,3]]' > data/UVOv1.0/VideoDenseSet/UVO_video_val_dense.tiny.1.json` 54 | "UVOv1.0/uvo_videos_dense_frames/", 55 | "UVOv1.0/VideoDenseSet/UVO_video_val_dense.tiny.1.json", 56 | ), 57 | "uvo_v1_test": ("UVOv1.0/uvo_videos_dense_frames/", 58 | "UVOv1.0/VideoDenseSet/UVO_video_test_dense.json"), 59 | } 60 | 61 | _PREDEFINED_SPLITS_UVO_V05 = { 62 | "uvo_v05_train": ("UVOv1.0/uvo_videos_dense_frames/", 63 | "UVOv0.5/VideoDenseSet/UVO_video_train_dense.json"), 64 | "uvo_v05_val": ("UVOv1.0/uvo_videos_dense_frames/", 65 | "UVOv0.5/VideoDenseSet/UVO_video_val_dense.json"), 66 | "uvo_v05_val_tiny": ( 67 | # Contains only 1 video 68 | # Split created using jq: `cat data/UVOv0.5/VideoDenseSet/UVO_video_val_dense.json | jq '.videos |= [.[0]] | .annotations |= [.[0,1,2,3]]' > data/UVOv0.5/VideoDenseSet/UVO_video_val_dense.tiny.1.json` 69 | "UVOv1.0/uvo_videos_dense_frames/", 70 | "UVOv0.5/VideoDenseSet/UVO_video_val_dense.tiny.1.json", 71 | ), 72 | "uvo_v05_test": ("UVOv1.0/uvo_videos_dense_frames/", 73 | "UVOv0.5/VideoDenseSet/UVO_video_test_dense.json"), 74 | } 75 | 76 | 77 | def register_all_ytvis_2019(root): 78 | for key, (image_root, json_file) in _PREDEFINED_SPLITS_YTVIS_2019.items(): 79 | # Assume pre-defined datasets live in `./datasets`. 
80 | register_ytvis_instances( 81 | key, 82 | _get_ytvis_2019_instances_meta(), 83 | os.path.join(root, json_file) if "://" not in json_file else json_file, 84 | os.path.join(root, image_root), 85 | ) 86 | 87 | 88 | def register_all_ytvis_2021(root): 89 | for key, (image_root, json_file) in _PREDEFINED_SPLITS_YTVIS_2021.items(): 90 | # Assume pre-defined datasets live in `./datasets`. 91 | register_ytvis_instances( 92 | key, 93 | _get_ytvis_2021_instances_meta(), 94 | os.path.join(root, json_file) if "://" not in json_file else json_file, 95 | os.path.join(root, image_root), 96 | ) 97 | 98 | 99 | def register_all_uvo_v1(_root): 100 | for key, (image_root, json_file) in _PREDEFINED_SPLITS_UVO_V1.items(): 101 | # Assume pre-defined datasets live in `./datasets`. 102 | register_ytvis_instances( 103 | key, 104 | _get_uvo_v1_instances_meta(), 105 | os.path.join(_root, json_file) if "://" not in json_file else json_file, 106 | os.path.join(_root, image_root), 107 | ) 108 | 109 | 110 | def register_all_uvo_v05(_root): 111 | for key, (image_root, json_file) in _PREDEFINED_SPLITS_UVO_V05.items(): 112 | # Assume pre-defined datasets live in `./datasets`. 113 | register_ytvis_instances( 114 | key, 115 | _get_uvo_v1_instances_meta(), 116 | os.path.join(_root, json_file) if "://" not in json_file else json_file, 117 | os.path.join(_root, image_root), 118 | ) 119 | 120 | 121 | if __name__.endswith(".builtin"): 122 | # Assume pre-defined datasets live in `./datasets`. 123 | _root = os.getenv("DETECTRON2_DATASETS", "datasets") 124 | register_all_ytvis_2019(_root) 125 | register_all_ytvis_2021(_root) 126 | register_all_uvo_v1(_root) 127 | register_all_uvo_v05(_root) 128 | -------------------------------------------------------------------------------- /sam_pt/vis_eval/mask2former_video/data_video/datasets/uvo.py: -------------------------------------------------------------------------------- 1 | UVO_CATEGORIES_V1_CLASS_AGNOSTIC = [ 2 | {"color": [106, 0, 228], "isthing": 1, "id": 1, "name": "object"}, 3 | ] 4 | 5 | 6 | def _get_uvo_v1_instances_meta(): 7 | thing_ids = [k["id"] for k in UVO_CATEGORIES_V1_CLASS_AGNOSTIC if k["isthing"] == 1] 8 | assert len(thing_ids) == 1, len(thing_ids) 9 | thing_dataset_id_to_contiguous_id = {k: i for i, k in enumerate(thing_ids)} 10 | thing_classes = [k["name"] for k in UVO_CATEGORIES_V1_CLASS_AGNOSTIC if k["isthing"] == 1] 11 | thing_colors = [k["color"] for k in UVO_CATEGORIES_V1_CLASS_AGNOSTIC if k["isthing"] == 1] 12 | ret = { 13 | "thing_dataset_id_to_contiguous_id": thing_dataset_id_to_contiguous_id, 14 | "thing_classes": thing_classes, 15 | "thing_colors": thing_colors, 16 | } 17 | return ret 18 | -------------------------------------------------------------------------------- /sam_pt/vis_eval/mask2former_video/data_video/datasets/ytvis_api/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
2 | # Modified by Bowen Cheng from https://github.com/youtubevos/cocoapi 3 | -------------------------------------------------------------------------------- /sam_pt/vos_eval/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SysCV/sam-pt/874ff7e73d6ab05418a494d7a02ca233c0b31e8c/sam_pt/vos_eval/__init__.py -------------------------------------------------------------------------------- /sam_pt/vos_eval/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SysCV/sam-pt/874ff7e73d6ab05418a494d7a02ca233c0b31e8c/sam_pt/vos_eval/data/__init__.py -------------------------------------------------------------------------------- /sam_pt/vos_eval/data/mask_mapper.py: -------------------------------------------------------------------------------- 1 | # Taken from: https://github.com/hkchengrex/XMem/blob/083698bbb4c5ac0ffe1a8923a6c313de46169983/inference/data/mask_mapper.py 2 | 3 | import numpy as np 4 | import torch 5 | 6 | 7 | def all_to_onehot(masks, labels): 8 | if len(masks.shape) == 3: 9 | Ms = np.zeros((len(labels), masks.shape[0], masks.shape[1], masks.shape[2]), dtype=np.uint8) 10 | else: 11 | Ms = np.zeros((len(labels), masks.shape[0], masks.shape[1]), dtype=np.uint8) 12 | 13 | for ni, l in enumerate(labels): 14 | Ms[ni] = (masks == l).astype(np.uint8) 15 | 16 | return Ms 17 | 18 | 19 | class MaskMapper: 20 | """ 21 | This class is used to convert a indexed-mask to a one-hot representation. 22 | It also takes care of remapping non-continuous indices 23 | It has two modes: 24 | 1. Default. Only masks with new indices are supposed to go into the remapper. 25 | This is also the case for YouTubeVOS. 26 | i.e., regions with index 0 are not "background", but "don't care". 27 | 28 | 2. Exhaustive. Regions with index 0 are considered "background". 29 | Every single pixel is considered to be "labeled". 
30 | """ 31 | 32 | def __init__(self): 33 | self.labels = [] 34 | self.remappings = {} 35 | 36 | # if coherent, no mapping is required 37 | self.coherent = True 38 | 39 | def convert_mask(self, mask, exhaustive=False, dtype=np.uint8, old_labels_allowed=False): 40 | # mask is in index representation, H*W numpy array 41 | labels = np.unique(mask).astype(dtype) 42 | labels = labels[labels != 0].tolist() 43 | 44 | new_labels = list(set(labels) - set(self.labels)) 45 | if not exhaustive and not old_labels_allowed: 46 | assert len(new_labels) == len(labels), 'Old labels found in non-exhaustive mode' 47 | 48 | # add new remappings 49 | for i, l in enumerate(new_labels): 50 | self.remappings[l] = i + len(self.labels) + 1 51 | if self.coherent and i + len(self.labels) + 1 != l: 52 | self.coherent = False 53 | 54 | if exhaustive: 55 | new_mapped_labels = range(1, len(self.labels) + len(new_labels) + 1) 56 | else: 57 | if self.coherent: 58 | new_mapped_labels = new_labels 59 | else: 60 | new_mapped_labels = range(len(self.labels) + 1, len(self.labels) + len(new_labels) + 1) 61 | 62 | self.labels.extend(new_labels) 63 | mask = torch.from_numpy(all_to_onehot(mask, self.labels)).float() 64 | 65 | # mask num_objects*H*W 66 | return mask, new_mapped_labels 67 | 68 | def remap_index_mask(self, mask): 69 | # mask is in index representation, H*W numpy array 70 | if self.coherent: 71 | return mask 72 | 73 | new_mask = np.zeros_like(mask) 74 | for l, i in self.remappings.items(): 75 | new_mask[mask == i] = l 76 | return new_mask 77 | -------------------------------------------------------------------------------- /sam_pt/vos_eval/data/test_datasets.py: -------------------------------------------------------------------------------- 1 | # Adapted from: https://github.com/hkchengrex/XMem/blob/083698bbb4c5ac0ffe1a8923a6c313de46169983/inference/data/test_datasets.py 2 | 3 | import json 4 | import os 5 | from os import path 6 | 7 | import numpy as np 8 | 9 | from .video_reader import VideoReader 10 | 11 | 12 | class LongTestDataset: 13 | def __init__(self, data_root, size=-1, longest_size=None): 14 | self.image_dir = path.join(data_root, 'JPEGImages') 15 | self.mask_dir = path.join(data_root, 'Annotations') 16 | self.size = size 17 | self.longest_size = longest_size 18 | 19 | self.vid_list = sorted(os.listdir(self.image_dir)) 20 | 21 | def get_datasets(self): 22 | for video in self.vid_list: 23 | yield VideoReader(video, 24 | path.join(self.image_dir, video), 25 | path.join(self.mask_dir, video), 26 | to_save=[ 27 | name[:-4] for name in os.listdir(path.join(self.mask_dir, video)) 28 | ], 29 | shortest_size=self.size, 30 | longest_size=self.longest_size) 31 | 32 | def __len__(self): 33 | return len(self.vid_list) 34 | 35 | 36 | class DAVISTestDataset: 37 | def __init__(self, data_root, imset='2017/val.txt', size=-1, longest_size=None, return_all_gt_masks=False): 38 | if size != 480: 39 | self.image_dir = path.join(data_root, 'JPEGImages', 'Full-Resolution') 40 | self.mask_dir = path.join(data_root, 'Annotations', 'Full-Resolution') 41 | if not path.exists(self.image_dir): 42 | print(f'{self.image_dir} not found. 
Look at other options.') 43 | self.image_dir = path.join(data_root, 'JPEGImages', '1080p') 44 | self.mask_dir = path.join(data_root, 'Annotations', '1080p') 45 | assert path.exists(self.image_dir), 'path not found' 46 | else: 47 | self.image_dir = path.join(data_root, 'JPEGImages', '480p') 48 | self.mask_dir = path.join(data_root, 'Annotations', '480p') 49 | self.size_dir = path.join(data_root, 'JPEGImages', '480p') 50 | self.size = size 51 | self.longest_size = longest_size 52 | self.return_all_gt_masks = return_all_gt_masks 53 | 54 | with open(path.join(data_root, 'ImageSets', imset)) as f: 55 | self.vid_list = sorted([line.strip() for line in f]) 56 | 57 | def get_datasets(self): 58 | for video in self.vid_list: 59 | yield VideoReader(video, 60 | path.join(self.image_dir, video), 61 | path.join(self.mask_dir, video), 62 | shortest_size=self.size, 63 | longest_size=self.longest_size, 64 | size_dir=path.join(self.size_dir, video), 65 | use_all_mask=self.return_all_gt_masks) 66 | 67 | def __len__(self): 68 | return len(self.vid_list) 69 | 70 | 71 | class YouTubeVOSTestDataset: 72 | def __init__(self, data_root, split, size=480, longest_size=None): 73 | self.image_dir = path.join(data_root, 'all_frames', split + '_all_frames', 'JPEGImages') 74 | self.mask_dir = path.join(data_root, split, 'Annotations') 75 | self.size = size 76 | self.longest_size = longest_size 77 | 78 | self.vid_list = sorted(os.listdir(self.image_dir)) 79 | self.req_frame_list = {} 80 | 81 | with open(path.join(data_root, split, 'meta.json')) as f: 82 | # read meta.json to know which frame is required for evaluation 83 | meta = json.load(f)['videos'] 84 | 85 | for vid in self.vid_list: 86 | req_frames = [] 87 | objects = meta[vid]['objects'] 88 | for value in objects.values(): 89 | req_frames.extend(value['frames']) 90 | 91 | req_frames = list(set(req_frames)) 92 | self.req_frame_list[vid] = req_frames 93 | 94 | def get_datasets(self): 95 | for video in self.vid_list: 96 | yield VideoReader(video, 97 | path.join(self.image_dir, video), 98 | path.join(self.mask_dir, video), 99 | shortest_size=self.size, 100 | longest_size=self.longest_size, 101 | to_save=self.req_frame_list[video], 102 | use_all_mask=True) 103 | 104 | def __len__(self): 105 | return len(self.vid_list) 106 | 107 | 108 | class MOSETestDataset: 109 | def __init__(self, data_root, split, shortest_size=-1, longest_size=None): 110 | if split == "val": 111 | split = "valid" 112 | 113 | self.shortest_size = shortest_size 114 | self.longest_size = longest_size 115 | 116 | self.image_dir = path.abspath(path.join(data_root, split, 'JPEGImages')) 117 | self.mask_dir = path.abspath(path.join(data_root, split, 'Annotations')) 118 | 119 | print(f'MOSE-{split}: {self.image_dir}') 120 | print(f'MOSE-{split}: {self.mask_dir}') 121 | assert path.exists(self.image_dir) 122 | assert path.exists(self.mask_dir) 123 | 124 | self.vid_list = sorted(os.listdir(self.image_dir)) 125 | print(f'MOSE-{split}: Found {len(self.vid_list)} videos in {self.image_dir}') 126 | 127 | def get_datasets(self): 128 | for video in self.vid_list: 129 | yield VideoReader( 130 | vid_name=video, 131 | image_dir=path.join(self.image_dir, video), 132 | mask_dir=path.join(self.mask_dir, video), 133 | shortest_size=self.shortest_size, 134 | longest_size=self.longest_size, 135 | use_all_mask=True, 136 | ) 137 | 138 | def __len__(self): 139 | return len(self.vid_list) 140 | 141 | 142 | class BDD100KTestDataset: 143 | def __init__(self, data_root, split, shortest_size=-1, longest_size=None): 144 | 
self.shortest_size = shortest_size 145 | self.longest_size = longest_size 146 | 147 | self.image_dir = path.abspath(path.join(data_root, split, 'JPEGImages')) 148 | self.mask_dir = path.abspath(path.join(data_root, split, 'Annotations')) 149 | 150 | print(f'BDD100K-{split}: {self.image_dir}') 151 | print(f'BDD100K-{split}: {self.mask_dir}') 152 | assert path.exists(self.image_dir) 153 | assert path.exists(self.mask_dir) 154 | 155 | self.vid_list = sorted(os.listdir(self.image_dir)) 156 | print(f'BDD100K-{split}: Found {len(self.vid_list)} videos in {self.image_dir}') 157 | 158 | def get_datasets(self): 159 | for video in self.vid_list: 160 | yield VideoReader( 161 | vid_name=video, 162 | image_dir=path.join(self.image_dir, video), 163 | mask_dir=path.join(self.mask_dir, video), 164 | shortest_size=self.shortest_size, 165 | longest_size=self.longest_size, 166 | use_all_mask=True, 167 | # mask_mode='I;16', 168 | # mask_dtype=np.int32, 169 | ) 170 | 171 | def __len__(self): 172 | return len(self.vid_list) 173 | -------------------------------------------------------------------------------- /sam_pt/vos_eval/data/video_reader.py: -------------------------------------------------------------------------------- 1 | # Taken from: https://github.com/hkchengrex/XMem/blob/083698bbb4c5ac0ffe1a8923a6c313de46169983/inference/data/video_reader.py 2 | 3 | import os 4 | from os import path 5 | 6 | import numpy as np 7 | import torch.nn.functional as F 8 | from PIL import Image 9 | from segment_anything.utils.transforms import ResizeLongestSide 10 | from torch.utils.data.dataset import Dataset 11 | from torchvision import transforms 12 | from torchvision.transforms import InterpolationMode 13 | 14 | 15 | class VideoReader(Dataset): 16 | """ 17 | This class is used to read a video, one frame at a time 18 | """ 19 | 20 | def __init__(self, vid_name, image_dir, mask_dir, 21 | shortest_size=-1, longest_size=None, 22 | to_save=None, use_all_mask=False, size_dir=None, 23 | mask_mode='P', mask_dtype=np.uint8, 24 | ): 25 | """ 26 | image_dir - points to a directory of jpg images 27 | mask_dir - points to a directory of png masks 28 | size - resize min. side to size. Does nothing if <0. 29 | to_save - optionally contains a list of file names without extensions 30 | where the segmentation mask is required 31 | use_all_mask - when true, read all available mask in mask_dir. 32 | Default false. Set to true for YouTubeVOS validation. 33 | """ 34 | assert shortest_size == -1 or longest_size is None, 'One size constraint should be given, not both.' 
35 | 36 | self.vid_name = vid_name 37 | self.image_dir = image_dir 38 | self.mask_dir = mask_dir 39 | self.to_save = to_save 40 | self.use_all_mask = use_all_mask 41 | if size_dir is None: 42 | self.size_dir = self.image_dir 43 | else: 44 | self.size_dir = size_dir 45 | 46 | self.mask_mode = mask_mode 47 | self.mask_dtype = mask_dtype 48 | 49 | self.frames = sorted(os.listdir(self.image_dir)) 50 | self.palette = Image.open(path.join(mask_dir, sorted(os.listdir(mask_dir))[0])).getpalette() 51 | self.first_gt_path = path.join(self.mask_dir, sorted(os.listdir(self.mask_dir))[0]) 52 | 53 | # TODO SegGPT specific 54 | if shortest_size == "seggpt": 55 | shortest_size = (448, 448) 56 | 57 | self.shortest_size = shortest_size 58 | self.longest_size = longest_size 59 | 60 | # TODO: Model specific transforms are hardcoded here 61 | if self.shortest_size == -1 and self.longest_size is None: 62 | self.resize_longest_side_transform = None 63 | self.im_transform = transforms.Compose([ 64 | transforms.ToTensor(), 65 | ]) 66 | elif self.shortest_size != -1: 67 | self.resize_longest_side_transform = None 68 | self.im_transform = transforms.Compose([ 69 | transforms.ToTensor(), 70 | transforms.Resize(self.shortest_size, interpolation=InterpolationMode.BILINEAR), 71 | ]) 72 | elif self.longest_size is not None: 73 | self.resize_longest_side_transform = ResizeLongestSide(self.longest_size) 74 | self.im_transform = transforms.Compose([ 75 | transforms.ToTensor(), 76 | ]) 77 | else: 78 | raise RuntimeError('Invalid size constraints.') 79 | 80 | def __getitem__(self, idx): 81 | frame = self.frames[idx] 82 | info = {} 83 | data = {} 84 | info['frame'] = frame 85 | info['save'] = (self.to_save is None) or (frame[:-4] in self.to_save) 86 | 87 | im_path = path.join(self.image_dir, frame) 88 | img = Image.open(im_path).convert('RGB') 89 | 90 | if self.image_dir == self.size_dir: 91 | shape = np.array(img).shape[:2] 92 | else: 93 | size_path = path.join(self.size_dir, frame) 94 | size_im = Image.open(size_path).convert('RGB') 95 | shape = np.array(size_im).shape[:2] 96 | 97 | gt_path = path.join(self.mask_dir, frame[:-4] + '.png') 98 | if self.resize_longest_side_transform is not None: 99 | img = np.array(img) 100 | img = self.resize_longest_side_transform.apply_image(img) 101 | 102 | img = self.im_transform(img) 103 | 104 | load_mask = self.use_all_mask or (gt_path == self.first_gt_path) 105 | if load_mask and path.exists(gt_path): 106 | mask = Image.open(gt_path).convert(self.mask_mode) 107 | mask = np.array(mask, dtype=self.mask_dtype) 108 | data['mask'] = mask 109 | 110 | info['shape'] = shape 111 | info['need_resize'] = self.shortest_size != 0 or self.longest_size is not None 112 | data['rgb'] = img 113 | data['info'] = info 114 | 115 | # TODO: SegGPT specific 116 | if self.shortest_size == (448, 448): 117 | info['shape'] = (448, 448) 118 | 119 | return data 120 | 121 | def resize_mask(self, mask): 122 | # mask transform is applied AFTER mapper, so we need to post-process it in eval.py 123 | old_h, old_w = mask.shape[-2:] 124 | if self.resize_longest_side_transform is None: 125 | min_hw = min(old_h, old_w) 126 | if self.shortest_size == (448, 448): 127 | # TODO SegGPT specific 128 | shape = (448, 448) 129 | else: 130 | shape = (int(old_h / min_hw * self.shortest_size), int(old_w / min_hw * self.shortest_size)) 131 | else: 132 | shape = ResizeLongestSide.get_preprocess_shape(old_h, old_w, self.longest_size) 133 | return F.interpolate(mask, shape, mode='nearest') 134 | 135 | def get_palette(self): 136 | return 
self.palette 137 | 138 | def __len__(self): 139 | return len(self.frames) 140 | -------------------------------------------------------------------------------- /sam_pt/vos_eval/davis2017eval.py: -------------------------------------------------------------------------------- 1 | """ 2 | This script is a modified version of the original DAVIS 2017 evaluation script from: 3 | https://github.com/davisvideochallenge/davis2017-evaluation/blob/ac7c43fca936f9722837b7fbd337d284ba37004b/evaluation_method.py 4 | 5 | Usage: 6 | ``` 7 | python -m sam_pt.vos_eval.davis2017eval \ 8 | --results_path /srv/beegfs02/scratch/visobt4s/data/3d_point_tracking/sampt_outputs/SegGPT--D17-val--in-sampt-env_D17_val_72_2023.11.09_15.52.53/eval_D17_val \ 9 | --davis_path data/DAVIS/2017/trainval \ 10 | --set val \ 11 | --task semi-supervised \ 12 | --year 2017 13 | ``` 14 | """ 15 | 16 | import argparse 17 | import os 18 | import sys 19 | from time import time 20 | from typing import Union 21 | 22 | import numpy as np 23 | import pandas as pd 24 | from davis2017.evaluation import DAVISEvaluation 25 | 26 | 27 | class Davis2017Evaluator: 28 | def __init__(self, results_path: str, davis_path: str, set: str = "val", task: str = "semi-unsupervised", 29 | year: str = '2017', sequences: Union[str, list] = "all", ): 30 | """ 31 | :param results_path: Path to the folder containing the sequences folders. 32 | :param davis_path: Path to the DAVIS folder containing the `JPEGImages`, `Annotations`, `ImageSets`, 33 | `Annotations_unsupervised` folders. 34 | :param set: Subset to evaluate the results. 35 | :param task: Task to evaluate the results. 36 | :param year: DAVIS dataset year. 37 | :param sequences: List of sequences to evaluate. If "all", evaluate all sequences. 38 | """ 39 | assert set in ['val', 'test-dev', 'test-challenge'] 40 | assert task in ['semi-supervised', 'unsupervised'] 41 | 42 | self.davis_path = davis_path 43 | self.set = set 44 | self.task = task 45 | self.year = year 46 | self.sequences = sequences 47 | self.results_path = results_path 48 | 49 | def evaluate(self): 50 | time_start = time() 51 | csv_name_global = f'global_results-{self.set}.csv' 52 | csv_name_per_sequence = f'per-sequence_results-{self.set}.csv' 53 | 54 | # Check if the method has been evaluated before, if so read the results, otherwise compute the results 55 | csv_name_global_path = os.path.join(self.results_path, csv_name_global) 56 | csv_name_per_sequence_path = os.path.join(self.results_path, csv_name_per_sequence) 57 | if os.path.exists(csv_name_global_path) and os.path.exists(csv_name_per_sequence_path): 58 | print('Using precomputed results...') 59 | table_g = pd.read_csv(csv_name_global_path) 60 | table_seq = pd.read_csv(csv_name_per_sequence_path) 61 | else: 62 | print(f'Evaluating sequences for the {self.task} task...') 63 | # Create dataset and evaluate 64 | dataset_eval = DAVISEvaluation(davis_root=self.davis_path, task=self.task, gt_set=self.set, year=self.year, 65 | sequences=self.sequences) 66 | metrics_res = dataset_eval.evaluate(self.results_path) 67 | J, F = metrics_res['J'], metrics_res['F'] 68 | 69 | # Generate dataframe for the general results 70 | g_measures = ['J&F-Mean', 'J-Mean', 'J-Recall', 'J-Decay', 'F-Mean', 'F-Recall', 'F-Decay'] 71 | final_mean = (np.mean(J["M"]) + np.mean(F["M"])) / 2. 
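# J is the region similarity (Jaccard / mask IoU) and F the contour accuracy
# (boundary F-measure); their "M", "R", and "D" entries hold the mean, recall,
# and decay statistics. J&F-Mean, the headline DAVIS metric, averages the two means.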
72 | g_res = np.array( 73 | [final_mean, np.mean(J["M"]), np.mean(J["R"]), np.mean(J["D"]), np.mean(F["M"]), np.mean(F["R"]), 74 | np.mean(F["D"])]) 75 | g_res = np.reshape(g_res, [1, len(g_res)]) 76 | table_g = pd.DataFrame(data=g_res, columns=g_measures) 77 | with open(csv_name_global_path, 'w') as f: 78 | table_g.to_csv(f, index=False, float_format="%.3f") 79 | print(f'Global results saved in {csv_name_global_path}') 80 | 81 | # Generate a dataframe for the per sequence results 82 | seq_names = list(J['M_per_object'].keys()) 83 | seq_measures = ['Sequence', 'J-Mean', 'F-Mean'] 84 | J_per_object = [J['M_per_object'][x] for x in seq_names] 85 | F_per_object = [F['M_per_object'][x] for x in seq_names] 86 | table_seq = pd.DataFrame(data=list(zip(seq_names, J_per_object, F_per_object)), columns=seq_measures) 87 | with open(csv_name_per_sequence_path, 'w') as f: 88 | table_seq.to_csv(f, index=False, float_format="%.3f") 89 | print(f'Per-sequence results saved in {csv_name_per_sequence_path}') 90 | 91 | # Print the results 92 | sys.stdout.write(f"--------------------------- Global results for {self.set} ---------------------------\n") 93 | print(table_g.to_string(index=False)) 94 | sys.stdout.write(f"\n---------- Per sequence results for {self.set} ----------\n") 95 | print(table_seq.to_string(index=False)) 96 | total_time = time() - time_start 97 | sys.stdout.write('\nTotal time:' + str(total_time)) 98 | 99 | return table_g, table_seq 100 | 101 | 102 | if __name__ == '__main__': 103 | parser = argparse.ArgumentParser(description='Evaluate a method on the DAVIS 2017 dataset') 104 | parser.add_argument('--results_path', type=str, required=True, 105 | help='Path to the folder containing the sequences folders.') 106 | parser.add_argument('--davis_path', type=str, required=True, 107 | help='Path to the DAVIS folder containing the `JPEGImages`, `Annotations`, `ImageSets`, ' 108 | '`Annotations_unsupervised` folders.') 109 | parser.add_argument('--set', type=str, default='val', choices=['val', 'test-dev', 'test-challenge'], 110 | help='Subset to evaluate the results.') 111 | parser.add_argument('--eval_only_on_the_sequences_present_in_the_results', action='store_true', 112 | help='If True, evaluate only on the sequences present in the results folder.') 113 | parser.add_argument('--task', type=str, default='semi-supervised', choices=['semi-supervised', 'unsupervised'], 114 | help='Task to evaluate the results.') 115 | parser.add_argument("--year", type=str, help="Davis dataset year (default: 2017)", default='2017', 116 | choices=['2016', '2017', '2019']) 117 | 118 | args = parser.parse_args() 119 | 120 | sequences = 'all' 121 | if args.eval_only_on_the_sequences_present_in_the_results: 122 | assert os.path.exists(args.results_path) 123 | sequences = sorted(os.listdir(args.results_path)) 124 | sequences = [s for s in sequences if s != "overlapping" and "." 
not in s] 125 | print(f"Evaluating only on the sequences present in the results folder: {sequences}") 126 | 127 | evaluator = Davis2017Evaluator(args.results_path, args.davis_path, args.set, args.task, args.year, sequences) 128 | evaluator.evaluate() 129 | -------------------------------------------------------------------------------- /sam_pt/vos_eval/evaluator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from abc import abstractmethod, ABC 3 | 4 | from sam_pt.modeling.sam_pt import SamPt 5 | 6 | 7 | class VOSEvaluator(ABC): 8 | """ 9 | Abstract class for evaluating a model on the semi-supervised video object segmentation task. 10 | """ 11 | 12 | def __init__(self, cfg, model): 13 | self.cfg = cfg 14 | self.model = model 15 | 16 | @abstractmethod 17 | def evaluate_video(self, video): 18 | """ 19 | Evaluates model on a video and returns the predictions. 20 | 21 | Parameters 22 | ---------- 23 | video : dict 24 | Dictionary with video data. It includes the following keys: 25 | 'video_name': str - The name of the video. 26 | 'video_id': int - The ID of the video. 27 | 'image': List[torch.Tensor] - The frames of the video as uint8 tensors of shape (channels, height, width) 28 | 'info': List[dict] - Information for each frame, includes keys like 'frame', 'save', 'shape', 'need_resize'. 29 | 'target_hw': Tuple[int, int] - The target height and width for the predicted masks. 30 | 'query_masks': torch.Tensor - The query masks as binary float32 tensor of shape (num_masks, height, width). 31 | 'query_point_timestep': torch.Tensor - The query point timesteps as float32 tensor of shape (num_masks,). 32 | 33 | Returns 34 | ------- 35 | dict 36 | Dictionary with predictions. It includes the following keys: 37 | 'logits': List[torch.Tensor] - The logits as float32 tensors of shape (num_frames, height, width). 38 | 'trajectories': torch.Tensor - The trajectories as float32 tensor 39 | of shape (num_frames, n_masks, n_points_per_mask, 2). 40 | 'visibilities': torch.Tensor - The visibilities as float32 tensor 41 | of shape (num_frames, n_masks, n_points_per_mask). 42 | 'scores': List[float] - The scores as list of 'num_masks' floats. 
43 | """ 44 | pass 45 | 46 | 47 | class SamPtEvaluator(VOSEvaluator): 48 | def evaluate_video(self, video): 49 | self.model: SamPt = self.model 50 | device = self.model.device 51 | for k, v in video.items(): 52 | if isinstance(v, torch.Tensor): 53 | video[k] = v.to(device) 54 | outputs = self.model(video) 55 | return { 56 | "logits": outputs["logits"], 57 | "trajectories": outputs["trajectories"], 58 | "visibilities": outputs["visibilities"], 59 | 'scores': outputs['scores'], 60 | } 61 | -------------------------------------------------------------------------------- /scripts/annotation_comparison_gif.py: -------------------------------------------------------------------------------- 1 | """ 2 | python scripts/annotation_comparison_gif.py 3 | """ 4 | 5 | import os 6 | from concurrent.futures import ThreadPoolExecutor, as_completed 7 | 8 | import imageio 9 | from PIL import Image 10 | from tqdm import tqdm 11 | 12 | 13 | def create_gif(results_dir, annotations_dir, images_dir, output_gif_path): 14 | # Get a sorted list of image files and annotation files 15 | result_files = sorted([f for f in os.listdir(results_dir) if f.endswith('.png')]) 16 | images_files = sorted([f for f in os.listdir(images_dir) if f.endswith('.jpg')]) 17 | annotation_files = sorted([f for f in os.listdir(annotations_dir) if f.endswith('.png')]) 18 | 19 | # Check if both folders have the same number of files 20 | assert len(result_files) == len(annotation_files) == len(images_files) 21 | 22 | # Create a list to store concatenated images 23 | concat_images = [] 24 | 25 | for res_file, img_file, ann_file in tqdm(list(zip(result_files, images_files, annotation_files))): 26 | # Open the images 27 | result = Image.open(os.path.join(results_dir, res_file)) 28 | image = Image.open(os.path.join(images_dir, img_file)) 29 | annotation = Image.open(os.path.join(annotations_dir, ann_file)) 30 | 31 | # Make sure the images can be concatenated 32 | assert image.size == annotation.size == result.size, "Image sizes do not match." 
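# Note: PIL's Image.size is (width, height), so size[0] is the common frame width
# and size[1] the per-image height used for the vertical stacking below.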
33 | 34 | # Concatenate the images vertically 35 | total_height = image.size[1] + annotation.size[1] + result.size[1] 36 | combined_image = Image.new('RGB', (image.size[0], total_height)) 37 | combined_image.paste(image, (0, 0)) 38 | combined_image.paste(annotation, (0, image.size[1])) 39 | combined_image.paste(result, (0, image.size[1] + annotation.size[1])) 40 | 41 | # Add to list of concatenated images 42 | concat_images.append(combined_image) 43 | 44 | # Save the frames as a GIF 45 | imageio.mimsave(output_gif_path, concat_images, duration=0.5, loop=0) 46 | 47 | print(f"GIF created at {output_gif_path}") 48 | 49 | 50 | def create_gif_per_video(video, results_path, annotations_path, images_path): 51 | print(f"Creating GIF for {video}") 52 | result_path = os.path.join(results_path, video) 53 | annotation_path = os.path.join(annotations_path, video) 54 | image_path = os.path.join(images_path, video) 55 | output_gif_path = os.path.join(results_path, video + ".gif") 56 | create_gif(result_path, annotation_path, image_path, output_gif_path) 57 | 58 | 59 | if __name__ == '__main__': 60 | # results_path = "/mnt/terra/xoding/eth-master-thesis/08-logs-september/bdd100k-results/K9.000--debug--cotracker-0--1-1024/" 61 | # annotations_path = "/mnt/terra/xoding/eth-master-thesis/08-logs-september/bdd100k-results/vos/val/Annotations/" 62 | # images_path = "/mnt/terra/xoding/eth-master-thesis/08-logs-september/bdd100k-results/vos/val/JPEGImages/" 63 | 64 | results_path = "outputs/K9.000--debug--cotracker-0--1-1024/eval_BDD100K_val" 65 | annotations_path = "data/bdd100k/vos/val/Annotations" 66 | images_path = "data/bdd100k/vos/val/JPEGImages" 67 | 68 | videos = [video for video in os.listdir(results_path) if not video.endswith(".gif") and not "." in video] 69 | 70 | with ThreadPoolExecutor() as executor: 71 | # Submit all tasks to the executor 72 | future_to_video = { 73 | executor.submit(create_gif_per_video, video, results_path, annotations_path, images_path): video for video 74 | in videos} 75 | 76 | # Process the futures as they complete 77 | for future in tqdm(as_completed(future_to_video), total=len(videos), desc="Processing videos", unit="video"): 78 | video = future_to_video[future] 79 | try: 80 | future.result() 81 | except Exception as exc: 82 | print(f'{video} generated an exception: {exc}') 83 | 84 | print("All GIFs have been created.") 85 | -------------------------------------------------------------------------------- /scripts/bdd100k_from_instance_seg_to_vos_annotations.py: -------------------------------------------------------------------------------- 1 | """ 2 | To create the VOS annotations from the instance segmentation annotations, run: 3 | ```bash 4 | # Prepare directories 5 | mkdir -p data/bdd100k/vos/val/{Annotations,JPEGImages} 6 | 7 | # Copy JPEGImages 8 | cp -r data/bdd100k/images/seg_track_20/val/* data/bdd100k/vos/val/JPEGImages/ 9 | 10 | # Create the Annotations 11 | python -m scripts.bdd100k_from_instance_seg_to_vos_annotations 12 | 13 | # Link the chunks 14 | # e.g., data/bdd100k/vos/val/JPEGImages/b1c66a42-6f7d68ca-chunk2 -> b1c66a42-6f7d68ca/ 15 | find data/bdd100k/vos/val/Annotations -type d -name "*-chunk*" | sed 's/Annotations/JPEGImages/' | while read -r src; do 16 | tgt=$(basename "$src" | sed 's/-chunk.*//') 17 | rm $src 18 | ln -s "$tgt" "$src" 19 | done 20 | ``` 21 | """ 22 | import json 23 | import os 24 | 25 | import math 26 | import numpy as np 27 | import pandas as pd 28 | from PIL import Image 29 | from tqdm import tqdm 30 | from tqdm.contrib.concurrent 
import process_map 31 | 32 | np.random.seed(72) 33 | palette = (np.multiply(np.random.rand(768), 255).astype(np.uint8).tolist()) 34 | palette[:3] = [0, 0, 0] 35 | 36 | 37 | def remap_ids(ids): 38 | # Find the unique IDs and their new remapped positions 39 | unique_ids, inverse_indices = np.unique(ids, return_inverse=True) 40 | 41 | # Reshape the inverse_indices to the shape of the original IDs array 42 | remapped_ids = inverse_indices.reshape(ids.shape) 43 | 44 | return remapped_ids 45 | 46 | 47 | def process_video(video_name, objects_per_chunk=100): 48 | print(f"Processing video {video_name}") 49 | frames = sorted(os.listdir(os.path.join(videos_path, video_name))) 50 | bitmasks = [] 51 | for frame_name in frames: 52 | frame_path = os.path.join(videos_path, video_name, frame_name) 53 | bitmask = np.array(Image.open(frame_path)) 54 | bitmasks.append(bitmask) 55 | bitmasks = np.stack(bitmasks) 56 | annotation_ids = (bitmasks[:, :, :, 2].astype(np.uint32) << 8) + bitmasks[:, :, :, 3] 57 | unique_ids = np.unique(annotation_ids).size 58 | print(f"Video {video_name} is loaded, it has {unique_ids} unique objects") 59 | 60 | annotation_ids = annotation_ids * (bitmasks[:, :, :, 1] & 1 == 0) # Remove ignored instances 61 | annotation_ids = annotation_ids * (bitmasks[:, :, :, 1] & 2 == 0) # Remove crowd instances 62 | # annotation_ids = annotation_ids * (bitmasks[:, :, :, 1] & 4 == 0) # Remove occluded instances 63 | # annotation_ids = annotation_ids * (bitmasks[:, :, :, 1] & 8 == 0) # Remove truncated instances 64 | unique_ids_old = unique_ids 65 | unique_ids = np.unique(annotation_ids).size 66 | print(f"Video {video_name} is filtered by ignored and crowd instances, " 67 | f"it has {unique_ids_old:>5d} --> {unique_ids:>5d} unique objects now") 68 | 69 | # # Randomly select max_objects objects 70 | # if unique_ids > max_objects: 71 | # np.random.seed(72) 72 | # selected_ids = np.random.choice(np.sort(np.unique(annotation_ids))[1:], max_objects, replace=False) 73 | # annotation_ids = np.where(np.isin(annotation_ids, selected_ids), annotation_ids, 0) 74 | # unique_ids_old = unique_ids 75 | # unique_ids = np.unique(annotation_ids).size 76 | # print(f"Video {video_name} is filtered by max_objects, " 77 | # f"it has {unique_ids_old:>5d} --> {unique_ids:>5d} unique objects now") 78 | 79 | # # Select the first max_objects objects 80 | # if unique_ids > max_objects: 81 | # selected_ids = np.sort(np.unique(annotation_ids))[1:max_objects + 1] 82 | # annotation_ids = np.where(np.isin(annotation_ids, selected_ids), annotation_ids, 0) 83 | # unique_ids_old = unique_ids 84 | # unique_ids = np.unique(annotation_ids).size 85 | # print(f"Video {video_name} is filtered by max_objects, " 86 | # f"it has {unique_ids_old:>5d} --> {unique_ids:>5d} unique objects now") 87 | 88 | # Split the objects into chunks of objects_per_chunk objects 89 | annotation_ids_unique = np.unique(annotation_ids)[1:] 90 | for chunk_id in range(math.ceil(annotation_ids_unique.size / objects_per_chunk)): 91 | chunk_name = f"{video_name}-chunk{chunk_id + 1}" if chunk_id > 0 else video_name 92 | chunk = annotation_ids_unique[chunk_id * objects_per_chunk:(chunk_id + 1) * objects_per_chunk] 93 | print(f"Processing {chunk_name}, it has {chunk.size} objects: {chunk}") 94 | 95 | # Select the objects in the chunk 96 | annotation_ids_chunk = np.where(np.isin(annotation_ids, chunk), annotation_ids, 0) 97 | unique_ids = np.unique(annotation_ids_chunk).size 98 | 99 | # Remap annotation IDs to be continuous 100 | remapped_annotation_ids = 
remap_ids(annotation_ids_chunk) 101 | assert np.unique(remapped_annotation_ids).size == unique_ids 102 | assert np.unique(remapped_annotation_ids).size == remapped_annotation_ids.max() + 1 103 | print(f"Video {video_name} is remapped") 104 | 105 | output_dir = os.path.join(output_path, chunk_name) 106 | os.makedirs(output_dir, exist_ok=True) 107 | assert unique_ids <= 255, "The number of unique objects should be less than 255 to use uint8" 108 | for frame_id, frame_name in enumerate(frames): 109 | x = Image.fromarray(remapped_annotation_ids[frame_id].astype(np.uint8), mode="P") 110 | x.putpalette(palette) 111 | x.save(os.path.join(output_dir, frame_name)) 112 | print(f"Video {video_name} is saved") 113 | 114 | 115 | def sanity_check(output_path, rles_path): 116 | for i, video_json_name in enumerate(tqdm(sorted([vp for vp in os.listdir(rles_path) if vp.endswith("json")]))): 117 | video_name = video_json_name.replace(".json", "") 118 | # if i < 15: 119 | # print(f"Skipping video {video_name}") 120 | # continue 121 | with open(os.path.join(rles_path, video_json_name), "r") as fp: 122 | video = json.load(fp) 123 | df = pd.DataFrame([ 124 | (label["category"], label["id"]) 125 | for frame in video["frames"] 126 | for label in frame["labels"] 127 | ], columns=["cat", "id"]) 128 | assert df[~df.duplicated()].groupby("id").count().max().item() == 1 129 | 130 | annotation_ids = [ 131 | np.array(Image.open(os.path.join(output_path, video_name, frame_name))) 132 | for frame_name in sorted(os.listdir(os.path.join(output_path, video_name))) 133 | ] 134 | annotation_ids = np.stack(annotation_ids) 135 | assert np.unique(annotation_ids).size == annotation_ids.max() + 1 136 | if np.unique(annotation_ids).size != df.id.unique().size + 1: 137 | print(f"Video {video_name} has {np.unique(annotation_ids).size} unique objects, " 138 | f"but RLE has {df.id.unique().size + 1} unique objects") 139 | # breakpoint() 140 | else: 141 | assert np.unique(annotation_ids).size == df.id.unique().size + 1 142 | 143 | print(f"Unique objects for video {i:02d}: {df.id.unique().size}") 144 | 145 | 146 | if __name__ == '__main__': 147 | videos_path = "data/bdd100k/labels/seg_track_20/bitmasks/val" 148 | output_path = "data/bdd100k/vos/val/Annotations" 149 | video_names = sorted([name for name in os.listdir(videos_path) if os.path.isdir(os.path.join(videos_path, name))]) 150 | 151 | # Create the VOS annotations 152 | process_map(process_video, video_names, chunksize=1) 153 | print("Done creating VOS annotations") 154 | 155 | # # Sanity check that the number of objects in the VOS annotations is the same as in the RLEs 156 | # print("Sanity check that the number of objects in the VOS annotations is the same as in the RLEs") 157 | # rles_path = "data/bdd100k/labels/seg_track_20/rles/val" 158 | # sanity_check(output_path, rles_path) 159 | -------------------------------------------------------------------------------- /scripts/clean_tapnet_checkpoint.py: -------------------------------------------------------------------------------- 1 | """ 2 | This script cleans-up the original TapNet checkpoint by removing objects 3 | that require `import tapnet` to work. The cleaned checkpoint saves only 4 | the weights and removes the optimizer state. The cleaned checkpoint can 5 | be used within SAM-PT. 6 | 7 | Note that we provide a link to the cleaned checkpoint in the 8 | documentation and that you might not need to run this script yourself. 9 | 10 | Usage: 11 | 1. 
Clone the [TapNet repository](https://github.com/deepmind/tapnet) and 12 | checkout the commit `ba1a8c8f2576d81f7b8d69dbee1e58e8b7d321e1`. 13 | 2. Setup the TapNet environment. 14 | 3. Run this script one level above the TapNet repository (i.e., not 15 | within the TapNet repository, but within its parent directory). For 16 | that, navigate to the parent directory of TapNet repository (`cd ..`) 17 | and set the PYTHONPATH environment variable 18 | (```export PYTHONPATH=`(cd ../ && pwd)`:`pwd`:$PYTHONPATH```). 19 | 20 | Run the script from the command line with the following arguments: 21 | - --input: The path to the original TapNet checkpoint file. 22 | - --output: The path where the cleaned checkpoint file will be saved. 23 | 24 | For example: 25 | ```bash 26 | python script_name.py \ 27 | --input "./models/tapnet_ckpts/open_source_ckpt/checkpoint.npy" \ 28 | --output "./models/tapnet_ckpts/open_source_ckpt/checkpoint_wo_optstate.npy" 29 | ``` 30 | """ 31 | 32 | import argparse 33 | import numpy as np 34 | import tensorflow as tf 35 | 36 | 37 | def clean_checkpoint(input_path, output_path): 38 | # Load the original checkpoint file. 39 | checkpoint = np.load(input_path, allow_pickle=True).item() 40 | 41 | print(checkpoint.keys()) 42 | # dict_keys(['params', 'state', 'opt_state', 'global_step']) 43 | 44 | # Create a new dictionary without the 'opt_state' and 'global_step'. 45 | checkpoint_wo_optstate = { 46 | "params": checkpoint["params"], 47 | "state": checkpoint["state"], 48 | } 49 | 50 | # Save the cleaned checkpoint file. 51 | with tf.io.gfile.GFile(output_path, 'wb') as fp: 52 | np.save(fp, checkpoint_wo_optstate) 53 | 54 | 55 | def parse_arguments(): 56 | parser = argparse.ArgumentParser() 57 | parser.add_argument("--input", help="The path to the original TapNet checkpoint file.") 58 | parser.add_argument("--output", help="The path where the cleaned checkpoint file will be saved.") 59 | return parser.parse_args() 60 | 61 | 62 | if __name__ == "__main__": 63 | args = parse_arguments() 64 | clean_checkpoint(args.input, args.output) 65 | -------------------------------------------------------------------------------- /scripts/davis_mask_to_contour.py: -------------------------------------------------------------------------------- 1 | """ 2 | Script to convert DAVIS mask annotation images to contour images. 3 | Used to prepare figures for the SAM-PT paper. 4 | Note that paths are hardcoded in the script. 
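Adjust `input_image_path` and `output_image_path` in the `__main__` block at the bottom
of the file to point at your own mask annotation images before running.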
5 | 6 | Usage: `python -m scripts.davis_mask_to_contour` 7 | """ 8 | 9 | import cv2 10 | import matplotlib.pyplot as plt 11 | import numpy as np 12 | 13 | 14 | def davis_mask_annotation_image_to_contour_image(input_image_path, output_image_path, contour_radius=5): 15 | # Open image and convert it to numpy array 16 | print(f"Input image path: {input_image_path}") 17 | image = cv2.imread(input_image_path) 18 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 19 | assert image.dtype == np.uint8 20 | assert image.min() >= 0 and image.max() <= 255 21 | plt.imshow(image) 22 | plt.show() 23 | 24 | # The number of masks is the number of unique colors in the image 25 | n_masks = len(np.unique(image.reshape(-1, image.shape[2]), axis=0)) - 1 26 | print(f"Number of masks: {n_masks}") 27 | 28 | # Take each mask separately and create a binary mask, remember the color of each mask 29 | masks = [] 30 | colors = np.unique(image.reshape(-1, image.shape[2]), axis=0) 31 | assert (colors[0] == [0, 0, 0]).all() 32 | colors = colors[1:] 33 | for mask_idx in range(n_masks): 34 | mask = (image == colors[mask_idx][None, None, :]).all(-1) 35 | masks.append(mask) 36 | 37 | # Create a contour mask for each mask 38 | contour_masks = [] 39 | for mask_idx in range(n_masks): 40 | m_8int = masks[mask_idx].astype(np.uint8) 41 | dist_transform_fore = cv2.distanceTransform(m_8int, cv2.DIST_L2, 3) 42 | contour_mask = (dist_transform_fore <= contour_radius) & (dist_transform_fore > 0) 43 | contour_mask = contour_mask.astype(np.uint8) 44 | contour_masks.append(contour_mask) 45 | plt.imshow(contour_mask) 46 | plt.show() 47 | 48 | # Add contour mask to the image 49 | output_image = np.zeros_like(image) 50 | for mask_idx in range(n_masks): 51 | output_image = np.where(contour_masks[mask_idx][:, :, None] == 1, colors[mask_idx][None, None, :], output_image) 52 | 53 | # Plot the image 54 | plt.imshow(output_image) 55 | plt.show() 56 | 57 | # Save the image 58 | print(f"Output image path: {output_image_path}") 59 | output_image = cv2.cvtColor(output_image, cv2.COLOR_RGB2BGR) 60 | cv2.imwrite(output_image_path, output_image) 61 | 62 | # Save also RGBA image 63 | output_image = cv2.imread(output_image_path) 64 | output_image = cv2.cvtColor(output_image, cv2.COLOR_BGR2RGB) 65 | r, g, b = cv2.split(output_image) 66 | a = 255 - (output_image == np.array([0, 0, 0])[None, None, :]).all(-1).astype(np.uint8) * 255 67 | output_image = cv2.merge([r, g, b, a], 4) 68 | print(f"RGBA image path: {output_image_path}.rgba.png") 69 | output_image = cv2.cvtColor(output_image, cv2.COLOR_RGBA2BGRA) 70 | cv2.imwrite(output_image_path + ".rgba.png", output_image) 71 | print("Done.") 72 | 73 | 74 | if __name__ == '__main__': 75 | for i in [1, 7, 16, 23, 32]: 76 | input_image_path = f"../../04-logs/system-figure/gt--mask-only--frame-{i}--cropped.png" 77 | output_image_path = f"../../04-logs/system-figure/gt--mask-only--contour--frame-{i}--cropped.png" 78 | davis_mask_annotation_image_to_contour_image(input_image_path, output_image_path) 79 | for i in [1, 7, 16, 23, 32]: 80 | input_image_path = f"../../04-logs/system-figure/gt--mask-only--frame-{i}--cropped.png" 81 | output_image_path = f"../../04-logs/system-figure/gt--mask-only--contour--thin--frame-{i}--cropped.png" 82 | davis_mask_annotation_image_to_contour_image(input_image_path, output_image_path, contour_radius=2) 83 | -------------------------------------------------------------------------------- /scripts/uvo_video2frames.py: 
-------------------------------------------------------------------------------- 1 | """ 2 | A utility script to split UVO videos into frames. 3 | 4 | The script takes two command-line arguments: 5 | 1. --video_dir: The directory containing the videos you wish to split into frames. 6 | 2. --frames_dir: The directory where the frames will be saved. 7 | 8 | Each video in the input directory will be split into frames, and these frames will be stored in a subdirectory of --frames_dir named after the video. 9 | 10 | Usage: 11 | 12 | ```bash 13 | python ../scripts/uvo_video2frames.py --video_dir UVOv1.0/uvo_videos_dense --frames_dir UVOv1.0/uvo_videos_dense_frames 14 | python ../scripts/uvo_video2frames.py --video_dir UVOv1.0/uvo_videos_sparse --frames_dir UVOv1.0/uvo_videos_sparse_frames 15 | ``` 16 | """ 17 | import argparse 18 | import cv2 19 | import os 20 | import pathlib 21 | from tqdm import tqdm 22 | 23 | 24 | def split_single_video(video_path, frames_dir=""): 25 | cap = cv2.VideoCapture(video_path) 26 | cnt = 0 27 | while cap.isOpened(): 28 | ret, frame = cap.read() 29 | if ret: 30 | success, buffer = cv2.imencode(".png", frame) 31 | if success: 32 | with open(f"{frames_dir}{cnt}.png", "wb") as f: 33 | f.write(buffer.tobytes()) 34 | f.flush() 35 | cnt += 1 36 | else: 37 | break 38 | return cnt 39 | 40 | 41 | def get_parser(): 42 | arg_parser = argparse.ArgumentParser() 43 | arg_parser.add_argument("--video_dir", type=str, default="NonPublic/uvo_videos_dense/") 44 | arg_parser.add_argument("--frames_dir", type=str, default="NonPublic/uvo_videos_dense_frames/") 45 | return arg_parser 46 | 47 | 48 | if __name__ == '__main__': 49 | parser = get_parser() 50 | args = parser.parse_args() 51 | video_paths = os.listdir(args.video_dir) 52 | print(f"Splitting videos in {args.video_dir} to frames in {args.frames_dir}...") 53 | print(f"Total number of videos: {len(video_paths)}") 54 | for video_path in tqdm(video_paths): 55 | print(f"Splitting {video_path}...") 56 | v_frame_dir = pathlib.Path(os.path.join(args.frames_dir, video_path[:-4])) 57 | if not v_frame_dir.is_dir(): 58 | v_frame_dir.mkdir(parents=True, exist_ok=False) 59 | n_frames = split_single_video(os.path.join(args.video_dir, video_path), frames_dir=v_frame_dir) 60 | print(f"Total number of frames extracted from {video_path}: {n_frames}") 61 | print(f"Done.") 62 | -------------------------------------------------------------------------------- /scripts/visualize_point_sampling_methods.py: -------------------------------------------------------------------------------- 1 | """ 2 | Script to visualize the different point sampling methods for the SAM-PT paper. 
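Supported sampling methods are kmedoids, shi-tomasi, random, and mixed; they are selected
via --point_sampling_methods, and the default input/output image paths in the argparse
arguments are hardcoded and should be adjusted to your own data.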
3 | 4 | Usage: `python -m scripts.visualize_point_sampling_methods` 5 | """ 6 | import argparse 7 | import cv2 8 | import matplotlib.pyplot as plt 9 | import numpy as np 10 | import torch 11 | from functools import partial 12 | 13 | from sam_pt.utils.query_points import extract_corner_points 14 | from sam_pt.utils.query_points import extract_kmedoid_points 15 | from sam_pt.utils.query_points import extract_mixed_points 16 | from sam_pt.utils.query_points import extract_random_mask_points 17 | from sam_pt.utils.util import seed_all 18 | 19 | 20 | def mixed_point_id_to_marker_and_rescale(n_points, point_id): 21 | n_kmedoid = n_points // 4 22 | n_shi_tomasi = n_points // 3 23 | if point_id < n_kmedoid: 24 | return "o", 1 25 | elif point_id < n_kmedoid + n_shi_tomasi: 26 | return "*", 3 27 | else: 28 | return "v", 1.2 29 | 30 | 31 | def visualize_point_sampling_methods( 32 | rgb_image_path, 33 | annotation_image_path, 34 | output_image_path, 35 | point_sampling_method_name="kmedoids", 36 | n_points=8, 37 | seed=72, 38 | ): 39 | # Open image and convert it to numpy array 40 | image = cv2.imread(rgb_image_path) 41 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 42 | assert image.dtype == np.uint8 43 | assert image.min() >= 0 and image.max() <= 255 44 | plt.imshow(image) 45 | plt.show() 46 | 47 | annotation_image = cv2.imread(annotation_image_path) 48 | annotation_image = cv2.cvtColor(annotation_image, cv2.COLOR_BGR2RGB) 49 | assert annotation_image.dtype == np.uint8 50 | assert annotation_image.min() >= 0 and annotation_image.max() <= 255 51 | plt.imshow(annotation_image) 52 | plt.show() 53 | 54 | # The number of masks is the number of unique colors in the image 55 | n_masks = len(np.unique(annotation_image.reshape(-1, annotation_image.shape[2]), axis=0)) - 1 56 | print(f"Number of masks: {n_masks}") 57 | 58 | # Prepare the point sampling methods 59 | point_sampling_methods = { 60 | "kmedoids": { 61 | "function": extract_kmedoid_points, 62 | "marker": ["o" for _ in range(n_points)], 63 | "rescale": [1 for _ in range(n_points)], 64 | }, 65 | "shi-tomasi": { 66 | "function": partial(extract_corner_points, image=torch.from_numpy(image).permute(2, 0, 1)), 67 | "marker": ["*" for _ in range(n_points)], 68 | "rescale": [3 for _ in range(n_points)], 69 | }, 70 | "random": { 71 | "function": extract_random_mask_points, 72 | "marker": ["v" for _ in range(n_points)], 73 | "rescale": [1.2 for _ in range(n_points)] 74 | }, 75 | "mixed": { 76 | "function": lambda mask, n_points_to_select: extract_mixed_points( 77 | query_masks=mask[None, ...], 78 | query_points_timestep=torch.zeros(n_masks), 79 | images=torch.from_numpy(image).permute(2, 0, 1)[None, ...], 80 | n_points=n_points_to_select, 81 | )[0], 82 | "marker": [mixed_point_id_to_marker_and_rescale(n_points, point_id)[0] for point_id in range(n_points)], 83 | "rescale": [mixed_point_id_to_marker_and_rescale(n_points, point_id)[1] for point_id in range(n_points)] 84 | }, 85 | } 86 | 87 | # Take each mask separately and create a binary mask, remember the color of each mask 88 | masks = [] 89 | colors = np.unique(annotation_image.reshape(-1, annotation_image.shape[2]), axis=0) 90 | assert (colors[0] == [0, 0, 0]).all() 91 | colors = colors[1:] 92 | for mask_idx in range(n_masks): 93 | mask = (annotation_image == colors[mask_idx][None, None, :]).all(-1) 94 | masks.append(mask) 95 | 96 | # Sample points from each mask 97 | mask_points = [] 98 | for mask_idx in range(n_masks): 99 | seed_all(seed + 3) 100 | mask = torch.from_numpy(masks[mask_idx]).bool() 101 | 
points = point_sampling_methods[point_sampling_method_name]["function"](mask=mask, n_points_to_select=n_points) 102 | mask_points.append(points) 103 | 104 | # Create a contour mask for each mask 105 | contour_radius = 3 106 | contour_masks = [] 107 | for mask_idx in range(n_masks): 108 | m_8int = masks[mask_idx].astype(np.uint8) 109 | dist_transform_fore = cv2.distanceTransform(m_8int, cv2.DIST_L2, 3) 110 | contour_mask = (dist_transform_fore <= contour_radius) & (dist_transform_fore > 0) 111 | contour_mask = contour_mask.astype(np.uint8) 112 | contour_masks.append(contour_mask) 113 | 114 | # Add contour and sampled points to the image 115 | output_image = np.zeros_like(annotation_image) 116 | for mask_idx in range(n_masks): 117 | output_image = np.where(contour_masks[mask_idx][:, :, None] == 1, colors[mask_idx][None, None, :], output_image) 118 | h, w, dpi = output_image.shape[0], output_image.shape[1], 100 119 | plt.figure(figsize=(w / dpi, h / dpi), dpi=dpi) 120 | plt.imshow(output_image) 121 | for mask_idx in range(n_masks): 122 | for point_idx in range(n_points): 123 | plt.scatter( 124 | x=mask_points[mask_idx][point_idx, 0], 125 | y=mask_points[mask_idx][point_idx, 1], 126 | s=90 * point_sampling_methods[point_sampling_method_name]["rescale"][point_idx], 127 | c=(colors[mask_idx][None, :] * 1.8 / 255).clip(min=0, max=1), 128 | linewidths=0, 129 | marker=point_sampling_methods[point_sampling_method_name]["marker"][point_idx] 130 | ) 131 | plt.axis("off") 132 | plt.tight_layout(pad=0) 133 | print(f"Output image path: {output_image_path}") 134 | plt.savefig(output_image_path, bbox_inches="tight", pad_inches=0) 135 | plt.show() 136 | 137 | # Save also RGBA image 138 | output_image = cv2.imread(output_image_path) 139 | output_image = cv2.cvtColor(output_image, cv2.COLOR_BGR2RGB) 140 | r, g, b = cv2.split(output_image) 141 | a = 255 - (output_image == np.array([0, 0, 0])[None, None, :]).all(-1).astype(np.uint8) * 255 142 | output_image = cv2.merge([r, g, b, a], 4) 143 | print(f"RGBA image path: {output_image_path}.rgba.png") 144 | output_image = cv2.cvtColor(output_image, cv2.COLOR_RGBA2BGRA) 145 | cv2.imwrite(output_image_path + ".rgba.png", output_image) 146 | print("Done.") 147 | 148 | 149 | def main(args): 150 | n_points = args.n_points 151 | for psm in args.point_sampling_methods: 152 | output_image_path = f"{args.output_path_prefix}--point-sampling-method-{psm}.png" 153 | visualize_point_sampling_methods( 154 | rgb_image_path=args.rgb_path, 155 | annotation_image_path=args.annotation_path, 156 | output_image_path=output_image_path, 157 | point_sampling_method_name=psm, 158 | n_points=n_points, 159 | seed=args.seed, 160 | ) 161 | 162 | 163 | if __name__ == '__main__': 164 | parser = argparse.ArgumentParser() 165 | parser.add_argument('--n_points', type=int, default=8) 166 | parser.add_argument('--rgb_path', type=str, 167 | default="../../04-logs/system-figure/horse-input--frame-16--cropped.png") 168 | parser.add_argument('--annotation_path', type=str, 169 | default="../../04-logs/system-figure/gt--mask-only--frame-16--cropped.png") 170 | parser.add_argument('--output_path_prefix', type=str, 171 | default="../../04-logs/system-figure/gt--mask-only--frame-16--cropped") 172 | parser.add_argument('--point_sampling_methods', type=str, nargs='+', 173 | default=["kmedoids", "shi-tomasi", "random", "mixed"]) 174 | parser.add_argument('--seed', type=int, default=72) 175 | args = parser.parse_args() 176 | main(args) 177 | 
--------------------------------------------------------------------------------
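As a closing illustration, here is a minimal, hypothetical sketch of calling one of the point-sampling utilities that `scripts/visualize_point_sampling_methods.py` wraps. It only mirrors how that script invokes `extract_kmedoid_points` (a boolean `(H, W)` mask in, points indexable as `(point, x/y)` out); the toy mask, its shape, and the printed expectation are illustrative assumptions, not part of the repository.

```python
import torch

from sam_pt.utils.query_points import extract_kmedoid_points

# Toy binary mask with a single rectangular foreground region (illustrative only).
mask = torch.zeros(256, 256, dtype=torch.bool)
mask[64:192, 96:224] = True

# Sample 8 query points from the mask, mirroring the call made in
# scripts/visualize_point_sampling_methods.py: function(mask=mask, n_points_to_select=n_points).
points = extract_kmedoid_points(mask=mask, n_points_to_select=8)

# As in the script, points[i, 0] is the x coordinate and points[i, 1] the y coordinate
# of the i-th sampled point.
for i in range(8):
    print(float(points[i, 0]), float(points[i, 1]))
```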