├── .gitignore
├── LICENSE
├── README.md
├── assets
├── avatar.gif
├── bees.gif
├── figure-1.png
├── horsejump-high.gif
├── interactive-camel.gif
├── interactive-drift-straight.gif
├── interactive-loading.gif
└── street.gif
├── configs
├── demo.yaml
├── logging
│ ├── base.yaml
│ ├── vis_eval.yaml
│ ├── vos_eval.yaml
│ └── wandb
│ │ └── base.yaml
├── model
│ ├── point_tracker
│ │ ├── cotracker.yaml
│ │ ├── pips.yaml
│ │ ├── pips_plus_plus.yaml
│ │ ├── raft.yaml
│ │ ├── superglue.yaml
│ │ ├── tapir.yaml
│ │ └── tapnet.yaml
│ ├── sam
│ │ ├── image_encoder
│ │ │ ├── vit_base.yaml
│ │ │ ├── vit_huge.yaml
│ │ │ └── vit_large.yaml
│ │ ├── mask_decoder
│ │ │ └── sam.yaml
│ │ ├── prompt_encoder
│ │ │ └── sam.yaml
│ │ ├── sam_mobile_vit_tiny.yaml
│ │ ├── sam_vit_base.yaml
│ │ ├── sam_vit_huge.yaml
│ │ ├── sam_vit_large.yaml
│ │ ├── samhq_light_vit_tiny.yaml
│ │ └── samhq_vit_huge.yaml
│ └── sam_pt.yaml
├── vis_eval_root.yaml
├── vis_eval_sam_pt.yaml
└── vos_eval_root.yaml
├── data
└── demo_data
│ ├── README.md
│ ├── bees.mp4
│ ├── query_points__bees.txt
│ ├── query_points__street.txt
│ └── street.mp4
├── demo
├── __init__.py
└── demo.py
├── docs
├── 01-getting-started.md
├── 02-prepare-datasets.md
├── 03-prepare-checkpoints.md
└── 04-running-experiments.md
├── requirements-jax.txt
├── requirements.txt
├── sam_pt
├── __init__.py
├── modeling
│ ├── __init__.py
│ ├── sam.py
│ ├── sam_pt.py
│ ├── sam_pt_interactive.py
│ └── vis_to_vos_adapter.py
├── point_tracker
│ ├── __init__.py
│ ├── cotracker
│ │ ├── __init__.py
│ │ └── tracker.py
│ ├── pips
│ │ ├── __init__.py
│ │ ├── pips.py
│ │ └── tracker.py
│ ├── pips_plus_plus
│ │ ├── __init__.py
│ │ ├── pips_plus_plus.py
│ │ └── tracker.py
│ ├── raft
│ │ ├── __init__.py
│ │ ├── raft_core
│ │ │ ├── __init__.py
│ │ │ ├── corr.py
│ │ │ ├── extractor.py
│ │ │ ├── raft.py
│ │ │ ├── update.py
│ │ │ └── util.py
│ │ ├── raftnet.py
│ │ └── tracker.py
│ ├── superglue
│ │ ├── __init__.py
│ │ ├── match_pairs.py
│ │ ├── models
│ │ │ ├── __init__.py
│ │ │ ├── matching.py
│ │ │ ├── superglue.py
│ │ │ ├── superpoint.py
│ │ │ └── utils.py
│ │ └── tracker.py
│ ├── tapir
│ │ ├── __init__.py
│ │ ├── configs
│ │ │ └── tapir_config.py
│ │ ├── demo.py
│ │ ├── models
│ │ │ ├── __init__.py
│ │ │ └── resnet.py
│ │ ├── tapir_model.py
│ │ ├── tracker.py
│ │ └── utils
│ │ │ ├── __init__.py
│ │ │ ├── model_utils.py
│ │ │ └── transforms.py
│ ├── tapnet
│ │ ├── __init__.py
│ │ ├── configs
│ │ │ └── tapnet_config.py
│ │ ├── demo.py
│ │ ├── models
│ │ │ ├── __init__.py
│ │ │ ├── tsm_resnet.py
│ │ │ └── tsm_utils.py
│ │ ├── tapnet_model.py
│ │ ├── tracker.py
│ │ └── utils
│ │ │ ├── __init__.py
│ │ │ └── transforms.py
│ ├── tracker.py
│ └── utils
│ │ ├── __init__.py
│ │ ├── basic.py
│ │ ├── improc.py
│ │ ├── misc.py
│ │ ├── samp.py
│ │ ├── saverloader.py
│ │ └── test.py
├── utils
│ ├── __init__.py
│ ├── query_points.py
│ └── util.py
├── vis_eval
│ ├── __init__.py
│ ├── eval.py
│ ├── mask2former
│ │ ├── __init__.py
│ │ └── config.py
│ ├── mask2former_video
│ │ ├── README.md
│ │ ├── __init__.py
│ │ ├── config.py
│ │ └── data_video
│ │ │ ├── __init__.py
│ │ │ ├── augmentation.py
│ │ │ ├── build.py
│ │ │ ├── dataset_mapper.py
│ │ │ ├── datasets
│ │ │ ├── __init__.py
│ │ │ ├── builtin.py
│ │ │ ├── uvo.py
│ │ │ ├── ytvis.py
│ │ │ └── ytvis_api
│ │ │ │ ├── __init__.py
│ │ │ │ ├── ytvos.py
│ │ │ │ └── ytvoseval.py
│ │ │ └── ytvis_eval.py
│ └── train_net_video.py
└── vos_eval
│ ├── __init__.py
│ ├── bdd100keval.py
│ ├── data
│ ├── __init__.py
│ ├── mask_mapper.py
│ ├── test_datasets.py
│ └── video_reader.py
│ ├── davis2017eval.py
│ ├── eval.py
│ └── evaluator.py
└── scripts
├── annotation_comparison_gif.py
├── bdd100k_from_instance_seg_to_vos_annotations.py
├── clean_tapnet_checkpoint.py
├── davis_mask_to_contour.py
├── uvo_video2frames.py
└── visualize_point_sampling_methods.py
/.gitignore:
--------------------------------------------------------------------------------
1 | artifacts
2 | /logs
3 | /wandb
4 | /data/**
5 | !/data/demo_data
6 | !/data/demo_data/README.md
7 | !/data/demo_data/bees.mp4
8 | !/data/demo_data/street.mp4
9 | !/data/demo_data/query_points__bees.txt
10 | !/data/demo_data/query_points__street.txt
11 | /outputs
12 | /output
13 | instant_test_output
14 | inference_test_output
15 | experiments
16 |
17 | *.png
18 | *.json
19 | *.diff
20 | *.jpg
21 | !/projects/DensePose/doc/images/*.jpg
22 |
23 | # compilation and distribution
24 | __pycache__
25 | _ext
26 | *.pyc
27 | *.pyd
28 | *.so
29 | *.dll
30 | *.egg-info/
31 | build/
32 | dist/
33 | wheels/
34 |
35 | # pytorch/python/numpy formats
36 | *.pth
37 | *.pkl
38 | *.npy
39 | *.ts
40 | model_ts*.txt
41 |
42 | # ipython/jupyter notebooks
43 | *.ipynb
44 | **/.ipynb_checkpoints/
45 |
46 | # Editor temporaries
47 | *.swn
48 | *.swo
49 | *.swp
50 | *~
51 |
52 | # editor settings
53 | .idea
54 | .vscode
55 | _darcs
56 |
57 | # project dirs
58 | /detectron2/model_zoo/configs
59 | /datasets/*
60 | !/datasets/*.*
61 | /projects/*/datasets
62 | /models
63 | /snippet
64 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Segment Anything Meets Point Tracking
2 |
3 | > [**Segment Anything Meets Point Tracking**](https://arxiv.org/abs/2307.01197) \
4 | > [Frano Rajič](https://m43.github.io/), [Lei Ke](http://www.kelei.site/), [Yu-Wing Tai](https://yuwingtai.github.io/), [Chi-Keung Tang](http://home.cse.ust.hk/~cktang/bio.html), [Martin Danelljan](https://martin-danelljan.github.io/), [Fisher Yu](https://www.yf.io/) \
5 | > ETH Zürich, HKUST, EPFL
6 |
7 |
8 | 
9 |
10 | We propose SAM-PT, an extension of the [Segment Anything Model](https://github.com/facebookresearch/segment-anything) (SAM) for zero-shot video segmentation. Our work offers a simple yet effective point-based perspective in video object segmentation research. For more details, refer to our paper.
11 |
12 | ## Video Object Segmentation Demo
13 |
14 | Annotators only need to provide a few points marking the target object in the first video frame to obtain segmentation results for the whole video. Please visit our [project page](https://www.vis.xyz/pub/sam-pt/) for more visualizations, including qualitative results on DAVIS 2017 videos and more Avatar clips.
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 | ## Interactive Point-Based Video Segmentation
23 |
24 | Annotators can interactively add or remove points to refine the segmentation results.
25 |
26 |
27 |
28 |
29 |
30 |
31 | ## Documentation
32 |
33 | Explore our step-by-step guides to get up and running:
34 |
35 | 1. [Getting Started](./docs/01-getting-started.md): Learn how to set up your environment and run the demo.
36 | 2. [Prepare Datasets](./docs/02-prepare-datasets.md): Instructions on acquiring and prepping necessary datasets.
37 | 3. [Prepare Checkpoints](./docs/03-prepare-checkpoints.md): Steps to fetch model checkpoints.
38 | 4. [Running Experiments](./docs/04-running-experiments.md): Details on how to execute experiments.
39 |
40 | ## Acknowledgments
41 |
42 | We want to thank [SAM](https://github.com/facebookresearch/segment-anything), [PIPS](https://github.com/aharley/pips), [CoTracker](https://github.com/facebookresearch/co-tracker), [HQ-SAM](https://github.com/SysCV/sam-hq), [MobileSAM](https://github.com/ChaoningZhang/MobileSAM), [XMem](https://github.com/hkchengrex/XMem), and [Mask2Former](https://github.com/facebookresearch/Mask2Former) for publicly releasing their code and pretrained models.
43 |
44 | ## Citation
45 |
46 | If you find SAM-PT useful in your research or if you refer to the results mentioned in our work, please star :star: this repository and consider citing :pencil::
47 | ```bibtex
48 | @article{sam-pt,
49 | title = {Segment Anything Meets Point Tracking},
50 | author = {Rajič, Frano and Ke, Lei and Tai, Yu-Wing and Tang, Chi-Keung and Danelljan, Martin and Yu, Fisher},
51 | journal = {arXiv:2307.01197},
52 | year = {2023}
53 | }
54 | ```
55 |
--------------------------------------------------------------------------------
/assets/avatar.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SysCV/sam-pt/874ff7e73d6ab05418a494d7a02ca233c0b31e8c/assets/avatar.gif
--------------------------------------------------------------------------------
/assets/bees.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SysCV/sam-pt/874ff7e73d6ab05418a494d7a02ca233c0b31e8c/assets/bees.gif
--------------------------------------------------------------------------------
/assets/figure-1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SysCV/sam-pt/874ff7e73d6ab05418a494d7a02ca233c0b31e8c/assets/figure-1.png
--------------------------------------------------------------------------------
/assets/horsejump-high.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SysCV/sam-pt/874ff7e73d6ab05418a494d7a02ca233c0b31e8c/assets/horsejump-high.gif
--------------------------------------------------------------------------------
/assets/interactive-camel.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SysCV/sam-pt/874ff7e73d6ab05418a494d7a02ca233c0b31e8c/assets/interactive-camel.gif
--------------------------------------------------------------------------------
/assets/interactive-drift-straight.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SysCV/sam-pt/874ff7e73d6ab05418a494d7a02ca233c0b31e8c/assets/interactive-drift-straight.gif
--------------------------------------------------------------------------------
/assets/interactive-loading.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SysCV/sam-pt/874ff7e73d6ab05418a494d7a02ca233c0b31e8c/assets/interactive-loading.gif
--------------------------------------------------------------------------------
/assets/street.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SysCV/sam-pt/874ff7e73d6ab05418a494d7a02ca233c0b31e8c/assets/street.gif
--------------------------------------------------------------------------------
/configs/demo.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - model: sam_pt
3 | - logging: base
4 | - _self_
5 |
6 | logging:
7 | wandb:
8 | project: demo
9 |
10 | model:
11 | iterative_refinement_iterations: 12
12 | add_other_objects_positive_points_as_negative_points: true
13 | use_point_reinit: false
14 | positive_points_per_mask: -1
15 | negative_points_per_mask: -1
16 |
17 | frames_path: ${hydra:runtime.cwd}/data/demo_data/bees # Path to the folder with frames of the video
18 | query_points_path: ${hydra:runtime.cwd}/data/demo_data/query_points__bees.txt # Path or null
19 |
20 | longest_side_length: 1024 # Resize the image so that the longest side is of this length
21 | frame_stride: 1 # Evaluate on every n frames
22 | max_frames: null # Maximum number of video frames to evaluate for
23 |
24 | seed: 72
25 |
26 | annot_size: 16 # Size of the point annotations in visualisations
27 | annot_line_width: 6 # Line width of the point annotations in visualisations
--------------------------------------------------------------------------------
/configs/logging/base.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - wandb: base
3 |
4 | debug: false
5 | exp_id: debug
6 |
--------------------------------------------------------------------------------
/configs/logging/vis_eval.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - wandb: base
3 |
4 | exp_id: eval
5 | wandb:
6 | project: point-tracking-for-vis
7 |
--------------------------------------------------------------------------------
/configs/logging/vos_eval.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - wandb: base
3 |
4 | exp_id: eval
5 | wandb:
6 | project: point-tracking-for-vos
7 |
--------------------------------------------------------------------------------
/configs/logging/wandb/base.yaml:
--------------------------------------------------------------------------------
1 | entity: null
2 | project: ???
3 | tensorboard: true
4 | log_code_path: ${hydra:runtime.cwd}/sam_pt
5 |
--------------------------------------------------------------------------------
/configs/model/point_tracker/cotracker.yaml:
--------------------------------------------------------------------------------
1 | _target_: sam_pt.point_tracker.cotracker.CoTrackerPointTracker
2 | checkpoint_path: "${hydra:runtime.cwd}/models/cotracker_ckpts/cotracker_stride_4_wind_8.pth"
3 | #checkpoint_path: "${hydra:runtime.cwd}/models/cotracker_ckpts/cotracker_stride_4_wind_12.pth"
4 | #checkpoint_path: "${hydra:runtime.cwd}/models/cotracker_ckpts/cotracker_stride_8_wind_16.pth"
5 |
6 | interp_shape: [384, 512]
7 | visibility_threshold: 0.7
8 | support_grid_size: 2
9 | support_grid_every_n_frames: 12
10 |
11 | add_debug_visualisations: false
12 |
--------------------------------------------------------------------------------
/configs/model/point_tracker/pips.yaml:
--------------------------------------------------------------------------------
1 | _target_: sam_pt.point_tracker.pips.PipsPointTracker
2 | checkpoint_path: "${hydra:runtime.cwd}/models/pips_ckpts/reference_model"
3 | stride: 4
4 | s: 8
5 | initial_next_frame_visibility_threshold: 0.9
6 |
--------------------------------------------------------------------------------
/configs/model/point_tracker/pips_plus_plus.yaml:
--------------------------------------------------------------------------------
1 | _target_: sam_pt.point_tracker.pips_plus_plus.PipsPlusPlusPointTracker
2 | checkpoint_path: "${hydra:runtime.cwd}/models/pips_plus_plus_ckpts/reference_model"
3 | stride: 8
4 | max_sequence_length: 128
5 | iters: 16
6 | image_size: null # [ 512, 896 ]
7 |
--------------------------------------------------------------------------------
/configs/model/point_tracker/raft.yaml:
--------------------------------------------------------------------------------
1 | _target_: sam_pt.point_tracker.raft.RaftPointTracker
2 | checkpoint_path: "${hydra:runtime.cwd}/models/raft_ckpts/raft-things.pth"
3 |
--------------------------------------------------------------------------------
/configs/model/point_tracker/superglue.yaml:
--------------------------------------------------------------------------------
1 | _target_: sam_pt.point_tracker.superglue.SuperGluePointTracker
2 |
3 | positive_points_per_mask: ${..positive_points_per_mask}
4 | negative_points_per_mask: ${..negative_points_per_mask}
5 |
6 | #resize: [ 640, 480 ]
7 | resize: [ -1, -1 ]
8 |
9 | matching_config:
10 | superpoint:
11 | checkpoint: ${hydra:runtime.cwd}/models/superglue_ckpts/superpoint_v1.pth
12 | nms_radius: 3
13 | keypoint_threshold: 0.005
14 | max_keypoints: -1
15 | descriptor_dim: 256
16 | remove_borders: 4
17 | superglue:
18 | #checkpoint: ${hydra:runtime.cwd}/models/superglue_ckpts/superglue_indoor.pth
19 | checkpoint: ${hydra:runtime.cwd}/models/superglue_ckpts/superglue_outdoor.pth
20 | sinkhorn_iterations: 20
21 | match_threshold: 0.2
22 |
--------------------------------------------------------------------------------
/configs/model/point_tracker/tapir.yaml:
--------------------------------------------------------------------------------
1 | _target_: sam_pt.point_tracker.tapir.TapirPointTracker
2 | checkpoint_path: "${hydra:runtime.cwd}/models/tapir_ckpts/open_source_ckpt/tapir_checkpoint_panning.npy"
3 | visibility_threshold: 0.1
4 |
--------------------------------------------------------------------------------
/configs/model/point_tracker/tapnet.yaml:
--------------------------------------------------------------------------------
1 | _target_: sam_pt.point_tracker.tapnet.TapnetPointTracker
2 | checkpoint_path: "${hydra:runtime.cwd}/models/tapnet_ckpts/open_source_ckpt/checkpoint_wo_optstate.npy"
3 | visibility_threshold: 0.5
4 |
--------------------------------------------------------------------------------
/configs/model/sam/image_encoder/vit_base.yaml:
--------------------------------------------------------------------------------
1 | _target_: segment_anything.modeling.image_encoder.ImageEncoderViT
2 | depth: 12
3 | embed_dim: 768
4 | img_size: ${ ..image_size }
5 | mlp_ratio: 4
6 | norm_layer:
7 | _partial_: true
8 | _target_: torch.nn.LayerNorm
9 | eps: 1e-6
10 | num_heads: 12
11 | patch_size: ${ ..vit_patch_size }
12 | qkv_bias: True
13 | use_rel_pos: True
14 | global_attn_indexes: [ 2, 5, 8, 11 ]
15 | window_size: 14
16 | out_chans: ${ ..prompt_embed_dim }
17 |
--------------------------------------------------------------------------------
/configs/model/sam/image_encoder/vit_huge.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - vit_base
3 | depth: 32
4 | embed_dim: 1280
5 | num_heads: 16
6 | global_attn_indexes: [ 7, 15, 23, 31 ]
7 |
--------------------------------------------------------------------------------
/configs/model/sam/image_encoder/vit_large.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - vit_base
3 | depth: 24
4 | embed_dim: 1024
5 | num_heads: 16
6 | global_attn_indexes: [ 5, 11, 17, 23 ]
7 |
--------------------------------------------------------------------------------
/configs/model/sam/mask_decoder/sam.yaml:
--------------------------------------------------------------------------------
1 | _target_: segment_anything.modeling.mask_decoder.MaskDecoder
2 | num_multimask_outputs: 3
3 | transformer:
4 | _target_: segment_anything.modeling.transformer.TwoWayTransformer
5 | depth: 2
6 | embedding_dim: ${ ...prompt_embed_dim }
7 | mlp_dim: 2048
8 | num_heads: 8
9 | transformer_dim: ${ ..prompt_embed_dim }
10 | iou_head_depth: 3
11 | iou_head_hidden_dim: 256
12 |
--------------------------------------------------------------------------------
/configs/model/sam/prompt_encoder/sam.yaml:
--------------------------------------------------------------------------------
1 | _target_: segment_anything.modeling.prompt_encoder.PromptEncoder
2 | embed_dim: ${ ..prompt_embed_dim }
3 | image_embedding_size:
4 | - ${ ...image_embedding_size }
5 | - ${ ...image_embedding_size }
6 | input_image_size:
7 | - ${ ...image_size }
8 | - ${ ...image_size }
9 | mask_in_chans: 16
10 |
--------------------------------------------------------------------------------
/configs/model/sam/sam_mobile_vit_tiny.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - prompt_encoder: sam
3 | - mask_decoder: sam
4 | - _self_
5 |
6 | _target_: sam_pt.modeling.sam.MobileSamHydra
7 |
8 | checkpoint: ${hydra:runtime.cwd}/models/sam_mobile_ckpts/sam_mobile_vit_t.pth
9 |
10 | prompt_embed_dim: 256
11 | image_size: 1024
12 | vit_patch_size: 16
13 | image_embedding_size: 64
14 |
15 | pixel_mean: [ 123.675, 116.28, 103.53 ]
16 | pixel_std: [ 58.395, 57.12, 57.375 ]
17 |
18 | image_encoder:
19 | _target_: mobile_sam.modeling.TinyViT
20 | img_size: ${..image_size}
21 | in_chans: 3
22 | num_classes: 1000
23 | embed_dims: [ 64, 128, 160, 320 ]
24 | depths: [ 2, 2, 6, 2 ]
25 | num_heads: [ 2, 4, 5, 10 ]
26 | window_sizes: [ 7, 7, 14, 7 ]
27 | mlp_ratio: 4.
28 | drop_rate: 0.
29 | drop_path_rate: 0.0
30 | use_checkpoint: False
31 | mbconv_expand_ratio: 4.0
32 | local_conv_size: 3
33 | layer_lr_decay: 0.8
34 |
35 | prompt_encoder:
36 | _target_: mobile_sam.modeling.prompt_encoder.PromptEncoder
37 |
38 | mask_decoder:
39 | _target_: mobile_sam.modeling.mask_decoder.MaskDecoder
40 | transformer:
41 | _target_: mobile_sam.modeling.transformer.TwoWayTransformer
42 |
--------------------------------------------------------------------------------
/configs/model/sam/sam_vit_base.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - image_encoder: vit_base
3 | - prompt_encoder: sam
4 | - mask_decoder: sam
5 |
6 | _target_: sam_pt.modeling.sam.SamHydra
7 |
8 | checkpoint: ${hydra:runtime.cwd}/models/sam_ckpts/sam_vit_b_01ec64.pth
9 |
10 | prompt_embed_dim: 256
11 | image_size: 1024
12 | vit_patch_size: 16
13 | image_embedding_size: 64
14 |
15 | pixel_mean: [ 123.675, 116.28, 103.53 ]
16 | pixel_std: [ 58.395, 57.12, 57.375 ]
17 |
--------------------------------------------------------------------------------
/configs/model/sam/sam_vit_huge.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - sam_vit_base
3 | - override image_encoder: vit_huge
4 |
5 | checkpoint: ${hydra:runtime.cwd}/models/sam_ckpts/sam_vit_h_4b8939.pth
6 |
--------------------------------------------------------------------------------
/configs/model/sam/sam_vit_large.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - sam_vit_base
3 | - override image_encoder: vit_large
4 |
5 | checkpoint: ${hydra:runtime.cwd}/models/sam_ckpts/sam_vit_l_0b3195.pth
6 |
--------------------------------------------------------------------------------
/configs/model/sam/samhq_light_vit_tiny.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - prompt_encoder: sam
3 | - mask_decoder: sam
4 | - _self_
5 |
6 | _target_: sam_pt.modeling.sam.SamHQHydra
7 |
8 | checkpoint: ${hydra:runtime.cwd}/models/samhq_ckpts/sam_hq_vit_t.pth
9 |
10 | prompt_embed_dim: 256
11 | image_size: 1024
12 | vit_patch_size: 16
13 | image_embedding_size: 64
14 |
15 | pixel_mean: [ 123.675, 116.28, 103.53 ]
16 | pixel_std: [ 58.395, 57.12, 57.375 ]
17 |
18 | image_encoder:
19 | _target_: segment_anything_hq.modeling.TinyViT
20 | img_size: ${..image_size}
21 | in_chans: 3
22 | num_classes: 1000
23 | embed_dims: [ 64, 128, 160, 320 ]
24 | depths: [ 2, 2, 6, 2 ]
25 | num_heads: [ 2, 4, 5, 10 ]
26 | window_sizes: [ 7, 7, 14, 7 ]
27 | mlp_ratio: 4.
28 | drop_rate: 0.
29 | drop_path_rate: 0.0
30 | use_checkpoint: False
31 | mbconv_expand_ratio: 4.0
32 | local_conv_size: 3
33 | layer_lr_decay: 0.8
34 |
35 | prompt_encoder:
36 | _target_: segment_anything_hq.modeling.prompt_encoder.PromptEncoder
37 |
38 | mask_decoder:
39 | _target_: segment_anything_hq.modeling.mask_decoder_hq.MaskDecoderHQ
40 | vit_dim: 160
41 |
--------------------------------------------------------------------------------
/configs/model/sam/samhq_vit_huge.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - image_encoder: vit_huge
3 | - prompt_encoder: sam
4 | - mask_decoder: sam
5 | - _self_
6 |
7 | _target_: sam_pt.modeling.sam.SamHQHydra
8 |
9 | checkpoint: ${hydra:runtime.cwd}/models/samhq_ckpts/sam_hq_vit_h.pth
10 |
11 | prompt_embed_dim: 256
12 | image_size: 1024
13 | vit_patch_size: 16
14 | image_embedding_size: 64
15 |
16 | pixel_mean: [ 123.675, 116.28, 103.53 ]
17 | pixel_std: [ 58.395, 57.12, 57.375 ]
18 |
19 | image_encoder:
20 | _target_: segment_anything_hq.modeling.image_encoder.ImageEncoderViT
21 |
22 | prompt_encoder:
23 | _target_: segment_anything_hq.modeling.prompt_encoder.PromptEncoder
24 |
25 | mask_decoder:
26 | _target_: segment_anything_hq.modeling.mask_decoder_hq.MaskDecoderHQ
27 | vit_dim: ${..image_encoder.embed_dim}
28 |
--------------------------------------------------------------------------------
/configs/model/sam_pt.yaml:
--------------------------------------------------------------------------------
1 | _target_: sam_pt.modeling.sam_pt.SamPt
2 |
3 | defaults:
4 | - point_tracker: cotracker
5 | - sam@sam_predictor.sam_model: samhq_vit_huge
6 |
7 | sam_predictor:
8 | _target_: segment_anything_hq.predictor.SamPredictor
9 |
10 | sam_iou_threshold: 0.7
11 |
12 | iterative_refinement_iterations: 12
13 |
14 | positive_point_selection_method: "kmedoids" # kmedoids, shi-tomasi, random, mixed
15 | negative_point_selection_method: "mixed" # kmedoids, shi-tomasi, random, mixed
16 | positive_points_per_mask: 16
17 | negative_points_per_mask: 1
18 | add_other_objects_positive_points_as_negative_points: true
19 | max_other_objects_positive_points: null
20 |
21 | point_tracker_mask_batch_size: 5
22 |
23 | use_patch_matching_filtering: false
24 | patch_size: 3
25 | patch_similarity_threshold: 0.01
26 |
27 | use_point_reinit: false
28 | reinit_point_tracker_horizon: 24
29 | reinit_horizon: 24
30 | reinit_variant: "reinit-at-median-of-area-diff"
31 | # Reinitialization variants:
32 | # A) reinit-on-horizon-and-sync-masks:
33 | # - simplest variant: reinitialize the points after a fixed number of
34 | # frames (e.g., every 8 frames) can fail if the mask happens to be
35 | # empty at the reinitialization timestep
36 | # - as fast as not using reinit
37 | # B) reinit-at-median-of-area-diff:
38 | # - reinitialize points for the non-empty mask with the mean mask area
39 | # - multiple times slower than no reinit, as many sam masks will be
40 | # rejected (e.g., 8 masks were computed, but we might reinit on the
41 | # second one, recomputing the rejected masks in the next step again)
42 | # C) reinit-on-similar-mask-area:
43 | # - reinit when the mask area is similar to the initial mask area
44 | # - multiple times slower than no reinit
45 | # D) reinit-on-similar-mask-area-and-sync-masks:
46 | # - reinit when the mask area is similar to the initial mask area for
47 | # all masks in the batch and synchronize the masks to be tracked from
48 | # the same timestep, as to be able to use negative points from other
49 | # masks when querying sam
50 | # - multiple times slower than no reinit
51 |
--------------------------------------------------------------------------------
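The `_target_` entries above are resolved by Hydra's `instantiate` mechanism when the project's entry points (e.g., `demo/demo.py`) are launched. As a minimal sketch (not part of the repository), the composed model tree can be inspected from the repository root as follows; note that the `${hydra:runtime.cwd}` checkpoint paths only resolve inside a `@hydra.main` run, so this sketch prints the raw config rather than instantiating the model:

```python
# Hypothetical sketch, not part of the repository: compose configs/demo.yaml (which pulls in
# model: sam_pt, i.e. the file above) and print the resulting model subtree.
from hydra import compose, initialize
from omegaconf import OmegaConf

with initialize(version_base=None, config_path="configs"):
    cfg = compose(config_name="demo")
    # resolve=False keeps ${hydra:runtime.cwd} and ${..} interpolations unexpanded
    print(OmegaConf.to_yaml(cfg.model, resolve=False))
```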
/configs/vis_eval_root.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - logging: vis_eval
3 | - _self_
4 | - model/sam@model.sam_generator.model: sam_vit_huge
5 | - model@model.model: ???
6 |
7 | model:
8 | _target_: sam_pt.modeling.vis_to_vos_adapter.SamBasedVisToVosAdapter
9 | max_num_masks: 100
10 | masks_batch_size: 100
11 | visualize_results: true
12 | max_videos_to_visualize: 30
13 | sam_generator:
14 | _target_: segment_anything.automatic_mask_generator.SamAutomaticMaskGenerator
15 | model: ???
16 | points_per_side: 32
17 | points_per_batch: 64
18 | pred_iou_thresh: 0.88
19 | stability_score_thresh: 0.95
20 | stability_score_offset: 1.0
21 | box_nms_thresh: 0.7
22 | crop_n_layers: 0
23 | crop_nms_thresh: 0.7
24 | crop_overlap_ratio: 512 / 1500
25 | crop_n_points_downscale_factor: 1
26 | point_grids: null
27 | min_mask_region_area: 0
28 | output_mode: "binary_mask"
29 |
30 | output: results
31 |
32 | device: cuda
33 | num_gpus_per_machine: 1
34 | num_machines: 1
35 | machine_rank: 0
36 | dist_url: tcp://127.0.0.1:27036
37 |
38 | DETECTRON2_CONFIG:
39 | CUDNN_BENCHMARK: false
40 | DATALOADER:
41 | ASPECT_RATIO_GROUPING: true
42 | FILTER_EMPTY_ANNOTATIONS: false
43 | NUM_WORKERS: 0
44 | REPEAT_THRESHOLD: 0.0
45 | SAMPLER_TRAIN: TrainingSampler
46 | DATASETS:
47 | PRECOMPUTED_PROPOSAL_TOPK_TEST: 1000
48 | PRECOMPUTED_PROPOSAL_TOPK_TRAIN: 2000
49 | PROPOSAL_FILES_TEST: [ ]
50 | PROPOSAL_FILES_TRAIN: [ ]
51 | TEST:
52 | - uvo_v1_val
53 | TRAIN:
54 | - null
55 | GLOBAL:
56 | HACK: 1.0
57 | INPUT:
58 | AUGMENTATIONS: [ ]
59 | COLOR_AUG_SSD: false
60 | CROP:
61 | ENABLED: false
62 | SINGLE_CATEGORY_MAX_AREA: 1.0
63 | SIZE:
64 | - 600
65 | - 720
66 | TYPE: absolute_range
67 | DATASET_MAPPER_NAME: mask_former_semantic
68 | FORMAT: RGB
69 | IMAGE_SIZE: 1024
70 | MASK_FORMAT: polygon
71 | MAX_SCALE: 2.0
72 | MAX_SIZE_TEST: 1333
73 | MAX_SIZE_TRAIN: 1333
74 | MIN_SCALE: 0.1
75 | MIN_SIZE_TEST: 360
76 | MIN_SIZE_TRAIN:
77 | - 360
78 | - 480
79 | MIN_SIZE_TRAIN_SAMPLING: choice_by_clip
80 | RANDOM_FLIP: flip_by_clip
81 | SAMPLING_FRAME_NUM: 2
82 | SAMPLING_FRAME_RANGE: 20
83 | SAMPLING_FRAME_SHUFFLE: false
84 | SIZE_DIVISIBILITY: -1
85 | MODEL:
86 | MASK_ON: false
87 | SEM_SEG_HEAD:
88 | NUM_CLASSES: 54
89 | LOAD_PROPOSALS: false
90 | OUTPUT_DIR: ${ output }
91 | SEED: -1
92 | TEST:
93 | AUG:
94 | ENABLED: false
95 | FLIP: true
96 | MAX_SIZE: 4000
97 | MIN_SIZES:
98 | - 400
99 | - 500
100 | - 600
101 | - 700
102 | - 800
103 | - 900
104 | - 1000
105 | - 1100
106 | - 1200
107 | DETECTIONS_PER_IMAGE: 10
108 | EVAL_PERIOD: 0
109 | EXPECTED_RESULTS: [ ]
110 | KEYPOINT_OKS_SIGMAS: [ ]
111 | PRECISE_BN:
112 | ENABLED: false
113 | NUM_ITER: 200
114 | VERSION: 2
115 | VIS_PERIOD: 0
116 |
--------------------------------------------------------------------------------
/configs/vis_eval_sam_pt.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - vis_eval_root
3 | - override model@model.model: sam_pt
4 | - _self_
5 |
6 | model:
7 | model:
8 | point_tracker_mask_batch_size: 100
9 | sam_predictor:
10 | sam_model: ${ ...sam_generator.model }
11 |
--------------------------------------------------------------------------------
/configs/vos_eval_root.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - model: sam_pt
3 | - logging: vos_eval
4 | - _self_
5 |
6 | evaluator:
7 | _target_: sam_pt.vos_eval.evaluator.SamPtEvaluator
8 | _recursive_: false
9 |
10 | dataset: D17 # D16/D17/Y18/Y19/LV1/LV3/MOSE/BDD100K/G
11 | split: val # val/test
12 | simulate_interactive_point_correction: false
13 | masks_batch_size: 100
14 | seed: 72
15 |
16 | d16_path: ${hydra:runtime.cwd}/data/DAVIS/2016
17 | d17_path: ${hydra:runtime.cwd}/data/DAVIS/2017
18 | y18_path: ${hydra:runtime.cwd}/data/YouTube2018
19 | y19_path: ${hydra:runtime.cwd}/data/YouTube
20 | lv_path: ${hydra:runtime.cwd}/data/long_video_set
21 | mose_path: ${hydra:runtime.cwd}/data/mose
22 | bdd100k_path: ${hydra:runtime.cwd}/data/bdd100k/vos
23 | generic_path: null # For generic (G) evaluation, point to a folder that contains "JPEGImages" and "Annotations"
24 |
25 | input_only_one_gt_mask_point: false # If true, only one gt mask point will be used for evaluation
26 |
27 | size: -1 # Resize the shorter side to this size. -1 to use original resolution
28 | longest_size: 1024 # Resize the longest side to this size. null to use original resolution. Must be used with size=-1
29 |
30 | flip: false
31 |
32 | output: eval_${dataset}_${split} # Path to save the results. If None, will save to the default path
33 | save_all: false # Save all frames. Useful only in YouTubeVOS/long-time video
34 | save_scores: false
35 | save_overlapping_masks: false # Save overlapping masks along with non-overlapping multi-object masks
36 |
37 | visualize_results: true # Whether to visualize the results using wandb
38 | verbose_visualisations: false # Whether to visualize the results in a verbose way (e.g. with input GIFs), slower
39 | vid_ids: null # Evaluate only on the videos specified in the list, e.g. [0,1,2] (or vid_ids=\[0,1,2\] in command line)
40 | max_videos_to_visualize: 30 # Max number of videos to visualize, used when visualize_results flag is set, videos with id >= max_videos_to_visualize will not be visualized
41 | vid_ids_to_visualize: [ 0, 1, 2, 15 ] # Videos to visualize, used when visualize_results flag is set, null for all videos
42 | log_fmt: gif # gif/mp4
43 |
44 | max_videos: null # Max number of videos to process, useful with the visualize_results flag and for debugging
45 | max_frames: null # Max number of frames to process per video. Useful for debugging
46 |
47 | logging:
48 | exp_id_verbose: ${logging.exp_id}_${dataset}_${split}_${seed}_${now:%Y.%m.%d_%H.%M.%S}
49 |
50 |
51 | hydra:
52 | job:
53 | chdir: True
54 | run:
55 | dir: outputs/${logging.exp_id_verbose}
56 |
--------------------------------------------------------------------------------
/data/demo_data/README.md:
--------------------------------------------------------------------------------
1 | # Demo Data
2 |
3 | This directory contains demo data that users can use to understand the structure and format of input data. Below, we've detailed the sources of our demo data and provided an in-depth explanation of the query points format.
4 |
5 | ## Data Sources
6 |
7 | The provided clips in this directory serve as sample data for the demo and were obtained from Pixabay:
8 |
9 | 1. [`street.mp4`](street.mp4) - [Video source](https://pixabay.com/videos/street-bus-village-bus-stop-city-38590/).
10 | 2. [`bees.mp4`](bees.mp4) - [Video source](https://pixabay.com/videos/bees-honey-bees-insect-pollen-35093/).
11 |
12 | ## Query Points Format
13 |
14 | Query points are crucial for our application as they define the target object (positive points) and the background/non-target objects (negative points).
15 |
16 | They can be provided interactively by the user or derived from a ground truth mask. The following section explains how they're structured when saved to a text file:
17 |
18 | ```bash
19 | number_of_positive_points
20 | mask_1_timestep ; pos_x_1,pos_y_1 ... pos_x_n,pos_y_n neg_x_1,neg_y_1 ... neg_x_m,neg_y_m
21 | mask_2_timestep ; pos_x_1,pos_y_1 ... pos_x_n,pos_y_n neg_x_1,neg_y_1 ... neg_x_m,neg_y_m
22 | ...
23 | ```
24 |
25 | - `number_of_positive_points` - Specifies the number of positive points
26 | - `mask_x_timestep` - The frame index (timestep) at which each mask's query points are given
27 | - `pos_x_i,pos_y_i` - x, y coordinates of the positive points
28 | - `neg_x_i,neg_y_i` - x, y coordinates of the negative points
29 |
30 | Note: The number of negative points is inferred as the total number of points on a line minus the number of positive points.
31 |
32 | Here is a simple example of a query point file with two masks:
33 |
34 | ```sh
35 | 1
36 | 0 ; 10,20 30,30 40,40
37 | 4 ; 123.123,456.456 72,72 5,6
38 | ```
39 |
40 | In this example, each mask has one positive point and two negative points. The positive query point for the first mask, for instance, has (x,y) coordinates of (10,20). Here, the value '10' denotes a distance of 10 pixels from the left image border, and '20' indicates a distance of 20 pixels from the top image border (as the coordinate system begins at the top left corner of the image).
41 |
--------------------------------------------------------------------------------
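The query points format described in the README above is simple enough to parse directly. Below is a minimal, hypothetical parsing sketch (the function name and the returned dictionary layout are ours, not part of the repository):

```python
# Hypothetical helper, not part of the repo: read a query points file as documented above.
# First line: number of positive points per mask. Each following line: "timestep ; x1,y1 x2,y2 ...",
# where the first n points are positive and the remaining ones are negative.
def read_query_points(path):
    with open(path) as f:
        lines = [line.strip() for line in f if line.strip()]
    num_positive = int(lines[0])
    masks = []
    for line in lines[1:]:
        timestep_str, points_str = line.split(";")
        points = [tuple(float(v) for v in p.split(",")) for p in points_str.split()]
        masks.append({
            "timestep": int(timestep_str),
            "positive": points[:num_positive],
            "negative": points[num_positive:],
        })
    return masks


print(read_query_points("data/demo_data/query_points__bees.txt"))  # 2 masks, 3 positive + 1 negative point each
```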
/data/demo_data/bees.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SysCV/sam-pt/874ff7e73d6ab05418a494d7a02ca233c0b31e8c/data/demo_data/bees.mp4
--------------------------------------------------------------------------------
/data/demo_data/query_points__bees.txt:
--------------------------------------------------------------------------------
1 | 3
2 | 0;282.5,241.25 336.25,240.0 357.5,276.25 531.25,242.5
3 | 0;1116.25,428.75 1137.5,411.25 1165.0,412.5 1130.0,272.5
4 |
--------------------------------------------------------------------------------
/data/demo_data/query_points__street.txt:
--------------------------------------------------------------------------------
1 | 4
2 | 0;403.75,426.25 382.5,481.25 423.75,507.5 403.75,561.25 425.0,603.75 350.0,360.0
3 | 0;668.75,520.0 631.25,493.75 648.75,443.75 681.25,427.5 870.0,513.75 691.25,316.25
4 | 0;307.5,393.75 337.5,540.0 335.0,296.25 497.5,231.25 265.0,286.25 330.0,588.75
5 |
--------------------------------------------------------------------------------
/data/demo_data/street.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SysCV/sam-pt/874ff7e73d6ab05418a494d7a02ca233c0b31e8c/data/demo_data/street.mp4
--------------------------------------------------------------------------------
/demo/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SysCV/sam-pt/874ff7e73d6ab05418a494d7a02ca233c0b31e8c/demo/__init__.py
--------------------------------------------------------------------------------
/docs/01-getting-started.md:
--------------------------------------------------------------------------------
1 | # Getting Started
2 |
3 | ## Setting Up the Environment
4 |
5 | This codebase has been tested and confirmed to be compatible with the package versions listed in [`requirements.txt`](../requirements.txt), along with PyTorch 1.12.0 and Python 3.8.16. These versions were tested on Manjaro Linux and Debian GNU/Linux 10 (buster) systems.
6 |
7 | Start by cloning the repository:
8 |
9 | ```bash
10 | git clone https://github.com/SysCV/sam-pt.git
11 | cd sam-pt
12 | ```
13 |
14 | With the repository now cloned, we recommend creating a new [conda](https://docs.conda.io/en/latest/) virtual environment:
15 |
16 | ```bash
17 | conda create --name sam-pt python=3.8.16 -y
18 | conda activate sam-pt
19 | ```
20 |
21 | Next, install [PyTorch](https://pytorch.org/) 1.12.0 and [torchvision](https://pytorch.org/vision/stable/index.html) 0.13.0, for example with CUDA 11 support:
22 |
23 | ```bash
24 | conda install pytorch==1.12.0 torchvision==0.13.0 torchaudio==0.12.0 cudatoolkit=11.3 -c pytorch
25 | ```
26 |
27 | Finally, install the required packages:
28 |
29 | ```bash
30 | pip install -r requirements.txt
31 | ```
32 |
33 | If you wish to use TapNet (or TAPIR) as a point tracker, it's necessary to configure JAX on your system. The required packages, including JAX library version [0.4.11](https://github.com/google/jax/tree/jax-v0.4.11) and others needed by TapNet, can be found in the [`requirements-jax.txt`](../requirements-jax.txt) file. To install JAX, we recommend following the [official installation instructions](https://github.com/google/jax#installation). In some environments, like ours, it may be necessary to build PyTorch and JAX from source.
34 |
35 | ## Running the Demo
36 |
37 | To run the demo, start by preparing your demo data. This can either be one of the clips provided in `data/demo_data`, or a clip of your own. You can also use the horse jumping video `data/DAVIS/2017/trainval/JPEGImages/Full-Resolution/horsejump-high` from [DAVIS 2017](02-prepare-datasets.md#davis-2017).
38 |
39 | The demo expects a sequence of images as input. If your data is a video clip, convert it to images ensuring their filenames are lexicographically ordered (e.g., `frame-000.png`, `frame-001.png`, etc.). For example, the `ffmpeg` command can be used to convert the provided demo clips as follows:
40 |
41 | ```bash
42 | # List the content of the demo_data directory
43 | ls data/demo_data
44 | # bees.mp4 street.mp4 ...
45 |
46 | # Convert bees.mp4 to png frames
47 | mkdir data/demo_data/bees
48 | ffmpeg -i data/demo_data/bees.mp4 -vf fps=5 data/demo_data/bees/frame-%05d.png
49 |
50 | # Convert street.mp4 to png frames
51 | mkdir data/demo_data/street
52 | ffmpeg -i data/demo_data/street.mp4 -vf fps=10 data/demo_data/street/frame-%05d.png
53 | ```
54 |
55 | Before running the demo, also make sure you have downloaded the SAM and PIPS checkpoints, as described under [minimal checkpoints](03-prepare-checkpoints.md#minimal-checkpoints).
56 |
57 | ### Running the Interactive Demo
58 |
59 | The interactive demo allows you to specify query points using mouse clicks on a pop-up window. This requires a GUI environment, which is typically available on personal computers. If you're using remote GPUs, you may need to set up X forwarding.
60 |
61 |
62 | Note that relative paths in the commands below need to be prefixed with [`${hydra:runtime.cwd}`](https://hydra.cc/docs/1.3/configure_hydra/intro/#hydraruntime), since we launch demos from within a [working directory created by Hydra](https://hydra.cc/docs/1.3/tutorials/basic/running_your_app/working_directory/). Follow the instructions displayed in your terminal after launching the interactive demo.
63 |
64 |
65 | ```bash
66 | # Run demo on bees.mp4
67 | export HYDRA_FULL_ERROR=1
68 | python -m demo.demo \
69 | frames_path='${hydra:runtime.cwd}/data/demo_data/bees/' \
70 | query_points_path=null \
71 | longest_side_length=1024 frame_stride=1 max_frames=-1
72 |
73 | # Run demo on street.mp4
74 | export HYDRA_FULL_ERROR=1
75 | python -m demo.demo \
76 | frames_path='${hydra:runtime.cwd}/data/demo_data/street/' \
77 | query_points_path=null \
78 | longest_side_length=1024 frame_stride=1 max_frames=-1
79 | ```
80 |
81 | ### Running the Non-interactive Demo
82 |
83 | You also have the option to run the demo in a non-interactive mode where query points are predefined in a file. You can create the content of a query points file using the interactive demo, which will print a string of the query points. This string can be saved and used for running the non-interactive demo. More details about the format of the query points file can be found in [`data/demo_data/README.md`](../data/demo_data/README.md). Examples of query point files for the [bees](../data/demo_data/query_points__bees.txt) and [street](../data/demo_data/query_points__street.txt) clips are also provided and can be used as in the following commands:
84 |
85 | ```bash
86 | # Run non-interactive demo on bees.mp4
87 | export HYDRA_FULL_ERROR=1
88 | python -m demo.demo \
89 | frames_path='${hydra:runtime.cwd}/data/demo_data/bees/' \
90 | query_points_path='${hydra:runtime.cwd}/data/demo_data/query_points__bees.txt' \
91 | longest_side_length=1024 frame_stride=1 max_frames=-1
92 |
93 | # Run non-interactive demo on street.mp4
94 | export HYDRA_FULL_ERROR=1
95 | python -m demo.demo \
96 | frames_path='${hydra:runtime.cwd}/data/demo_data/street/' \
97 | query_points_path='${hydra:runtime.cwd}/data/demo_data/query_points__street.txt' \
98 | longest_side_length=1024 frame_stride=1 max_frames=-1
99 | ```
100 |
101 | ## Codebase Overview
102 |
103 | Here's a quick overview of our project's codebase and its structure:
104 |
105 | - [`assets`](../assets): Assets related to the GitHub repository
106 | - [`configs`](../configs): YAML configuration files used with Hydra
107 | - [`data`](../data): Directory to store data
108 | - [`demo_data`](../data/demo_data): Demo data with README for data sources and query points file format
109 | - [`demo`](../demo): Code for running the demo
110 | - [`docs`](../docs): Documentation on how to use the codebase
111 | - [`sam_pt`](../sam_pt): Source for SAM-PT
112 | - [`modeling`](../sam_pt/modeling): Main code for SAM-PT
113 | - [`point_tracker`](../sam_pt/point_tracker): Code for different point trackers
114 | - [`utils`](../sam_pt/utils): Utilities used within the SAM-PT module
115 | - [`vis_eval`](../sam_pt/vis_eval): Code for evaluating on Video Instance Segmentation (VIS)
116 | - [`vos_eval`](../sam_pt/vos_eval): Code for evaluating on Video Object Segmentation (VOS)
117 | - [`scripts`](../scripts): Scripts used for small tasks
118 | - [`README.md`](../README.md): Main README file
119 | - [`requirements.txt`](../requirements.txt): General project requirements
120 | - [`requirements-jax.txt`](../requirements-jax.txt): Requirements for using the JAX-based TapNet and TAPIR point trackers
121 |
122 |
123 | ## What's Next?
124 |
125 | Once you are comfortable with running the demo, you might want to explore [how to prepare the data](02-prepare-datasets.md) and [how to prepare the checkpoints](03-prepare-checkpoints.md) that are necessary for [running our VOS and VIS experiments](04-running-experiments.md).
126 |
--------------------------------------------------------------------------------
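Since the demo relies on the lexicographic order of the extracted frame filenames, a quick sanity check after running the `ffmpeg` commands above can be useful. A hypothetical snippet, not part of the repository:

```python
# Hypothetical sanity check, not from the repo: confirm the extracted frames
# load in the lexicographic order the demo expects.
from pathlib import Path

from PIL import Image

frame_paths = sorted(Path("data/demo_data/bees").glob("frame-*.png"))  # lexicographic == chronological here
assert frame_paths, "No frames found - did the ffmpeg conversion run?"
print(f"{len(frame_paths)} frames, first: {frame_paths[0].name}, last: {frame_paths[-1].name}")
print(f"frame size: {Image.open(frame_paths[0]).size}")  # the demo resizes the longest side to 1024 by default
```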
/requirements-jax.txt:
--------------------------------------------------------------------------------
1 | # JAX version requirements, recommended to be installed following JAX's instructions here: https://github.com/google/jax#installation
2 | # jax==0.4.11
3 | # jaxlib==0.4.11
4 |
5 | jaxline==0.0.5
6 | chex==0.1.7
7 | dm-haiku==0.0.9
8 | optax==0.1.5
9 | einshape@git+https://github.com/deepmind/einshape
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | # Torch version requirements, recommended to be installed w/o pip:
2 | # torch==1.12.0
3 | # torchvision==0.13.0
4 |
5 | tensorflow==2.12.1
6 | einops==0.4.1
7 | opencv-python==4.7.0.72
8 | timm==0.9.2
9 | flow_vis==0.1
10 |
11 | numpy==1.24.3
12 | h5py==3.9.0
13 | Pillow==9.5.0
14 | pandas==1.5.3
15 | matplotlib==3.5.1
16 | seaborn==0.12.2
17 | scikit-learn==1.1.1
18 | scikit-learn-extra==0.3.0
19 |
20 | hydra-core==1.3.2
21 | wandb==0.15.3
22 | imageio==2.31.1
23 | moviepy==1.0.3
24 | mediapy==1.1.8
25 |
26 | git+https://github.com/facebookresearch/detectron2@v0.6
27 | git+https://github.com/m43/davis2016-davis2017-davis2019-evaluation.git@35401a5619757359673d9d1a7d9e02c177f06f7f
28 | git+https://github.com/facebookresearch/segment-anything.git@aac76a1fb03cf90dc7cb2ad481d511642e51aeba
29 | git+https://github.com/ChaoningZhang/MobileSAM.git@01ea8d0f5590082f0c1ceb0a3e2272593f20154b
30 | git+https://github.com/m43/sam-hq.git@75c73fa27b32435f33119d08a47788db4601e1da
31 | git+https://github.com/facebookresearch/co-tracker.git@4f297a92fe1a684b1b0980da138b706d62e45472
32 |
--------------------------------------------------------------------------------
/sam_pt/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SysCV/sam-pt/874ff7e73d6ab05418a494d7a02ca233c0b31e8c/sam_pt/__init__.py
--------------------------------------------------------------------------------
/sam_pt/modeling/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SysCV/sam-pt/874ff7e73d6ab05418a494d7a02ca233c0b31e8c/sam_pt/modeling/__init__.py
--------------------------------------------------------------------------------
/sam_pt/modeling/sam.py:
--------------------------------------------------------------------------------
1 | """
2 | This module contains hydra wrapper classes for different types of Sam models. Each hydra wrapper provides functionality
3 | for loading checkpoints and storing additional parameters that we used for variable interpolation within Hydra.
4 | """
5 |
6 | import torch
7 | from mobile_sam.modeling import Sam as MobileSam
8 | from segment_anything.modeling import Sam
9 | from segment_anything_hq.modeling import Sam as SamHQ
10 |
11 |
12 | class BaseHydra:
13 | """
14 | Base class for hydra wrappers that loads the model checkpoint and stores additional parameters that we used for
15 | variable interpolation within Hydra.
16 | """
17 |
18 | def __init__(self, model, checkpoint, prompt_embed_dim, image_size, vit_patch_size, image_embedding_size, **kwargs):
19 | super().__init__(**kwargs)
20 |
21 | if checkpoint is not None:
22 | with open(checkpoint, "rb") as f:
23 | state_dict = torch.load(f)
24 | model.load_state_dict(self, state_dict, strict=False)  # unbound call: `model` is the concrete SAM class, `self` is an instance of it
25 | print(f"Loaded checkpoint from {checkpoint}.")
26 |
27 | # Store additional parameters used for variable interpolation within Hydra
28 | self.prompt_embed_dim = prompt_embed_dim
29 | self.image_size = image_size
30 | self.vit_patch_size = vit_patch_size
31 | self.image_embedding_size = image_embedding_size
32 |
33 |
34 | class SamHydra(BaseHydra, Sam):
35 | """
36 | Wrapper for the Sam model that allows for loading a checkpoint
37 | and setting additional parameters used for variable interpolation.
38 | """
39 |
40 | def __init__(self, *args, **kwargs):
41 | super().__init__(Sam, *args, **kwargs)
42 |
43 |
44 | class SamHQHydra(BaseHydra, SamHQ):
45 | """
46 | Wrapper for the SamHQ model that allows for loading a checkpoint
47 | and setting additional parameters used for variable interpolation.
48 | """
49 |
50 | def __init__(self, *args, **kwargs):
51 | super().__init__(SamHQ, *args, **kwargs)
52 |
53 |
54 | class MobileSamHydra(BaseHydra, MobileSam):
55 | """
56 | Wrapper for the MobileSAM model that allows for loading a checkpoint
57 | and setting additional parameters used for variable interpolation.
58 | """
59 |
60 | def __init__(self, *args, **kwargs):
61 | super().__init__(MobileSam, *args, **kwargs)
62 |
--------------------------------------------------------------------------------
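The checkpoint loading above relies on an unbound method call through the concrete SAM class. A standalone sketch of that mixin pattern (illustrative only, not taken from the repository):

```python
# Illustrative only, not from the repo: the mixin pattern used by BaseHydra. The mixin receives
# the concrete class as an argument and calls its method unbound, passing `self`, which works
# because the final class inherits from both the mixin and the concrete class.
class Backbone:
    def load(self, payload):
        print(f"Backbone.load called on {type(self).__name__} with {payload!r}")


class LoaderMixin:
    def __init__(self, backbone_cls, payload):
        super().__init__()
        backbone_cls.load(self, payload)  # unbound call; `self` must also be a Backbone instance


class BackboneWithLoader(LoaderMixin, Backbone):
    def __init__(self, payload):
        super().__init__(Backbone, payload)


BackboneWithLoader("checkpoint.pth")  # prints: Backbone.load called on BackboneWithLoader with 'checkpoint.pth'
```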
/sam_pt/point_tracker/__init__.py:
--------------------------------------------------------------------------------
1 | from .tracker import PointTracker
2 | from .pips import PipsPointTracker
3 | from .raft import RaftPointTracker
4 | from .superglue import SuperGluePointTracker
5 | from .tapir import TapirPointTracker
6 | from .tapnet import TapnetPointTracker
7 | from .cotracker import CoTrackerPointTracker
--------------------------------------------------------------------------------
/sam_pt/point_tracker/cotracker/__init__.py:
--------------------------------------------------------------------------------
1 | from .tracker import CoTrackerPointTracker
2 |
--------------------------------------------------------------------------------
/sam_pt/point_tracker/pips/__init__.py:
--------------------------------------------------------------------------------
1 | from .pips import Pips
2 | from .tracker import PipsPointTracker
3 |
--------------------------------------------------------------------------------
/sam_pt/point_tracker/pips_plus_plus/__init__.py:
--------------------------------------------------------------------------------
1 | from .pips_plus_plus import PipsPlusPlus
2 | from .tracker import PipsPlusPlusPointTracker
3 |
--------------------------------------------------------------------------------
/sam_pt/point_tracker/pips_plus_plus/tracker.py:
--------------------------------------------------------------------------------
1 | from collections import defaultdict
2 |
3 | import torch
4 |
5 | from sam_pt.point_tracker import PointTracker
6 | from sam_pt.point_tracker.pips_plus_plus import PipsPlusPlus
7 | from sam_pt.point_tracker.utils import saverloader
8 |
9 |
10 | class PipsPlusPlusPointTracker(PointTracker):
11 |
12 | def __init__(self, checkpoint_path, stride=8, max_sequence_length=128, iters=16, image_size=(512, 896)):
13 | super().__init__()
14 | self.checkpoint_path = checkpoint_path
15 | self.stride = stride
16 | self.max_sequence_length = max_sequence_length
17 | self.iters = iters
18 | self.image_size = tuple(image_size) if image_size is not None else None
19 |
20 | print(f"Loading PIPS++ model from {self.checkpoint_path}")
21 | self.model = PipsPlusPlus(stride=self.stride)
22 | self._loaded_checkpoint_step = saverloader.load(self.checkpoint_path, self.model,
23 | device="cuda" if torch.cuda.is_available() else "cpu")
24 |
25 | def _forward(self, rgbs, query_points):
26 | """
27 | Single direction forward pass.
28 | """
29 | B, S, C, H, W = rgbs.shape
30 | assert query_points.ndim == 2
31 | assert query_points.shape[1] == 2
32 |
33 | # zero-vel init
34 | trajs_e = query_points[None, None, :, :].repeat(1, rgbs.shape[1], 1, 1)
35 |
36 | cur_frame = 0
37 | done = False
38 | feat_init = None
39 | while not done:
40 | end_frame = cur_frame + self.max_sequence_length
41 |
42 | if end_frame > S:
43 | diff = end_frame - S
44 | end_frame = end_frame - diff
45 | cur_frame = max(cur_frame - diff, 0)
46 |
47 | traj_seq = trajs_e[:, cur_frame:end_frame]
48 | rgb_seq = rgbs[:, cur_frame:end_frame]
49 | S_local = rgb_seq.shape[1]
50 |
51 | if feat_init is not None:
52 | feat_init = [fi[:, :S_local] for fi in feat_init]
53 |
54 | preds, preds_anim, feat_init, _ = self.model(traj_seq, rgb_seq, iters=self.iters, feat_init=feat_init)
55 |
56 | trajs_e[:, cur_frame:end_frame] = preds[-1][:, :S_local]
57 | trajs_e[:, end_frame:] = trajs_e[:, end_frame - 1:end_frame] # update the future with new zero-vel
58 |
59 | if end_frame >= S:
60 | done = True
61 | else:
62 | cur_frame = cur_frame + self.max_sequence_length - 1
63 |
64 | visibilities = torch.ones_like(trajs_e[:, :, :, 0])
65 | return trajs_e, visibilities
66 |
67 | def forward(self, rgbs, query_points):
68 | """
69 | Forward function for the tracker.
70 | """
71 | batch_size, num_frames, C, H, W = rgbs.shape
72 | if self.image_size is not None:
73 | rgbs = rgbs.reshape(batch_size * num_frames, C, H, W)
74 | rgbs = rgbs / 255.0
75 | rgbs = torch.nn.functional.interpolate(rgbs, size=tuple(self.image_size), mode="bilinear")
76 | rgbs = rgbs * 255.0
77 | rgbs = rgbs.reshape(batch_size, num_frames, C, *self.image_size)
78 | query_points[:, :, 1] *= self.image_size[0] / H
79 | query_points[:, :, 2] *= self.image_size[1] / W
80 |
81 | # Group query points by their time-step
82 | groups = defaultdict(list)
83 | assert query_points.shape[0] == batch_size == 1, "Only batch size 1 is supported."
84 | for idx, point in enumerate(query_points[0]):
85 | t = int(point[0].item())
86 | groups[t].append((idx, point[1:].tolist()))
87 |
88 | # Dictionary to store results
89 | trajectories_dict = {}
90 | visibilities_dict = {}
91 |
92 | for t, points_with_indices in groups.items():
93 | points = [x[1] for x in points_with_indices]
94 |
95 | # Left to right
96 | if t == num_frames - 1:
97 | left_trajectories = torch.empty((batch_size, 0, len(points), 2), dtype=torch.float32).cuda()
98 | left_visibilities = torch.empty((batch_size, 0, len(points)), dtype=torch.float32).cuda()
99 | else:
100 | left_rgbs = rgbs[:, t:]
101 | left_query = torch.tensor(points, dtype=torch.float32).cuda()
102 | left_trajectories, left_visibilities = self._forward(left_rgbs, left_query)
103 |
104 | # Right to left
105 | if t == 0:
106 | right_trajectories = torch.empty((batch_size, 0, len(points), 2), dtype=torch.float32).cuda()
107 | right_visibilities = torch.empty((batch_size, 0, len(points)), dtype=torch.float32).cuda()
108 | else:
109 | right_rgbs = rgbs[:, :t + 1].flip(1)
110 | right_query = torch.tensor(points, dtype=torch.float32).cuda()
111 | right_trajectories, right_visibilities = self._forward(right_rgbs, right_query)
112 | right_trajectories = right_trajectories.flip(1)
113 | right_visibilities = right_visibilities.flip(1)
114 |
115 | # Merge the results
116 | trajectories = torch.cat([right_trajectories[:, :-1], left_trajectories], dim=1)
117 | visibilities = torch.cat([right_visibilities[:, :-1], left_visibilities], dim=1)
118 |
119 | # Store in dictionary
120 | for local_idx, (global_idx, _) in enumerate(points_with_indices):
121 | trajectories_dict[global_idx] = trajectories[:, :, local_idx, :]
122 | visibilities_dict[global_idx] = visibilities[:, :, local_idx]
123 |
124 | # Assemble the results back in the order of the input query points
125 | n_points = query_points.shape[1]
126 | final_trajectories = torch.stack([trajectories_dict[i] for i in range(n_points)], dim=2)
127 | final_visibilities = torch.stack([visibilities_dict[i] for i in range(n_points)], dim=2)
128 |
129 | # Rescale trajectories back to the original size
130 | if self.image_size is not None:
131 | final_trajectories[:, :, :, 0] *= H / self.image_size[0]
132 | final_trajectories[:, :, :, 1] *= W / self.image_size[1]
133 |
134 | return final_trajectories, final_visibilities
135 |
--------------------------------------------------------------------------------
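For orientation, the tensor shapes used by `forward` above are: `rgbs` of shape `(B, S, C, H, W)` with values in `[0, 255]`, and `query_points` of shape `(B, N, 3)` holding `(t, x, y)` per point; the tracker returns per-frame trajectories of shape `(B, S, N, 2)` and visibilities of shape `(B, S, N)`. A hypothetical call sketch, assuming a CUDA device and the PIPS++ checkpoint at the path used in `configs/model/point_tracker/pips_plus_plus.yaml`:

```python
# Hypothetical usage sketch, not from the repo; shapes follow forward() above.
import torch

from sam_pt.point_tracker.pips_plus_plus import PipsPlusPlusPointTracker

tracker = PipsPlusPlusPointTracker(
    checkpoint_path="models/pips_plus_plus_ckpts/reference_model",  # assumed download location
    image_size=None,  # keep the original resolution
).cuda().eval()

rgbs = torch.randint(0, 256, (1, 8, 3, 256, 384), dtype=torch.float32).cuda()  # (B, S, C, H, W)
query_points = torch.tensor([[[0, 100.0, 50.0], [0, 200.0, 120.0]]]).cuda()    # (B, N, 3) as (t, x, y)

with torch.no_grad():
    trajectories, visibilities = tracker(rgbs, query_points)

print(trajectories.shape)  # (1, 8, 2, 2): per-frame (x, y) for each query point
print(visibilities.shape)  # (1, 8, 2)
```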
/sam_pt/point_tracker/raft/__init__.py:
--------------------------------------------------------------------------------
1 | from .tracker import RaftPointTracker
2 |
--------------------------------------------------------------------------------
/sam_pt/point_tracker/raft/raft_core/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SysCV/sam-pt/874ff7e73d6ab05418a494d7a02ca233c0b31e8c/sam_pt/point_tracker/raft/raft_core/__init__.py
--------------------------------------------------------------------------------
/sam_pt/point_tracker/raft/raft_core/corr.py:
--------------------------------------------------------------------------------
1 | # Taken from: https://github.com/princeton-vl/RAFT/blob/aac9dd54726caf2cf81d8661b07663e220c5586d/core/corr.py
2 |
3 | import torch
4 | import torch.nn.functional as F
5 |
6 | from .util import bilinear_sampler
7 |
8 | try:
9 | import alt_cuda_corr
10 | except:
11 | # alt_cuda_corr is not compiled
12 | pass
13 |
14 |
15 | class CorrBlock:
16 | def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
17 | self.num_levels = num_levels
18 | self.radius = radius
19 | self.corr_pyramid = []
20 |
21 | # all pairs correlation
22 | corr = CorrBlock.corr(fmap1, fmap2)
23 |
24 | batch, h1, w1, dim, h2, w2 = corr.shape
25 | corr = corr.reshape(batch * h1 * w1, dim, h2, w2)
26 |
27 | self.corr_pyramid.append(corr)
28 | for i in range(self.num_levels - 1):
29 | corr = F.avg_pool2d(corr, 2, stride=2)
30 | self.corr_pyramid.append(corr)
31 |
32 | def __call__(self, coords):
33 | r = self.radius
34 | coords = coords.permute(0, 2, 3, 1)
35 | batch, h1, w1, _ = coords.shape
36 |
37 | out_pyramid = []
38 | for i in range(self.num_levels):
39 | corr = self.corr_pyramid[i]
40 | dx = torch.linspace(-r, r, 2 * r + 1)
41 | dy = torch.linspace(-r, r, 2 * r + 1)
42 | delta = torch.stack(torch.meshgrid(dy, dx), axis=-1).to(coords.device)
43 |
44 | centroid_lvl = coords.reshape(batch * h1 * w1, 1, 1, 2) / 2 ** i
45 | delta_lvl = delta.view(1, 2 * r + 1, 2 * r + 1, 2)
46 | coords_lvl = centroid_lvl + delta_lvl
47 |
48 | corr = bilinear_sampler(corr, coords_lvl)
49 | corr = corr.view(batch, h1, w1, -1)
50 | out_pyramid.append(corr)
51 |
52 | out = torch.cat(out_pyramid, dim=-1)
53 | return out.permute(0, 3, 1, 2).contiguous().float()
54 |
55 | @staticmethod
56 | def corr(fmap1, fmap2):
57 | batch, dim, ht, wd = fmap1.shape
58 | fmap1 = fmap1.view(batch, dim, ht * wd)
59 | fmap2 = fmap2.view(batch, dim, ht * wd)
60 |
61 | corr = torch.matmul(fmap1.transpose(1, 2), fmap2)
62 | corr = corr.view(batch, ht, wd, 1, ht, wd)
63 | return corr / torch.sqrt(torch.tensor(dim).float())
64 |
65 |
66 | class AlternateCorrBlock:
67 | def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
68 | self.num_levels = num_levels
69 | self.radius = radius
70 |
71 | self.pyramid = [(fmap1, fmap2)]
72 | for i in range(self.num_levels):
73 | fmap1 = F.avg_pool2d(fmap1, 2, stride=2)
74 | fmap2 = F.avg_pool2d(fmap2, 2, stride=2)
75 | self.pyramid.append((fmap1, fmap2))
76 |
77 | def __call__(self, coords):
78 | coords = coords.permute(0, 2, 3, 1)
79 | B, H, W, _ = coords.shape
80 | dim = self.pyramid[0][0].shape[1]
81 |
82 | corr_list = []
83 | for i in range(self.num_levels):
84 | r = self.radius
85 | fmap1_i = self.pyramid[0][0].permute(0, 2, 3, 1).contiguous()
86 | fmap2_i = self.pyramid[i][1].permute(0, 2, 3, 1).contiguous()
87 |
88 | coords_i = (coords / 2 ** i).reshape(B, 1, H, W, 2).contiguous()
89 | corr, = alt_cuda_corr.forward(fmap1_i, fmap2_i, coords_i, r)
90 | corr_list.append(corr.squeeze(1))
91 |
92 | corr = torch.stack(corr_list, dim=1)
93 | corr = corr.reshape(B, -1, H, W)
94 | return corr / torch.sqrt(torch.tensor(dim).float())
95 |
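A small shape sketch for `CorrBlock` (not part of the repository file; random feature maps, with sizes chosen so that every pyramid level stays larger than 1 x 1):

```python
import torch

from sam_pt.point_tracker.raft.raft_core.corr import CorrBlock

batch, dim, ht, wd = 1, 256, 16, 16
fmap1 = torch.randn(batch, dim, ht, wd)
fmap2 = torch.randn(batch, dim, ht, wd)

# All-pairs correlation: one (ht, wd) response map per source pixel, scaled by 1/sqrt(dim).
corr = CorrBlock.corr(fmap1, fmap2)
assert corr.shape == (batch, ht, wd, 1, ht, wd)

# The block pools this volume into a num_levels-deep pyramid and, when indexed with
# (x, y) coordinates, samples a (2r + 1) x (2r + 1) window at every level.
block = CorrBlock(fmap1, fmap2, num_levels=4, radius=4)
coords = torch.zeros(batch, 2, ht, wd)
out = block(coords)
assert out.shape == (batch, 4 * 9 * 9, ht, wd)  # num_levels * (2r + 1)^2 channels
```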
--------------------------------------------------------------------------------
/sam_pt/point_tracker/raft/raft_core/raft.py:
--------------------------------------------------------------------------------
1 | # Taken from: https://github.com/princeton-vl/RAFT/blob/aac9dd54726caf2cf81d8661b07663e220c5586d/core/raft.py
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
7 | from .corr import CorrBlock, AlternateCorrBlock
8 | from .extractor import BasicEncoder, SmallEncoder
9 | from .update import BasicUpdateBlock, SmallUpdateBlock
10 | from .util import coords_grid, upflow8
11 |
12 | try:
13 | autocast = torch.cuda.amp.autocast
14 | except:
15 | # dummy autocast for PyTorch < 1.6
16 | class autocast:
17 | def __init__(self, enabled):
18 | pass
19 |
20 | def __enter__(self):
21 | pass
22 |
23 | def __exit__(self, *args):
24 | pass
25 |
26 |
27 | class RAFT(nn.Module):
28 | def __init__(self, args):
29 | super(RAFT, self).__init__()
30 | self.args = args
31 |
32 | if args.small:
33 | self.hidden_dim = hdim = 96
34 | self.context_dim = cdim = 64
35 | args.corr_levels = 4
36 | args.corr_radius = 3
37 |
38 | else:
39 | self.hidden_dim = hdim = 128
40 | self.context_dim = cdim = 128
41 | args.corr_levels = 4
42 | args.corr_radius = 4
43 |
44 | if 'dropout' not in self.args:
45 | self.args.dropout = 0
46 |
47 | if 'alternate_corr' not in self.args:
48 | self.args.alternate_corr = False
49 |
50 | # feature network, context network, and update block
51 | if args.small:
52 | self.fnet = SmallEncoder(output_dim=128, norm_fn='instance', dropout=args.dropout)
53 | self.cnet = SmallEncoder(output_dim=hdim + cdim, norm_fn='none', dropout=args.dropout)
54 | self.update_block = SmallUpdateBlock(self.args, hidden_dim=hdim)
55 |
56 | else:
57 | self.fnet = BasicEncoder(output_dim=256, norm_fn='instance', dropout=args.dropout)
58 | self.cnet = BasicEncoder(output_dim=hdim + cdim, norm_fn='batch', dropout=args.dropout)
59 | self.update_block = BasicUpdateBlock(self.args, hidden_dim=hdim)
60 |
61 | def freeze_bn(self):
62 | for m in self.modules():
63 | if isinstance(m, nn.BatchNorm2d):
64 | m.eval()
65 |
66 | def initialize_flow(self, img):
67 |         """ Flow is represented as the difference between two coordinate grids: flow = coords1 - coords0 """
68 | N, C, H, W = img.shape
69 | coords0 = coords_grid(N, H // 8, W // 8).to(img.device)
70 | coords1 = coords_grid(N, H // 8, W // 8).to(img.device)
71 |
72 | # optical flow computed as difference: flow = coords1 - coords0
73 | return coords0, coords1
74 |
75 | def upsample_flow(self, flow, mask):
76 | """ Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination """
77 | N, _, H, W = flow.shape
78 | mask = mask.view(N, 1, 9, 8, 8, H, W)
79 | mask = torch.softmax(mask, dim=2)
80 |
81 | up_flow = F.unfold(8 * flow, [3, 3], padding=1)
82 | up_flow = up_flow.view(N, 2, 9, 1, 1, H, W)
83 |
84 | up_flow = torch.sum(mask * up_flow, dim=2)
85 | up_flow = up_flow.permute(0, 1, 4, 2, 5, 3)
86 | return up_flow.reshape(N, 2, 8 * H, 8 * W)
87 |
88 | def forward(self, image1, image2, iters=12, flow_init=None, upsample=True, test_mode=False):
89 | """ Estimate optical flow between pair of frames """
90 |
91 | image1 = 2 * (image1 / 255.0) - 1.0
92 | image2 = 2 * (image2 / 255.0) - 1.0
93 |
94 | image1 = image1.contiguous()
95 | image2 = image2.contiguous()
96 |
97 | hdim = self.hidden_dim
98 | cdim = self.context_dim
99 |
100 | # run the feature network
101 | with autocast(enabled=self.args.mixed_precision):
102 | fmap1, fmap2 = self.fnet([image1, image2])
103 |
104 | fmap1 = fmap1.float()
105 | fmap2 = fmap2.float()
106 | if self.args.alternate_corr:
107 | corr_fn = AlternateCorrBlock(fmap1, fmap2, radius=self.args.corr_radius)
108 | else:
109 | corr_fn = CorrBlock(fmap1, fmap2, radius=self.args.corr_radius)
110 |
111 | # run the context network
112 | with autocast(enabled=self.args.mixed_precision):
113 | cnet = self.cnet(image1)
114 | net, inp = torch.split(cnet, [hdim, cdim], dim=1)
115 | net = torch.tanh(net)
116 | inp = torch.relu(inp)
117 |
118 | coords0, coords1 = self.initialize_flow(image1)
119 |
120 | if flow_init is not None:
121 | coords1 = coords1 + flow_init
122 |
123 | flow_predictions = []
124 | for itr in range(iters):
125 | coords1 = coords1.detach()
126 | corr = corr_fn(coords1) # index correlation volume
127 |
128 | flow = coords1 - coords0
129 | with autocast(enabled=self.args.mixed_precision):
130 | net, up_mask, delta_flow = self.update_block(net, inp, corr, flow)
131 |
132 | # F(t+1) = F(t) + \Delta(t)
133 | coords1 = coords1 + delta_flow
134 |
135 | # upsample predictions
136 | if up_mask is None:
137 | flow_up = upflow8(coords1 - coords0)
138 | else:
139 | flow_up = self.upsample_flow(coords1 - coords0, up_mask)
140 |
141 | flow_predictions.append(flow_up)
142 |
143 | if test_mode:
144 | corr = corr_fn(coords1) # index correlation volume
145 | # feat = torch.cat([inp, corr], dim=1)
146 | feat = inp
147 | return coords1 - coords0, flow_up, (feat, fmap1, fmap2)
148 |
149 | return flow_predictions
150 |
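A minimal forward-pass sketch for `RAFT` (not part of the repository file; random weights, dummy frames, CPU-friendly settings):

```python
import argparse

import torch

from sam_pt.point_tracker.raft.raft_core.raft import RAFT

args = argparse.Namespace(small=False, mixed_precision=False, alternate_corr=False)
model = RAFT(args).eval()

# Frames are expected in [0, 255]; 128 x 128 keeps every correlation pyramid level above 1 x 1.
image1 = torch.randint(0, 256, (1, 3, 128, 128)).float()
image2 = torch.randint(0, 256, (1, 3, 128, 128)).float()

with torch.no_grad():
    flow_low, flow_up, _ = model(image1, image2, iters=4, test_mode=True)
assert flow_low.shape == (1, 2, 16, 16)    # flow at 1/8 resolution
assert flow_up.shape == (1, 2, 128, 128)   # convex-upsampled, full-resolution flow
```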
--------------------------------------------------------------------------------
/sam_pt/point_tracker/raft/raft_core/update.py:
--------------------------------------------------------------------------------
1 | # Taken from: https://github.com/princeton-vl/RAFT/blob/aac9dd54726caf2cf81d8661b07663e220c5586d/core/update.py
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
7 |
8 | class FlowHead(nn.Module):
9 | def __init__(self, input_dim=128, hidden_dim=256):
10 | super(FlowHead, self).__init__()
11 | self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1)
12 | self.conv2 = nn.Conv2d(hidden_dim, 2, 3, padding=1)
13 | self.relu = nn.ReLU(inplace=True)
14 |
15 | def forward(self, x):
16 | return self.conv2(self.relu(self.conv1(x)))
17 |
18 |
19 | class ConvGRU(nn.Module):
20 | def __init__(self, hidden_dim=128, input_dim=192 + 128):
21 | super(ConvGRU, self).__init__()
22 | self.convz = nn.Conv2d(hidden_dim + input_dim, hidden_dim, 3, padding=1)
23 | self.convr = nn.Conv2d(hidden_dim + input_dim, hidden_dim, 3, padding=1)
24 | self.convq = nn.Conv2d(hidden_dim + input_dim, hidden_dim, 3, padding=1)
25 |
26 | def forward(self, h, x):
27 | hx = torch.cat([h, x], dim=1)
28 |
29 | z = torch.sigmoid(self.convz(hx))
30 | r = torch.sigmoid(self.convr(hx))
31 | q = torch.tanh(self.convq(torch.cat([r * h, x], dim=1)))
32 |
33 | h = (1 - z) * h + z * q
34 | return h
35 |
36 |
37 | class SepConvGRU(nn.Module):
38 | def __init__(self, hidden_dim=128, input_dim=192 + 128):
39 | super(SepConvGRU, self).__init__()
40 | self.convz1 = nn.Conv2d(hidden_dim + input_dim, hidden_dim, (1, 5), padding=(0, 2))
41 | self.convr1 = nn.Conv2d(hidden_dim + input_dim, hidden_dim, (1, 5), padding=(0, 2))
42 | self.convq1 = nn.Conv2d(hidden_dim + input_dim, hidden_dim, (1, 5), padding=(0, 2))
43 |
44 | self.convz2 = nn.Conv2d(hidden_dim + input_dim, hidden_dim, (5, 1), padding=(2, 0))
45 | self.convr2 = nn.Conv2d(hidden_dim + input_dim, hidden_dim, (5, 1), padding=(2, 0))
46 | self.convq2 = nn.Conv2d(hidden_dim + input_dim, hidden_dim, (5, 1), padding=(2, 0))
47 |
48 | def forward(self, h, x):
49 | # horizontal
50 | hx = torch.cat([h, x], dim=1)
51 | z = torch.sigmoid(self.convz1(hx))
52 | r = torch.sigmoid(self.convr1(hx))
53 | q = torch.tanh(self.convq1(torch.cat([r * h, x], dim=1)))
54 | h = (1 - z) * h + z * q
55 |
56 | # vertical
57 | hx = torch.cat([h, x], dim=1)
58 | z = torch.sigmoid(self.convz2(hx))
59 | r = torch.sigmoid(self.convr2(hx))
60 | q = torch.tanh(self.convq2(torch.cat([r * h, x], dim=1)))
61 | h = (1 - z) * h + z * q
62 |
63 | return h
64 |
65 |
66 | class SmallMotionEncoder(nn.Module):
67 | def __init__(self, args):
68 | super(SmallMotionEncoder, self).__init__()
69 | cor_planes = args.corr_levels * (2 * args.corr_radius + 1) ** 2
70 | self.convc1 = nn.Conv2d(cor_planes, 96, 1, padding=0)
71 | self.convf1 = nn.Conv2d(2, 64, 7, padding=3)
72 | self.convf2 = nn.Conv2d(64, 32, 3, padding=1)
73 | self.conv = nn.Conv2d(128, 80, 3, padding=1)
74 |
75 | def forward(self, flow, corr):
76 | cor = F.relu(self.convc1(corr))
77 | flo = F.relu(self.convf1(flow))
78 | flo = F.relu(self.convf2(flo))
79 | cor_flo = torch.cat([cor, flo], dim=1)
80 | out = F.relu(self.conv(cor_flo))
81 | return torch.cat([out, flow], dim=1)
82 |
83 |
84 | class BasicMotionEncoder(nn.Module):
85 | def __init__(self, args):
86 | super(BasicMotionEncoder, self).__init__()
87 | cor_planes = args.corr_levels * (2 * args.corr_radius + 1) ** 2
88 | self.convc1 = nn.Conv2d(cor_planes, 256, 1, padding=0)
89 | self.convc2 = nn.Conv2d(256, 192, 3, padding=1)
90 | self.convf1 = nn.Conv2d(2, 128, 7, padding=3)
91 | self.convf2 = nn.Conv2d(128, 64, 3, padding=1)
92 | self.conv = nn.Conv2d(64 + 192, 128 - 2, 3, padding=1)
93 |
94 | def forward(self, flow, corr):
95 | cor = F.relu(self.convc1(corr))
96 | cor = F.relu(self.convc2(cor))
97 | flo = F.relu(self.convf1(flow))
98 | flo = F.relu(self.convf2(flo))
99 |
100 | cor_flo = torch.cat([cor, flo], dim=1)
101 | out = F.relu(self.conv(cor_flo))
102 | return torch.cat([out, flow], dim=1)
103 |
104 |
105 | class SmallUpdateBlock(nn.Module):
106 | def __init__(self, args, hidden_dim=96):
107 | super(SmallUpdateBlock, self).__init__()
108 | self.encoder = SmallMotionEncoder(args)
109 | self.gru = ConvGRU(hidden_dim=hidden_dim, input_dim=82 + 64)
110 | self.flow_head = FlowHead(hidden_dim, hidden_dim=128)
111 |
112 | def forward(self, net, inp, corr, flow):
113 | motion_features = self.encoder(flow, corr)
114 | inp = torch.cat([inp, motion_features], dim=1)
115 | net = self.gru(net, inp)
116 | delta_flow = self.flow_head(net)
117 |
118 | return net, None, delta_flow
119 |
120 |
121 | class BasicUpdateBlock(nn.Module):
122 | def __init__(self, args, hidden_dim=128, input_dim=128):
123 | super(BasicUpdateBlock, self).__init__()
124 | self.args = args
125 | self.encoder = BasicMotionEncoder(args)
126 | self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=128 + hidden_dim)
127 | self.flow_head = FlowHead(hidden_dim, hidden_dim=256)
128 |
129 | self.mask = nn.Sequential(
130 | nn.Conv2d(128, 256, 3, padding=1),
131 | nn.ReLU(inplace=True),
132 | nn.Conv2d(256, 64 * 9, 1, padding=0))
133 |
134 | def forward(self, net, inp, corr, flow, upsample=True):
135 | motion_features = self.encoder(flow, corr)
136 | inp = torch.cat([inp, motion_features], dim=1)
137 |
138 | net = self.gru(net, inp)
139 | delta_flow = self.flow_head(net)
140 |
141 |         # scale mask to balance gradients
142 | mask = .25 * self.mask(net)
143 | return net, mask, delta_flow
144 |
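A shape sketch for `BasicUpdateBlock` under the default `corr_levels=4`, `corr_radius=4` configuration (not part of the repository file; dummy tensors only):

```python
import argparse

import torch

from sam_pt.point_tracker.raft.raft_core.update import BasicUpdateBlock

args = argparse.Namespace(corr_levels=4, corr_radius=4)
block = BasicUpdateBlock(args, hidden_dim=128)

n, h, w = 1, 8, 8
net = torch.zeros(n, 128, h, w)         # GRU hidden state
inp = torch.zeros(n, 128, h, w)         # context features
corr = torch.zeros(n, 4 * 9 * 9, h, w)  # sampled correlation volume from CorrBlock
flow = torch.zeros(n, 2, h, w)          # current flow estimate

net, up_mask, delta_flow = block(net, inp, corr, flow)
assert delta_flow.shape == (n, 2, h, w)      # residual flow update
assert up_mask.shape == (n, 64 * 9, h, w)    # convex upsampling weights
```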
--------------------------------------------------------------------------------
/sam_pt/point_tracker/raft/raft_core/util.py:
--------------------------------------------------------------------------------
1 | # Taken from: https://github.com/princeton-vl/RAFT/blob/aac9dd54726caf2cf81d8661b07663e220c5586d/core/utils/utils.py
2 |
3 | import numpy as np
4 | import torch
5 | import torch.nn.functional as F
6 | from scipy import interpolate
7 |
8 |
9 | class InputPadder:
10 | """ Pads images such that dimensions are divisible by 8 """
11 |
12 | def __init__(self, dims, mode='sintel'):
13 | self.ht, self.wd = dims[-2:]
14 | pad_ht = (((self.ht // 8) + 1) * 8 - self.ht) % 8
15 | pad_wd = (((self.wd // 8) + 1) * 8 - self.wd) % 8
16 | if mode == 'sintel':
17 | self._pad = [pad_wd // 2, pad_wd - pad_wd // 2, pad_ht // 2, pad_ht - pad_ht // 2]
18 | else:
19 | self._pad = [pad_wd // 2, pad_wd - pad_wd // 2, 0, pad_ht]
20 |
21 | def pad(self, *inputs):
22 | return [F.pad(x, self._pad, mode='replicate') for x in inputs]
23 |
24 | def unpad(self, x):
25 | ht, wd = x.shape[-2:]
26 | c = [self._pad[2], ht - self._pad[3], self._pad[0], wd - self._pad[1]]
27 | return x[..., c[0]:c[1], c[2]:c[3]]
28 |
29 |
30 | def forward_interpolate(flow):
31 | flow = flow.detach().cpu().numpy()
32 | dx, dy = flow[0], flow[1]
33 |
34 | ht, wd = dx.shape
35 | x0, y0 = np.meshgrid(np.arange(wd), np.arange(ht))
36 |
37 | x1 = x0 + dx
38 | y1 = y0 + dy
39 |
40 | x1 = x1.reshape(-1)
41 | y1 = y1.reshape(-1)
42 | dx = dx.reshape(-1)
43 | dy = dy.reshape(-1)
44 |
45 | valid = (x1 > 0) & (x1 < wd) & (y1 > 0) & (y1 < ht)
46 | x1 = x1[valid]
47 | y1 = y1[valid]
48 | dx = dx[valid]
49 | dy = dy[valid]
50 |
51 | flow_x = interpolate.griddata(
52 | (x1, y1), dx, (x0, y0), method='nearest', fill_value=0)
53 |
54 | flow_y = interpolate.griddata(
55 | (x1, y1), dy, (x0, y0), method='nearest', fill_value=0)
56 |
57 | flow = np.stack([flow_x, flow_y], axis=0)
58 | return torch.from_numpy(flow).float()
59 |
60 |
61 | def bilinear_sampler(img, coords, mode='bilinear', mask=False):
62 | """ Wrapper for grid_sample, uses pixel coordinates """
63 | H, W = img.shape[-2:]
64 | xgrid, ygrid = coords.split([1, 1], dim=-1)
65 | xgrid = 2 * xgrid / (W - 1) - 1
66 | ygrid = 2 * ygrid / (H - 1) - 1
67 |
68 | grid = torch.cat([xgrid, ygrid], dim=-1)
69 | img = F.grid_sample(img, grid, align_corners=True)
70 |
71 | if mask:
72 | mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1)
73 | return img, mask.float()
74 |
75 | return img
76 |
77 |
78 | def coords_grid(batch, ht, wd):
79 | coords = torch.meshgrid(torch.arange(ht), torch.arange(wd))
80 | coords = torch.stack(coords[::-1], dim=0).float()
81 | return coords[None].repeat(batch, 1, 1, 1)
82 |
83 |
84 | def upflow8(flow, mode='bilinear'):
85 | new_size = (8 * flow.shape[2], 8 * flow.shape[3])
86 | return 8 * F.interpolate(flow, size=new_size, mode=mode, align_corners=True)
87 |
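A usage sketch for `InputPadder` (not part of the repository file): it pads height and width up to the next multiple of 8, and `unpad()` restores the original size.

```python
import torch

from sam_pt.point_tracker.raft.raft_core.util import InputPadder

image = torch.zeros(1, 3, 436, 1024)  # Sintel-sized frame; height is not divisible by 8
padder = InputPadder(image.shape)
padded, = padder.pad(image)
assert padded.shape == (1, 3, 440, 1024)  # 436 -> 440, width already divisible by 8
assert padder.unpad(padded).shape == image.shape
```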
--------------------------------------------------------------------------------
/sam_pt/point_tracker/raft/raftnet.py:
--------------------------------------------------------------------------------
1 | # Adapted from: https://github.com/aharley/pips/blob/486124b4236bb228a20750b496f0fa8aa6343157/nets/raftnet.py
2 |
3 | import argparse
4 |
5 | import torch
6 | import torch.nn as nn
7 |
8 | from .raft_core.raft import RAFT
9 | from .raft_core.util import InputPadder
10 |
11 |
12 | class Raftnet(nn.Module):
13 | def __init__(self, ckpt_name=None, small=False, alternate_corr=False, mixed_precision=True):
14 | super(Raftnet, self).__init__()
15 | args = argparse.Namespace()
16 | args.small = small
17 | args.alternate_corr = alternate_corr
18 | args.mixed_precision = mixed_precision
19 | self.model = RAFT(args)
20 | if ckpt_name is not None:
21 | state_dict = torch.load(ckpt_name)
22 | state_dict = { # The checkpoint was saved as wrapped in nn.DataParallel, this removes the wrapper
23 | k.replace('module.', ''): v
24 | for k, v in state_dict.items()
25 | if k != 'module'
26 | }
27 | self.model.load_state_dict(state_dict)
28 |
29 | def forward(self, image1, image2, iters=20, test_mode=True):
30 | # input images are in [-0.5, 0.5]
31 | # raftnet wants the images to be in [0,255]
32 | image1 = (image1 + 0.5) * 255.0
33 | image2 = (image2 + 0.5) * 255.0
34 |
35 | padder = InputPadder(image1.shape)
36 | image1, image2 = padder.pad(image1, image2)
37 | if test_mode:
38 | flow_low, flow_up, feat = self.model(image1=image1, image2=image2, iters=iters, test_mode=test_mode)
39 | flow_up = padder.unpad(flow_up)
40 | return flow_up, feat
41 | else:
42 | flow_predictions = self.model(image1=image1, image2=image2, iters=iters, test_mode=test_mode)
43 | return flow_predictions
44 |
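A minimal sketch of the `Raftnet` wrapper (not part of the repository file): frames are given in `[-0.5, 0.5]` as the comment above states, `ckpt_name=None` keeps random weights, and `mixed_precision=False` is set here so the sketch also runs on CPU.

```python
import torch

from sam_pt.point_tracker.raft.raftnet import Raftnet

model = Raftnet(ckpt_name=None, mixed_precision=False).eval()
frame0 = torch.rand(1, 3, 128, 128) - 0.5
frame1 = torch.rand(1, 3, 128, 128) - 0.5

with torch.no_grad():
    flow, _ = model(frame0, frame1, iters=8, test_mode=True)
assert flow.shape == (1, 2, 128, 128)  # per-pixel (dx, dy) from frame0 to frame1
```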
--------------------------------------------------------------------------------
/sam_pt/point_tracker/raft/tracker.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 |
4 | import sam_pt.point_tracker.utils.improc
5 | import sam_pt.point_tracker.utils.samp
6 | from sam_pt.point_tracker import PointTracker
7 | from .raftnet import Raftnet
8 |
9 |
10 | class RaftPointTracker(PointTracker):
11 | """
12 | Implements a point tracker that uses the RAFT algorithm for optical flow estimation
13 | from https://arxiv.org/abs/2003.12039. The tracker computes forward and backward flows
14 | for each frame in a video sequence and uses these to estimate the trajectories of given points.
15 | """
16 |
17 | def __init__(self, checkpoint_path):
18 | """
19 | Args:
20 | checkpoint_path (str): The path to the trained RAFT model checkpoint.
21 | """
22 | super().__init__()
23 | self.checkpoint_path = checkpoint_path
24 | if self.checkpoint_path is not None and not os.path.exists(self.checkpoint_path):
25 | raise FileNotFoundError(f"Raft checkpoint not found at {self.checkpoint_path}")
26 | print(f"Loading Raft model from {self.checkpoint_path}")
27 | self.model = Raftnet(ckpt_name=self.checkpoint_path)
28 |
29 | def forward(self, rgbs, query_points, summary_writer=None):
30 | batch_size, n_frames, channels, height, width = rgbs.shape
31 | n_points = query_points.shape[1]
32 |
33 | prep_rgbs = sam_pt.point_tracker.utils.improc.preprocess_color(rgbs)
34 |
35 | flows_forward = []
36 | flows_backward = []
37 | for t in range(1, n_frames):
38 | rgb0 = prep_rgbs[:, t - 1]
39 | rgb1 = prep_rgbs[:, t]
40 | flows_forward.append(self.model.forward(rgb0, rgb1, iters=32)[0])
41 | flows_backward.append(self.model.forward(rgb1, rgb0, iters=32)[0])
42 | flows_forward = torch.stack(flows_forward, dim=1)
43 | flows_backward = torch.stack(flows_backward, dim=1)
44 | assert flows_forward.shape == flows_backward.shape == (batch_size, n_frames - 1, 2, height, width)
45 |
46 | coords = []
47 | for t in range(n_frames):
48 | if t == 0:
49 | coord = torch.zeros_like(query_points[:, :, 1:])
50 | else:
51 | prev_coord = coords[t - 1]
52 | delta = sam_pt.point_tracker.utils.samp.bilinear_sample2d(
53 | im=flows_forward[:, t - 1],
54 | x=prev_coord[:, :, 0],
55 | y=prev_coord[:, :, 1],
56 | ).permute(0, 2, 1)
57 | assert delta.shape == (batch_size, n_points, 2), "Forward flow at the discrete points"
58 | coord = prev_coord + delta
59 |
60 | # Set the ground truth query point location if the timestep is correct
61 | query_point_mask = query_points[:, :, 0] == t
62 | coord = coord * ~query_point_mask.unsqueeze(-1) + query_points[:, :, 1:] * query_point_mask.unsqueeze(-1)
63 |
64 | coords.append(coord)
65 |
66 | for t in range(n_frames - 2, -1, -1):
67 | coord = coords[t]
68 | successor_coord = coords[t + 1]
69 |
70 | delta = sam_pt.point_tracker.utils.samp.bilinear_sample2d(
71 | im=flows_backward[:, t],
72 | x=successor_coord[:, :, 0],
73 | y=successor_coord[:, :, 1],
74 | ).permute(0, 2, 1)
75 | assert delta.shape == (batch_size, n_points, 2), "Backward flow at the discrete points"
76 |
77 | # Update only the points that are located prior to the query point
78 | prior_to_query_point_mask = t < query_points[:, :, 0]
79 | coord = (coord * ~prior_to_query_point_mask.unsqueeze(-1) +
80 | (successor_coord + delta) * prior_to_query_point_mask.unsqueeze(-1))
81 | coords[t] = coord
82 |
83 | trajectories = torch.stack(coords, dim=1)
84 | visibilities = (trajectories[:, :, :, 0] >= 0) & \
85 | (trajectories[:, :, :, 1] >= 0) & \
86 | (trajectories[:, :, :, 0] < width) & \
87 | (trajectories[:, :, :, 1] < height)
88 | return trajectories, visibilities
89 |
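A minimal sketch of the tracker's `forward()` contract (not part of the repository file): `checkpoint_path=None` runs RAFT with random weights, which is enough to exercise the shapes, while a real run needs the trained RAFT checkpoint. Query points are given as `(t, x, y)` rows.

```python
import torch

from sam_pt.point_tracker.raft import RaftPointTracker

tracker = RaftPointTracker(checkpoint_path=None)
rgbs = torch.randint(0, 256, (1, 3, 3, 128, 128)).float()  # (batch, frames, channels, height, width)
query_points = torch.tensor([[[0., 10., 20.],               # point queried at frame 0
                              [2., 50., 60.]]])             # point queried at frame 2

with torch.no_grad():
    trajectories, visibilities = tracker.forward(rgbs, query_points)
assert trajectories.shape == (1, 3, 2, 2)  # (batch, frames, points, xy)
assert visibilities.shape == (1, 3, 2)     # True while a point stays inside the frame
```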
--------------------------------------------------------------------------------
/sam_pt/point_tracker/superglue/__init__.py:
--------------------------------------------------------------------------------
1 | from .tracker import SuperGluePointTracker
2 |
--------------------------------------------------------------------------------
/sam_pt/point_tracker/superglue/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SysCV/sam-pt/874ff7e73d6ab05418a494d7a02ca233c0b31e8c/sam_pt/point_tracker/superglue/models/__init__.py
--------------------------------------------------------------------------------
/sam_pt/point_tracker/superglue/models/matching.py:
--------------------------------------------------------------------------------
1 | # %BANNER_BEGIN%
2 | # ---------------------------------------------------------------------
3 | # %COPYRIGHT_BEGIN%
4 | #
5 | # Magic Leap, Inc. ("COMPANY") CONFIDENTIAL
6 | #
7 | # Unpublished Copyright (c) 2020
8 | # Magic Leap, Inc., All Rights Reserved.
9 | #
10 | # NOTICE: All information contained herein is, and remains the property
11 | # of COMPANY. The intellectual and technical concepts contained herein
12 | # are proprietary to COMPANY and may be covered by U.S. and Foreign
13 | # Patents, patents in process, and are protected by trade secret or
14 | # copyright law. Dissemination of this information or reproduction of
15 | # this material is strictly forbidden unless prior written permission is
16 | # obtained from COMPANY. Access to the source code contained herein is
17 | # hereby forbidden to anyone except current COMPANY employees, managers
18 | # or contractors who have executed Confidentiality and Non-disclosure
19 | # agreements explicitly covering such access.
20 | #
21 | # The copyright notice above does not evidence any actual or intended
22 | # publication or disclosure of this source code, which includes
23 | # information that is confidential and/or proprietary, and is a trade
24 | # secret, of COMPANY. ANY REPRODUCTION, MODIFICATION, DISTRIBUTION,
25 | # PUBLIC PERFORMANCE, OR PUBLIC DISPLAY OF OR THROUGH USE OF THIS
26 | # SOURCE CODE WITHOUT THE EXPRESS WRITTEN CONSENT OF COMPANY IS
27 | # STRICTLY PROHIBITED, AND IN VIOLATION OF APPLICABLE LAWS AND
28 | # INTERNATIONAL TREATIES. THE RECEIPT OR POSSESSION OF THIS SOURCE
29 | # CODE AND/OR RELATED INFORMATION DOES NOT CONVEY OR IMPLY ANY RIGHTS
30 | # TO REPRODUCE, DISCLOSE OR DISTRIBUTE ITS CONTENTS, OR TO MANUFACTURE,
31 | # USE, OR SELL ANYTHING THAT IT MAY DESCRIBE, IN WHOLE OR IN PART.
32 | #
33 | # %COPYRIGHT_END%
34 | # ----------------------------------------------------------------------
35 | # %AUTHORS_BEGIN%
36 | #
37 | # Originating Authors: Paul-Edouard Sarlin
38 | #
39 | # %AUTHORS_END%
40 | # --------------------------------------------------------------------*/
41 | # %BANNER_END%
42 |
43 | # Taken from: https://github.com/magicleap/SuperGluePretrainedNetwork/blob/ddcf11f42e7e0732a0c4607648f9448ea8d73590/models/matching.py
44 |
45 | import torch
46 |
47 | from .superglue import SuperGlue
48 | from .superpoint import SuperPoint
49 |
50 |
51 | class Matching(torch.nn.Module):
52 | """ Image Matching Frontend (SuperPoint + SuperGlue) """
53 |
54 | def __init__(self, config={}):
55 | super().__init__()
56 | self.superpoint = SuperPoint(config.get('superpoint', {}))
57 | self.superglue = SuperGlue(config.get('superglue', {}))
58 |
59 | def forward(self, data):
60 | """ Run SuperPoint (optionally) and SuperGlue
61 | SuperPoint is skipped if ['keypoints0', 'keypoints1'] exist in input
62 | Args:
63 | data: dictionary with minimal keys: ['image0', 'image1']
64 | """
65 | pred = {}
66 |
67 | # Extract SuperPoint (keypoints, scores, descriptors) if not provided
68 | if 'keypoints0' not in data:
69 | pred0 = self.superpoint({'image': data['image0']})
70 | pred = {**pred, **{k + '0': v for k, v in pred0.items()}}
71 | if 'keypoints1' not in data:
72 | pred1 = self.superpoint({'image': data['image1']})
73 | pred = {**pred, **{k + '1': v for k, v in pred1.items()}}
74 |
75 | # Batch all features
76 | # We should either have i) one image per batch, or
77 | # ii) the same number of local features for all images in the batch.
78 | data = {**data, **pred}
79 |
80 | for k in data:
81 | if isinstance(data[k], (list, tuple)):
82 | data[k] = torch.stack(data[k])
83 |
84 | # Perform the matching
85 | pred = {**pred, **self.superglue(data)}
86 |
87 | return pred
88 |
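A minimal sketch of the `Matching` input/output contract described in the docstring above (not part of the repository file; it assumes the pretrained SuperPoint/SuperGlue weights are available where the two modules expect them, as in the upstream SuperGluePretrainedNetwork release):

```python
import torch

from sam_pt.point_tracker.superglue.models.matching import Matching

matching = Matching().eval()  # default SuperPoint + SuperGlue configs
data = {
    'image0': torch.rand(1, 1, 240, 320),  # grayscale images in [0, 1], shape (B, 1, H, W)
    'image1': torch.rand(1, 1, 240, 320),
}

with torch.no_grad():
    pred = matching(data)
# pred holds keypoints, scores and descriptors for both images,
# plus the SuperGlue matches and matching scores.
print(sorted(pred.keys()))
```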
--------------------------------------------------------------------------------
/sam_pt/point_tracker/tapir/__init__.py:
--------------------------------------------------------------------------------
1 | from .tracker import TapirPointTracker
2 |
--------------------------------------------------------------------------------
/sam_pt/point_tracker/tapir/configs/tapir_config.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | # Taken from: https://github.com/deepmind/tapnet/blob/ba1a8c8f2576d81f7b8d69dbee1e58e8b7d321e1/configs/tapir_config.py
17 |
18 | """Default config to train the TAPIR."""
19 |
20 | from jaxline import base_config
21 | from ml_collections import config_dict
22 |
23 | TRAIN_SIZE = (24, 256, 256, 3) # (num_frames, height, width, channels)
24 |
25 |
26 | # We define the experiment launch config in the same file as the experiment to
27 | # keep things self-contained in a single file.
28 | def get_config() -> config_dict.ConfigDict:
29 | """Return config object for training."""
30 | config = base_config.get_base_config()
31 |
32 | # Experiment config.
33 | config.training_steps = 100000
34 |
35 | # NOTE: duplicates not allowed.
36 | config.shared_module_names = ('tapir_model',)
37 |
38 | config.dataset_names = ('kubric',)
39 | # Note: eval modes must always start with 'eval_'.
40 | config.eval_modes = (
41 | 'eval_davis_points',
42 | 'eval_jhmdb',
43 | 'eval_robotics_points',
44 | 'eval_kinetics_points',
45 | )
46 | config.checkpoint_dir = '/tmp/tapnet_training/'
47 | config.evaluate_every = 10000
48 |
49 | config.experiment_kwargs = config_dict.ConfigDict(
50 | dict(
51 | config=dict(
52 | sweep_name='default_sweep',
53 | save_final_checkpoint_as_npy=True,
54 | # `enable_double_transpose` should always be false when using 1D.
55 | # For other D It is also completely untested and very unlikely
56 | # to work.
57 | optimizer=dict(
58 | base_lr=1e-3,
59 | max_norm=-1, # < 0 to turn off.
60 | weight_decay=1e-1,
61 | schedule_type='cosine',
62 | cosine_decay_kwargs=dict(
63 | init_value=0.0,
64 | warmup_steps=1000,
65 | end_value=0.0,
66 | ),
67 | optimizer='adam',
68 | # Optimizer-specific kwargs.
69 | adam_kwargs=dict(
70 | b1=0.9,
71 | b2=0.95,
72 | eps=1e-8,
73 | ),
74 | ),
75 | fast_variables=tuple(),
76 | shared_modules=dict(
77 | shared_module_names=config.get_oneway_ref(
78 | 'shared_module_names',
79 | ),
80 | tapir_model_kwargs=dict(
81 | bilinear_interp_with_depthwise_conv=True,
82 | use_causal_conv=False,
83 | ),
84 | ),
85 | datasets=dict(
86 | dataset_names=config.get_oneway_ref('dataset_names'),
87 | kubric_kwargs=dict(
88 | batch_dims=8,
89 | shuffle_buffer_size=128,
90 | train_size=TRAIN_SIZE[1:3],
91 | ),
92 | ),
93 | supervised_point_prediction_kwargs=dict(
94 | prediction_algo='cost_volume_regressor',
95 | model_key='tapir_model',
96 | ),
97 | checkpoint_dir=config.get_oneway_ref('checkpoint_dir'),
98 | evaluate_every=config.get_oneway_ref('evaluate_every'),
99 | eval_modes=config.get_oneway_ref('eval_modes'),
100 | # If true, run evaluate() on the experiment once before
101 | # you load a checkpoint.
102 | # This is useful for getting initial values of metrics
103 | # at random weights, or when debugging locally if you
104 | # do not have any train job running.
105 | davis_points_path='',
106 | jhmdb_path='',
107 | robotics_points_path='',
108 | training=dict(
109 | # Note: to sweep n_training_steps, DO NOT sweep these
110 | # fields directly. Instead sweep config.training_steps.
111 | # Otherwise, decay/stopping logic
112 | # is not guaranteed to be consistent.
113 | n_training_steps=config.get_oneway_ref('training_steps'),
114 | ),
115 | inference=dict(
116 | input_video_path='',
117 | output_video_path='',
118 | resize_height=256, # video height resized to before inference
119 | resize_width=256, # video width resized to before inference
120 | num_points=20, # number of random points to sample
121 | ),
122 | )
123 | )
124 | )
125 |
126 | # Set up where to store the resulting model.
127 | config.train_checkpoint_all_hosts = False
128 | config.save_checkpoint_interval = 10
129 | config.eval_initial_weights = True
130 |
131 | # Prevents accidentally setting keys that aren't recognized (e.g. in tests).
132 | config.lock()
133 |
134 | return config
135 |
--------------------------------------------------------------------------------
/sam_pt/point_tracker/tapir/demo.py:
--------------------------------------------------------------------------------
1 | """
2 | Demo program for TAPIR, to verify that the PyTorch + JAX setup works correctly.
3 | The script should run without error and should ideally run on a GPU/TPU so that the benchmarking loop at the end is fast.
4 |
5 | Example usage:
6 | ```
7 | python -m sam_pt.point_tracker.tapir.demo
8 | ```
9 | """
10 | import time
11 |
12 | import haiku as hk
13 | import jax
14 | import matplotlib.cm as cm
15 | import matplotlib.pyplot as plt
16 | import numpy as np
17 | import tensorflow as tf
18 | import torch
19 | from torch.nn import functional as F
20 |
21 | from demo.demo import load_demo_data
22 | from . import tapir_model
23 | from .configs.tapir_config import get_config
24 |
25 | if __name__ == '__main__':
26 | # 1. Prepare config
27 | config = get_config()
28 | checkpoint_dir = "./models/tapir_ckpts/open_source_ckpt/"
29 | # Keep TF off the GPU; otherwise it hogs all the memory and leaves none for JAX.
30 | tf.config.experimental.set_visible_devices([], 'GPU')
31 | tf.config.experimental.set_visible_devices([], 'TPU')
32 |
33 | # 2. Prepare model
34 | checkpoint = np.load(checkpoint_dir + "tapir_checkpoint_panning.npy", allow_pickle=True).item()
35 | params, state = checkpoint["params"], checkpoint["state"]
36 | # tapir_model_kwargs = config.experiment_kwargs.config.shared_modules["tapir_model_kwargs"]
37 | tapir_model_kwargs = {
38 | "bilinear_interp_with_depthwise_conv": False,
39 | "pyramid_level": 0,
40 | "use_causal_conv": False,
41 | }
42 |
43 |
44 | def forward(rgbs, query_points):
45 | tapir = tapir_model.TAPIR(**tapir_model_kwargs)
46 | outputs = tapir(
47 | video=rgbs[None, ...],
48 | query_points=query_points[None, ...],
49 | query_chunk_size=64,
50 | is_training=False,
51 | )
52 | return outputs
53 |
54 |
55 | transform = hk.transform_with_state(forward)
56 |
57 |
58 | def f(rgbs_tapir, query_points_tapir):
59 | rng = jax.random.PRNGKey(72)
60 | outputs, _ = transform.apply(params, state, rng, rgbs_tapir, query_points_tapir)
61 | return outputs
62 |
63 |
64 | jitted_f = jax.jit(f)
65 |
66 | # 3. Prepare data
67 | rgbs, _, query_points = load_demo_data(
68 | frames_path="data/demo_data/bees",
69 | query_points_path="data/demo_data/query_points__bees.txt",
70 | )
71 | original_hw = rgbs.shape[-2:]
72 | tapir_input_hw = (
73 | config.experiment_kwargs.config.inference.resize_height, config.experiment_kwargs.config.inference.resize_width)
74 | rescale_factor_hw = torch.tensor(tapir_input_hw) / torch.tensor(original_hw)
75 | rgbs_tapir = F.interpolate(rgbs / 255, tapir_input_hw, mode="bilinear", align_corners=False, antialias=True)
76 | rgbs_tapir = rgbs_tapir.numpy() * 2 - 1
77 | rgbs_tapir = rgbs_tapir.transpose(0, 2, 3, 1)
78 |
79 | ## Take the loaded query points
80 | # query_points = query_points
81 | ## Or make a 16x16 grid of query points
82 | query_points = torch.zeros((1, 16, 16, 3), dtype=torch.float32)
83 | query_points[:, :, :, 0] = 1
84 | query_points[:, :, :, 1] = torch.linspace(1, original_hw[1] - 1, 16)
85 | query_points[:, :, :, 2] = torch.linspace(1, original_hw[0] - 1, 16).unsqueeze(-1)
86 | query_points = query_points.reshape(1, -1, 3)
87 |
88 | query_points_tapir = query_points.clone()
89 | query_points_tapir[:, :, 1:] *= rescale_factor_hw.flip(0)
90 | query_points_tapir = query_points_tapir.flatten(0, 1)
91 | query_points_tapir[:, 1:] = query_points_tapir[:, 1:].flip(-1)
92 | query_points_tapir = query_points_tapir.numpy()
93 |
94 | # 4. Run model
95 | outputs = jitted_f(rgbs_tapir, query_points_tapir)
96 |
97 | n_frames = rgbs.shape[0]
98 | n_masks, n_points_per_mask, _ = query_points.shape
99 |
100 | # 5. Postprocess
101 | tapir_visibility_threshold = 0.5
102 |
103 | expected_dist = torch.from_numpy(np.asarray(outputs["expected_dist"][0]).copy()).permute(1, 0)
104 | expected_dist = expected_dist.unflatten(1, (n_masks, n_points_per_mask))
105 |
106 |     occlusion_logits = torch.from_numpy(np.asarray(outputs["occlusion"][0]).copy()).permute(1, 0)
107 |     occlusion_logits = occlusion_logits.unflatten(1, (n_masks, n_points_per_mask))
108 |     visibilities_probs = (1 - torch.sigmoid(occlusion_logits)) * (1 - torch.sigmoid(expected_dist))
109 | visibilities = visibilities_probs > tapir_visibility_threshold
110 |
111 | trajectories = torch.from_numpy(np.asarray(outputs["tracks"][0]).copy()).permute(1, 0, 2)
112 | trajectories = trajectories.unflatten(1, (n_masks, n_points_per_mask))
113 | trajectories = trajectories / rescale_factor_hw.flip(-1)
114 |
115 | # 6. Visualize
116 | mask_idx = -1
117 | for frame_idx in range(n_frames):
118 | h, w = rgbs.shape[2], rgbs.shape[3]
119 | dpi = 100
120 | plt.figure(figsize=(w / dpi, h / dpi))
121 | plt.imshow(rgbs[frame_idx].permute(1, 2, 0).numpy(), interpolation="none")
122 | x = trajectories[frame_idx, mask_idx, :, 0]
123 | y = trajectories[frame_idx, mask_idx, :, 1]
124 | colors = cm.rainbow(np.linspace(0, 1, len(y)))
125 | v = visibilities[frame_idx, mask_idx, :]
126 | # v = (visibilities[frame_idx, mask_idx, :] * 0) == 0
127 | x = x[v]
128 | y = y[v]
129 | colors = colors[v]
130 | plt.title(f"F{frame_idx:02}-M{mask_idx:02}-V{(visibilities_probs[frame_idx, mask_idx, :5] * 1)}")
131 | plt.scatter(x, y, color=colors, linewidths=6)
132 | plt.xlim(trajectories[..., 0].min(), trajectories[..., 0].max())
133 | plt.ylim(trajectories[..., 1].max(), trajectories[..., 1].min())
134 | plt.axis("off")
135 | plt.tight_layout(pad=0)
136 | plt.show()
137 | time.sleep(0.1)
138 |
139 | # 7. Benchmark forward pass speed in for loop
140 | n_loops = 100
141 | start_time = time.time()
142 | for _ in range(n_loops):
143 | outputs = jitted_f(rgbs_tapir, query_points_tapir)
144 | end_time = time.time()
145 | print(f"Forward pass speed: {(end_time - start_time) / n_loops * 1000} ms")
146 |
147 | print("Done")
148 |
--------------------------------------------------------------------------------
/sam_pt/point_tracker/tapir/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SysCV/sam-pt/874ff7e73d6ab05418a494d7a02ca233c0b31e8c/sam_pt/point_tracker/tapir/models/__init__.py
--------------------------------------------------------------------------------
/sam_pt/point_tracker/tapir/tracker.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import tensorflow as tf
3 | import torch
4 | from torch.nn import functional as F
5 |
6 | from sam_pt.point_tracker import PointTracker
7 |
8 |
9 | class TapirPointTracker(PointTracker):
10 | """
11 | A point tracker that uses TAPIR from https://arxiv.org/abs/2306.08637 to track points.
12 | """
13 |
14 | def __init__(self, checkpoint_path, visibility_threshold):
15 | from .configs.tapir_config import get_config
16 | super().__init__()
17 |
18 | # Keep TF off the GPU; otherwise it hogs all the memory and leaves none for JAX
19 | tf.config.experimental.set_visible_devices([], 'GPU')
20 | tf.config.experimental.set_visible_devices([], 'TPU')
21 |
22 | # # v1: use the last GPU
23 | # # Hardcode JAX to use the last GPU (the first is reserved for other modules from PyTorch)
24 |         # # The environment variable `XLA_PYTHON_CLIENT_PREALLOCATE=false` is also required for this
25 | # gpus = jax.devices('gpu')
26 | # device = gpus[-1]
27 | # jax.jit ... device=device
28 |
29 | # v2: share the gpu with Sam since they are run sequentially
30 | # but make jax free up the allocated memory once it is done
31 |         # by setting the environment variable `XLA_PYTHON_CLIENT_ALLOCATOR=platform`
32 |
33 | assert checkpoint_path is not None
34 | self.checkpoint_path = checkpoint_path
35 | self.config = get_config()
36 | self.visibility_threshold = visibility_threshold
37 | self.jitted_forward = self._create_jitted_forward()
38 |
39 | def _create_jitted_forward(self):
40 | import haiku as hk
41 | import jax
42 | from . import tapir_model
43 |
44 | checkpoint = np.load(self.checkpoint_path, allow_pickle=True).item()
45 | params, state = checkpoint["params"], checkpoint["state"]
46 | # tapir_model_kwargs = self.config.experiment_kwargs.config.shared_modules["tapir_model_kwargs"]
47 | tapir_model_kwargs = {
48 | "bilinear_interp_with_depthwise_conv": False,
49 | "pyramid_level": 0,
50 | "use_causal_conv": False,
51 | }
52 |
53 | def _forward(rgbs, query_points):
54 | tapir = tapir_model.TAPIR(**tapir_model_kwargs)
55 | outputs = tapir(
56 | video=rgbs,
57 | query_points=query_points,
58 | query_chunk_size=64,
59 | is_training=False,
60 | )
61 | return outputs
62 |
63 | transform = hk.transform_with_state(_forward)
64 |
65 | def forward(rgbs_tapir, query_points_tapir):
66 | rng = jax.random.PRNGKey(72)
67 | outputs, _ = transform.apply(params, state, rng, rgbs_tapir, query_points_tapir)
68 | return outputs
69 |
70 | return jax.jit(forward)
71 |
72 | def forward(self, rgbs, query_points, summary_writer=None):
73 | batch_size, n_frames, channels, height, width = rgbs.shape
74 | n_points = query_points.shape[1]
75 |
76 | # 1. Prepare image resizing
77 | original_hw = (height, width)
78 | tapir_input_hw = (
79 | self.config.experiment_kwargs.config.inference.resize_height,
80 | self.config.experiment_kwargs.config.inference.resize_width,
81 | )
82 | rescale_factor_hw = torch.tensor(tapir_input_hw) / torch.tensor(original_hw)
83 |
84 | # 2. Prepare inputs
85 | assert rgbs.dtype == torch.uint8
86 | rgbs_tapir = F.interpolate(rgbs.flatten(0, 1) / 255, tapir_input_hw, mode="bilinear", align_corners=False,
87 | antialias=True)
88 | rgbs_tapir = rgbs_tapir.unflatten(0, (batch_size, n_frames))
89 | rgbs_tapir = rgbs_tapir.cpu().numpy() * 2 - 1
90 | rgbs_tapir = rgbs_tapir.transpose(0, 1, 3, 4, 2)
91 | query_points_tapir = query_points.cpu().clone()
92 | query_points_tapir[:, :, 1:] *= rescale_factor_hw.flip(0)
93 | query_points_tapir[:, :, 1:] = query_points_tapir[:, :, 1:].flip(-1) # flip x and y
94 | query_points_tapir = query_points_tapir.numpy()
95 |
96 | # 3. Run model
97 |         self._create_jitted_forward()  # TODO: Can the jitted forward be compiled just once instead of on every call?
98 | outputs = self.jitted_forward(rgbs_tapir, query_points_tapir)
99 |
100 | # 4. Postprocess outputs
101 | expected_dist = torch.from_numpy(np.asarray(outputs["expected_dist"]).copy()).permute(0, 2, 1)
102 |         occlusion_logits = torch.from_numpy(np.asarray(outputs["occlusion"]).copy()).permute(0, 2, 1)
103 |         visibilities_probs = (1 - torch.sigmoid(occlusion_logits)) * (1 - torch.sigmoid(expected_dist))
104 | visibilities = visibilities_probs > self.visibility_threshold
105 | trajectories = torch.from_numpy(np.asarray(outputs["tracks"]).copy()).permute(0, 2, 1, 3)
106 | trajectories = trajectories / rescale_factor_hw.flip(-1)
107 |
108 | return trajectories, visibilities
109 |
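A sketch of the tracker's `forward()` contract (not part of the repository file): it assumes JAX/Haiku are installed and that `checkpoint_path` points to the released TAPIR checkpoint (the same `tapir_checkpoint_panning.npy` used by the demo above); frames must be uint8 and query points are `(t, x, y)` rows.

```python
import torch

from sam_pt.point_tracker.tapir import TapirPointTracker

tracker = TapirPointTracker(
    checkpoint_path="./models/tapir_ckpts/open_source_ckpt/tapir_checkpoint_panning.npy",
    visibility_threshold=0.5,
)
rgbs = torch.randint(0, 256, (1, 8, 3, 480, 640), dtype=torch.uint8)  # (batch, frames, channels, height, width)
query_points = torch.tensor([[[0., 100., 200.],
                              [0., 320., 240.]]])

trajectories, visibilities = tracker.forward(rgbs, query_points)
assert trajectories.shape == (1, 8, 2, 2)  # (batch, frames, points, xy) in original pixels
assert visibilities.shape == (1, 8, 2)     # visibility probability > visibility_threshold
```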
--------------------------------------------------------------------------------
/sam_pt/point_tracker/tapir/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SysCV/sam-pt/874ff7e73d6ab05418a494d7a02ca233c0b31e8c/sam_pt/point_tracker/tapir/utils/__init__.py
--------------------------------------------------------------------------------
/sam_pt/point_tracker/tapir/utils/transforms.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | # Taken from: https://github.com/deepmind/tapnet/blob/ba1a8c8f2576d81f7b8d69dbee1e58e8b7d321e1/utils/transforms.py
17 |
18 | """Utilities for transforming image coordinates."""
19 |
20 | from typing import Sequence
21 |
22 | import numpy as np
23 |
24 |
25 | def convert_grid_coordinates(
26 | coords: np.ndarray,
27 | input_grid_size: Sequence[int],
28 | output_grid_size: Sequence[int],
29 | coordinate_format: str = 'xy',
30 | ) -> np.ndarray:
31 | """Convert image coordinates between image grids of different sizes.
32 |
33 | By default, it assumes that the image corners are aligned. Therefore,
34 | it adds .5 (since (0,0) is assumed to be the center of the upper-left grid
35 | cell), multiplies by the size ratio, and then subtracts .5.
36 |
37 | Args:
38 | coords: The coordinates to be converted. It is of shape [..., 2] if
39 | coordinate_format is 'xy' or [..., 3] if coordinate_format is 'tyx'.
40 | input_grid_size: The size of the image/grid that the coordinates currently
41 | are with respect to. This is a 2-tuple of the format [width, height]
42 | if coordinate_format is 'xy' or a 3-tuple of the format
43 | [num_frames, height, width] if coordinate_format is 'tyx'.
44 | output_grid_size: The size of the target image/grid that you want the
45 | coordinates to be with respect to. This is a 2-tuple of the format
46 | [width, height] if coordinate_format is 'xy' or a 3-tuple of the format
47 | [num_frames, height, width] if coordinate_format is 'tyx'.
48 | coordinate_format: Which format the coordinates are in. This can be one
49 | of 'xy' (the default) or 'tyx', which are the only formats used in this
50 | project.
51 |
52 | Returns:
53 | The transformed coordinates, of the same shape as coordinates.
54 |
55 | Raises:
56 | ValueError: if coordinates don't match the given format.
57 | """
58 | if isinstance(input_grid_size, tuple):
59 | input_grid_size = np.array(input_grid_size)
60 | if isinstance(output_grid_size, tuple):
61 | output_grid_size = np.array(output_grid_size)
62 |
63 | if coordinate_format == 'xy':
64 | if input_grid_size.shape[0] != 2 or output_grid_size.shape[0] != 2:
65 | raise ValueError(
66 | 'If coordinate_format is xy, the shapes must be length 2.')
67 | elif coordinate_format == 'tyx':
68 | if input_grid_size.shape[0] != 3 or output_grid_size.shape[0] != 3:
69 | raise ValueError(
70 | 'If coordinate_format is tyx, the shapes must be length 3.')
71 | if input_grid_size[0] != output_grid_size[0]:
72 | raise ValueError('converting frame count is not supported.')
73 | else:
74 | raise ValueError('Recognized coordinate formats are xy and tyx.')
75 |
76 | position_in_grid = coords
77 | position_in_grid = position_in_grid * output_grid_size / input_grid_size
78 |
79 | return position_in_grid
80 |
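A worked example for `convert_grid_coordinates` in 'xy' format (not part of the repository file): converting a point from a 256 x 256 grid to a 1280 x 720 frame multiplies each coordinate by the corresponding size ratio.

```python
import numpy as np

from sam_pt.point_tracker.tapir.utils.transforms import convert_grid_coordinates

point_xy = np.array([128.0, 64.0])
converted = convert_grid_coordinates(
    coords=point_xy,
    input_grid_size=(256, 256),    # (width, height)
    output_grid_size=(1280, 720),  # (width, height)
    coordinate_format='xy',
)
# x is scaled by 1280 / 256 = 5, y by 720 / 256 = 2.8125
assert np.allclose(converted, [640.0, 180.0])
```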
--------------------------------------------------------------------------------
/sam_pt/point_tracker/tapnet/__init__.py:
--------------------------------------------------------------------------------
1 | from .tracker import TapnetPointTracker
2 |
--------------------------------------------------------------------------------
/sam_pt/point_tracker/tapnet/configs/tapnet_config.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | # Taken from: https://github.com/deepmind/tapnet/blob/ba1a8c8f2576d81f7b8d69dbee1e58e8b7d321e1/configs/tapnet_config.py
17 |
18 | """Default config to train the TapNet."""
19 |
20 | from jaxline import base_config
21 | from ml_collections import config_dict
22 |
23 | TRAIN_SIZE = (24, 256, 256, 3) # (num_frames, height, width, channels)
24 |
25 |
26 | # We define the experiment launch config in the same file as the experiment to
27 | # keep things self-contained in a single file.
28 | def get_config() -> config_dict.ConfigDict:
29 | """Return config object for training."""
30 | config = base_config.get_base_config()
31 |
32 | # Experiment config.
33 | config.training_steps = 100000
34 |
35 | # NOTE: duplicates not allowed.
36 | config.shared_module_names = ('tapnet_model',)
37 |
38 | config.dataset_names = ('kubric',)
39 | # Note: eval modes must always start with 'eval_'.
40 | config.eval_modes = (
41 | 'eval_davis_points',
42 | 'eval_jhmdb',
43 | 'eval_robotics_points',
44 | 'eval_kinetics_points',
45 | )
46 | config.checkpoint_dir = 'logs/tapnet_training/'
47 | config.evaluate_every = 100
48 |
49 | config.experiment_kwargs = config_dict.ConfigDict(
50 | dict(
51 | config=dict(
52 | sweep_name='default_sweep',
53 | save_final_checkpoint_as_npy=True,
54 | # `enable_double_transpose` should always be false when using 1D.
55 | # For other D It is also completely untested and very unlikely
56 | # to work.
57 | optimizer=dict(
58 | base_lr=2e-3,
59 | max_norm=-1, # < 0 to turn off.
60 | weight_decay=1e-2,
61 | schedule_type='cosine',
62 | cosine_decay_kwargs=dict(
63 | init_value=0.0,
64 | warmup_steps=5000,
65 | end_value=0.0,
66 | ),
67 | optimizer='adam',
68 | # Optimizer-specific kwargs.
69 | adam_kwargs=dict(
70 | b1=0.9,
71 | b2=0.95,
72 | eps=1e-8,
73 | ),
74 | ),
75 | fast_variables=tuple(),
76 | shared_modules=dict(
77 | shared_module_names=config.get_oneway_ref(
78 | 'shared_module_names',
79 | ),
80 | tapnet_model_kwargs=dict(),
81 | ),
82 | datasets=dict(
83 | dataset_names=config.get_oneway_ref('dataset_names'),
84 | kubric_kwargs=dict(
85 | batch_dims=8,
86 | shuffle_buffer_size=128,
87 | train_size=TRAIN_SIZE[1:3],
88 | ),
89 | ),
90 | supervised_point_prediction_kwargs=dict(
91 | prediction_algo='cost_volume_regressor',
92 | ),
93 | checkpoint_dir=config.get_oneway_ref('checkpoint_dir'),
94 | evaluate_every=config.get_oneway_ref('evaluate_every'),
95 | eval_modes=config.get_oneway_ref('eval_modes'),
96 | # If true, run evaluate() on the experiment once before
97 | # you load a checkpoint.
98 | # This is useful for getting initial values of metrics
99 | # at random weights, or when debugging locally if you
100 | # do not have any train job running.
101 | davis_points_path='',
102 | jhmdb_path='',
103 | robotics_points_path='',
104 | training=dict(
105 | # Note: to sweep n_training_steps, DO NOT sweep these
106 | # fields directly. Instead, sweep config.training_steps.
107 | # Otherwise, decay/stopping logic
108 | # is not guaranteed to be consistent.
109 | n_training_steps=config.get_oneway_ref('training_steps'),
110 | ),
111 | inference=dict(
112 | input_video_path='',
113 | output_video_path='',
114 | resize_height=256, # video height resized to before inference
115 | resize_width=256, # video width resized to before inference
116 | num_points=20, # number of random points to sample
117 | ),
118 | )
119 | )
120 | )
121 |
122 | # Set up where to store the resulting model.
123 | config.train_checkpoint_all_hosts = False
124 | config.save_checkpoint_interval = 10
125 | config.eval_initial_weights = True
126 |
127 | # Prevents accidentally setting keys that aren't recognized (e.g. in tests).
128 | config.lock()
129 |
130 | return config
131 |
--------------------------------------------------------------------------------
/sam_pt/point_tracker/tapnet/demo.py:
--------------------------------------------------------------------------------
1 | """
2 | Demo program for TAPNet, to verify that the PyTorch + JAX setup works correctly.
3 | The script should run without error and should ideally run on a GPU/TPU so that the benchmarking loop at the end is fast.
4 |
5 | Example usage:
6 | ```
7 | python -m sam_pt.point_tracker.tapnet.demo
8 | ```
9 | """
10 | import time
11 |
12 | import haiku as hk
13 | import jax
14 | import matplotlib.pyplot as plt
15 | import numpy as np
16 | import tensorflow as tf
17 | import torch
18 | from torch.nn import functional as F
19 |
20 | from demo.demo import load_demo_data
21 | from . import tapnet_model
22 | from .configs.tapnet_config import get_config
23 |
24 | if __name__ == '__main__':
25 | # 1. Prepare config
26 | config = get_config()
27 | checkpoint_dir = "./models/tapnet_ckpts/open_source_ckpt/"
28 | # Keep TF off the GPU; otherwise it hogs all the memory and leaves none for JAX.
29 | tf.config.experimental.set_visible_devices([], 'GPU')
30 | tf.config.experimental.set_visible_devices([], 'TPU')
31 |
32 | # 2. Prepare model
33 | checkpoint = np.load(checkpoint_dir + "checkpoint_wo_optstate.npy", allow_pickle=True).item()
34 | params, state = checkpoint["params"], checkpoint["state"]
35 | tapnet_model_kwargs = config.experiment_kwargs.config.shared_modules["tapnet_model_kwargs"]
36 |
37 |
38 | def forward(rgbs, query_points):
39 | tapnet = tapnet_model.TAPNet(**tapnet_model_kwargs)
40 | outputs = tapnet(
41 | video=rgbs[None, ...],
42 | query_points=query_points[None, ...],
43 | query_chunk_size=16,
44 | get_query_feats=True,
45 | is_training=False,
46 | )
47 | return outputs
48 |
49 |
50 | transform = hk.transform_with_state(forward)
51 |
52 |
53 | def f(rgbs_tapnet, query_points_tapnet):
54 | rng = jax.random.PRNGKey(72)
55 | outputs, _ = transform.apply(params, state, rng, rgbs_tapnet, query_points_tapnet)
56 | return outputs
57 |
58 |
59 | jitted_f = jax.jit(f)
60 |
61 | # 3. Prepare data
62 | rgbs, _, query_points = load_demo_data(
63 | frames_path="data/demo_data/bees",
64 | query_points_path="data/demo_data/query_points__bees.txt",
65 | )
66 | original_hw = rgbs.shape[-2:]
67 | tapnet_input_hw = (
68 | config.experiment_kwargs.config.inference.resize_height, config.experiment_kwargs.config.inference.resize_width)
69 | rescale_factor_hw = torch.tensor(tapnet_input_hw) / torch.tensor(original_hw)
70 | rgbs_tapnet = F.interpolate(rgbs / 255, tapnet_input_hw, mode="bilinear", align_corners=False, antialias=True)
71 | rgbs_tapnet = rgbs_tapnet.numpy() * 2 - 1
72 | rgbs_tapnet = rgbs_tapnet.transpose(0, 2, 3, 1)
73 | query_points_tapnet = query_points.clone()
74 | query_points_tapnet[:, :, 1:] *= rescale_factor_hw.flip(0)
75 | query_points_tapnet = query_points_tapnet.flatten(0, 1)
76 | query_points_tapnet[:, 1:] = query_points_tapnet[:, 1:].flip(-1)
77 | query_points_tapnet = query_points_tapnet.numpy()
78 | query_points_tapnet = query_points_tapnet
79 |
80 | # 4. Run model
81 | outputs = jitted_f(rgbs_tapnet, query_points_tapnet)
82 |
83 | n_frames = rgbs.shape[0]
84 | n_masks, n_points_per_mask, _ = query_points.shape
85 |
86 | # 5. Postprocess
87 | tapnet_visibility_threshold = 0.5
88 |
89 |     occlusion_logits = torch.from_numpy(np.asarray(outputs["occlusion"][0]).copy()).permute(1, 0)
90 |     occlusion_logits = occlusion_logits.unflatten(1, (n_masks, n_points_per_mask))
91 |     occlusion_probs = torch.sigmoid(occlusion_logits)
92 |     visibilities_probs = 1 - occlusion_probs
93 | visibilities = visibilities_probs > tapnet_visibility_threshold
94 |
95 | trajectories = torch.from_numpy(np.asarray(outputs["tracks"][0]).copy()).permute(1, 0, 2)
96 | trajectories = trajectories.unflatten(1, (n_masks, n_points_per_mask))
97 | trajectories = trajectories / rescale_factor_hw.flip(-1)
98 |
99 | # 6. Visualize
100 | for mask_idx in range(n_masks):
101 | if mask_idx != 2:
102 | continue
103 | for frame_idx in range(n_frames):
104 | h, w = rgbs.shape[2], rgbs.shape[3]
105 | dpi = 100
106 | plt.figure(figsize=(w / dpi, h / dpi))
107 | plt.imshow(rgbs[frame_idx].permute(1, 2, 0).numpy(), interpolation="none")
108 | plt.scatter(trajectories[frame_idx, mask_idx, :, 0], trajectories[frame_idx, mask_idx, :, 1])
109 | plt.axis("off")
110 | plt.tight_layout(pad=0)
111 | plt.show()
112 |
113 | # 7. Benchmark forward pass speed in for loop
114 | n_loops = 100
115 | start_time = time.time()
116 | for _ in range(n_loops):
117 | outputs = jitted_f(rgbs_tapnet, query_points_tapnet)
118 | end_time = time.time()
119 | print(f"Forward pass speed: {(end_time - start_time) / n_loops * 1000} ms")
120 |
121 | print("Done")
122 |
--------------------------------------------------------------------------------
/sam_pt/point_tracker/tapnet/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SysCV/sam-pt/874ff7e73d6ab05418a494d7a02ca233c0b31e8c/sam_pt/point_tracker/tapnet/models/__init__.py
--------------------------------------------------------------------------------
/sam_pt/point_tracker/tapnet/tracker.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import tensorflow as tf
3 | import torch
4 | from torch.nn import functional as F
5 |
6 | from sam_pt.point_tracker import PointTracker
7 |
8 |
9 | class TapnetPointTracker(PointTracker):
10 | """
11 | A point tracker that uses TapNet from https://arxiv.org/abs/2211.03726 to track points.
12 | """
13 | def __init__(self, checkpoint_path, visibility_threshold):
14 | from .configs.tapnet_config import get_config
15 | super().__init__()
16 |
17 | # Keep TF off the GPU; otherwise it hogs all the memory and leaves none for JAX
18 | tf.config.experimental.set_visible_devices([], 'GPU')
19 | tf.config.experimental.set_visible_devices([], 'TPU')
20 |
21 | # # v1: use the last GPU
22 | # # Hardcode JAX to use the last GPU (the first is reserved for other modules from PyTorch)
23 |         # # The environment variable `XLA_PYTHON_CLIENT_PREALLOCATE=false` is also required for this
24 | # gpus = jax.devices('gpu')
25 | # device = gpus[-1]
26 | # jax.jit ... device=device
27 |
28 | # v2: share the GPU with SAM since the two are run sequentially,
29 | # but make JAX free up the allocated memory once it is done
30 | # by setting the environment variable `XLA_PYTHON_CLIENT_ALLOCATOR=platform`
31 |
32 | assert checkpoint_path is not None
33 | self.checkpoint_path = checkpoint_path
34 | self.config = get_config()
35 | self.visibility_threshold = visibility_threshold
36 | self.jitted_forward = self._create_jitted_forward()
37 |
38 | def _create_jitted_forward(self):
39 | import haiku as hk
40 | import jax
41 | from . import tapnet_model
42 |
43 | checkpoint = np.load(self.checkpoint_path, allow_pickle=True).item()
44 | params, state = checkpoint["params"], checkpoint["state"]
45 | tapnet_model_kwargs = self.config.experiment_kwargs.config.shared_modules["tapnet_model_kwargs"]
46 |
47 | def _forward(rgbs, query_points):
48 | tapnet = tapnet_model.TAPNet(**tapnet_model_kwargs)
49 | outputs = tapnet(
50 | video=rgbs,
51 | query_points=query_points,
52 | query_chunk_size=16,
53 | get_query_feats=True,
54 | is_training=False,
55 | )
56 | return outputs
57 |
58 | transform = hk.transform_with_state(_forward)
59 |
60 | def forward(rgbs_tapnet, query_points_tapnet):
61 | rng = jax.random.PRNGKey(72)
62 | outputs, _ = transform.apply(params, state, rng, rgbs_tapnet, query_points_tapnet)
63 | return outputs
64 |
65 | return jax.jit(forward)
66 |
67 | def forward(self, rgbs, query_points, summary_writer=None):
68 | batch_size, n_frames, channels, height, width = rgbs.shape
69 | n_points = query_points.shape[1]
70 |
71 | # 1. Prepare image resizing
72 | original_hw = (height, width)
73 | tapnet_input_hw = (
74 | self.config.experiment_kwargs.config.inference.resize_height,
75 | self.config.experiment_kwargs.config.inference.resize_width,
76 | )
77 | rescale_factor_hw = torch.tensor(tapnet_input_hw) / torch.tensor(original_hw)
78 |
79 | # 2. Prepare inputs
80 | rgbs_tapnet = F.interpolate(rgbs.flatten(0, 1) / 255, tapnet_input_hw, mode="bilinear", align_corners=False,
81 | antialias=True)
82 | rgbs_tapnet = rgbs_tapnet.unflatten(0, (batch_size, n_frames))
83 | rgbs_tapnet = rgbs_tapnet.cpu().numpy() * 2 - 1
84 | rgbs_tapnet = rgbs_tapnet.transpose(0, 1, 3, 4, 2)
85 | query_points_tapnet = query_points.cpu().clone()
86 | query_points_tapnet[:, :, 1:] *= rescale_factor_hw.flip(0)
87 | query_points_tapnet[:, :, 1:] = query_points_tapnet[:, :, 1:].flip(-1) # flip x and y
88 | query_points_tapnet = query_points_tapnet.numpy()
89 |
90 | # 3. Run model
91 | self.jitted_forward = self._create_jitted_forward()  # TODO: Can the jitted forward be compiled just once instead of being re-created on every call?
92 | outputs = self.jitted_forward(rgbs_tapnet, query_points_tapnet)
93 |
94 | # 4. Postprocess outputs
95 | occlusion_logits = torch.from_numpy(np.asarray(outputs["occlusion"]).copy()).permute(0, 2, 1)
96 | occlusion_probs = torch.sigmoid(occlusion_logits)
97 | visibilities_probs = 1 - occlusion_probs
98 | visibilities = visibilities_probs > self.visibility_threshold
99 |
100 | trajectories = torch.from_numpy(np.asarray(outputs["tracks"]).copy()).permute(0, 2, 1, 3)
101 | trajectories = trajectories / rescale_factor_hw.flip(-1)
102 |
103 | return trajectories, visibilities
104 |
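To make the coordinate bookkeeping in `forward` above concrete, here is a small illustrative sketch of the query-point preprocessing with made-up sizes (512x1024 original frames, the 256x256 TAPNet input): the (t, x, y) points are rescaled to the resized frame and reordered to TAPNet's (t, y, x) convention.

import torch

original_hw = torch.tensor([512., 1024.])
tapnet_input_hw = torch.tensor([256., 256.])
rescale_factor_hw = tapnet_input_hw / original_hw   # [0.5, 0.25] as (h, w) ratios

query_points = torch.tensor([[[0., 512., 256.]]])   # one point: (t, x, y) in original pixels
qp = query_points.clone()
qp[:, :, 1:] *= rescale_factor_hw.flip(0)           # scale x by the w-ratio, y by the h-ratio
qp[:, :, 1:] = qp[:, :, 1:].flip(-1)                # reorder (x, y) -> (y, x)
print(qp)                                           # tensor([[[  0., 128., 128.]]])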
--------------------------------------------------------------------------------
/sam_pt/point_tracker/tapnet/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SysCV/sam-pt/874ff7e73d6ab05418a494d7a02ca233c0b31e8c/sam_pt/point_tracker/tapnet/utils/__init__.py
--------------------------------------------------------------------------------
/sam_pt/point_tracker/tapnet/utils/transforms.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | # Taken from: https://github.com/deepmind/tapnet/blob/ba1a8c8f2576d81f7b8d69dbee1e58e8b7d321e1/utils/transforms.py
17 |
18 | """Utilities for transforming image coordinates."""
19 |
20 | from typing import Sequence
21 |
22 | import numpy as np
23 |
24 |
25 | def convert_grid_coordinates(
26 | coords: np.ndarray,
27 | input_grid_size: Sequence[int],
28 | output_grid_size: Sequence[int],
29 | coordinate_format: str = 'xy',
30 | ) -> np.ndarray:
31 | """Convert image coordinates between image grids of different sizes.
32 |
33 | By default, it assumes that the image corners are aligned. Therefore,
34 | it adds .5 (since (0,0) is assumed to be the center of the upper-left grid
35 | cell), multiplies by the size ratio, and then subtracts .5.
36 |
37 | Args:
38 | coords: The coordinates to be converted. It is of shape [..., 2] if
39 | coordinate_format is 'xy' or [..., 3] if coordinate_format is 'tyx'.
40 | input_grid_size: The size of the image/grid that the coordinates currently
41 | are with respect to. This is a 2-tuple of the format [width, height]
42 | if coordinate_format is 'xy' or a 3-tuple of the format
43 | [num_frames, height, width] if coordinate_format is 'tyx'.
44 | output_grid_size: The size of the target image/grid that you want the
45 | coordinates to be with respect to. This is a 2-tuple of the format
46 | [width, height] if coordinate_format is 'xy' or a 3-tuple of the format
47 | [num_frames, height, width] if coordinate_format is 'tyx'.
48 | coordinate_format: Which format the coordinates are in. This can be one
49 | of 'xy' (the default) or 'tyx', which are the only formats used in this
50 | project.
51 |
52 | Returns:
53 | The transformed coordinates, of the same shape as coordinates.
54 |
55 | Raises:
56 | ValueError: if coordinates don't match the given format.
57 | """
58 | if isinstance(input_grid_size, tuple):
59 | input_grid_size = np.array(input_grid_size)
60 | if isinstance(output_grid_size, tuple):
61 | output_grid_size = np.array(output_grid_size)
62 |
63 | if coordinate_format == 'xy':
64 | if input_grid_size.shape[0] != 2 or output_grid_size.shape[0] != 2:
65 | raise ValueError(
66 | 'If coordinate_format is xy, the shapes must be length 2.')
67 | elif coordinate_format == 'tyx':
68 | if input_grid_size.shape[0] != 3 or output_grid_size.shape[0] != 3:
69 | raise ValueError(
70 | 'If coordinate_format is tyx, the shapes must be length 3.')
71 | if input_grid_size[0] != output_grid_size[0]:
72 | raise ValueError('converting frame count is not supported.')
73 | else:
74 | raise ValueError('Recognized coordinate formats are xy and tyx.')
75 |
76 | position_in_grid = coords
77 | position_in_grid = position_in_grid * output_grid_size / input_grid_size
78 |
79 | return position_in_grid
80 |
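A small usage sketch of `convert_grid_coordinates` as implemented above (a pure size-ratio rescale); the numbers are illustrative.

import numpy as np

coords = np.array([[10.0, 20.0], [100.0, 200.0]])  # (x, y) points on a 256x256 grid
rescaled = convert_grid_coordinates(
    coords,
    input_grid_size=(256, 256),    # (width, height) of the source grid
    output_grid_size=(512, 512),   # (width, height) of the target grid
    coordinate_format='xy',
)
print(rescaled)  # [[ 20.  40.]
                 #  [200. 400.]]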
--------------------------------------------------------------------------------
/sam_pt/point_tracker/tracker.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from abc import ABC, abstractmethod
3 | from torch import nn
4 | from typing import Tuple
5 |
6 |
7 | class PointTracker(ABC, nn.Module):
8 | """
9 | Abstract class for point trackers.
10 |
11 | Methods
12 | -------
13 | forward(rgbs, query_points)
14 | Performs a forward pass through the model and returns the predicted trajectories and visibilities.
15 | evaluate_batch(rgbs, query_points, trajectories_gt=None, visibilities_gt=None)
16 | Evaluates a batch of videos and returns the results.
17 | unpack_results(packed_results, batch_idx)
18 | Unpacks the results for all points and all videos in the batch.
19 | """
20 |
21 | @abstractmethod
22 | def forward(self, rgbs, query_points) -> Tuple[torch.Tensor, torch.Tensor]:
23 | """
24 | Performs a forward pass through the model and returns the predicted trajectories and visibilities.
25 |
26 | Parameters
27 | ----------
28 | rgbs : torch.Tensor
29 | A tensor of shape (batch_size, n_frames, channels, height, width)
30 | containing the RGB images in uint8 [0-255] format.
31 | query_points : torch.Tensor
32 | A tensor of shape (batch_size, n_points, 3) containing the query points,
33 | each point being (t, x, y).
34 |
35 | Returns
36 | -------
37 | tuple of two torch.Tensor
38 | Returns a tuple of (trajectories, visibilities).
39 | - `trajectories`: Predicted point trajectories with shape (batch_size, n_frames, n_points, 2), where each
40 | trajectory represents a series of (x, y) coordinates in the video for a specific point.
41 | - `visibilities`: Predicted point visibilities with shape (batch_size, n_frames, n_points), where each
42 | visibility represents the likelihood of a point being visible in the corresponding frame
43 | of the video.
44 | """
45 | pass
46 |
47 | def evaluate_batch(self, rgbs, query_points, trajectories_gt=None, visibilities_gt=None):
48 | """
49 | Evaluates a batch of data and returns the results.
50 |
51 | Parameters
52 | ----------
53 | rgbs : torch.Tensor
54 | A tensor of shape (batch_size, n_frames, channels, height, width)
55 | containing the RGB images in uint8 [0-255] format.
56 | query_points : torch.Tensor
57 | A tensor of shape (batch_size, n_points, 3) containing the query points,
58 | each point being (t, x, y).
59 | trajectories_gt : torch.Tensor, optional
60 | A 4D tensor representing the ground-truth trajectory. Its shape is (batch_size, n_frames, n_points, 2).
61 | visibilities_gt : torch.Tensor, optional
62 | A 3D tensor representing the ground-truth visibilities. Its shape is (batch_size, n_frames, n_points).
63 |
64 | Returns
65 | -------
66 | dict
67 | A dictionary containing the results.
68 | """
69 | trajectories_pred, visibilities_pred = self.forward(rgbs, query_points)
70 | batch_size = rgbs.shape[0]
71 | n_frames = rgbs.shape[1]
72 | n_points = query_points.shape[1]
73 | assert trajectories_pred.shape == (batch_size, n_frames, n_points, 2)
74 |
75 | results = {
76 | "trajectories_pred": trajectories_pred.detach().clone().cpu(),
77 | "visibilities_pred": visibilities_pred.detach().clone().cpu(),
78 | "query_points": query_points.detach().clone().cpu(),
79 | "trajectories_gt": trajectories_gt.detach().clone().cpu() if trajectories_gt is not None else None,
80 | "visibilities_gt": visibilities_gt.detach().clone().cpu() if visibilities_gt is not None else None,
81 | }
82 |
83 | return results
84 |
85 | @classmethod
86 | def unpack_results(cls, packed_results, batch_idx):
87 | """
88 | Unpacks the results for all points and all videos in the batch.
89 |
90 | Parameters
91 | ----------
92 | packed_results : dict
93 | The dictionary containing the packed results, for all videos in the batch and all points in the video.
94 | batch_idx : int
95 | The index of the current batch.
96 |
97 | Returns
98 | -------
99 | list
100 | A list of dictionaries, each containing the unpacked results for a data point.
101 | """
102 | unpacked_results_list = []
103 | for b in range(packed_results["trajectories_pred"].shape[0]):
104 | for n in range(packed_results["trajectories_pred"].shape[2]):
105 | result = {
106 | "idx": f"{batch_idx}_{b}_{n}",
107 | "iter": batch_idx,
108 | "video_idx": b,
109 | "point_idx_in_video": n,
110 | "query_point": packed_results["query_points"][b, n, :],
111 | "trajectory_pred": packed_results["trajectories_pred"][b, :, n, :],
112 | "visibility_pred": packed_results["visibilities_pred"][b, :, n],
113 | }
114 | if packed_results["trajectories_gt"] is not None:
115 | result["trajectory_gt"] = packed_results["trajectories_gt"][b, :, n, :]
116 | result["visibility_gt"] = packed_results["visibilities_gt"][b, :, n]
117 | unpacked_results_list += [result]
118 | return unpacked_results_list
119 |
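To make the interface above concrete, here is a minimal, hypothetical subclass (not part of the repository): a toy tracker that keeps every point at its query location and marks it always visible, which is enough to exercise `evaluate_batch` and `unpack_results`.

import torch

class StaticPointTracker(PointTracker):
    """Toy tracker: every point stays at its query (x, y) and is reported as visible in all frames."""

    def forward(self, rgbs, query_points):
        batch_size, n_frames = rgbs.shape[:2]
        n_points = query_points.shape[1]
        xy = query_points[:, :, 1:]  # drop the query timestep, keep (x, y)
        trajectories = xy[:, None].expand(batch_size, n_frames, n_points, 2).clone()
        visibilities = torch.ones(batch_size, n_frames, n_points, dtype=torch.bool)
        return trajectories, visibilities

tracker = StaticPointTracker()
rgbs = torch.randint(0, 256, (1, 4, 3, 64, 64), dtype=torch.uint8)
query_points = torch.tensor([[[0., 10., 20.], [0., 30., 40.]]])
results = tracker.evaluate_batch(rgbs, query_points)
print(results["trajectories_pred"].shape)  # torch.Size([1, 4, 2, 2])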
--------------------------------------------------------------------------------
/sam_pt/point_tracker/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SysCV/sam-pt/874ff7e73d6ab05418a494d7a02ca233c0b31e8c/sam_pt/point_tracker/utils/__init__.py
--------------------------------------------------------------------------------
/sam_pt/point_tracker/utils/misc.py:
--------------------------------------------------------------------------------
1 | # Adapted from:
2 | # - https://github.com/aharley/pips/blob/486124b4236bb228a20750b496f0fa8aa6343157/utils/misc.py
3 | # - https://github.com/aharley/pips2/blob/06bff81f25f2866728ff94f5d3a02c00893a8f15/utils/misc.py
4 |
5 |
6 | import numpy as np
7 | import torch
8 |
9 |
10 | def posemb_sincos_2d_xy(xy, C, temperature=10000, cat_coords=False):
11 | device = xy.device
12 | dtype = xy.dtype
13 | B, S, D = xy.shape
14 | assert (D == 2)
15 | x = xy[:, :, 0]
16 | y = xy[:, :, 1]
17 | assert (C % 4) == 0, 'feature dimension must be multiple of 4 for sincos emb'
18 | omega = torch.arange(C // 4, device=device) / (C // 4 - 1)
19 | omega = 1. / (temperature ** omega)
20 |
21 | y = y.flatten()[:, None] * omega[None, :]
22 | x = x.flatten()[:, None] * omega[None, :]
23 | pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim=1)
24 | pe = pe.reshape(B, S, C).type(dtype)
25 | if cat_coords:
26 | pe = torch.cat([pe, xy], dim=2) # B,N,C+2
27 | return pe
28 |
29 |
30 | def get_3d_embedding(xyz, C, cat_coords=True):
31 | B, N, D = xyz.shape
32 | assert (D == 3)
33 |
34 | x = xyz[:, :, 0:1]
35 | y = xyz[:, :, 1:2]
36 | z = xyz[:, :, 2:3]
37 | div_term = (torch.arange(0, C, 2, device=xyz.device, dtype=torch.float32) * (1000.0 / C)).reshape(1, 1, int(C / 2))
38 |
39 | pe_x = torch.zeros(B, N, C, device=xyz.device, dtype=torch.float32)
40 | pe_y = torch.zeros(B, N, C, device=xyz.device, dtype=torch.float32)
41 | pe_z = torch.zeros(B, N, C, device=xyz.device, dtype=torch.float32)
42 |
43 | pe_x[:, :, 0::2] = torch.sin(x * div_term)
44 | pe_x[:, :, 1::2] = torch.cos(x * div_term)
45 |
46 | pe_y[:, :, 0::2] = torch.sin(y * div_term)
47 | pe_y[:, :, 1::2] = torch.cos(y * div_term)
48 |
49 | pe_z[:, :, 0::2] = torch.sin(z * div_term)
50 | pe_z[:, :, 1::2] = torch.cos(z * div_term)
51 |
52 | pe = torch.cat([pe_x, pe_y, pe_z], dim=2) # B, N, C*3
53 | if cat_coords:
54 | pe = torch.cat([pe, xyz], dim=2) # B, N, C*3+3
55 | return pe
56 |
57 |
58 | class SimplePool():
59 | def __init__(self, pool_size, version='pt'):
60 | self.pool_size = pool_size
61 | self.version = version
62 | # random.seed(125)
63 | if self.pool_size > 0:
64 | self.num = 0
65 | self.items = []
66 | if not (version == 'pt' or version == 'np'):
67 | print('version = %s; please choose pt or np' % version)
68 | assert (False) # please choose pt or np
69 |
70 | def __len__(self):
71 | return len(self.items)
72 |
73 | def mean(self, min_size='none'):
74 | if min_size == 'half':
75 | pool_size_thresh = self.pool_size / 2
76 | else:
77 | pool_size_thresh = 1
78 |
79 | if self.version == 'np':
80 | if len(self.items) >= pool_size_thresh:
81 | return np.sum(self.items) / float(len(self.items))
82 | else:
83 | return np.nan
84 | if self.version == 'pt':
85 | if len(self.items) >= pool_size_thresh:
86 | return torch.sum(self.items) / float(len(self.items))
87 | else:
88 | return torch.tensor(np.nan)
89 |
90 | def sample(self):
91 | idx = np.random.randint(len(self.items))
92 | return self.items[idx]
93 |
94 | def fetch(self, num=None):
95 | if self.version == 'pt':
96 | item_array = torch.stack(self.items)
97 | elif self.version == 'np':
98 | item_array = np.stack(self.items)
99 | if num is not None:
100 | # there better be some items
101 | assert (len(self.items) >= num)
102 |
103 | # if there are not that many elements just return however many there are
104 | if len(self.items) < num:
105 | return item_array
106 | else:
107 | idxs = np.random.randint(len(self.items), size=num)
108 | return item_array[idxs]
109 | else:
110 | return item_array
111 |
112 | def is_full(self):
113 | full = self.num == self.pool_size
114 | # print 'num = %d; full = %s' % (self.num, full)
115 | return full
116 |
117 | def empty(self):
118 | self.items = []
119 | self.num = 0
120 |
121 | def update(self, items):
122 | for item in items:
123 | if self.num < self.pool_size:
124 | # the pool is not full, so let's add this in
125 | self.num = self.num + 1
126 | else:
127 | # the pool is full
128 | # pop from the front
129 | self.items.pop(0)
130 | # add to the back
131 | self.items.append(item)
132 | return self.items
133 |
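A quick shape check for the positional-embedding helpers above (the inputs are random and purely illustrative).

import torch

xy = torch.rand(2, 128, 2)                 # B, S, 2 point coordinates
pe_2d = posemb_sincos_2d_xy(xy, C=64)      # sin/cos features; C must be a multiple of 4
print(pe_2d.shape)                         # torch.Size([2, 128, 64])

xyz = torch.rand(2, 128, 3)                # B, N, 3
pe_3d = get_3d_embedding(xyz, C=64)        # per-axis sin/cos features + raw coords
print(pe_3d.shape)                         # torch.Size([2, 128, 195])  (3*64 + 3)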
--------------------------------------------------------------------------------
/sam_pt/point_tracker/utils/samp.py:
--------------------------------------------------------------------------------
1 | # Taken from: https://github.com/aharley/pips/blob/486124b4236bb228a20750b496f0fa8aa6343157/utils/samp.py
2 |
3 | import torch
4 |
5 |
6 | def bilinear_sample2d(im, x, y, return_inbounds=False):
7 | # x and y are each B, N
8 | # output is B, C, N
9 | B, C, H, W = list(im.shape)
10 | N = list(x.shape)[1]
11 |
12 | x = x.float()
13 | y = y.float()
14 | H_f = torch.tensor(H, dtype=torch.float32)
15 | W_f = torch.tensor(W, dtype=torch.float32)
16 |
17 | # inbound_mask = (x>-0.5).float()*(y>-0.5).float()*(x<float(W_f-0.5)).float()*(y<float(H_f-0.5)).float()
73 | x_valid = (x > -0.5).byte() & (x < float(W_f - 0.5)).byte()
74 | y_valid = (y > -0.5).byte() & (y < float(H_f - 0.5)).byte()
75 | inbounds = (x_valid & y_valid).float()
76 | inbounds = inbounds.reshape(B,
77 | N) # something seems wrong here for B>1; i'm getting an error here (or downstream if i put -1)
78 | return output, inbounds
79 |
80 | return output # B, C, N
81 |
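For reference, the kind of per-point bilinear sampling that `bilinear_sample2d` performs can also be expressed with `torch.nn.functional.grid_sample`; the sketch below is an illustrative stand-in rather than the PIPS implementation, and its border handling may differ from the original.

import torch
import torch.nn.functional as F

def bilinear_sample2d_ref(im, x, y):
    """im: (B, C, H, W); x, y: (B, N) pixel coordinates. Returns (B, C, N)."""
    B, C, H, W = im.shape
    # Map pixel coordinates to the [-1, 1] range expected by grid_sample (align_corners=True).
    gx = 2.0 * x / (W - 1) - 1.0
    gy = 2.0 * y / (H - 1) - 1.0
    grid = torch.stack([gx, gy], dim=-1).unsqueeze(1)         # (B, 1, N, 2) in (x, y) order
    out = F.grid_sample(im, grid, mode="bilinear", align_corners=True)
    return out[:, :, 0]                                       # (B, C, N)

feats = torch.arange(2 * 3 * 4 * 4, dtype=torch.float32).reshape(2, 3, 4, 4)
x = torch.tensor([[1.5, 2.0]]).repeat(2, 1)
y = torch.tensor([[0.5, 3.0]]).repeat(2, 1)
print(bilinear_sample2d_ref(feats, x, y).shape)               # torch.Size([2, 3, 2])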
--------------------------------------------------------------------------------
/sam_pt/point_tracker/utils/saverloader.py:
--------------------------------------------------------------------------------
1 | # Adapted from: https://github.com/aharley/pips/blob/486124b4236bb228a20750b496f0fa8aa6343157/saverloader.py
2 |
3 | import os
4 | import pathlib
5 |
6 | import torch
7 |
8 |
9 | def save(ckpt_dir, optimizer, model, global_step, scheduler=None, model_ema=None, keep_latest=5, model_name='model'):
10 | if not os.path.exists(ckpt_dir):
11 | os.makedirs(ckpt_dir)
12 |
13 | prev_ckpts = list(pathlib.Path(ckpt_dir).glob('%s-*' % model_name))
14 | prev_ckpts.sort(key=lambda p: p.stat().st_mtime, reverse=True)
15 | if len(prev_ckpts) > keep_latest - 1:
16 | for f in prev_ckpts[keep_latest - 1:]:
17 | f.unlink()
18 | model_path = '%s/%s-%09d.pth' % (ckpt_dir, model_name, global_step)
19 |
20 | ckpt = {'optimizer_state_dict': optimizer.state_dict()}
21 | ckpt['model_state_dict'] = model.state_dict()
22 | if scheduler is not None:
23 | ckpt['scheduler_state_dict'] = scheduler.state_dict()
24 | if model_ema is not None:
25 | ckpt['ema_model_state_dict'] = model_ema.state_dict()
26 | torch.save(ckpt, model_path)
27 | print("saved a checkpoint: %s" % (model_path))
28 |
29 |
30 | def load(ckpt_dir, model, device=None, optimizer=None, scheduler=None, model_ema=None, step=0, model_name='model',
31 | ignore_load=None):
32 | print('reading ckpt from %s' % ckpt_dir)
33 | assert os.path.exists(ckpt_dir)
34 |
35 | ckpt_names = os.listdir(ckpt_dir)
36 | steps = [int((i.split('-')[1]).split('.')[0]) for i in ckpt_names]
37 | assert len(ckpt_names) > 0
38 |
39 | if step == 0:
40 | step = max(steps)
41 | model_name = '%s-%09d.pth' % (model_name, step)
42 | path = os.path.join(ckpt_dir, model_name)
43 | print('...found checkpoint %s' % (path))
44 |
45 | if ignore_load is not None:
46 |
47 | print('ignoring', ignore_load)
48 |
49 | checkpoint = torch.load(path)['model_state_dict']
50 |
51 | model_dict = model.state_dict()
52 |
53 | # 1. filter out ignored keys
54 | pretrained_dict = {k: v for k, v in checkpoint.items()}
55 | for ign in ignore_load:
56 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if not ign in k}
57 |
58 | # 2. overwrite entries in the existing state dict
59 | model_dict.update(pretrained_dict)
60 | # 3. load the new state dict
61 | model.load_state_dict(model_dict, strict=False)
62 | else:
63 | checkpoint = torch.load(path, map_location=device)
64 | model.load_state_dict(checkpoint['model_state_dict'], strict=False)
65 |
66 | if optimizer is not None:
67 | optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
68 | if scheduler is not None:
69 | scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
70 | if model_ema is not None:
71 | model_ema.load_state_dict(checkpoint['ema_model_state_dict'])
72 |
73 | return step
74 |
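A minimal usage sketch of the saver/loader above; the directory, toy model, and optimizer are placeholders for illustration.

import torch

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

# Writes e.g. checkpoints/toy/model-000001000.pth and keeps at most the 3 newest checkpoints.
save("checkpoints/toy", optimizer, model, global_step=1000, keep_latest=3, model_name="model")

# With step=0 (the default), the most recent checkpoint in the directory is restored.
restored_step = load("checkpoints/toy", model, optimizer=optimizer, model_name="model")
print(restored_step)  # 1000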
--------------------------------------------------------------------------------
/sam_pt/point_tracker/utils/test.py:
--------------------------------------------------------------------------------
1 | # Taken from: https://github.com/aharley/pips/blob/486124b4236bb228a20750b496f0fa8aa6343157/utils/test.py
2 |
3 | import cv2
4 | import numpy as np
5 | import torch
6 | import torch.nn.functional as F
7 |
8 |
9 | def prep_frame_for_dino(img, scale_size=[192]):
10 | """
11 | read a single frame & preprocess
12 | """
13 | ori_h, ori_w, _ = img.shape
14 | if len(scale_size) == 1:
15 | if (ori_h > ori_w):
16 | tw = scale_size[0]
17 | th = (tw * ori_h) / ori_w
18 | th = int((th // 64) * 64)
19 | else:
20 | th = scale_size[0]
21 | tw = (th * ori_w) / ori_h
22 | tw = int((tw // 64) * 64)
23 | else:
24 | th, tw = scale_size
25 | img = cv2.resize(img, (tw, th))
26 | img = img.astype(np.float32)
27 | img = img / 255.0
28 | img = img[:, :, ::-1]
29 | img = np.transpose(img.copy(), (2, 0, 1))
30 | img = torch.from_numpy(img).float()
31 |
32 | def color_normalize(x, mean=[0.485, 0.456, 0.406], std=[0.228, 0.224, 0.225]):
33 | for t, m, s in zip(x, mean, std):
34 | t.sub_(m)
35 | t.div_(s)
36 | return x
37 |
38 | img = color_normalize(img)
39 | return img, ori_h, ori_w
40 |
41 |
42 | def get_feats_from_dino(model, frame):
43 | # batch version of the other func
44 | B = frame.shape[0]
45 | patch_size = model.patch_embed.patch_size
46 | h, w = int(frame.shape[2] / patch_size), int(frame.shape[3] / patch_size)
47 | out = model.get_intermediate_layers(frame.cuda(), n=1)[0] # B, 1+h*w, dim
48 | dim = out.shape[-1]
49 | out = out[:, 1:, :] # discard the [CLS] token
50 | outmap = out.permute(0, 2, 1).reshape(B, dim, h, w)
51 | return out, outmap, h, w
52 |
53 |
54 | def restrict_neighborhood(h, w):
55 | size_mask_neighborhood = 12
56 | # We restrict the set of source nodes considered to a spatial neighborhood of the query node (i.e. ``local attention'')
57 | mask = torch.zeros(h, w, h, w)
58 | for i in range(h):
59 | for j in range(w):
60 | for p in range(2 * size_mask_neighborhood + 1):
61 | for q in range(2 * size_mask_neighborhood + 1):
62 | if i - size_mask_neighborhood + p < 0 or i - size_mask_neighborhood + p >= h:
63 | continue
64 | if j - size_mask_neighborhood + q < 0 or j - size_mask_neighborhood + q >= w:
65 | continue
66 | mask[i, j, i - size_mask_neighborhood + p, j - size_mask_neighborhood + q] = 1
67 |
68 | mask = mask.reshape(h * w, h * w)
69 | return mask.cuda(non_blocking=True)
70 |
71 |
72 | def label_propagation(h, w, feat_tar, list_frame_feats, list_segs, mask_neighborhood=None):
73 | ncontext = len(list_frame_feats)
74 | feat_sources = torch.stack(list_frame_feats) # nmb_context x dim x h*w
75 |
76 | feat_tar = F.normalize(feat_tar, dim=1, p=2)
77 | feat_sources = F.normalize(feat_sources, dim=1, p=2)
78 |
79 | # print('feat_tar', feat_tar.shape)
80 | # print('feat_sources', feat_sources.shape)
81 |
82 | feat_tar = feat_tar.unsqueeze(0).repeat(ncontext, 1, 1)
83 | aff = torch.exp(torch.bmm(feat_tar, feat_sources) / 0.1)
84 |
85 | size_mask_neighborhood = 12
86 | if size_mask_neighborhood > 0:
87 | if mask_neighborhood is None:
88 | mask_neighborhood = restrict_neighborhood(h, w)
89 | mask_neighborhood = mask_neighborhood.unsqueeze(0).repeat(ncontext, 1, 1)
90 | aff *= mask_neighborhood
91 |
92 | aff = aff.transpose(2, 1).reshape(-1, h * w) # nmb_context*h*w (source: keys) x h*w (tar: queries)
93 | topk = 5
94 | tk_val, _ = torch.topk(aff, dim=0, k=topk)
95 | tk_val_min, _ = torch.min(tk_val, dim=0)
96 | aff[aff < tk_val_min] = 0
97 |
98 | aff = aff / torch.sum(aff, keepdim=True, axis=0)
99 |
100 | list_segs = [s.cuda() for s in list_segs]
101 | segs = torch.cat(list_segs)
102 | nmb_context, C, h, w = segs.shape
103 | segs = segs.reshape(nmb_context, C, -1).transpose(2, 1).reshape(-1, C).T # C x nmb_context*h*w
104 | seg_tar = torch.mm(segs, aff)
105 | seg_tar = seg_tar.reshape(1, C, h, w)
106 |
107 | return seg_tar, mask_neighborhood
108 |
109 |
110 | def norm_mask(mask):
111 | c, h, w = mask.size()
112 | for cnt in range(c):
113 | mask_cnt = mask[cnt, :, :]
114 | if (mask_cnt.max() > 0):
115 | mask_cnt = (mask_cnt - mask_cnt.min())
116 | mask_cnt = mask_cnt / mask_cnt.max()
117 | mask[cnt, :, :] = mask_cnt
118 | return mask
119 |
120 |
121 | def get_dino_output(dino, rgbs, trajs_g, vis_g):
122 | B, S, C, H, W = rgbs.shape
123 |
124 | B1, S1, N, D = trajs_g.shape
125 | assert (B1 == B)
126 | assert (S1 == S)
127 | assert (D == 2)
128 |
129 | assert (B == 1)
130 | xy0 = trajs_g[:, 0] # B, N, 2
131 |
132 | # The queue stores the n preceding frames
133 | import queue
134 | import copy
135 | n_last_frames = 7
136 | que = queue.Queue(n_last_frames)
137 |
138 | # run dino
139 | prep_rgbs = []
140 | for s in range(S):
141 | prep_rgb, ori_h, ori_w = prep_frame_for_dino(rgbs[0, s].permute(1, 2, 0).detach().cpu().numpy(), scale_size=[H])
142 | prep_rgbs.append(prep_rgb)
143 | prep_rgbs = torch.stack(prep_rgbs, dim=0) # S, 3, H, W
144 | with torch.no_grad():
145 | bs = 8
146 | idx = 0
147 | featmaps = []
148 | while idx < S:
149 | end_id = min(S, idx + bs)
150 | _, featmaps_cur, h, w = get_feats_from_dino(dino, prep_rgbs[idx:end_id]) # S, C, h, w
151 | idx = end_id
152 | featmaps.append(featmaps_cur)
153 | featmaps = torch.cat(featmaps, dim=0)
154 | C = featmaps.shape[1]
155 | featmaps = featmaps.unsqueeze(0) # 1, S, C, h, w
156 | # featmaps = F.normalize(featmaps, dim=2, p=2)
157 |
158 | xy0 = trajs_g[:, 0, :] # B, N, 2
159 | patch_size = dino.patch_embed.patch_size
160 | first_seg = torch.zeros((1, N, H // patch_size, W // patch_size))
161 | for n in range(N):
162 | first_seg[0, n, (xy0[0, n, 1] / patch_size).long(), (xy0[0, n, 0] / patch_size).long()] = 1
163 |
164 | frame1_feat = featmaps[0, 0].reshape(C, h * w) # dim x h*w
165 | mask_neighborhood = None
166 | accs = []
167 | trajs_e = torch.zeros_like(trajs_g)
168 | trajs_e[0, 0] = trajs_g[0, 0]
169 | for cnt in range(1, S):
170 | used_frame_feats = [frame1_feat] + [pair[0] for pair in list(que.queue)]
171 | used_segs = [first_seg] + [pair[1] for pair in list(que.queue)]
172 |
173 | feat_tar = featmaps[0, cnt].reshape(C, h * w)
174 |
175 | frame_tar_avg, mask_neighborhood = label_propagation(h, w, feat_tar.T, used_frame_feats, used_segs,
176 | mask_neighborhood)
177 |
178 | # pop out the oldest frame if necessary
179 | if que.qsize() == n_last_frames:
180 | que.get()
181 | # push current results into queue
182 | seg = copy.deepcopy(frame_tar_avg)
183 | que.put([feat_tar, seg])
184 |
185 | # upsampling & argmax
186 | frame_tar_avg = F.interpolate(frame_tar_avg, scale_factor=patch_size, mode='bilinear', align_corners=False,
187 | recompute_scale_factor=False)[0]
188 | frame_tar_avg = norm_mask(frame_tar_avg)
189 | _, frame_tar_seg = torch.max(frame_tar_avg, dim=0)
190 |
191 | for n in range(N):
192 | vis = vis_g[0, cnt, n]
193 | if len(torch.nonzero(frame_tar_avg[n])) > 0:
194 | # weighted average
195 | nz = torch.nonzero(frame_tar_avg[n])
196 | coord_e = torch.sum(frame_tar_avg[n][nz[:, 0], nz[:, 1]].reshape(-1, 1) * nz.float(), 0) / \
197 | frame_tar_avg[n][nz[:, 0], nz[:, 1]].sum() # 2
198 | coord_e = coord_e[[1, 0]]
199 | else:
200 | # stay where it was
201 | coord_e = trajs_e[0, cnt - 1, n]
202 |
203 | trajs_e[0, cnt, n] = coord_e
204 | return trajs_e
205 |
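A small shape check for `prep_frame_for_dino` above (the frame is random and purely illustrative); the remaining helpers then feed such frames through a DINO backbone and propagate point labels frame to frame in `get_dino_output`.

import numpy as np

frame = (np.random.rand(480, 854, 3) * 255).astype(np.uint8)   # dummy H x W x 3 frame
img, ori_h, ori_w = prep_frame_for_dino(frame, scale_size=[192])
print(img.shape, ori_h, ori_w)   # torch.Size([3, 192, 320]) 480 854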
--------------------------------------------------------------------------------
/sam_pt/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SysCV/sam-pt/874ff7e73d6ab05418a494d7a02ca233c0b31e8c/sam_pt/utils/__init__.py
--------------------------------------------------------------------------------
/sam_pt/vis_eval/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SysCV/sam-pt/874ff7e73d6ab05418a494d7a02ca233c0b31e8c/sam_pt/vis_eval/__init__.py
--------------------------------------------------------------------------------
/sam_pt/vis_eval/eval.py:
--------------------------------------------------------------------------------
1 | import hydra
2 | import pandas as pd
3 | import wandb
4 | from hydra.core.hydra_config import HydraConfig
5 | from hydra.utils import instantiate
6 | from omegaconf import DictConfig, OmegaConf
7 |
8 | from .train_net_video import *
9 |
10 |
11 | def main_inner(cfg: DictConfig) -> None:
12 | # Setup config
13 | detectron2_config = cfg.DETECTRON2_CONFIG
14 | default_setup(detectron2_config, {"eval_only": True})
15 |
16 | # Setup logging
17 | setup_logger(name="point_tracking_vis_eval")
18 | setup_logger(output=cfg.DETECTRON2_CONFIG.OUTPUT_DIR, distributed_rank=comm.get_rank(), name="point_tracking_video")
19 | if comm.is_main_process():
20 | wandb.init(
21 | entity=cfg.logging.wandb.entity,
22 | project=cfg.logging.wandb.project,
23 | name=cfg.logging.exp_id,
24 | group=cfg.logging.exp_id,
25 | config={
26 | "cfg": OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True),
27 | "work_dir": os.getcwd(),
28 | "hydra_cfg": HydraConfig.get() if HydraConfig.instance().cfg is not None else None,
29 | },
30 | )
31 | wandb.run.log_code(cfg.logging.wandb.log_code_path)
32 | wandb.run.summary["work_dir"] = os.path.abspath(os.getcwd())
33 |
34 | # Load model
35 | model = instantiate(cfg.model)
36 | model = model.to(cfg.device)
37 | model = model.eval()
38 |
39 | # Evaluate model
40 | results = Trainer.test(detectron2_config, model)
41 | print(f"Process {comm.get_rank()} has finished evaluation. Results: {results}")
42 | if detectron2_config.TEST.AUG.ENABLED:
43 | raise NotImplementedError
44 | if comm.is_main_process():
45 | print("Results verification by the main process has started")
46 | verify_results(detectron2_config, results)
47 | print("Results verification has finished")
48 |
49 | df_global = pd.DataFrame.from_dict(results["segm"], orient="index").T
50 | wandb.log({"df_global": wandb.Table(dataframe=df_global)})
51 | wandb.run.summary["score"] = df_global["AR100"].item()
52 |
53 |
54 | @hydra.main(config_path="../../configs", config_name="vis_eval_sam_pt", version_base="1.1")
55 | def main(cfg: DictConfig) -> None:
56 | print(OmegaConf.to_yaml(cfg))
57 | OmegaConf.resolve(cfg)
58 | OmegaConf.set_readonly(cfg, True)
59 | launch(
60 | main_inner,
61 | num_gpus_per_machine=cfg.num_gpus_per_machine,
62 | num_machines=cfg.num_machines,
63 | machine_rank=cfg.machine_rank,
64 | dist_url=cfg.dist_url,
65 | args=(cfg,),
66 | )
67 |
68 |
69 | if __name__ == "__main__":
70 | main()
71 |
--------------------------------------------------------------------------------
/sam_pt/vis_eval/mask2former/__init__.py:
--------------------------------------------------------------------------------
1 | from .config import add_maskformer2_config
2 |
--------------------------------------------------------------------------------
/sam_pt/vis_eval/mask2former/config.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Copyright (c) Facebook, Inc. and its affiliates.
3 | # Taken from: https://github.com/facebookresearch/Mask2Former/blob/9b0651c6c1d5b3af2e6da0589b719c514ec0d69a/mask2former/config.py
4 |
5 | from detectron2.config import CfgNode as CN
6 |
7 |
8 | def add_maskformer2_config(cfg):
9 | """
10 | Add config for MASK_FORMER.
11 | """
12 | # NOTE: configs from original maskformer
13 | # data config
14 | # select the dataset mapper
15 | cfg.INPUT.DATASET_MAPPER_NAME = "mask_former_semantic"
16 | # Color augmentation
17 | cfg.INPUT.COLOR_AUG_SSD = False
18 | # We retry random cropping until no single category in semantic segmentation GT occupies more
19 | # than `SINGLE_CATEGORY_MAX_AREA` part of the crop.
20 | cfg.INPUT.CROP.SINGLE_CATEGORY_MAX_AREA = 1.0
21 | # Pad image and segmentation GT in dataset mapper.
22 | cfg.INPUT.SIZE_DIVISIBILITY = -1
23 |
24 | # solver config
25 | # weight decay on embedding
26 | cfg.SOLVER.WEIGHT_DECAY_EMBED = 0.0
27 | # optimizer
28 | cfg.SOLVER.OPTIMIZER = "ADAMW"
29 | cfg.SOLVER.BACKBONE_MULTIPLIER = 0.1
30 |
31 | # mask_former model config
32 | cfg.MODEL.MASK_FORMER = CN()
33 |
34 | # loss
35 | cfg.MODEL.MASK_FORMER.DEEP_SUPERVISION = True
36 | cfg.MODEL.MASK_FORMER.NO_OBJECT_WEIGHT = 0.1
37 | cfg.MODEL.MASK_FORMER.CLASS_WEIGHT = 1.0
38 | cfg.MODEL.MASK_FORMER.DICE_WEIGHT = 1.0
39 | cfg.MODEL.MASK_FORMER.MASK_WEIGHT = 20.0
40 |
41 | # transformer config
42 | cfg.MODEL.MASK_FORMER.NHEADS = 8
43 | cfg.MODEL.MASK_FORMER.DROPOUT = 0.1
44 | cfg.MODEL.MASK_FORMER.DIM_FEEDFORWARD = 2048
45 | cfg.MODEL.MASK_FORMER.ENC_LAYERS = 0
46 | cfg.MODEL.MASK_FORMER.DEC_LAYERS = 6
47 | cfg.MODEL.MASK_FORMER.PRE_NORM = False
48 |
49 | cfg.MODEL.MASK_FORMER.HIDDEN_DIM = 256
50 | cfg.MODEL.MASK_FORMER.NUM_OBJECT_QUERIES = 100
51 |
52 | cfg.MODEL.MASK_FORMER.TRANSFORMER_IN_FEATURE = "res5"
53 | cfg.MODEL.MASK_FORMER.ENFORCE_INPUT_PROJ = False
54 |
55 | # mask_former inference config
56 | cfg.MODEL.MASK_FORMER.TEST = CN()
57 | cfg.MODEL.MASK_FORMER.TEST.SEMANTIC_ON = True
58 | cfg.MODEL.MASK_FORMER.TEST.INSTANCE_ON = False
59 | cfg.MODEL.MASK_FORMER.TEST.PANOPTIC_ON = False
60 | cfg.MODEL.MASK_FORMER.TEST.OBJECT_MASK_THRESHOLD = 0.0
61 | cfg.MODEL.MASK_FORMER.TEST.OVERLAP_THRESHOLD = 0.0
62 | cfg.MODEL.MASK_FORMER.TEST.SEM_SEG_POSTPROCESSING_BEFORE_INFERENCE = False
63 |
64 | # Sometimes `backbone.size_divisibility` is set to 0 for some backbone (e.g. ResNet)
65 | # you can use this config to override
66 | cfg.MODEL.MASK_FORMER.SIZE_DIVISIBILITY = 32
67 |
68 | # pixel decoder config
69 | cfg.MODEL.SEM_SEG_HEAD.MASK_DIM = 256
70 | # adding transformer in pixel decoder
71 | cfg.MODEL.SEM_SEG_HEAD.TRANSFORMER_ENC_LAYERS = 0
72 | # pixel decoder
73 | cfg.MODEL.SEM_SEG_HEAD.PIXEL_DECODER_NAME = "BasePixelDecoder"
74 |
75 | # swin transformer backbone
76 | cfg.MODEL.SWIN = CN()
77 | cfg.MODEL.SWIN.PRETRAIN_IMG_SIZE = 224
78 | cfg.MODEL.SWIN.PATCH_SIZE = 4
79 | cfg.MODEL.SWIN.EMBED_DIM = 96
80 | cfg.MODEL.SWIN.DEPTHS = [2, 2, 6, 2]
81 | cfg.MODEL.SWIN.NUM_HEADS = [3, 6, 12, 24]
82 | cfg.MODEL.SWIN.WINDOW_SIZE = 7
83 | cfg.MODEL.SWIN.MLP_RATIO = 4.0
84 | cfg.MODEL.SWIN.QKV_BIAS = True
85 | cfg.MODEL.SWIN.QK_SCALE = None
86 | cfg.MODEL.SWIN.DROP_RATE = 0.0
87 | cfg.MODEL.SWIN.ATTN_DROP_RATE = 0.0
88 | cfg.MODEL.SWIN.DROP_PATH_RATE = 0.3
89 | cfg.MODEL.SWIN.APE = False
90 | cfg.MODEL.SWIN.PATCH_NORM = True
91 | cfg.MODEL.SWIN.OUT_FEATURES = ["res2", "res3", "res4", "res5"]
92 | cfg.MODEL.SWIN.USE_CHECKPOINT = False
93 |
94 | # NOTE: maskformer2 extra configs
95 | # transformer module
96 | cfg.MODEL.MASK_FORMER.TRANSFORMER_DECODER_NAME = "MultiScaleMaskedTransformerDecoder"
97 |
98 | # LSJ aug
99 | cfg.INPUT.IMAGE_SIZE = 1024
100 | cfg.INPUT.MIN_SCALE = 0.1
101 | cfg.INPUT.MAX_SCALE = 2.0
102 |
103 | # MSDeformAttn encoder configs
104 | cfg.MODEL.SEM_SEG_HEAD.DEFORMABLE_TRANSFORMER_ENCODER_IN_FEATURES = ["res3", "res4", "res5"]
105 | cfg.MODEL.SEM_SEG_HEAD.DEFORMABLE_TRANSFORMER_ENCODER_N_POINTS = 4
106 | cfg.MODEL.SEM_SEG_HEAD.DEFORMABLE_TRANSFORMER_ENCODER_N_HEADS = 8
107 |
108 | # point loss configs
109 | # Number of points sampled during training for a mask point head.
110 | cfg.MODEL.MASK_FORMER.TRAIN_NUM_POINTS = 112 * 112
111 | # Oversampling parameter for PointRend point sampling during training. Parameter `k` in the
112 | # original paper.
113 | cfg.MODEL.MASK_FORMER.OVERSAMPLE_RATIO = 3.0
114 | # Importance sampling parameter for PointRend point sampling during training. Parameter `beta` in
115 | # the original paper.
116 | cfg.MODEL.MASK_FORMER.IMPORTANCE_SAMPLE_RATIO = 0.75
117 |
--------------------------------------------------------------------------------
/sam_pt/vis_eval/mask2former_video/README.md:
--------------------------------------------------------------------------------
1 | # License
2 |
3 | This directory contains code adapted from the [Mask2Former](https://github.com/facebookresearch/Mask2Former/tree/9b0651c6c1d5b3af2e6da0589b719c514ec0d69a) project by Facebook Research, which was released under the MIT license as follows:
4 |
5 | ```txt
6 | Copyright (c) 2022 Meta, Inc.
7 |
8 | Permission is hereby granted, free of charge, to any person obtaining a copy
9 | of this software and associated documentation files (the "Software"), to deal
10 | in the Software without restriction, including without limitation the rights
11 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
12 | copies of the Software, and to permit persons to whom the Software is
13 | furnished to do so, subject to the following conditions:
14 |
15 | The above copyright notice and this permission notice shall be included in all
16 | copies or substantial portions of the Software.
17 |
18 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
19 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
20 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
21 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
22 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
23 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
24 | SOFTWARE.
25 | ```
26 |
--------------------------------------------------------------------------------
/sam_pt/vis_eval/mask2former_video/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 |
3 | # config
4 | from .config import add_maskformer2_video_config
5 |
6 | # video
7 | from .data_video import (
8 | YTVISDatasetMapper,
9 | YTVISEvaluator,
10 | build_detection_train_loader,
11 | build_detection_test_loader,
12 | get_detection_dataset_dicts,
13 | )
14 |
--------------------------------------------------------------------------------
/sam_pt/vis_eval/mask2former_video/config.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Copyright (c) Facebook, Inc. and its affiliates.
3 | from detectron2.config import CfgNode as CN
4 |
5 |
6 | def add_maskformer2_video_config(cfg):
7 | # video data
8 | # DataLoader
9 | cfg.INPUT.SAMPLING_FRAME_NUM = 2
10 | cfg.INPUT.SAMPLING_FRAME_RANGE = 20
11 | cfg.INPUT.SAMPLING_FRAME_SHUFFLE = False
12 | cfg.INPUT.AUGMENTATIONS = [] # "brightness", "contrast", "saturation", "rotation"
13 |
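A minimal sketch of how these config adders are typically stacked on top of a default detectron2 config before merging an experiment file; this mirrors the usual Mask2Former setup rather than quoting `train_net_video.py` verbatim.

from detectron2.config import get_cfg

from sam_pt.vis_eval.mask2former import add_maskformer2_config
from sam_pt.vis_eval.mask2former_video import add_maskformer2_video_config

cfg = get_cfg()
add_maskformer2_config(cfg)        # image-level Mask2Former keys
add_maskformer2_video_config(cfg)  # video-specific data-loading keys
print(cfg.INPUT.SAMPLING_FRAME_NUM, cfg.MODEL.MASK_FORMER.NUM_OBJECT_QUERIES)  # 2 100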
--------------------------------------------------------------------------------
/sam_pt/vis_eval/mask2former_video/data_video/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # Modified by Bowen Cheng from https://github.com/sukjunhwang/IFC
3 |
4 | from .dataset_mapper import YTVISDatasetMapper, CocoClipDatasetMapper
5 | from .build import *
6 |
7 | from .datasets import *
8 | from .ytvis_eval import YTVISEvaluator
9 |
--------------------------------------------------------------------------------
/sam_pt/vis_eval/mask2former_video/data_video/augmentation.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # Modified by Bowen Cheng from https://github.com/sukjunhwang/IFC
3 |
4 | import logging
5 | import sys
6 |
7 | import numpy as np
8 | from PIL import Image
9 | from detectron2.data import transforms as T
10 | from fvcore.transforms.transform import (
11 | HFlipTransform,
12 | NoOpTransform,
13 | VFlipTransform,
14 | )
15 |
16 |
17 | class ResizeShortestEdge(T.Augmentation):
18 | """
19 | Scale the shorter edge to the given size, with a limit of `max_size` on the longer edge.
20 | If `max_size` is reached, then downscale so that the longer edge does not exceed max_size.
21 | """
22 |
23 | def __init__(
24 | self, short_edge_length, max_size=sys.maxsize, sample_style="range", interp=Image.BILINEAR, clip_frame_cnt=1
25 | ):
26 | """
27 | Args:
28 | short_edge_length (list[int]): If ``sample_style=="range"``,
29 | a [min, max] interval from which to sample the shortest edge length.
30 | If ``sample_style=="choice"``, a list of shortest edge lengths to sample from.
31 | max_size (int): maximum allowed longest edge length.
32 | sample_style (str): either "range" or "choice".
33 | """
34 | super().__init__()
35 | assert sample_style in ["range", "choice", "range_by_clip", "choice_by_clip"], sample_style
36 |
37 | self.is_range = ("range" in sample_style)
38 | if isinstance(short_edge_length, int):
39 | short_edge_length = (short_edge_length, short_edge_length)
40 | if self.is_range:
41 | assert len(short_edge_length) == 2, (
42 | "short_edge_length must be two values using 'range' sample style."
43 | f" Got {short_edge_length}!"
44 | )
45 | self._cnt = 0
46 | self._init(locals())
47 | self.interp = interp
48 |
49 | def get_transform(self, image):
50 | if self._cnt % self.clip_frame_cnt == 0:
51 | if self.is_range:
52 | self.size = np.random.randint(self.short_edge_length[0], self.short_edge_length[1] + 1)
53 | else:
54 | self.size = np.random.choice(self.short_edge_length)
55 | if self.size == 0:
56 | return NoOpTransform()
57 |
58 | self._cnt = 0 # avoiding overflow
59 | self._cnt += 1
60 |
61 | h, w = image.shape[:2]
62 |
63 | scale = self.size * 1.0 / min(h, w)
64 | if h < w:
65 | newh, neww = self.size, scale * w
66 | else:
67 | newh, neww = scale * h, self.size
68 | if max(newh, neww) > self.max_size:
69 | scale = self.max_size * 1.0 / max(newh, neww)
70 | newh = newh * scale
71 | neww = neww * scale
72 | neww = int(neww + 0.5)
73 | newh = int(newh + 0.5)
74 | return T.ResizeTransform(h, w, newh, neww, self.interp)
75 |
76 |
77 | class RandomFlip(T.Augmentation):
78 | """
79 | Flip the image horizontally or vertically with the given probability.
80 | """
81 |
82 | def __init__(self, prob=0.5, *, horizontal=True, vertical=False, clip_frame_cnt=1):
83 | """
84 | Args:
85 | prob (float): probability of flip.
86 | horizontal (boolean): whether to apply horizontal flipping
87 | vertical (boolean): whether to apply vertical flipping
88 | """
89 | super().__init__()
90 |
91 | if horizontal and vertical:
92 | raise ValueError("Cannot do both horiz and vert. Please use two Flip instead.")
93 | if not horizontal and not vertical:
94 | raise ValueError("At least one of horiz or vert has to be True!")
95 | self._cnt = 0
96 |
97 | self._init(locals())
98 |
99 | def get_transform(self, image):
100 | if self._cnt % self.clip_frame_cnt == 0:
101 | self.do = self._rand_range() < self.prob
102 | self._cnt = 0 # avoiding overflow
103 | self._cnt += 1
104 |
105 | h, w = image.shape[:2]
106 |
107 | if self.do:
108 | if self.horizontal:
109 | return HFlipTransform(w)
110 | elif self.vertical:
111 | return VFlipTransform(h)
112 | else:
113 | return NoOpTransform()
114 |
115 |
116 | def build_augmentation(cfg, is_train):
117 | logger = logging.getLogger(__name__)
118 | aug_list = []
119 | if is_train:
120 | # Crop
121 | if cfg.INPUT.CROP.ENABLED:
122 | aug_list.append(T.RandomCrop(cfg.INPUT.CROP.TYPE, cfg.INPUT.CROP.SIZE))
123 |
124 | # Resize
125 | min_size = cfg.INPUT.MIN_SIZE_TRAIN
126 | max_size = cfg.INPUT.MAX_SIZE_TRAIN
127 | sample_style = cfg.INPUT.MIN_SIZE_TRAIN_SAMPLING
128 | ms_clip_frame_cnt = cfg.INPUT.SAMPLING_FRAME_NUM if "by_clip" in cfg.INPUT.MIN_SIZE_TRAIN_SAMPLING else 1
129 | aug_list.append(ResizeShortestEdge(min_size, max_size, sample_style, clip_frame_cnt=ms_clip_frame_cnt))
130 |
131 | # Flip
132 | if cfg.INPUT.RANDOM_FLIP != "none":
133 | if cfg.INPUT.RANDOM_FLIP == "flip_by_clip":
134 | flip_clip_frame_cnt = cfg.INPUT.SAMPLING_FRAME_NUM
135 | else:
136 | flip_clip_frame_cnt = 1
137 |
138 | aug_list.append(
139 | # NOTE using RandomFlip modified for the support of flip maintenance
140 | RandomFlip(
141 | horizontal=(cfg.INPUT.RANDOM_FLIP == "horizontal") or (cfg.INPUT.RANDOM_FLIP == "flip_by_clip"),
142 | vertical=cfg.INPUT.RANDOM_FLIP == "vertical",
143 | clip_frame_cnt=flip_clip_frame_cnt,
144 | )
145 | )
146 |
147 | # Additional augmentations : brightness, contrast, saturation, rotation
148 | augmentations = cfg.INPUT.AUGMENTATIONS
149 | if "brightness" in augmentations:
150 | aug_list.append(T.RandomBrightness(0.9, 1.1))
151 | if "contrast" in augmentations:
152 | aug_list.append(T.RandomContrast(0.9, 1.1))
153 | if "saturation" in augmentations:
154 | aug_list.append(T.RandomSaturation(0.9, 1.1))
155 | if "rotation" in augmentations:
156 | aug_list.append(
157 | T.RandomRotation(
158 | [-15, 15], expand=False, center=[(0.4, 0.4), (0.6, 0.6)], sample_style="range"
159 | )
160 | )
161 | else:
162 | # Resize
163 | min_size = cfg.INPUT.MIN_SIZE_TEST
164 | max_size = cfg.INPUT.MAX_SIZE_TEST
165 | sample_style = "choice"
166 | aug_list.append(T.ResizeShortestEdge(min_size, max_size, sample_style))
167 |
168 | return aug_list
169 |
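The clip-aware augmentations above reuse one random decision for all frames of a clip (via `clip_frame_cnt`); a small illustrative check with dummy frames:

import numpy as np

flip = RandomFlip(prob=0.5, horizontal=True, clip_frame_cnt=4)
frames = [np.zeros((240, 320, 3), dtype=np.uint8) for _ in range(4)]
transforms = [flip.get_transform(f) for f in frames]
# All four frames of the clip receive the same transform type (all flipped or all no-ops).
print({type(t).__name__ for t in transforms})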
--------------------------------------------------------------------------------
/sam_pt/vis_eval/mask2former_video/data_video/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # Modified by Bowen Cheng from https://github.com/sukjunhwang/IFC
3 |
4 | from . import builtin # ensure the builtin datasets are registered
5 |
6 | __all__ = [k for k in globals().keys() if "builtin" not in k and not k.startswith("_")]
7 |
--------------------------------------------------------------------------------
/sam_pt/vis_eval/mask2former_video/data_video/datasets/builtin.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # Modified by Bowen Cheng from https://github.com/sukjunhwang/IFC
3 |
4 | import os
5 |
6 | from .uvo import (
7 | _get_uvo_v1_instances_meta,
8 | )
9 | from .ytvis import (
10 | register_ytvis_instances,
11 | _get_ytvis_2019_instances_meta,
12 | _get_ytvis_2021_instances_meta,
13 | )
14 |
15 | # ==== Predefined splits for YTVIS 2019 ===========
16 | _PREDEFINED_SPLITS_YTVIS_2019 = {
17 | "ytvis_2019_train": ("ytvis_2019/train/JPEGImages",
18 | "ytvis_2019/train.json"),
19 | "ytvis_2019_val": ("ytvis_2019/valid/JPEGImages",
20 | "ytvis_2019/valid.json"),
21 | "ytvis_2019_test": ("ytvis_2019/test/JPEGImages",
22 | "ytvis_2019/test.json"),
23 | }
24 |
25 | # ==== Predefined splits for YTVIS 2021 ===========
26 | _PREDEFINED_SPLITS_YTVIS_2021 = {
27 | "ytvis_2021_train": ("ytvis_2021/train/JPEGImages",
28 | "ytvis_2021/train/instances.json"),
29 | "ytvis_2021_train_mini": ("ytvis_2021/train/JPEGImages",
30 | "ytvis_2021/train/instances.mini.27.json"),
31 | "ytvis_2021_train_tiny": (
32 | # cat data/ytvis_2021/train/instances.mini.27.json | jq '.videos |= [.[0]] | .annotations |= [.[0,1]]' > data/ytvis_2021/train/instances.tiny.1.json
33 | "ytvis_2021/train/JPEGImages",
34 | "ytvis_2021/train/instances.tiny.1.json",
35 | ),
36 | "ytvis_2021_val": ("ytvis_2021/valid/JPEGImages",
37 | "ytvis_2021/valid/instances.json"),
38 | "ytvis_2021_val_mini": ("ytvis_2021/valid/JPEGImages",
39 | "ytvis_2021/valid/instances.mini.27.json"),
40 | "ytvis_2021_val_tiny": ("ytvis_2021/valid/JPEGImages",
41 | "ytvis_2021/valid/instances.mini.1.json"),
42 | "ytvis_2021_test": ("ytvis_2021/test/JPEGImages",
43 | "ytvis_2021/test/instances.json"),
44 | }
45 |
46 | _PREDEFINED_SPLITS_UVO_V1 = {
47 | "uvo_v1_train": ("UVOv1.0/uvo_videos_dense_frames/",
48 | "UVOv1.0/VideoDenseSet/UVO_video_train_dense.json"),
49 | "uvo_v1_val": ("UVOv1.0/uvo_videos_dense_frames/",
50 | "UVOv1.0/VideoDenseSet/UVO_video_val_dense.json"),
51 | "uvo_v1_val_tiny": (
52 | # Contains only 1 video
53 | # Split created using jq: `cat data/UVOv1.0/VideoDenseSet/UVO_video_val_dense.json | jq '.videos |= [.[0]] | .annotations |= [.[0,1,2,3]]' > data/UVOv1.0/VideoDenseSet/UVO_video_val_dense.tiny.1.json`
54 | "UVOv1.0/uvo_videos_dense_frames/",
55 | "UVOv1.0/VideoDenseSet/UVO_video_val_dense.tiny.1.json",
56 | ),
57 | "uvo_v1_test": ("UVOv1.0/uvo_videos_dense_frames/",
58 | "UVOv1.0/VideoDenseSet/UVO_video_test_dense.json"),
59 | }
60 |
61 | _PREDEFINED_SPLITS_UVO_V05 = {
62 | "uvo_v05_train": ("UVOv1.0/uvo_videos_dense_frames/",
63 | "UVOv0.5/VideoDenseSet/UVO_video_train_dense.json"),
64 | "uvo_v05_val": ("UVOv1.0/uvo_videos_dense_frames/",
65 | "UVOv0.5/VideoDenseSet/UVO_video_val_dense.json"),
66 | "uvo_v05_val_tiny": (
67 | # Contains only 1 video
68 | # Split created using jq: `cat data/UVOv0.5/VideoDenseSet/UVO_video_val_dense.json | jq '.videos |= [.[0]] | .annotations |= [.[0,1,2,3]]' > data/UVOv0.5/VideoDenseSet/UVO_video_val_dense.tiny.1.json`
69 | "UVOv1.0/uvo_videos_dense_frames/",
70 | "UVOv0.5/VideoDenseSet/UVO_video_val_dense.tiny.1.json",
71 | ),
72 | "uvo_v05_test": ("UVOv1.0/uvo_videos_dense_frames/",
73 | "UVOv0.5/VideoDenseSet/UVO_video_test_dense.json"),
74 | }
75 |
76 |
77 | def register_all_ytvis_2019(root):
78 | for key, (image_root, json_file) in _PREDEFINED_SPLITS_YTVIS_2019.items():
79 | # Assume pre-defined datasets live in `./datasets`.
80 | register_ytvis_instances(
81 | key,
82 | _get_ytvis_2019_instances_meta(),
83 | os.path.join(root, json_file) if "://" not in json_file else json_file,
84 | os.path.join(root, image_root),
85 | )
86 |
87 |
88 | def register_all_ytvis_2021(root):
89 | for key, (image_root, json_file) in _PREDEFINED_SPLITS_YTVIS_2021.items():
90 | # Assume pre-defined datasets live in `./datasets`.
91 | register_ytvis_instances(
92 | key,
93 | _get_ytvis_2021_instances_meta(),
94 | os.path.join(root, json_file) if "://" not in json_file else json_file,
95 | os.path.join(root, image_root),
96 | )
97 |
98 |
99 | def register_all_uvo_v1(_root):
100 | for key, (image_root, json_file) in _PREDEFINED_SPLITS_UVO_V1.items():
101 | # Assume pre-defined datasets live in `./datasets`.
102 | register_ytvis_instances(
103 | key,
104 | _get_uvo_v1_instances_meta(),
105 | os.path.join(_root, json_file) if "://" not in json_file else json_file,
106 | os.path.join(_root, image_root),
107 | )
108 |
109 |
110 | def register_all_uvo_v05(_root):
111 | for key, (image_root, json_file) in _PREDEFINED_SPLITS_UVO_V05.items():
112 | # Assume pre-defined datasets live in `./datasets`.
113 | register_ytvis_instances(
114 | key,
115 | _get_uvo_v1_instances_meta(),
116 | os.path.join(_root, json_file) if "://" not in json_file else json_file,
117 | os.path.join(_root, image_root),
118 | )
119 |
120 |
121 | if __name__.endswith(".builtin"):
122 | # Assume pre-defined datasets live in `./datasets`.
123 | _root = os.getenv("DETECTRON2_DATASETS", "datasets")
124 | register_all_ytvis_2019(_root)
125 | register_all_ytvis_2021(_root)
126 | register_all_uvo_v1(_root)
127 | register_all_uvo_v05(_root)
128 |
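Once this module is imported, the splits above are available through detectron2's catalogs; a small sketch (reading annotations additionally assumes the JSON files exist under `$DETECTRON2_DATASETS`, which defaults to `./datasets`):

from detectron2.data import DatasetCatalog, MetadataCatalog

import sam_pt.vis_eval.mask2former_video.data_video.datasets.builtin  # noqa: F401  (registers the splits)

print(MetadataCatalog.get("uvo_v1_val").thing_classes)     # ['object']
dataset_dicts = DatasetCatalog.get("ytvis_2021_val_tiny")  # loads the tiny split's annotations from disk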
--------------------------------------------------------------------------------
/sam_pt/vis_eval/mask2former_video/data_video/datasets/uvo.py:
--------------------------------------------------------------------------------
1 | UVO_CATEGORIES_V1_CLASS_AGNOSTIC = [
2 | {"color": [106, 0, 228], "isthing": 1, "id": 1, "name": "object"},
3 | ]
4 |
5 |
6 | def _get_uvo_v1_instances_meta():
7 | thing_ids = [k["id"] for k in UVO_CATEGORIES_V1_CLASS_AGNOSTIC if k["isthing"] == 1]
8 | assert len(thing_ids) == 1, len(thing_ids)
9 | thing_dataset_id_to_contiguous_id = {k: i for i, k in enumerate(thing_ids)}
10 | thing_classes = [k["name"] for k in UVO_CATEGORIES_V1_CLASS_AGNOSTIC if k["isthing"] == 1]
11 | thing_colors = [k["color"] for k in UVO_CATEGORIES_V1_CLASS_AGNOSTIC if k["isthing"] == 1]
12 | ret = {
13 | "thing_dataset_id_to_contiguous_id": thing_dataset_id_to_contiguous_id,
14 | "thing_classes": thing_classes,
15 | "thing_colors": thing_colors,
16 | }
17 | return ret
18 |
--------------------------------------------------------------------------------
/sam_pt/vis_eval/mask2former_video/data_video/datasets/ytvis_api/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # Modified by Bowen Cheng from https://github.com/youtubevos/cocoapi
3 |
--------------------------------------------------------------------------------
/sam_pt/vos_eval/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SysCV/sam-pt/874ff7e73d6ab05418a494d7a02ca233c0b31e8c/sam_pt/vos_eval/__init__.py
--------------------------------------------------------------------------------
/sam_pt/vos_eval/data/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SysCV/sam-pt/874ff7e73d6ab05418a494d7a02ca233c0b31e8c/sam_pt/vos_eval/data/__init__.py
--------------------------------------------------------------------------------
/sam_pt/vos_eval/data/mask_mapper.py:
--------------------------------------------------------------------------------
1 | # Taken from: https://github.com/hkchengrex/XMem/blob/083698bbb4c5ac0ffe1a8923a6c313de46169983/inference/data/mask_mapper.py
2 |
3 | import numpy as np
4 | import torch
5 |
6 |
7 | def all_to_onehot(masks, labels):
8 | if len(masks.shape) == 3:
9 | Ms = np.zeros((len(labels), masks.shape[0], masks.shape[1], masks.shape[2]), dtype=np.uint8)
10 | else:
11 | Ms = np.zeros((len(labels), masks.shape[0], masks.shape[1]), dtype=np.uint8)
12 |
13 | for ni, l in enumerate(labels):
14 | Ms[ni] = (masks == l).astype(np.uint8)
15 |
16 | return Ms
17 |
18 |
19 | class MaskMapper:
20 | """
21 | This class is used to convert an indexed mask into a one-hot representation.
22 | It also takes care of remapping non-contiguous indices.
23 | It has two modes:
24 | 1. Default. Only masks with new indices are supposed to go into the remapper.
25 | This is also the case for YouTubeVOS.
26 | i.e., regions with index 0 are not "background", but "don't care".
27 |
28 | 2. Exhaustive. Regions with index 0 are considered "background".
29 | Every single pixel is considered to be "labeled".
30 | """
31 |
32 | def __init__(self):
33 | self.labels = []
34 | self.remappings = {}
35 |
36 | # if coherent, no mapping is required
37 | self.coherent = True
38 |
39 | def convert_mask(self, mask, exhaustive=False, dtype=np.uint8, old_labels_allowed=False):
40 | # mask is in index representation, H*W numpy array
41 | labels = np.unique(mask).astype(dtype)
42 | labels = labels[labels != 0].tolist()
43 |
44 | new_labels = list(set(labels) - set(self.labels))
45 | if not exhaustive and not old_labels_allowed:
46 | assert len(new_labels) == len(labels), 'Old labels found in non-exhaustive mode'
47 |
48 | # add new remappings
49 | for i, l in enumerate(new_labels):
50 | self.remappings[l] = i + len(self.labels) + 1
51 | if self.coherent and i + len(self.labels) + 1 != l:
52 | self.coherent = False
53 |
54 | if exhaustive:
55 | new_mapped_labels = range(1, len(self.labels) + len(new_labels) + 1)
56 | else:
57 | if self.coherent:
58 | new_mapped_labels = new_labels
59 | else:
60 | new_mapped_labels = range(len(self.labels) + 1, len(self.labels) + len(new_labels) + 1)
61 |
62 | self.labels.extend(new_labels)
63 | mask = torch.from_numpy(all_to_onehot(mask, self.labels)).float()
64 |
65 | # mask num_objects*H*W
66 | return mask, new_mapped_labels
67 |
68 | def remap_index_mask(self, mask):
69 | # mask is in index representation, H*W numpy array
70 | if self.coherent:
71 | return mask
72 |
73 | new_mask = np.zeros_like(mask)
74 | for l, i in self.remappings.items():
75 | new_mask[mask == i] = l
76 | return new_mask
77 |
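A tiny worked example of the default (non-exhaustive) mode described above, using a 2x2 index mask with labels 3 and 5; the numbers are illustrative.

import numpy as np

mapper = MaskMapper()
mask = np.array([[0, 3],
                 [5, 5]], dtype=np.uint8)
onehot, new_mapped_labels = mapper.convert_mask(mask)
print(onehot.shape)             # torch.Size([2, 2, 2]) -- one channel per object
print(list(new_mapped_labels))  # [1, 2] -- labels 3 and 5 mapped to contiguous ids
print(mapper.remap_index_mask(np.array([[0, 1],
                                        [2, 2]], dtype=np.uint8)))
# [[0 3]
#  [5 5]] -- contiguous ids mapped back to the original labels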
--------------------------------------------------------------------------------
/sam_pt/vos_eval/data/test_datasets.py:
--------------------------------------------------------------------------------
1 | # Adapted from: https://github.com/hkchengrex/XMem/blob/083698bbb4c5ac0ffe1a8923a6c313de46169983/inference/data/test_datasets.py
2 |
3 | import json
4 | import os
5 | from os import path
6 |
7 | import numpy as np
8 |
9 | from .video_reader import VideoReader
10 |
11 |
12 | class LongTestDataset:
13 | def __init__(self, data_root, size=-1, longest_size=None):
14 | self.image_dir = path.join(data_root, 'JPEGImages')
15 | self.mask_dir = path.join(data_root, 'Annotations')
16 | self.size = size
17 | self.longest_size = longest_size
18 |
19 | self.vid_list = sorted(os.listdir(self.image_dir))
20 |
21 | def get_datasets(self):
22 | for video in self.vid_list:
23 | yield VideoReader(video,
24 | path.join(self.image_dir, video),
25 | path.join(self.mask_dir, video),
26 | to_save=[
27 | name[:-4] for name in os.listdir(path.join(self.mask_dir, video))
28 | ],
29 | shortest_size=self.size,
30 | longest_size=self.longest_size)
31 |
32 | def __len__(self):
33 | return len(self.vid_list)
34 |
35 |
36 | class DAVISTestDataset:
37 | def __init__(self, data_root, imset='2017/val.txt', size=-1, longest_size=None, return_all_gt_masks=False):
38 | if size != 480:
39 | self.image_dir = path.join(data_root, 'JPEGImages', 'Full-Resolution')
40 | self.mask_dir = path.join(data_root, 'Annotations', 'Full-Resolution')
41 | if not path.exists(self.image_dir):
42 | print(f'{self.image_dir} not found. Falling back to the 1080p folders.')
43 | self.image_dir = path.join(data_root, 'JPEGImages', '1080p')
44 | self.mask_dir = path.join(data_root, 'Annotations', '1080p')
45 | assert path.exists(self.image_dir), 'path not found'
46 | else:
47 | self.image_dir = path.join(data_root, 'JPEGImages', '480p')
48 | self.mask_dir = path.join(data_root, 'Annotations', '480p')
49 | self.size_dir = path.join(data_root, 'JPEGImages', '480p')
50 | self.size = size
51 | self.longest_size = longest_size
52 | self.return_all_gt_masks = return_all_gt_masks
53 |
54 | with open(path.join(data_root, 'ImageSets', imset)) as f:
55 | self.vid_list = sorted([line.strip() for line in f])
56 |
57 | def get_datasets(self):
58 | for video in self.vid_list:
59 | yield VideoReader(video,
60 | path.join(self.image_dir, video),
61 | path.join(self.mask_dir, video),
62 | shortest_size=self.size,
63 | longest_size=self.longest_size,
64 | size_dir=path.join(self.size_dir, video),
65 | use_all_mask=self.return_all_gt_masks)
66 |
67 | def __len__(self):
68 | return len(self.vid_list)
69 |
70 |
71 | class YouTubeVOSTestDataset:
72 | def __init__(self, data_root, split, size=480, longest_size=None):
73 | self.image_dir = path.join(data_root, 'all_frames', split + '_all_frames', 'JPEGImages')
74 | self.mask_dir = path.join(data_root, split, 'Annotations')
75 | self.size = size
76 | self.longest_size = longest_size
77 |
78 | self.vid_list = sorted(os.listdir(self.image_dir))
79 | self.req_frame_list = {}
80 |
81 | with open(path.join(data_root, split, 'meta.json')) as f:
82 | # read meta.json to know which frames are required for evaluation
83 | meta = json.load(f)['videos']
84 |
85 | for vid in self.vid_list:
86 | req_frames = []
87 | objects = meta[vid]['objects']
88 | for value in objects.values():
89 | req_frames.extend(value['frames'])
90 |
91 | req_frames = list(set(req_frames))
92 | self.req_frame_list[vid] = req_frames
93 |
94 | def get_datasets(self):
95 | for video in self.vid_list:
96 | yield VideoReader(video,
97 | path.join(self.image_dir, video),
98 | path.join(self.mask_dir, video),
99 | shortest_size=self.size,
100 | longest_size=self.longest_size,
101 | to_save=self.req_frame_list[video],
102 | use_all_mask=True)
103 |
104 | def __len__(self):
105 | return len(self.vid_list)
106 |
107 |
108 | class MOSETestDataset:
109 | def __init__(self, data_root, split, shortest_size=-1, longest_size=None):
110 | if split == "val":
111 | split = "valid"
112 |
113 | self.shortest_size = shortest_size
114 | self.longest_size = longest_size
115 |
116 | self.image_dir = path.abspath(path.join(data_root, split, 'JPEGImages'))
117 | self.mask_dir = path.abspath(path.join(data_root, split, 'Annotations'))
118 |
119 | print(f'MOSE-{split}: {self.image_dir}')
120 | print(f'MOSE-{split}: {self.mask_dir}')
121 | assert path.exists(self.image_dir)
122 | assert path.exists(self.mask_dir)
123 |
124 | self.vid_list = sorted(os.listdir(self.image_dir))
125 | print(f'MOSE-{split}: Found {len(self.vid_list)} videos in {self.image_dir}')
126 |
127 | def get_datasets(self):
128 | for video in self.vid_list:
129 | yield VideoReader(
130 | vid_name=video,
131 | image_dir=path.join(self.image_dir, video),
132 | mask_dir=path.join(self.mask_dir, video),
133 | shortest_size=self.shortest_size,
134 | longest_size=self.longest_size,
135 | use_all_mask=True,
136 | )
137 |
138 | def __len__(self):
139 | return len(self.vid_list)
140 |
141 |
142 | class BDD100KTestDataset:
143 | def __init__(self, data_root, split, shortest_size=-1, longest_size=None):
144 | self.shortest_size = shortest_size
145 | self.longest_size = longest_size
146 |
147 | self.image_dir = path.abspath(path.join(data_root, split, 'JPEGImages'))
148 | self.mask_dir = path.abspath(path.join(data_root, split, 'Annotations'))
149 |
150 | print(f'BDD100K-{split}: {self.image_dir}')
151 | print(f'BDD100K-{split}: {self.mask_dir}')
152 | assert path.exists(self.image_dir)
153 | assert path.exists(self.mask_dir)
154 |
155 | self.vid_list = sorted(os.listdir(self.image_dir))
156 | print(f'BDD100K-{split}: Found {len(self.vid_list)} videos in {self.image_dir}')
157 |
158 | def get_datasets(self):
159 | for video in self.vid_list:
160 | yield VideoReader(
161 | vid_name=video,
162 | image_dir=path.join(self.image_dir, video),
163 | mask_dir=path.join(self.mask_dir, video),
164 | shortest_size=self.shortest_size,
165 | longest_size=self.longest_size,
166 | use_all_mask=True,
167 | # mask_mode='I;16',
168 | # mask_dtype=np.int32,
169 | )
170 |
171 | def __len__(self):
172 | return len(self.vid_list)
173 |
--------------------------------------------------------------------------------
/sam_pt/vos_eval/data/video_reader.py:
--------------------------------------------------------------------------------
1 | # Taken from: https://github.com/hkchengrex/XMem/blob/083698bbb4c5ac0ffe1a8923a6c313de46169983/inference/data/video_reader.py
2 |
3 | import os
4 | from os import path
5 |
6 | import numpy as np
7 | import torch.nn.functional as F
8 | from PIL import Image
9 | from segment_anything.utils.transforms import ResizeLongestSide
10 | from torch.utils.data.dataset import Dataset
11 | from torchvision import transforms
12 | from torchvision.transforms import InterpolationMode
13 |
14 |
15 | class VideoReader(Dataset):
16 | """
17 | This class is used to read a video, one frame at a time
18 | """
19 |
20 | def __init__(self, vid_name, image_dir, mask_dir,
21 | shortest_size=-1, longest_size=None,
22 | to_save=None, use_all_mask=False, size_dir=None,
23 | mask_mode='P', mask_dtype=np.uint8,
24 | ):
25 | """
26 | image_dir - points to a directory of jpg images
27 | mask_dir - points to a directory of png masks
28 | shortest_size - resize the shortest side to this size; does nothing if < 0
29 | longest_size - alternatively, resize the longest side to this size (SAM-style); at most one of the two may be set
30 | to_save - optionally contains a list of file names without extensions where the segmentation mask is required
31 | use_all_mask - when true, read all available masks in mask_dir.
32 | Default false. Set to true for YouTubeVOS validation.
33 | """
34 | assert shortest_size == -1 or longest_size is None, 'At most one size constraint may be given, not both.'
35 |
36 | self.vid_name = vid_name
37 | self.image_dir = image_dir
38 | self.mask_dir = mask_dir
39 | self.to_save = to_save
40 | self.use_all_mask = use_all_mask
41 | if size_dir is None:
42 | self.size_dir = self.image_dir
43 | else:
44 | self.size_dir = size_dir
45 |
46 | self.mask_mode = mask_mode
47 | self.mask_dtype = mask_dtype
48 |
49 | self.frames = sorted(os.listdir(self.image_dir))
50 | self.palette = Image.open(path.join(mask_dir, sorted(os.listdir(mask_dir))[0])).getpalette()
51 | self.first_gt_path = path.join(self.mask_dir, sorted(os.listdir(self.mask_dir))[0])
52 |
53 | # TODO SegGPT specific
54 | if shortest_size == "seggpt":
55 | shortest_size = (448, 448)
56 |
57 | self.shortest_size = shortest_size
58 | self.longest_size = longest_size
59 |
60 | # TODO: Model specific transforms are hardcoded here
61 | if self.shortest_size == -1 and self.longest_size is None:
62 | self.resize_longest_side_transform = None
63 | self.im_transform = transforms.Compose([
64 | transforms.ToTensor(),
65 | ])
66 | elif self.shortest_size != -1:
67 | self.resize_longest_side_transform = None
68 | self.im_transform = transforms.Compose([
69 | transforms.ToTensor(),
70 | transforms.Resize(self.shortest_size, interpolation=InterpolationMode.BILINEAR),
71 | ])
72 | elif self.longest_size is not None:
73 | self.resize_longest_side_transform = ResizeLongestSide(self.longest_size)
74 | self.im_transform = transforms.Compose([
75 | transforms.ToTensor(),
76 | ])
77 | else:
78 | raise RuntimeError('Invalid size constraints.')
79 |
80 | def __getitem__(self, idx):
81 | frame = self.frames[idx]
82 | info = {}
83 | data = {}
84 | info['frame'] = frame
85 | info['save'] = (self.to_save is None) or (frame[:-4] in self.to_save)
86 |
87 | im_path = path.join(self.image_dir, frame)
88 | img = Image.open(im_path).convert('RGB')
89 |
90 | if self.image_dir == self.size_dir:
91 | shape = np.array(img).shape[:2]
92 | else:
93 | size_path = path.join(self.size_dir, frame)
94 | size_im = Image.open(size_path).convert('RGB')
95 | shape = np.array(size_im).shape[:2]
96 |
97 | gt_path = path.join(self.mask_dir, frame[:-4] + '.png')
98 | if self.resize_longest_side_transform is not None:
99 | img = np.array(img)
100 | img = self.resize_longest_side_transform.apply_image(img)
101 |
102 | img = self.im_transform(img)
103 |
104 | load_mask = self.use_all_mask or (gt_path == self.first_gt_path)
105 | if load_mask and path.exists(gt_path):
106 | mask = Image.open(gt_path).convert(self.mask_mode)
107 | mask = np.array(mask, dtype=self.mask_dtype)
108 | data['mask'] = mask
109 |
110 | info['shape'] = shape
111 | info['need_resize'] = self.shortest_size != -1 or self.longest_size is not None  # -1 means no resizing was applied
112 | data['rgb'] = img
113 | data['info'] = info
114 |
115 | # TODO: SegGPT specific
116 | if self.shortest_size == (448, 448):
117 | info['shape'] = (448, 448)
118 |
119 | return data
120 |
121 | def resize_mask(self, mask):
122 | # mask transform is applied AFTER mapper, so we need to post-process it in eval.py
123 | old_h, old_w = mask.shape[-2:]
124 | if self.resize_longest_side_transform is None:
125 | min_hw = min(old_h, old_w)
126 | if self.shortest_size == (448, 448):
127 | # TODO SegGPT specific
128 | shape = (448, 448)
129 | else:
130 | shape = (int(old_h / min_hw * self.shortest_size), int(old_w / min_hw * self.shortest_size))
131 | else:
132 | shape = ResizeLongestSide.get_preprocess_shape(old_h, old_w, self.longest_size)
133 | return F.interpolate(mask, shape, mode='nearest')
134 |
135 | def get_palette(self):
136 | return self.palette
137 |
138 | def __len__(self):
139 | return len(self.frames)
140 |
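To make the two resizing modes of VideoReader concrete, here is a small sketch (paths and the 'bike-packing' sequence are placeholders; longest_size=1024 mirrors the SAM-style ResizeLongestSide transform used above): the frame is resized so that its longest side is 1024, info['shape'] keeps the original resolution, and resize_mask brings the first-frame ground-truth mask to the resized resolution.

```python
import torch

from sam_pt.vos_eval.data.video_reader import VideoReader

reader = VideoReader(
    vid_name='bike-packing',
    image_dir='data/DAVIS/2017/trainval/JPEGImages/480p/bike-packing',
    mask_dir='data/DAVIS/2017/trainval/Annotations/480p/bike-packing',
    longest_size=1024,  # SAM-style: resize the longest image side to 1024
)

data = reader[0]
print(data['rgb'].shape)      # resized frame, longest side == 1024
print(data['info']['shape'])  # original (H, W), used to resize predictions back

# The first-frame ground-truth mask is resized the same way before being fed to the model.
gt = torch.from_numpy(data['mask']).float()[None, None]  # (1, 1, H, W)
gt_resized = reader.resize_mask(gt)
print(gt.shape, '->', gt_resized.shape)
```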
--------------------------------------------------------------------------------
/sam_pt/vos_eval/davis2017eval.py:
--------------------------------------------------------------------------------
1 | """
2 | This script is a modified version of the original DAVIS 2017 evaluation script from:
3 | https://github.com/davisvideochallenge/davis2017-evaluation/blob/ac7c43fca936f9722837b7fbd337d284ba37004b/evaluation_method.py
4 |
5 | Usage:
6 | ```
7 | python -m sam_pt.vos_eval.davis2017eval \
8 | --results_path /srv/beegfs02/scratch/visobt4s/data/3d_point_tracking/sampt_outputs/SegGPT--D17-val--in-sampt-env_D17_val_72_2023.11.09_15.52.53/eval_D17_val \
9 | --davis_path data/DAVIS/2017/trainval \
10 | --set val \
11 | --task semi-supervised \
12 | --year 2017
13 | ```
14 | """
15 |
16 | import argparse
17 | import os
18 | import sys
19 | from time import time
20 | from typing import Union
21 |
22 | import numpy as np
23 | import pandas as pd
24 | from davis2017.evaluation import DAVISEvaluation
25 |
26 |
27 | class Davis2017Evaluator:
28 | def __init__(self, results_path: str, davis_path: str, set: str = "val", task: str = "semi-supervised",
29 | year: str = '2017', sequences: Union[str, list] = "all", ):
30 | """
31 | :param results_path: Path to the folder containing the sequences folders.
32 | :param davis_path: Path to the DAVIS folder containing the `JPEGImages`, `Annotations`, `ImageSets`,
33 | `Annotations_unsupervised` folders.
34 | :param set: Subset on which to evaluate ('val', 'test-dev', or 'test-challenge').
35 | :param task: Task to evaluate ('semi-supervised' or 'unsupervised').
36 | :param year: DAVIS dataset year.
37 | :param sequences: List of sequences to evaluate. If "all", evaluate all sequences.
38 | """
39 | assert set in ['val', 'test-dev', 'test-challenge']
40 | assert task in ['semi-supervised', 'unsupervised']
41 |
42 | self.davis_path = davis_path
43 | self.set = set
44 | self.task = task
45 | self.year = year
46 | self.sequences = sequences
47 | self.results_path = results_path
48 |
49 | def evaluate(self):
50 | time_start = time()
51 | csv_name_global = f'global_results-{self.set}.csv'
52 | csv_name_per_sequence = f'per-sequence_results-{self.set}.csv'
53 |
54 | # Check if the method has been evaluated before, if so read the results, otherwise compute the results
55 | csv_name_global_path = os.path.join(self.results_path, csv_name_global)
56 | csv_name_per_sequence_path = os.path.join(self.results_path, csv_name_per_sequence)
57 | if os.path.exists(csv_name_global_path) and os.path.exists(csv_name_per_sequence_path):
58 | print('Using precomputed results...')
59 | table_g = pd.read_csv(csv_name_global_path)
60 | table_seq = pd.read_csv(csv_name_per_sequence_path)
61 | else:
62 | print(f'Evaluating sequences for the {self.task} task...')
63 | # Create dataset and evaluate
64 | dataset_eval = DAVISEvaluation(davis_root=self.davis_path, task=self.task, gt_set=self.set, year=self.year,
65 | sequences=self.sequences)
66 | metrics_res = dataset_eval.evaluate(self.results_path)
67 | J, F = metrics_res['J'], metrics_res['F']
68 |
69 | # Generate dataframe for the general results
70 | g_measures = ['J&F-Mean', 'J-Mean', 'J-Recall', 'J-Decay', 'F-Mean', 'F-Recall', 'F-Decay']
71 | final_mean = (np.mean(J["M"]) + np.mean(F["M"])) / 2.
72 | g_res = np.array(
73 | [final_mean, np.mean(J["M"]), np.mean(J["R"]), np.mean(J["D"]), np.mean(F["M"]), np.mean(F["R"]),
74 | np.mean(F["D"])])
75 | g_res = np.reshape(g_res, [1, len(g_res)])
76 | table_g = pd.DataFrame(data=g_res, columns=g_measures)
77 | with open(csv_name_global_path, 'w') as f:
78 | table_g.to_csv(f, index=False, float_format="%.3f")
79 | print(f'Global results saved in {csv_name_global_path}')
80 |
81 | # Generate a dataframe for the per sequence results
82 | seq_names = list(J['M_per_object'].keys())
83 | seq_measures = ['Sequence', 'J-Mean', 'F-Mean']
84 | J_per_object = [J['M_per_object'][x] for x in seq_names]
85 | F_per_object = [F['M_per_object'][x] for x in seq_names]
86 | table_seq = pd.DataFrame(data=list(zip(seq_names, J_per_object, F_per_object)), columns=seq_measures)
87 | with open(csv_name_per_sequence_path, 'w') as f:
88 | table_seq.to_csv(f, index=False, float_format="%.3f")
89 | print(f'Per-sequence results saved in {csv_name_per_sequence_path}')
90 |
91 | # Print the results
92 | sys.stdout.write(f"--------------------------- Global results for {self.set} ---------------------------\n")
93 | print(table_g.to_string(index=False))
94 | sys.stdout.write(f"\n---------- Per sequence results for {self.set} ----------\n")
95 | print(table_seq.to_string(index=False))
96 | total_time = time() - time_start
97 | sys.stdout.write('\nTotal time: ' + str(total_time))
98 |
99 | return table_g, table_seq
100 |
101 |
102 | if __name__ == '__main__':
103 | parser = argparse.ArgumentParser(description='Evaluate a method on the DAVIS 2017 dataset')
104 | parser.add_argument('--results_path', type=str, required=True,
105 | help='Path to the folder containing the sequences folders.')
106 | parser.add_argument('--davis_path', type=str, required=True,
107 | help='Path to the DAVIS folder containing the `JPEGImages`, `Annotations`, `ImageSets`, '
108 | '`Annotations_unsupervised` folders.')
109 | parser.add_argument('--set', type=str, default='val', choices=['val', 'test-dev', 'test-challenge'],
110 | help='Subset to evaluate the results.')
111 | parser.add_argument('--eval_only_on_the_sequences_present_in_the_results', action='store_true',
112 | help='If True, evaluate only on the sequences present in the results folder.')
113 | parser.add_argument('--task', type=str, default='semi-supervised', choices=['semi-supervised', 'unsupervised'],
114 | help='Task to evaluate the results.')
115 | parser.add_argument("--year", type=str, help="DAVIS dataset year (default: 2017)", default='2017',
116 | choices=['2016', '2017', '2019'])
117 |
118 | args = parser.parse_args()
119 |
120 | sequences = 'all'
121 | if args.eval_only_on_the_sequences_present_in_the_results:
122 | assert os.path.exists(args.results_path)
123 | sequences = sorted(os.listdir(args.results_path))
124 | sequences = [s for s in sequences if s != "overlapping" and "." not in s]
125 | print(f"Evaluating only on the sequences present in the results folder: {sequences}")
126 |
127 | evaluator = Davis2017Evaluator(args.results_path, args.davis_path, args.set, args.task, args.year, sequences)
128 | evaluator.evaluate()
129 |
--------------------------------------------------------------------------------
/sam_pt/vos_eval/evaluator.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from abc import abstractmethod, ABC
3 |
4 | from sam_pt.modeling.sam_pt import SamPt
5 |
6 |
7 | class VOSEvaluator(ABC):
8 | """
9 | Abstract class for evaluating a model on the semi-supervised video object segmentation task.
10 | """
11 |
12 | def __init__(self, cfg, model):
13 | self.cfg = cfg
14 | self.model = model
15 |
16 | @abstractmethod
17 | def evaluate_video(self, video):
18 | """
19 | Evaluates model on a video and returns the predictions.
20 |
21 | Parameters
22 | ----------
23 | video : dict
24 | Dictionary with video data. It includes the following keys:
25 | 'video_name': str - The name of the video.
26 | 'video_id': int - The ID of the video.
27 | 'image': List[torch.Tensor] - The frames of the video as uint8 tensors of shape (channels, height, width)
28 | 'info': List[dict] - Information for each frame, includes keys like 'frame', 'save', 'shape', 'need_resize'.
29 | 'target_hw': Tuple[int, int] - The target height and width for the predicted masks.
30 | 'query_masks': torch.Tensor - The query masks as binary float32 tensor of shape (num_masks, height, width).
31 | 'query_point_timestep': torch.Tensor - The query point timesteps as float32 tensor of shape (num_masks,).
32 |
33 | Returns
34 | -------
35 | dict
36 | Dictionary with predictions. It includes the following keys:
37 | 'logits': List[torch.Tensor] - The logits as float32 tensors of shape (num_frames, height, width).
38 | 'trajectories': torch.Tensor - The trajectories as float32 tensor
39 | of shape (num_frames, n_masks, n_points_per_mask, 2).
40 | 'visibilities': torch.Tensor - The visibilities as float32 tensor
41 | of shape (num_frames, n_masks, n_points_per_mask).
42 | 'scores': List[float] - The scores as list of 'num_masks' floats.
43 | """
44 | pass
45 |
46 |
47 | class SamPtEvaluator(VOSEvaluator):
48 | def evaluate_video(self, video):
49 | self.model: SamPt = self.model
50 | device = self.model.device
51 | for k, v in video.items():
52 | if isinstance(v, torch.Tensor):
53 | video[k] = v.to(device)
54 | outputs = self.model(video)
55 | return {
56 | "logits": outputs["logits"],
57 | "trajectories": outputs["trajectories"],
58 | "visibilities": outputs["visibilities"],
59 | "scores": outputs["scores"],
60 | }
61 |
--------------------------------------------------------------------------------
/scripts/annotation_comparison_gif.py:
--------------------------------------------------------------------------------
1 | """
2 | Create comparison GIFs that vertically stack input frames, ground-truth annotations, and predicted masks for each video. Usage: `python scripts/annotation_comparison_gif.py`
3 | """
4 |
5 | import os
6 | from concurrent.futures import ThreadPoolExecutor, as_completed
7 |
8 | import imageio
9 | from PIL import Image
10 | from tqdm import tqdm
11 |
12 |
13 | def create_gif(results_dir, annotations_dir, images_dir, output_gif_path):
14 | # Get a sorted list of image files and annotation files
15 | result_files = sorted([f for f in os.listdir(results_dir) if f.endswith('.png')])
16 | images_files = sorted([f for f in os.listdir(images_dir) if f.endswith('.jpg')])
17 | annotation_files = sorted([f for f in os.listdir(annotations_dir) if f.endswith('.png')])
18 |
19 | # Check if both folders have the same number of files
20 | assert len(result_files) == len(annotation_files) == len(images_files)
21 |
22 | # Create a list to store concatenated images
23 | concat_images = []
24 |
25 | for res_file, img_file, ann_file in tqdm(list(zip(result_files, images_files, annotation_files))):
26 | # Open the images
27 | result = Image.open(os.path.join(results_dir, res_file))
28 | image = Image.open(os.path.join(images_dir, img_file))
29 | annotation = Image.open(os.path.join(annotations_dir, ann_file))
30 |
31 | # Make sure the images can be concatenated
32 | assert image.size == annotation.size == result.size, "Image sizes do not match."
33 |
34 | # Concatenate the images vertically
35 | total_height = image.size[1] + annotation.size[1] + result.size[1]
36 | combined_image = Image.new('RGB', (image.size[0], total_height))
37 | combined_image.paste(image, (0, 0))
38 | combined_image.paste(annotation, (0, image.size[1]))
39 | combined_image.paste(result, (0, image.size[1] + annotation.size[1]))
40 |
41 | # Add to list of concatenated images
42 | concat_images.append(combined_image)
43 |
44 | # Save the frames as a GIF
45 | imageio.mimsave(output_gif_path, concat_images, duration=0.5, loop=0)
46 |
47 | print(f"GIF created at {output_gif_path}")
48 |
49 |
50 | def create_gif_per_video(video, results_path, annotations_path, images_path):
51 | print(f"Creating GIF for {video}")
52 | result_path = os.path.join(results_path, video)
53 | annotation_path = os.path.join(annotations_path, video)
54 | image_path = os.path.join(images_path, video)
55 | output_gif_path = os.path.join(results_path, video + ".gif")
56 | create_gif(result_path, annotation_path, image_path, output_gif_path)
57 |
58 |
59 | if __name__ == '__main__':
60 | # results_path = "/mnt/terra/xoding/eth-master-thesis/08-logs-september/bdd100k-results/K9.000--debug--cotracker-0--1-1024/"
61 | # annotations_path = "/mnt/terra/xoding/eth-master-thesis/08-logs-september/bdd100k-results/vos/val/Annotations/"
62 | # images_path = "/mnt/terra/xoding/eth-master-thesis/08-logs-september/bdd100k-results/vos/val/JPEGImages/"
63 |
64 | results_path = "outputs/K9.000--debug--cotracker-0--1-1024/eval_BDD100K_val"
65 | annotations_path = "data/bdd100k/vos/val/Annotations"
66 | images_path = "data/bdd100k/vos/val/JPEGImages"
67 |
68 | videos = [video for video in os.listdir(results_path) if not video.endswith(".gif") and "." not in video]
69 |
70 | with ThreadPoolExecutor() as executor:
71 | # Submit all tasks to the executor
72 | future_to_video = {
73 | executor.submit(create_gif_per_video, video, results_path, annotations_path, images_path): video for video
74 | in videos}
75 |
76 | # Process the futures as they complete
77 | for future in tqdm(as_completed(future_to_video), total=len(videos), desc="Processing videos", unit="video"):
78 | video = future_to_video[future]
79 | try:
80 | future.result()
81 | except Exception as exc:
82 | print(f'{video} generated an exception: {exc}')
83 |
84 | print("All GIFs have been created.")
85 |
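For a single video, the thread pool is unnecessary; a one-off sketch (the video name and paths mirror the hardcoded BDD100K examples above and are placeholders for your own run):

```python
from scripts.annotation_comparison_gif import create_gif_per_video

create_gif_per_video(
    video='b1c66a42-6f7d68ca',
    results_path='outputs/K9.000--debug--cotracker-0--1-1024/eval_BDD100K_val',
    annotations_path='data/bdd100k/vos/val/Annotations',
    images_path='data/bdd100k/vos/val/JPEGImages',
)
```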
--------------------------------------------------------------------------------
/scripts/bdd100k_from_instance_seg_to_vos_annotations.py:
--------------------------------------------------------------------------------
1 | """
2 | To create the VOS annotations from the instance segmentation annotations, run:
3 | ```bash
4 | # Prepare directories
5 | mkdir -p data/bdd100k/vos/val/{Annotations,JPEGImages}
6 |
7 | # Copy JPEGImages
8 | cp -r data/bdd100k/images/seg_track_20/val/* data/bdd100k/vos/val/JPEGImages/
9 |
10 | # Create the Annotations
11 | python -m scripts.bdd100k_from_instance_seg_to_vos_annotations
12 |
13 | # Link the chunks
14 | # e.g., data/bdd100k/vos/val/JPEGImages/b1c66a42-6f7d68ca-chunk2 -> b1c66a42-6f7d68ca/
15 | find data/bdd100k/vos/val/Annotations -type d -name "*-chunk*" | sed 's/Annotations/JPEGImages/' | while read -r src; do
16 | tgt=$(basename "$src" | sed 's/-chunk.*//')
17 | rm $src
18 | ln -s "$tgt" "$src"
19 | done
20 | ```
21 | """
22 | import json
23 | import os
24 |
25 | import math
26 | import numpy as np
27 | import pandas as pd
28 | from PIL import Image
29 | from tqdm import tqdm
30 | from tqdm.contrib.concurrent import process_map
31 |
32 | np.random.seed(72)
33 | palette = (np.multiply(np.random.rand(768), 255).astype(np.uint8).tolist())
34 | palette[:3] = [0, 0, 0]
35 |
36 |
37 | def remap_ids(ids):
38 | # Find the unique IDs and their new remapped positions
39 | unique_ids, inverse_indices = np.unique(ids, return_inverse=True)
40 |
41 | # Reshape the inverse_indices to the shape of the original IDs array
42 | remapped_ids = inverse_indices.reshape(ids.shape)
43 |
44 | return remapped_ids
45 |
46 |
47 | def process_video(video_name, objects_per_chunk=100):
48 | print(f"Processing video {video_name}")
49 | frames = sorted(os.listdir(os.path.join(videos_path, video_name)))
50 | bitmasks = []
51 | for frame_name in frames:
52 | frame_path = os.path.join(videos_path, video_name, frame_name)
53 | bitmask = np.array(Image.open(frame_path))
54 | bitmasks.append(bitmask)
55 | bitmasks = np.stack(bitmasks)
56 | annotation_ids = (bitmasks[:, :, :, 2].astype(np.uint32) << 8) + bitmasks[:, :, :, 3]
57 | unique_ids = np.unique(annotation_ids).size
58 | print(f"Video {video_name} is loaded, it has {unique_ids} unique objects")
59 |
60 | annotation_ids = annotation_ids * (bitmasks[:, :, :, 1] & 1 == 0) # Remove ignored instances
61 | annotation_ids = annotation_ids * (bitmasks[:, :, :, 1] & 2 == 0) # Remove crowd instances
62 | # annotation_ids = annotation_ids * (bitmasks[:, :, :, 1] & 4 == 0) # Remove occluded instances
63 | # annotation_ids = annotation_ids * (bitmasks[:, :, :, 1] & 8 == 0) # Remove truncated instances
64 | unique_ids_old = unique_ids
65 | unique_ids = np.unique(annotation_ids).size
66 | print(f"Video {video_name} is filtered by ignored and crowd instances, "
67 | f"it has {unique_ids_old:>5d} --> {unique_ids:>5d} unique objects now")
68 |
69 | # # Randomly select max_objects objects
70 | # if unique_ids > max_objects:
71 | # np.random.seed(72)
72 | # selected_ids = np.random.choice(np.sort(np.unique(annotation_ids))[1:], max_objects, replace=False)
73 | # annotation_ids = np.where(np.isin(annotation_ids, selected_ids), annotation_ids, 0)
74 | # unique_ids_old = unique_ids
75 | # unique_ids = np.unique(annotation_ids).size
76 | # print(f"Video {video_name} is filtered by max_objects, "
77 | # f"it has {unique_ids_old:>5d} --> {unique_ids:>5d} unique objects now")
78 |
79 | # # Select the first max_objects objects
80 | # if unique_ids > max_objects:
81 | # selected_ids = np.sort(np.unique(annotation_ids))[1:max_objects + 1]
82 | # annotation_ids = np.where(np.isin(annotation_ids, selected_ids), annotation_ids, 0)
83 | # unique_ids_old = unique_ids
84 | # unique_ids = np.unique(annotation_ids).size
85 | # print(f"Video {video_name} is filtered by max_objects, "
86 | # f"it has {unique_ids_old:>5d} --> {unique_ids:>5d} unique objects now")
87 |
88 | # Split the objects into chunks of objects_per_chunk objects
89 | annotation_ids_unique = np.unique(annotation_ids)[1:]
90 | for chunk_id in range(math.ceil(annotation_ids_unique.size / objects_per_chunk)):
91 | chunk_name = f"{video_name}-chunk{chunk_id + 1}" if chunk_id > 0 else video_name
92 | chunk = annotation_ids_unique[chunk_id * objects_per_chunk:(chunk_id + 1) * objects_per_chunk]
93 | print(f"Processing {chunk_name}, it has {chunk.size} objects: {chunk}")
94 |
95 | # Select the objects in the chunk
96 | annotation_ids_chunk = np.where(np.isin(annotation_ids, chunk), annotation_ids, 0)
97 | unique_ids = np.unique(annotation_ids_chunk).size
98 |
99 | # Remap annotation IDs to be continuous
100 | remapped_annotation_ids = remap_ids(annotation_ids_chunk)
101 | assert np.unique(remapped_annotation_ids).size == unique_ids
102 | assert np.unique(remapped_annotation_ids).size == remapped_annotation_ids.max() + 1
103 | print(f"Video {video_name} is remapped")
104 |
105 | output_dir = os.path.join(output_path, chunk_name)
106 | os.makedirs(output_dir, exist_ok=True)
107 | assert unique_ids <= 255, "The number of unique IDs (incl. background) must be at most 255 to fit into uint8"
108 | for frame_id, frame_name in enumerate(frames):
109 | x = Image.fromarray(remapped_annotation_ids[frame_id].astype(np.uint8), mode="P")
110 | x.putpalette(palette)
111 | x.save(os.path.join(output_dir, frame_name))
112 | print(f"Video {video_name} is saved")
113 |
114 |
115 | def sanity_check(output_path, rles_path):
116 | for i, video_json_name in enumerate(tqdm(sorted([vp for vp in os.listdir(rles_path) if vp.endswith("json")]))):
117 | video_name = video_json_name.replace(".json", "")
118 | # if i < 15:
119 | # print(f"Skipping video {video_name}")
120 | # continue
121 | with open(os.path.join(rles_path, video_json_name), "r") as fp:
122 | video = json.load(fp)
123 | df = pd.DataFrame([
124 | (label["category"], label["id"])
125 | for frame in video["frames"]
126 | for label in frame["labels"]
127 | ], columns=["cat", "id"])
128 | assert df[~df.duplicated()].groupby("id").count().max().item() == 1
129 |
130 | annotation_ids = [
131 | np.array(Image.open(os.path.join(output_path, video_name, frame_name)))
132 | for frame_name in sorted(os.listdir(os.path.join(output_path, video_name)))
133 | ]
134 | annotation_ids = np.stack(annotation_ids)
135 | assert np.unique(annotation_ids).size == annotation_ids.max() + 1
136 | if np.unique(annotation_ids).size != df.id.unique().size + 1:
137 | print(f"Video {video_name} has {np.unique(annotation_ids).size} unique objects, "
138 | f"but RLE has {df.id.unique().size + 1} unique objects")
139 | # breakpoint()
140 | else:
141 | assert np.unique(annotation_ids).size == df.id.unique().size + 1
142 |
143 | print(f"Unique objects for video {i:02d}: {df.id.unique().size}")
144 |
145 |
146 | if __name__ == '__main__':
147 | videos_path = "data/bdd100k/labels/seg_track_20/bitmasks/val"
148 | output_path = "data/bdd100k/vos/val/Annotations"
149 | video_names = sorted([name for name in os.listdir(videos_path) if os.path.isdir(os.path.join(videos_path, name))])
150 |
151 | # Create the VOS annotations
152 | process_map(process_video, video_names, chunksize=1)
153 | print("Done creating VOS annotations")
154 |
155 | # # Sanity check that the number of objects in the VOS annotations is the same as in the RLEs
156 | # print("Sanity check that the number of objects in the VOS annotations is the same as in the RLEs")
157 | # rles_path = "data/bdd100k/labels/seg_track_20/rles/val"
158 | # sanity_check(output_path, rles_path)
159 |
--------------------------------------------------------------------------------
/scripts/clean_tapnet_checkpoint.py:
--------------------------------------------------------------------------------
1 | """
2 | This script cleans up the original TapNet checkpoint by removing objects
3 | that require `import tapnet` to load. The cleaned checkpoint keeps only
4 | the parameters and model state and drops the optimizer state and global
5 | step. The cleaned checkpoint can be used within SAM-PT.
6 |
7 | Note that we provide a link to the cleaned checkpoint in the
8 | documentation and that you might not need to run this script yourself.
9 |
10 | Usage:
11 | 1. Clone the [TapNet repository](https://github.com/deepmind/tapnet) and
12 | checkout the commit `ba1a8c8f2576d81f7b8d69dbee1e58e8b7d321e1`.
13 | 2. Setup the TapNet environment.
14 | 3. Run this script one level above the TapNet repository (i.e., not
15 | within the TapNet repository, but within its parent directory). For
16 | that, navigate to the parent directory of TapNet repository (`cd ..`)
17 | and set the PYTHONPATH environment variable
18 | (```export PYTHONPATH=`(cd ../ && pwd)`:`pwd`:$PYTHONPATH```).
19 |
20 | Run the script from the command line with the following arguments:
21 | - --input: The path to the original TapNet checkpoint file.
22 | - --output: The path where the cleaned checkpoint file will be saved.
23 |
24 | For example:
25 | ```bash
26 | python clean_tapnet_checkpoint.py \
27 | --input "./models/tapnet_ckpts/open_source_ckpt/checkpoint.npy" \
28 | --output "./models/tapnet_ckpts/open_source_ckpt/checkpoint_wo_optstate.npy"
29 | ```
30 | """
31 |
32 | import argparse
33 | import numpy as np
34 | import tensorflow as tf
35 |
36 |
37 | def clean_checkpoint(input_path, output_path):
38 | # Load the original checkpoint file.
39 | checkpoint = np.load(input_path, allow_pickle=True).item()
40 |
41 | print(checkpoint.keys())
42 | # dict_keys(['params', 'state', 'opt_state', 'global_step'])
43 |
44 | # Create a new dictionary without the 'opt_state' and 'global_step'.
45 | checkpoint_wo_optstate = {
46 | "params": checkpoint["params"],
47 | "state": checkpoint["state"],
48 | }
49 |
50 | # Save the cleaned checkpoint file.
51 | with tf.io.gfile.GFile(output_path, 'wb') as fp:
52 | np.save(fp, checkpoint_wo_optstate)
53 |
54 |
55 | def parse_arguments():
56 | parser = argparse.ArgumentParser()
57 | parser.add_argument("--input", help="The path to the original TapNet checkpoint file.")
58 | parser.add_argument("--output", help="The path where the cleaned checkpoint file will be saved.")
59 | return parser.parse_args()
60 |
61 |
62 | if __name__ == "__main__":
63 | args = parse_arguments()
64 | clean_checkpoint(args.input, args.output)
65 |
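A quick sanity check after running the script (sketch; the output path matches the example in the docstring above): the cleaned checkpoint loads with plain NumPy and contains only the parameter and state trees.

```python
import numpy as np

ckpt = np.load("./models/tapnet_ckpts/open_source_ckpt/checkpoint_wo_optstate.npy",
               allow_pickle=True).item()
print(ckpt.keys())  # dict_keys(['params', 'state'])
```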
--------------------------------------------------------------------------------
/scripts/davis_mask_to_contour.py:
--------------------------------------------------------------------------------
1 | """
2 | Script to convert DAVIS mask annotation images to contour images.
3 | Used to prepare figures for the SAM-PT paper.
4 | Note that paths are hardcoded in the script.
5 |
6 | Usage: `python -m scripts.davis_mask_to_contour`
7 | """
8 |
9 | import cv2
10 | import matplotlib.pyplot as plt
11 | import numpy as np
12 |
13 |
14 | def davis_mask_annotation_image_to_contour_image(input_image_path, output_image_path, contour_radius=5):
15 | # Open image and convert it to numpy array
16 | print(f"Input image path: {input_image_path}")
17 | image = cv2.imread(input_image_path)
18 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
19 | assert image.dtype == np.uint8
20 | assert image.min() >= 0 and image.max() <= 255
21 | plt.imshow(image)
22 | plt.show()
23 |
24 | # The number of masks is the number of unique colors in the image
25 | n_masks = len(np.unique(image.reshape(-1, image.shape[2]), axis=0)) - 1
26 | print(f"Number of masks: {n_masks}")
27 |
28 | # Take each mask separately and create a binary mask, remember the color of each mask
29 | masks = []
30 | colors = np.unique(image.reshape(-1, image.shape[2]), axis=0)
31 | assert (colors[0] == [0, 0, 0]).all()
32 | colors = colors[1:]
33 | for mask_idx in range(n_masks):
34 | mask = (image == colors[mask_idx][None, None, :]).all(-1)
35 | masks.append(mask)
36 |
37 | # Create a contour mask for each mask
38 | contour_masks = []
39 | for mask_idx in range(n_masks):
40 | m_8int = masks[mask_idx].astype(np.uint8)
41 | dist_transform_fore = cv2.distanceTransform(m_8int, cv2.DIST_L2, 3)
42 | contour_mask = (dist_transform_fore <= contour_radius) & (dist_transform_fore > 0)
43 | contour_mask = contour_mask.astype(np.uint8)
44 | contour_masks.append(contour_mask)
45 | plt.imshow(contour_mask)
46 | plt.show()
47 |
48 | # Add contour mask to the image
49 | output_image = np.zeros_like(image)
50 | for mask_idx in range(n_masks):
51 | output_image = np.where(contour_masks[mask_idx][:, :, None] == 1, colors[mask_idx][None, None, :], output_image)
52 |
53 | # Plot the image
54 | plt.imshow(output_image)
55 | plt.show()
56 |
57 | # Save the image
58 | print(f"Output image path: {output_image_path}")
59 | output_image = cv2.cvtColor(output_image, cv2.COLOR_RGB2BGR)
60 | cv2.imwrite(output_image_path, output_image)
61 |
62 | # Save also RGBA image
63 | output_image = cv2.imread(output_image_path)
64 | output_image = cv2.cvtColor(output_image, cv2.COLOR_BGR2RGB)
65 | r, g, b = cv2.split(output_image)
66 | a = 255 - (output_image == np.array([0, 0, 0])[None, None, :]).all(-1).astype(np.uint8) * 255
67 | output_image = cv2.merge([r, g, b, a], 4)
68 | print(f"RGBA image path: {output_image_path}.rgba.png")
69 | output_image = cv2.cvtColor(output_image, cv2.COLOR_RGBA2BGRA)
70 | cv2.imwrite(output_image_path + ".rgba.png", output_image)
71 | print("Done.")
72 |
73 |
74 | if __name__ == '__main__':
75 | for i in [1, 7, 16, 23, 32]:
76 | input_image_path = f"../../04-logs/system-figure/gt--mask-only--frame-{i}--cropped.png"
77 | output_image_path = f"../../04-logs/system-figure/gt--mask-only--contour--frame-{i}--cropped.png"
78 | davis_mask_annotation_image_to_contour_image(input_image_path, output_image_path)
79 | for i in [1, 7, 16, 23, 32]:
80 | input_image_path = f"../../04-logs/system-figure/gt--mask-only--frame-{i}--cropped.png"
81 | output_image_path = f"../../04-logs/system-figure/gt--mask-only--contour--thin--frame-{i}--cropped.png"
82 | davis_mask_annotation_image_to_contour_image(input_image_path, output_image_path, contour_radius=2)
83 |
--------------------------------------------------------------------------------
/scripts/uvo_video2frames.py:
--------------------------------------------------------------------------------
1 | """
2 | A utility script to split UVO videos into frames.
3 |
4 | The script takes two command-line arguments:
5 | 1. --video_dir: The directory containing the videos you wish to split into frames.
6 | 2. --frames_dir: The directory where the frames will be saved.
7 |
8 | Each video in the input directory will be split into frames, and these frames will be stored in a subdirectory of --frames_dir named after the video.
9 |
10 | Usage:
11 |
12 | ```bash
13 | python ../scripts/uvo_video2frames.py --video_dir UVOv1.0/uvo_videos_dense --frames_dir UVOv1.0/uvo_videos_dense_frames
14 | python ../scripts/uvo_video2frames.py --video_dir UVOv1.0/uvo_videos_sparse --frames_dir UVOv1.0/uvo_videos_sparse_frames
15 | ```
16 | """
17 | import argparse
18 | import cv2
19 | import os
20 | import pathlib
21 | from tqdm import tqdm
22 |
23 |
24 | def split_single_video(video_path, frames_dir=""):
25 | cap = cv2.VideoCapture(video_path)
26 | cnt = 0
27 | while cap.isOpened():
28 | ret, frame = cap.read()
29 | if ret:
30 | success, buffer = cv2.imencode(".png", frame)
31 | if success:
32 | with open(os.path.join(frames_dir, f"{cnt}.png"), "wb") as f:  # write the frame inside the per-video directory
33 | f.write(buffer.tobytes())
34 | f.flush()
35 | cnt += 1
36 | else:
37 | break
38 | return cnt
39 |
40 |
41 | def get_parser():
42 | arg_parser = argparse.ArgumentParser()
43 | arg_parser.add_argument("--video_dir", type=str, default="NonPublic/uvo_videos_dense/")
44 | arg_parser.add_argument("--frames_dir", type=str, default="NonPublic/uvo_videos_dense_frames/")
45 | return arg_parser
46 |
47 |
48 | if __name__ == '__main__':
49 | parser = get_parser()
50 | args = parser.parse_args()
51 | video_paths = os.listdir(args.video_dir)
52 | print(f"Splitting videos in {args.video_dir} to frames in {args.frames_dir}...")
53 | print(f"Total number of videos: {len(video_paths)}")
54 | for video_path in tqdm(video_paths):
55 | print(f"Splitting {video_path}...")
56 | v_frame_dir = pathlib.Path(os.path.join(args.frames_dir, video_path[:-4]))
57 | if not v_frame_dir.is_dir():
58 | v_frame_dir.mkdir(parents=True, exist_ok=False)
59 | n_frames = split_single_video(os.path.join(args.video_dir, video_path), frames_dir=v_frame_dir)
60 | print(f"Total number of frames extracted from {video_path}: {n_frames}")
61 | print(f"Done.")
62 |
--------------------------------------------------------------------------------
/scripts/visualize_point_sampling_methods.py:
--------------------------------------------------------------------------------
1 | """
2 | Script to visualize the different point sampling methods for the SAM-PT paper.
3 |
4 | Usage: `python -m scripts.visualize_point_sampling_methods`
5 | """
6 | import argparse
7 | import cv2
8 | import matplotlib.pyplot as plt
9 | import numpy as np
10 | import torch
11 | from functools import partial
12 |
13 | from sam_pt.utils.query_points import extract_corner_points
14 | from sam_pt.utils.query_points import extract_kmedoid_points
15 | from sam_pt.utils.query_points import extract_mixed_points
16 | from sam_pt.utils.query_points import extract_random_mask_points
17 | from sam_pt.utils.util import seed_all
18 |
19 |
20 | def mixed_point_id_to_marker_and_rescale(n_points, point_id):
21 | n_kmedoid = n_points // 4
22 | n_shi_tomasi = n_points // 3
23 | if point_id < n_kmedoid:
24 | return "o", 1
25 | elif point_id < n_kmedoid + n_shi_tomasi:
26 | return "*", 3
27 | else:
28 | return "v", 1.2
29 |
30 |
31 | def visualize_point_sampling_methods(
32 | rgb_image_path,
33 | annotation_image_path,
34 | output_image_path,
35 | point_sampling_method_name="kmedoids",
36 | n_points=8,
37 | seed=72,
38 | ):
39 | # Open image and convert it to numpy array
40 | image = cv2.imread(rgb_image_path)
41 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
42 | assert image.dtype == np.uint8
43 | assert image.min() >= 0 and image.max() <= 255
44 | plt.imshow(image)
45 | plt.show()
46 |
47 | annotation_image = cv2.imread(annotation_image_path)
48 | annotation_image = cv2.cvtColor(annotation_image, cv2.COLOR_BGR2RGB)
49 | assert annotation_image.dtype == np.uint8
50 | assert annotation_image.min() >= 0 and annotation_image.max() <= 255
51 | plt.imshow(annotation_image)
52 | plt.show()
53 |
54 | # The number of masks is the number of unique colors in the image
55 | n_masks = len(np.unique(annotation_image.reshape(-1, annotation_image.shape[2]), axis=0)) - 1
56 | print(f"Number of masks: {n_masks}")
57 |
58 | # Prepare the point sampling methods
59 | point_sampling_methods = {
60 | "kmedoids": {
61 | "function": extract_kmedoid_points,
62 | "marker": ["o" for _ in range(n_points)],
63 | "rescale": [1 for _ in range(n_points)],
64 | },
65 | "shi-tomasi": {
66 | "function": partial(extract_corner_points, image=torch.from_numpy(image).permute(2, 0, 1)),
67 | "marker": ["*" for _ in range(n_points)],
68 | "rescale": [3 for _ in range(n_points)],
69 | },
70 | "random": {
71 | "function": extract_random_mask_points,
72 | "marker": ["v" for _ in range(n_points)],
73 | "rescale": [1.2 for _ in range(n_points)]
74 | },
75 | "mixed": {
76 | "function": lambda mask, n_points_to_select: extract_mixed_points(
77 | query_masks=mask[None, ...],
78 | query_points_timestep=torch.zeros(n_masks),
79 | images=torch.from_numpy(image).permute(2, 0, 1)[None, ...],
80 | n_points=n_points_to_select,
81 | )[0],
82 | "marker": [mixed_point_id_to_marker_and_rescale(n_points, point_id)[0] for point_id in range(n_points)],
83 | "rescale": [mixed_point_id_to_marker_and_rescale(n_points, point_id)[1] for point_id in range(n_points)]
84 | },
85 | }
86 |
87 | # Take each mask separately and create a binary mask, remember the color of each mask
88 | masks = []
89 | colors = np.unique(annotation_image.reshape(-1, annotation_image.shape[2]), axis=0)
90 | assert (colors[0] == [0, 0, 0]).all()
91 | colors = colors[1:]
92 | for mask_idx in range(n_masks):
93 | mask = (annotation_image == colors[mask_idx][None, None, :]).all(-1)
94 | masks.append(mask)
95 |
96 | # Sample points from each mask
97 | mask_points = []
98 | for mask_idx in range(n_masks):
99 | seed_all(seed + 3)
100 | mask = torch.from_numpy(masks[mask_idx]).bool()
101 | points = point_sampling_methods[point_sampling_method_name]["function"](mask=mask, n_points_to_select=n_points)
102 | mask_points.append(points)
103 |
104 | # Create a contour mask for each mask
105 | contour_radius = 3
106 | contour_masks = []
107 | for mask_idx in range(n_masks):
108 | m_8int = masks[mask_idx].astype(np.uint8)
109 | dist_transform_fore = cv2.distanceTransform(m_8int, cv2.DIST_L2, 3)
110 | contour_mask = (dist_transform_fore <= contour_radius) & (dist_transform_fore > 0)
111 | contour_mask = contour_mask.astype(np.uint8)
112 | contour_masks.append(contour_mask)
113 |
114 | # Add contour and sampled points to the image
115 | output_image = np.zeros_like(annotation_image)
116 | for mask_idx in range(n_masks):
117 | output_image = np.where(contour_masks[mask_idx][:, :, None] == 1, colors[mask_idx][None, None, :], output_image)
118 | h, w, dpi = output_image.shape[0], output_image.shape[1], 100
119 | plt.figure(figsize=(w / dpi, h / dpi), dpi=dpi)
120 | plt.imshow(output_image)
121 | for mask_idx in range(n_masks):
122 | for point_idx in range(n_points):
123 | plt.scatter(
124 | x=mask_points[mask_idx][point_idx, 0],
125 | y=mask_points[mask_idx][point_idx, 1],
126 | s=90 * point_sampling_methods[point_sampling_method_name]["rescale"][point_idx],
127 | c=(colors[mask_idx][None, :] * 1.8 / 255).clip(min=0, max=1),
128 | linewidths=0,
129 | marker=point_sampling_methods[point_sampling_method_name]["marker"][point_idx]
130 | )
131 | plt.axis("off")
132 | plt.tight_layout(pad=0)
133 | print(f"Output image path: {output_image_path}")
134 | plt.savefig(output_image_path, bbox_inches="tight", pad_inches=0)
135 | plt.show()
136 |
137 | # Save also RGBA image
138 | output_image = cv2.imread(output_image_path)
139 | output_image = cv2.cvtColor(output_image, cv2.COLOR_BGR2RGB)
140 | r, g, b = cv2.split(output_image)
141 | a = 255 - (output_image == np.array([0, 0, 0])[None, None, :]).all(-1).astype(np.uint8) * 255
142 | output_image = cv2.merge([r, g, b, a], 4)
143 | print(f"RGBA image path: {output_image_path}.rgba.png")
144 | output_image = cv2.cvtColor(output_image, cv2.COLOR_RGBA2BGRA)
145 | cv2.imwrite(output_image_path + ".rgba.png", output_image)
146 | print("Done.")
147 |
148 |
149 | def main(args):
150 | n_points = args.n_points
151 | for psm in args.point_sampling_methods:
152 | output_image_path = f"{args.output_path_prefix}--point-sampling-method-{psm}.png"
153 | visualize_point_sampling_methods(
154 | rgb_image_path=args.rgb_path,
155 | annotation_image_path=args.annotation_path,
156 | output_image_path=output_image_path,
157 | point_sampling_method_name=psm,
158 | n_points=n_points,
159 | seed=args.seed,
160 | )
161 |
162 |
163 | if __name__ == '__main__':
164 | parser = argparse.ArgumentParser()
165 | parser.add_argument('--n_points', type=int, default=8)
166 | parser.add_argument('--rgb_path', type=str,
167 | default="../../04-logs/system-figure/horse-input--frame-16--cropped.png")
168 | parser.add_argument('--annotation_path', type=str,
169 | default="../../04-logs/system-figure/gt--mask-only--frame-16--cropped.png")
170 | parser.add_argument('--output_path_prefix', type=str,
171 | default="../../04-logs/system-figure/gt--mask-only--frame-16--cropped")
172 | parser.add_argument('--point_sampling_methods', type=str, nargs='+',
173 | default=["kmedoids", "shi-tomasi", "random", "mixed"])
174 | parser.add_argument('--seed', type=int, default=72)
175 | args = parser.parse_args()
176 | main(args)
177 |
--------------------------------------------------------------------------------