├── .gitignore
├── README.md
├── config
│   ├── avss
│   │   ├── AVSegFormer_pvt2_avss.py
│   │   └── AVSegFormer_res50_avss.py
│   ├── ms3
│   │   ├── AVSegFormer_pvt2_ms3.py
│   │   └── AVSegFormer_res50_ms3.py
│   └── s4
│       ├── AVSegFormer_pvt2_s4.py
│       └── AVSegFormer_res50_s4.py
├── dataloader
│   ├── __init__.py
│   ├── ms3_dataset.py
│   ├── s4_dataset.py
│   └── v2_dataset.py
├── image
│   └── arch.png
├── model
│   ├── AVSegFormer.py
│   ├── __init__.py
│   ├── backbone
│   │   ├── __init__.py
│   │   ├── pvt.py
│   │   └── resnet.py
│   ├── head
│   │   ├── AVSegHead.py
│   │   └── __init__.py
│   ├── utils
│   │   ├── __init__.py
│   │   ├── fusion_block.py
│   │   ├── positional_encoding.py
│   │   ├── query_generator.py
│   │   └── transformer.py
│   └── vggish
│       ├── __init__.py
│       ├── mel_features.py
│       ├── vggish.py
│       ├── vggish_input.py
│       └── vggish_params.py
├── ops
│   ├── functions
│   │   ├── __init__.py
│   │   └── ms_deform_attn_func.py
│   ├── make.sh
│   ├── modules
│   │   ├── __init__.py
│   │   └── ms_deform_attn.py
│   ├── setup.py
│   ├── src
│   │   ├── cpu
│   │   │   ├── ms_deform_attn_cpu.cpp
│   │   │   └── ms_deform_attn_cpu.h
│   │   ├── cuda
│   │   │   ├── ms_deform_attn_cuda.cu
│   │   │   ├── ms_deform_attn_cuda.h
│   │   │   └── ms_deform_im2col_cuda.cuh
│   │   ├── ms_deform_attn.h
│   │   └── vision.cpp
│   └── test.py
├── scripts
│   ├── avss
│   │   ├── loss.py
│   │   ├── test.py
│   │   └── train.py
│   ├── ms3
│   │   ├── loss.py
│   │   ├── test.py
│   │   ├── train.py
│   │   └── utility.py
│   └── s4
│       ├── loss.py
│       ├── test.py
│       ├── train.py
│       └── utility.py
├── test.sh
├── train.sh
└── utils
    ├── compute_color_metrics.py
    ├── logger.py
    ├── loss_util.py
    ├── pyutils.py
    └── vis_mask.py
/.gitignore:
--------------------------------------------------------------------------------
1 | .vscode
2 | **/__pycache__/**
3 | *.pth
4 | temp
5 | data
6 | pretrained
7 | work_dir
8 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # 💬 AVSegFormer [[paper](https://arxiv.org/abs/2307.01146)]
2 | The combination of vision and audio has long been a topic of interest among researchers in the multi-modal field. Recently, a new audio-visual segmentation task has been introduced, aiming to locate and segment the corresponding sound source objects in a given video. This task demands pixel-level fine-grained features for the first time, posing significant challenges. In this paper, we propose AVSegFormer, a new method for audio-visual segmentation tasks that leverages the Transformer architecture for its outstanding performance in multi-modal tasks. We combine audio features and learnable queries as decoder inputs to facilitate multi-modal information exchange. Furthermore, we design an audio-visual mixer to amplify the features of target objects. Additionally, we devise an intermediate mask loss to enhance training efficacy. Our method demonstrates robust performance and achieves state-of-the-art results in audio-visual segmentation tasks.
3 |
4 |
5 | ## 🚀 What's New
6 | - (2023.04.28) Uploaded pre-trained checkpoints and updated the README.
7 | - (2023.04.25) We completed the implementation of AVSegFormer and pushed the code.
8 |
9 |
10 | ## 🏠 Method
11 |
12 | ![AVSegFormer architecture](image/arch.png)
13 |
14 | ## 🛠️ Get Started
15 |
16 | ### 1. Environments
17 | ```shell
18 | # recommended
19 | pip install torch==1.10.0+cu111 torchvision==0.11.0+cu111 torchaudio==0.10.0 -f https://download.pytorch.org/whl/torch_stable.html
20 | pip install mmcv-full -f https://download.openmmlab.com/mmcv/dist/cu111/torch1.10.0/index.html
21 | pip install pandas
22 | pip install timm
23 | pip install resampy
24 | pip install soundfile
25 | # build MSDeformAttention
26 | cd ops
27 | sh make.sh
28 | ```
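
After building, you can optionally sanity-check the environment. The snippet below is only a minimal sketch: it assumes `ops/modules` exports a Deformable-DETR-style `MSDeformAttn` module (as the directory layout suggests), and it should be run from inside the `ops` directory on a machine with CUDA available.

```python
# Optional sanity check (sketch; assumptions noted above):
# verify the PyTorch/CUDA setup and that the compiled deformable-attention
# extension can be imported and instantiated.
import torch
from modules import MSDeformAttn  # assumed export, as in Deformable DETR

print(torch.__version__, torch.cuda.is_available())
attn = MSDeformAttn(d_model=256, n_levels=4, n_heads=8, n_points=4)
print(attn)
```

`ops/test.py` appears to be the op's own unit/gradient check and can also be run after the build.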
29 |
30 |
31 | ### 2. Data
32 |
33 | Please refer to [AVSBenchmark](https://github.com/OpenNLPLab/AVSBench) to download the datasets. You can place the data under the `data` folder, or use your own folder and update the paths in the config files accordingly. The expected `data` directory layout is shown below:
34 | ```
35 | |--data
36 | |--AVSS
37 | |--Multi-sources
38 | |--Single-source
39 | ```
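
Once the data is in place, you can verify that a config resolves to a working dataset using the repository's `build_dataset` helper. This is a minimal sketch and assumes the paths in the config match your local `data` layout:

```python
# Sketch: build the S4 training set from its config and inspect one sample.
from mmcv import Config
from dataloader import build_dataset

cfg = Config.fromfile('config/s4/AVSegFormer_pvt2_s4.py')
train_set = build_dataset(**cfg.dataset.train)

# The S4 train split returns (frames, audio log-mel, masks) per clip.
imgs, audio_log_mel, masks = train_set[0]
print(imgs.shape, audio_log_mel.shape, masks.shape)
```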
40 |
41 |
42 | ### 3. Download Pre-Trained Models
43 |
44 | - The pretrained backbones are available from the benchmark repository: [AVSBench pretrained backbones](https://drive.google.com/drive/folders/1386rcFHJ1QEQQMF6bV1rXJTzy8v26RTV).
45 | - We provide pre-trained models for all three subtasks. You can download them from [AVSegFormer pretrained models](https://drive.google.com/drive/folders/1ZYZOWAfoXcGPDsocswEN7ZYvcAn4H8kY).
46 |
47 | |Method|Backbone|Subset|Lr schd|Config|mIoU|F-score|Download|
48 | |:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|
49 | |AVSegFormer-R50|ResNet-50|S4|30ep|[config](config/s4/AVSegFormer_res50_s4.py)|76.38|86.7|[ckpt](https://drive.google.com/file/d/1nvIfR-1XZ_BgP8ZSUDuAsGhAwDJRxgC3/view?usp=drive_link)|
50 | |AVSegFormer-PVTv2|PVTv2-B5|S4|30ep|[config](config/s4/AVSegFormer_pvt2_s4.py)|83.06|90.5|[ckpt](https://drive.google.com/file/d/1ZJ55jxoHP1ur-hLBkGcha8sjptE_shfw/view?usp=drive_link)|
51 | |AVSegFormer-R50|ResNet-50|MS3|60ep|[config](config/ms3/AVSegFormer_res50_ms3.py)|53.81|65.6|[ckpt](https://drive.google.com/file/d/1MRk5gQnUtiWwYDpPfB20fO07SVLhfuIV/view?usp=drive_link)|
52 | |AVSegFormer-PVTv2|PVTv2-B5|MS3|60ep|[config](config/ms3/AVSegFormer_pvt2_ms3.py)|61.33|73.0|[ckpt](https://drive.google.com/file/d/1iKTxWtehAgCkNVty-4H1zVyAOaNxipHv/view?usp=drive_link)|
53 | |AVSegFormer-R50|ResNet-50|AVSS|30ep|[config](config/avss/AVSegFormer_res50_avss.py)|26.58|31.5|[ckpt](https://drive.google.com/file/d/1RvL6psDsINuUwd9V1ESgE2Kixh9MXIke/view?usp=drive_link)|
54 | |AVSegFormer-PVTv2|PVTv2-B5|AVSS|30ep|[config](config/avss/AVSegFormer_pvt2_avss.py)|37.31|42.8|[ckpt](https://drive.google.com/file/d/1P8a2dJSUoW0EqFyxyP8B1-Rnscxnh0YY/view?usp=drive_link)|
55 |
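To use a downloaded checkpoint independently of the provided train/test scripts, the model can be built directly from its config. The sketch below assumes the checkpoint stores a plain `state_dict` (adjust the loading step if your file wraps it, e.g. under a `'model'` key); note that building the model also loads the pretrained backbone and VGGish weights referenced in the config, so place them under `pretrained/` first.

```python
# Sketch: build AVSegFormer from a config and load a downloaded checkpoint.
import torch
from mmcv import Config
from model import build_model

cfg = Config.fromfile('config/s4/AVSegFormer_pvt2_s4.py')
model = build_model(**cfg.model)

state = torch.load('work_dir/AVSegFormer_pvt2_s4/S4_best.pth', map_location='cpu')
if isinstance(state, dict) and 'model' in state:  # tolerate wrapped checkpoints
    state = state['model']
model.load_state_dict(state, strict=False)
model.eval()
```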
56 |
57 | ### 4. Train
58 | ```shell
59 | TASK="s4"  # or "ms3", "avss"
60 | CONFIG="config/s4/AVSegFormer_pvt2_s4.py"
61 |
62 | bash train.sh ${TASK} ${CONFIG}
63 | ```
64 |
65 |
66 | ### 5. Test
67 | ```shell
68 | TASK="s4"  # or "ms3", "avss"
69 | CONFIG="config/s4/AVSegFormer_pvt2_s4.py"
70 | CHECKPOINT="work_dir/AVSegFormer_pvt2_s4/S4_best.pth"
71 |
72 | bash test.sh ${TASK} ${CONFIG} ${CHECKPOINT}
73 | ```
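
For a quick qualitative check without the full test script, you can also push a single clip through the model. This is a sketch only (a GPU is assumed, since the deformable-attention op is CUDA-only); shapes follow `dataloader/s4_dataset.py`, i.e. five frames and their log-mel audio per clip:

```python
# Sketch: forward one S4 test clip through AVSegFormer.
import torch
from mmcv import Config
from model import build_model
from dataloader import build_dataset

cfg = Config.fromfile('config/s4/AVSegFormer_pvt2_s4.py')
model = build_model(**cfg.model).cuda().eval()   # load a checkpoint as shown above
test_set = build_dataset(**cfg.dataset.test)

# The test split returns (frames, audio log-mel, masks, category, video_name).
imgs, audio_log_mel, masks, category, video_name = test_set[0]
with torch.no_grad():
    pred, mask_feature = model(audio_log_mel.cuda(), imgs.cuda())
print(video_name, pred.shape)  # per-frame mask logits
```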
74 |
75 |
76 | ## 🤝 Citation
77 |
78 | If you use our model, please consider citing the following papers:
79 | ```
80 | @article{zhou2023avss,
81 | title={Audio-Visual Segmentation with Semantics},
82 | author={Zhou, Jinxing and Shen, Xuyang and Wang, Jianyuan and Zhang, Jiayi and Sun, Weixuan and Zhang, Jing and Birchfield, Stan and Guo, Dan and Kong, Lingpeng and Wang, Meng and Zhong, Yiran},
83 | journal={arXiv preprint arXiv:2301.13190},
84 | year={2023},
85 | }
86 |
87 | @misc{gao2023avsegformer,
88 | title={AVSegFormer: Audio-Visual Segmentation with Transformer},
89 | author={Shengyi Gao and Zhe Chen and Guo Chen and Wenhai Wang and Tong Lu},
90 | year={2023},
91 | eprint={2307.01146},
92 | archivePrefix={arXiv},
93 | primaryClass={cs.CV}
94 | }
95 | ```
96 |
--------------------------------------------------------------------------------
/config/avss/AVSegFormer_pvt2_avss.py:
--------------------------------------------------------------------------------
1 | model = dict(
2 | type='AVSegFormer',
3 | neck=None,
4 | backbone=dict(
5 | type='pvt_v2_b5',
6 | init_weights_path='pretrained/pvt_v2_b5.pth'),
7 | vggish=dict(
8 | freeze_audio_extractor=True,
9 | pretrained_vggish_model_path='pretrained/vggish-10086976.pth',
10 | preprocess_audio_to_log_mel=True,
11 | postprocess_log_mel_with_pca=False,
12 | pretrained_pca_params_path=None),
13 | head=dict(
14 | type='AVSegHead',
15 | in_channels=[64, 128, 320, 512],
16 | num_classes=71,
17 | query_num=300,
18 | use_learnable_queries=True,
19 | fusion_block=dict(type='CrossModalMixer'),
20 | query_generator=dict(
21 | type='AttentionGenerator',
22 | num_layers=6,
23 | query_num=300),
24 | positional_encoding=dict(
25 | type='SinePositionalEncoding',
26 | num_feats=128),
27 | transformer=dict(
28 | type='AVSTransformer',
29 | encoder=dict(
30 | num_layers=6,
31 | layer=dict(
32 | dim=256,
33 | ffn_dim=2048,
34 | dropout=0.1)),
35 | decoder=dict(
36 | num_layers=6,
37 | layer=dict(
38 | dim=256,
39 | ffn_dim=2048,
40 | dropout=0.1)))),
41 | audio_dim=128,
42 | embed_dim=256,
43 | freeze_audio_backbone=True,
44 | T=10)
45 | dataset = dict(
46 | train=dict(
47 | type='V2Dataset',
48 | split='train',
49 | num_class=71,
50 | mask_num=10,
51 | crop_img_and_mask=True,
52 | crop_size=224,
53 | meta_csv_path='data/AVSS/metadata.csv',
54 | label_idx_path='data/AVSS/label2idx.json',
55 | dir_base='data/AVSS',
56 | img_size=(224, 224),
57 | resize_pred_mask=True,
58 | save_pred_mask_img_size=(360, 240),
59 | batch_size=4),
60 | val=dict(
61 | type='V2Dataset',
62 | split='val',
63 | num_class=71,
64 | mask_num=10,
65 | crop_img_and_mask=True,
66 | crop_size=224,
67 | meta_csv_path='data/AVSS/metadata.csv',
68 | label_idx_path='data/AVSS/label2idx.json',
69 | dir_base='data/AVSS',
70 | img_size=(224, 224),
71 | resize_pred_mask=True,
72 | save_pred_mask_img_size=(360, 240),
73 | batch_size=4),
74 | test=dict(
75 | type='V2Dataset',
76 | split='test',
77 | num_class=71,
78 | mask_num=10,
79 | crop_img_and_mask=True,
80 | crop_size=224,
81 | meta_csv_path='data/AVSS/metadata.csv',
82 | label_idx_path='data/AVSS/label2idx.json',
83 | dir_base='data/AVSS',
84 | img_size=(224, 224),
85 | resize_pred_mask=True,
86 | save_pred_mask_img_size=(360, 240),
87 | batch_size=4))
88 | optimizer = dict(
89 | type='AdamW',
90 | lr=2e-5)
91 | loss = dict(
92 | weight_dict=dict(
93 | iou_loss=1.0,
94 | mix_loss=0.1))
95 | process = dict(
96 | num_works=8,
97 | train_epochs=30,
98 | start_eval_epoch=10,
99 | eval_interval=2,
100 | freeze_epochs=0)
101 |
--------------------------------------------------------------------------------
/config/avss/AVSegFormer_res50_avss.py:
--------------------------------------------------------------------------------
1 | model = dict(
2 | type='AVSegFormer',
3 | neck=None,
4 | backbone=dict(
5 | type='res50',
6 | init_weights_path='pretrained/resnet50-19c8e357.pth'),
7 | vggish=dict(
8 | freeze_audio_extractor=True,
9 | pretrained_vggish_model_path='pretrained/vggish-10086976.pth',
10 | preprocess_audio_to_log_mel=True,
11 | postprocess_log_mel_with_pca=False,
12 | pretrained_pca_params_path=None),
13 | head=dict(
14 | type='AVSegHead',
15 | in_channels=[256, 512, 1024, 2048],
16 | num_classes=71,
17 | query_num=300,
18 | use_learnable_queries=True,
19 | fusion_block=dict(type='CrossModalMixer'),
20 | query_generator=dict(
21 | type='AttentionGenerator',
22 | num_layers=6,
23 | query_num=300),
24 | positional_encoding=dict(
25 | type='SinePositionalEncoding',
26 | num_feats=128),
27 | transformer=dict(
28 | type='AVSTransformer',
29 | encoder=dict(
30 | num_layers=6,
31 | layer=dict(
32 | dim=256,
33 | ffn_dim=2048,
34 | dropout=0.1)),
35 | decoder=dict(
36 | num_layers=6,
37 | layer=dict(
38 | dim=256,
39 | ffn_dim=2048,
40 | dropout=0.1)))),
41 | audio_dim=128,
42 | embed_dim=256,
43 | freeze_audio_backbone=True,
44 | T=10)
45 | dataset = dict(
46 | train=dict(
47 | type='V2Dataset',
48 | split='train',
49 | num_class=71,
50 | mask_num=10,
51 | crop_img_and_mask=True,
52 | crop_size=224,
53 | meta_csv_path='data/AVSS/metadata.csv',
54 | label_idx_path='data/AVSS/label2idx.json',
55 | dir_base='data/AVSS',
56 | img_size=(224, 224),
57 | resize_pred_mask=True,
58 | save_pred_mask_img_size=(360, 240),
59 | batch_size=4),
60 | val=dict(
61 | type='V2Dataset',
62 | split='val',
63 | num_class=71,
64 | mask_num=10,
65 | crop_img_and_mask=True,
66 | crop_size=224,
67 | meta_csv_path='data/AVSS/metadata.csv',
68 | label_idx_path='data/AVSS/label2idx.json',
69 | dir_base='data/AVSS',
70 | img_size=(224, 224),
71 | resize_pred_mask=True,
72 | save_pred_mask_img_size=(360, 240),
73 | batch_size=4),
74 | test=dict(
75 | type='V2Dataset',
76 | split='test',
77 | num_class=71,
78 | mask_num=10,
79 | crop_img_and_mask=True,
80 | crop_size=224,
81 | meta_csv_path='data/AVSS/metadata.csv',
82 | label_idx_path='data/AVSS/label2idx.json',
83 | dir_base='data/AVSS',
84 | img_size=(224, 224),
85 | resize_pred_mask=True,
86 | save_pred_mask_img_size=(360, 240),
87 | batch_size=4))
88 | optimizer = dict(
89 | type='AdamW',
90 | lr=2e-5)
91 | loss = dict(
92 | weight_dict=dict(
93 | iou_loss=1.0,
94 | mix_loss=0.1))
95 | process = dict(
96 | num_works=8,
97 | train_epochs=30,
98 | start_eval_epoch=10,
99 | eval_interval=2,
100 | freeze_epochs=0)
101 |
--------------------------------------------------------------------------------
/config/ms3/AVSegFormer_pvt2_ms3.py:
--------------------------------------------------------------------------------
1 | model = dict(
2 | type='AVSegFormer',
3 | neck=None,
4 | backbone=dict(
5 | type='pvt_v2_b5',
6 | init_weights_path='pretrained/pvt_v2_b5.pth'),
7 | vggish=dict(
8 | freeze_audio_extractor=True,
9 | pretrained_vggish_model_path='pretrained/vggish-10086976.pth',
10 | preprocess_audio_to_log_mel=False,
11 | postprocess_log_mel_with_pca=False,
12 | pretrained_pca_params_path=None),
13 | head=dict(
14 | type='AVSegHead',
15 | in_channels=[64, 128, 320, 512],
16 | num_classes=1,
17 | query_num=300,
18 | use_learnable_queries=True,
19 | fusion_block=dict(type='CrossModalMixer'),
20 | query_generator=dict(
21 | type='AttentionGenerator',
22 | num_layers=6,
23 | query_num=300),
24 | positional_encoding=dict(
25 | type='SinePositionalEncoding',
26 | num_feats=128),
27 | transformer=dict(
28 | type='AVSTransformer',
29 | encoder=dict(
30 | num_layers=6,
31 | layer=dict(
32 | dim=256,
33 | ffn_dim=2048,
34 | dropout=0.1)),
35 | decoder=dict(
36 | num_layers=6,
37 | layer=dict(
38 | dim=256,
39 | ffn_dim=2048,
40 | dropout=0.1)))),
41 | audio_dim=128,
42 | embed_dim=256,
43 | freeze_audio_backbone=True,
44 | T=5)
45 | dataset = dict(
46 | train=dict(
47 | type='MS3Dataset',
48 | split='train',
49 | anno_csv='data/Multi-sources/ms3_meta_data.csv',
50 | dir_img='data/Multi-sources/ms3_data/visual_frames',
51 | dir_audio_log_mel='data/Multi-sources/ms3_data/audio_log_mel',
52 | dir_mask='data/Multi-sources/ms3_data/gt_masks',
53 | img_size=(224, 224),
54 | batch_size=2),
55 | val=dict(
56 | type='MS3Dataset',
57 | split='val',
58 | anno_csv='data/Multi-sources/ms3_meta_data.csv',
59 | dir_img='data/Multi-sources/ms3_data/visual_frames',
60 | dir_audio_log_mel='data/Multi-sources/ms3_data/audio_log_mel',
61 | dir_mask='data/Multi-sources/ms3_data/gt_masks',
62 | img_size=(224, 224),
63 | batch_size=2),
64 | test=dict(
65 | type='MS3Dataset',
66 | split='test',
67 | anno_csv='data/Multi-sources/ms3_meta_data.csv',
68 | dir_img='data/Multi-sources/ms3_data/visual_frames',
69 | dir_audio_log_mel='data/Multi-sources/ms3_data/audio_log_mel',
70 | dir_mask='data/Multi-sources/ms3_data/gt_masks',
71 | img_size=(224, 224),
72 | batch_size=2))
73 | optimizer = dict(
74 | type='AdamW',
75 | lr=2e-5)
76 | loss = dict(
77 | weight_dict=dict(
78 | iou_loss=1.0,
79 | mix_loss=0.1),
80 | loss_type='dice')
81 | process = dict(
82 | num_works=8,
83 | train_epochs=60,
84 | freeze_epochs=10)
85 |
--------------------------------------------------------------------------------
/config/ms3/AVSegFormer_res50_ms3.py:
--------------------------------------------------------------------------------
1 | model = dict(
2 | type='AVSegFormer',
3 | neck=None,
4 | backbone=dict(
5 | type='res50',
6 | init_weights_path='pretrained/resnet50-19c8e357.pth'),
7 | vggish=dict(
8 | freeze_audio_extractor=True,
9 | pretrained_vggish_model_path='pretrained/vggish-10086976.pth',
10 | preprocess_audio_to_log_mel=False,
11 | postprocess_log_mel_with_pca=False,
12 | pretrained_pca_params_path=None),
13 | head=dict(
14 | type='AVSegHead',
15 | in_channels=[256, 512, 1024, 2048],
16 | num_classes=1,
17 | query_num=300,
18 | use_learnable_queries=True,
19 | fusion_block=dict(type='CrossModalMixer'),
20 | query_generator=dict(
21 | type='AttentionGenerator',
22 | num_layers=6,
23 | query_num=300),
24 | positional_encoding=dict(
25 | type='SinePositionalEncoding',
26 | num_feats=128),
27 | transformer=dict(
28 | type='AVSTransformer',
29 | encoder=dict(
30 | num_layers=6,
31 | layer=dict(
32 | dim=256,
33 | ffn_dim=2048,
34 | dropout=0.1)),
35 | decoder=dict(
36 | num_layers=6,
37 | layer=dict(
38 | dim=256,
39 | ffn_dim=2048,
40 | dropout=0.1)))),
41 | audio_dim=128,
42 | embed_dim=256,
43 | freeze_audio_backbone=True,
44 | T=5)
45 | dataset = dict(
46 | train=dict(
47 | type='MS3Dataset',
48 | split='train',
49 | anno_csv='data/Multi-sources/ms3_meta_data.csv',
50 | dir_img='data/Multi-sources/ms3_data/visual_frames',
51 | dir_audio_log_mel='data/Multi-sources/ms3_data/audio_log_mel',
52 | dir_mask='data/Multi-sources/ms3_data/gt_masks',
53 | img_size=(224, 224),
54 | batch_size=2),
55 | val=dict(
56 | type='MS3Dataset',
57 | split='val',
58 | anno_csv='data/Multi-sources/ms3_meta_data.csv',
59 | dir_img='data/Multi-sources/ms3_data/visual_frames',
60 | dir_audio_log_mel='data/Multi-sources/ms3_data/audio_log_mel',
61 | dir_mask='data/Multi-sources/ms3_data/gt_masks',
62 | img_size=(224, 224),
63 | batch_size=2),
64 | test=dict(
65 | type='MS3Dataset',
66 | split='test',
67 | anno_csv='data/Multi-sources/ms3_meta_data.csv',
68 | dir_img='data/Multi-sources/ms3_data/visual_frames',
69 | dir_audio_log_mel='data/Multi-sources/ms3_data/audio_log_mel',
70 | dir_mask='data/Multi-sources/ms3_data/gt_masks',
71 | img_size=(224, 224),
72 | batch_size=2))
73 | optimizer = dict(
74 | type='AdamW',
75 | lr=2e-5)
76 | loss = dict(
77 | weight_dict=dict(
78 | iou_loss=1.0,
79 | mix_loss=0.1),
80 | loss_type='dice')
81 | process = dict(
82 | num_works=8,
83 | train_epochs=60,
84 | freeze_epochs=10)
85 |
--------------------------------------------------------------------------------
/config/s4/AVSegFormer_pvt2_s4.py:
--------------------------------------------------------------------------------
1 | model = dict(
2 | type='AVSegFormer',
3 | neck=None,
4 | backbone=dict(
5 | type='pvt_v2_b5',
6 | init_weights_path='pretrained/pvt_v2_b5.pth'),
7 | vggish=dict(
8 | freeze_audio_extractor=True,
9 | pretrained_vggish_model_path='pretrained/vggish-10086976.pth',
10 | preprocess_audio_to_log_mel=False,
11 | postprocess_log_mel_with_pca=False,
12 | pretrained_pca_params_path=None),
13 | head=dict(
14 | type='AVSegHead',
15 | in_channels=[64, 128, 320, 512],
16 | num_classes=1,
17 | query_num=300,
18 | use_learnable_queries=True,
19 | fusion_block=dict(type='CrossModalMixer'),
20 | positional_encoding=dict(
21 | type='SinePositionalEncoding',
22 | num_feats=128),
23 | query_generator=dict(
24 | type='AttentionGenerator',
25 | num_layers=6,
26 | query_num=300),
27 | transformer=dict(
28 | type='AVSTransformer',
29 | encoder=dict(
30 | num_layers=6,
31 | layer=dict(
32 | dim=256,
33 | ffn_dim=2048,
34 | dropout=0.1)),
35 | decoder=dict(
36 | num_layers=6,
37 | layer=dict(
38 | dim=256,
39 | ffn_dim=2048,
40 | dropout=0.1)))),
41 | audio_dim=128,
42 | embed_dim=256,
43 | freeze_audio_backbone=True,
44 | T=5)
45 | dataset = dict(
46 | train=dict(
47 | type='S4Dataset',
48 | split='train',
49 | anno_csv='data/Single-source/s4_meta_data.csv',
50 | dir_img='data/Single-source/s4_data/visual_frames',
51 | dir_audio_log_mel='data/Single-source/s4_data/audio_log_mel',
52 | dir_mask='data/Single-source/s4_data/gt_masks',
53 | img_size=(224, 224),
54 | batch_size=2),
55 | val=dict(
56 | type='S4Dataset',
57 | split='val',
58 | anno_csv='data/Single-source/s4_meta_data.csv',
59 | dir_img='data/Single-source/s4_data/visual_frames',
60 | dir_audio_log_mel='data/Single-source/s4_data/audio_log_mel',
61 | dir_mask='data/Single-source/s4_data/gt_masks',
62 | img_size=(224, 224),
63 | batch_size=2),
64 | test=dict(
65 | type='S4Dataset',
66 | split='test',
67 | anno_csv='data/Single-source/s4_meta_data.csv',
68 | dir_img='data/Single-source/s4_data/visual_frames',
69 | dir_audio_log_mel='data/Single-source/s4_data/audio_log_mel',
70 | dir_mask='data/Single-source/s4_data/gt_masks',
71 | img_size=(224, 224),
72 | batch_size=2))
73 | optimizer = dict(
74 | type='AdamW',
75 | lr=2e-5)
76 | loss = dict(
77 | weight_dict=dict(
78 | iou_loss=1.0,
79 | mix_loss=0.1),
80 | loss_type='dice')
81 | process = dict(
82 | num_works=8,
83 | train_epochs=30,
84 | freeze_epochs=5)
85 |
--------------------------------------------------------------------------------
/config/s4/AVSegFormer_res50_s4.py:
--------------------------------------------------------------------------------
1 | model = dict(
2 | type='AVSegFormer',
3 | neck=None,
4 | backbone=dict(
5 | type='res50',
6 | init_weights_path='pretrained/resnet50-19c8e357.pth'),
7 | vggish=dict(
8 | freeze_audio_extractor=True,
9 | pretrained_vggish_model_path='pretrained/vggish-10086976.pth',
10 | preprocess_audio_to_log_mel=False,
11 | postprocess_log_mel_with_pca=False,
12 | pretrained_pca_params_path=None),
13 | head=dict(
14 | type='AVSegHead',
15 | in_channels=[256, 512, 1024, 2048],
16 | num_classes=1,
17 | query_num=300,
18 | use_learnable_queries=True,
19 | fusion_block=dict(type='CrossModalMixer'),
20 | positional_encoding=dict(
21 | type='SinePositionalEncoding',
22 | num_feats=128),
23 | query_generator=dict(
24 | type='AttentionGenerator',
25 | num_layers=6,
26 | query_num=300),
27 | transformer=dict(
28 | type='AVSTransformer',
29 | encoder=dict(
30 | num_layers=6,
31 | layer=dict(
32 | dim=256,
33 | ffn_dim=2048,
34 | dropout=0.1)),
35 | decoder=dict(
36 | num_layers=6,
37 | layer=dict(
38 | dim=256,
39 | ffn_dim=2048,
40 | dropout=0.1)))),
41 | audio_dim=128,
42 | embed_dim=256,
43 | freeze_audio_backbone=True,
44 | T=5)
45 | dataset = dict(
46 | train=dict(
47 | type='S4Dataset',
48 | split='train',
49 | anno_csv='data/Single-source/s4_meta_data.csv',
50 | dir_img='data/Single-source/s4_data/visual_frames',
51 | dir_audio_log_mel='data/Single-source/s4_data/audio_log_mel',
52 | dir_mask='data/Single-source/s4_data/gt_masks',
53 | img_size=(224, 224),
54 | batch_size=2),
55 | val=dict(
56 | type='S4Dataset',
57 | split='val',
58 | anno_csv='data/Single-source/s4_meta_data.csv',
59 | dir_img='data/Single-source/s4_data/visual_frames',
60 | dir_audio_log_mel='data/Single-source/s4_data/audio_log_mel',
61 | dir_mask='data/Single-source/s4_data/gt_masks',
62 | img_size=(224, 224),
63 | batch_size=2),
64 | test=dict(
65 | type='S4Dataset',
66 | split='test',
67 | anno_csv='data/Single-source/s4_meta_data.csv',
68 | dir_img='data/Single-source/s4_data/visual_frames',
69 | dir_audio_log_mel='data/Single-source/s4_data/audio_log_mel',
70 | dir_mask='data/Single-source/s4_data/gt_masks',
71 | img_size=(224, 224),
72 | batch_size=2))
73 | optimizer = dict(
74 | type='AdamW',
75 | lr=2e-5)
76 | loss = dict(
77 | weight_dict=dict(
78 | iou_loss=1.0,
79 | mix_loss=0.1),
80 | loss_type='dice')
81 | process = dict(
82 | num_works=8,
83 | train_epochs=30,
84 | freeze_epochs=5)
85 |
--------------------------------------------------------------------------------
/dataloader/__init__.py:
--------------------------------------------------------------------------------
1 | from .v2_dataset import V2Dataset, get_v2_pallete
2 | from .s4_dataset import S4Dataset
3 | from .ms3_dataset import MS3Dataset
4 | from mmcv import Config
5 |
6 |
7 | def build_dataset(type, split, **kwargs):
8 | if type == 'V2Dataset':
9 | return V2Dataset(split=split, cfg=Config(kwargs))
10 | elif type == 'S4Dataset':
11 | return S4Dataset(split=split, cfg=Config(kwargs))
12 | elif type == 'MS3Dataset':
13 | return MS3Dataset(split=split, cfg=Config(kwargs))
14 | else:
15 | raise ValueError
16 |
17 |
18 | __all__ = ['build_dataset', 'get_v2_pallete']
19 |
--------------------------------------------------------------------------------
/dataloader/ms3_dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.nn as nn
4 | from torch.utils.data import Dataset
5 |
6 | import numpy as np
7 | import pandas as pd
8 | import pickle
9 |
10 | import cv2
11 | from PIL import Image
12 | from torchvision import transforms
13 |
14 |
15 | def load_image_in_PIL_to_Tensor(path, mode='RGB', transform=None):
16 | img_PIL = Image.open(path).convert(mode)
17 | if transform:
18 | img_tensor = transform(img_PIL)
19 | return img_tensor
20 | return img_PIL
21 |
22 |
23 | def load_audio_lm(audio_lm_path):
24 | with open(audio_lm_path, 'rb') as fr:
25 | audio_log_mel = pickle.load(fr)
26 | audio_log_mel = audio_log_mel.detach() # [5, 1, 96, 64]
27 | return audio_log_mel
28 |
29 |
30 | class MS3Dataset(Dataset):
31 | """Dataset for multiple sound source segmentation"""
32 |
33 | def __init__(self, split='train', cfg=None):
34 | super(MS3Dataset, self).__init__()
35 | self.split = split
36 | self.mask_num = 5
37 | self.cfg = cfg
38 | df_all = pd.read_csv(cfg.anno_csv, sep=',')
39 | self.df_split = df_all[df_all['split'] == split]
40 | print("{}/{} videos are used for {}".format(len(self.df_split),
41 | len(df_all), self.split))
42 | self.img_transform = transforms.Compose([
43 | transforms.Resize([512, 512]),
44 | transforms.ToTensor(),
45 | transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
46 | ])
47 | self.mask_transform = transforms.Compose([
48 | transforms.Resize([512, 512]),
49 | transforms.ToTensor(),
50 | ])
51 |
52 | def __getitem__(self, index):
53 | df_one_video = self.df_split.iloc[index]
54 | video_name = df_one_video[0]
55 | img_base_path = os.path.join(self.cfg.dir_img, video_name)
56 | audio_lm_path = os.path.join(
57 | self.cfg.dir_audio_log_mel, self.split, video_name + '.pkl')
58 | mask_base_path = os.path.join(
59 | self.cfg.dir_mask, self.split, video_name)
60 | audio_log_mel = load_audio_lm(audio_lm_path)
61 | # audio_lm_tensor = torch.from_numpy(audio_log_mel)
62 | imgs, masks = [], []
63 | for img_id in range(1, 6):
64 | img = load_image_in_PIL_to_Tensor(os.path.join(img_base_path, "%s.mp4_%d.png" % (
65 | video_name, img_id)), transform=self.img_transform)
66 | imgs.append(img)
67 | for mask_id in range(1, self.mask_num + 1):
68 | mask = load_image_in_PIL_to_Tensor(os.path.join(mask_base_path, "%s_%d.png" % (
69 | video_name, mask_id)), transform=self.mask_transform, mode='P')
70 | masks.append(mask)
71 | imgs_tensor = torch.stack(imgs, dim=0)
72 | masks_tensor = torch.stack(masks, dim=0)
73 |
74 | return imgs_tensor, audio_log_mel, masks_tensor, video_name
75 |
76 | def __len__(self):
77 | return len(self.df_split)
78 |
--------------------------------------------------------------------------------
/dataloader/s4_dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | from wave import _wave_params
3 | import torch
4 | import torch.nn as nn
5 | from torch.utils.data import Dataset
6 |
7 | import numpy as np
8 | import pandas as pd
9 | import pickle
10 |
11 | import cv2
12 | from PIL import Image
13 | from torchvision import transforms
14 |
15 |
16 | def load_image_in_PIL_to_Tensor(path, mode='RGB', transform=None):
17 | img_PIL = Image.open(path).convert(mode)
18 | if transform:
19 | img_tensor = transform(img_PIL)
20 | return img_tensor
21 | return img_PIL
22 |
23 |
24 | def load_audio_lm(audio_lm_path):
25 | with open(audio_lm_path, 'rb') as fr:
26 | audio_log_mel = pickle.load(fr)
27 | audio_log_mel = audio_log_mel.detach() # [5, 1, 96, 64]
28 | return audio_log_mel
29 |
30 |
31 | class S4Dataset(Dataset):
32 | """Dataset for single sound source segmentation"""
33 |
34 | def __init__(self, split='train', cfg=None):
35 | super(S4Dataset, self).__init__()
36 | self.split = split
37 | self.cfg = cfg
38 | self.mask_num = 1 if self.split == 'train' else 5
39 | df_all = pd.read_csv(cfg.anno_csv, sep=',')
40 | self.df_split = df_all[df_all['split'] == split]
41 | print("{}/{} videos are used for {}".format(len(self.df_split),
42 | len(df_all), self.split))
43 | self.img_transform = transforms.Compose([
44 | transforms.Resize([512, 512]),
45 | transforms.ToTensor(),
46 | transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
47 | ])
48 | self.mask_transform = transforms.Compose([
49 | transforms.Resize([512, 512]),
50 | transforms.ToTensor(),
51 | ])
52 |
53 | def __getitem__(self, index):
54 | df_one_video = self.df_split.iloc[index]
55 | video_name, category = df_one_video[0], df_one_video[2]
56 | img_base_path = os.path.join(
57 | self.cfg.dir_img, self.split, category, video_name)
58 | audio_lm_path = os.path.join(
59 | self.cfg.dir_audio_log_mel, self.split, category, video_name + '.pkl')
60 | mask_base_path = os.path.join(
61 | self.cfg.dir_mask, self.split, category, video_name)
62 | audio_log_mel = load_audio_lm(audio_lm_path)
63 | # audio_lm_tensor = torch.from_numpy(audio_log_mel)
64 | imgs, masks = [], []
65 | for img_id in range(1, 6):
66 | img = load_image_in_PIL_to_Tensor(os.path.join(img_base_path, "%s_%d.png" % (
67 | video_name, img_id)), transform=self.img_transform)
68 | imgs.append(img)
69 | for mask_id in range(1, self.mask_num + 1):
70 | mask = load_image_in_PIL_to_Tensor(os.path.join(mask_base_path, "%s_%d.png" % (
71 | video_name, mask_id)), transform=self.mask_transform, mode='1')
72 | masks.append(mask)
73 | imgs_tensor = torch.stack(imgs, dim=0)
74 | masks_tensor = torch.stack(masks, dim=0)
75 |
76 | if self.split == 'train':
77 | return imgs_tensor, audio_log_mel, masks_tensor
78 | else:
79 | return imgs_tensor, audio_log_mel, masks_tensor, category, video_name
80 |
81 | def __len__(self):
82 | return len(self.df_split)
83 |
--------------------------------------------------------------------------------
/dataloader/v2_dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | # from wave import _wave_params
3 | import torch
4 | import torch.nn as nn
5 | from torch.utils.data import Dataset
6 |
7 | import numpy as np
8 | import pandas as pd
9 | import pickle
10 | import json
11 |
12 | # import cv2
13 | from PIL import Image
14 | from torchvision import transforms
15 |
16 | # from .config import cfg_avs
17 |
18 |
19 | def get_v2_pallete(label_to_idx_path, num_cls=71):
20 | def _getpallete(num_cls=71):
21 | """build the unified color pallete for AVSBench-object (V1) and AVSBench-semantic (V2),
22 | 71 is the total category number of V2 dataset, you should not change that"""
23 | n = num_cls
24 | pallete = [0] * (n * 3)
25 | for j in range(0, n):
26 | lab = j
27 | pallete[j * 3 + 0] = 0
28 | pallete[j * 3 + 1] = 0
29 | pallete[j * 3 + 2] = 0
30 | i = 0
31 | while (lab > 0):
32 | pallete[j * 3 + 0] |= (((lab >> 0) & 1) << (7 - i))
33 | pallete[j * 3 + 1] |= (((lab >> 1) & 1) << (7 - i))
34 | pallete[j * 3 + 2] |= (((lab >> 2) & 1) << (7 - i))
35 | i = i + 1
36 | lab >>= 3
37 |         return pallete  # list, length is n_classes*3
38 |
39 | with open(label_to_idx_path, 'r') as fr:
40 | label_to_pallete_idx = json.load(fr)
41 | v2_pallete = _getpallete(num_cls) # list
42 | v2_pallete = np.array(v2_pallete).reshape(-1, 3)
43 | assert len(v2_pallete) == len(label_to_pallete_idx)
44 | return v2_pallete
45 |
46 |
47 | def crop_resize_img(crop_size, img, img_is_mask=False):
48 | outsize = crop_size
49 | short_size = outsize
50 | w, h = img.size
51 | if w > h:
52 | oh = short_size
53 | ow = int(1.0 * w * oh / h)
54 | else:
55 | ow = short_size
56 | oh = int(1.0 * h * ow / w)
57 | if not img_is_mask:
58 | img = img.resize((ow, oh), Image.BILINEAR)
59 | else:
60 | img = img.resize((ow, oh), Image.NEAREST)
61 | # center crop
62 | w, h = img.size
63 | x1 = int(round((w - outsize) / 2.))
64 | y1 = int(round((h - outsize) / 2.))
65 | img = img.crop((x1, y1, x1 + outsize, y1 + outsize))
66 | # print("crop for train. set")
67 | return img
68 |
69 |
70 | def resize_img(crop_size, img, img_is_mask=False):
71 | outsize = crop_size
72 | # only resize for val./test. set
73 | if not img_is_mask:
74 | img = img.resize((outsize, outsize), Image.BILINEAR)
75 | else:
76 | img = img.resize((outsize, outsize), Image.NEAREST)
77 | return img
78 |
79 |
80 | def color_mask_to_label(mask, v_pallete):
81 | mask_array = np.array(mask).astype('int32')
82 | semantic_map = []
83 | for colour in v_pallete:
84 | equality = np.equal(mask_array, colour)
85 | class_map = np.all(equality, axis=-1)
86 | semantic_map.append(class_map)
87 | semantic_map = np.stack(semantic_map, axis=-1).astype(np.float32)
88 | # pdb.set_trace() # there is only one '1' value for each pixel, run np.sum(semantic_map, axis=-1)
89 | label = np.argmax(semantic_map, axis=-1)
90 | return label
91 |
92 |
93 | def load_image_in_PIL_to_Tensor(path, split='train', mode='RGB', transform=None, cfg=None):
94 | img_PIL = Image.open(path).convert(mode)
95 | if cfg.crop_img_and_mask:
96 | if split == 'train':
97 | img_PIL = crop_resize_img(
98 | cfg.crop_size, img_PIL, img_is_mask=False)
99 | else:
100 | img_PIL = resize_img(cfg.crop_size,
101 | img_PIL, img_is_mask=False)
102 | if transform:
103 | img_tensor = transform(img_PIL)
104 | return img_tensor
105 | return img_PIL
106 |
107 |
108 | def load_color_mask_in_PIL_to_Tensor(path, v_pallete, split='train', mode='RGB', cfg=None):
109 | color_mask_PIL = Image.open(path).convert(mode)
110 | if cfg.crop_img_and_mask:
111 | if split == 'train':
112 | color_mask_PIL = crop_resize_img(
113 | cfg.crop_size, color_mask_PIL, img_is_mask=True)
114 | else:
115 | color_mask_PIL = resize_img(
116 | cfg.crop_size, color_mask_PIL, img_is_mask=True)
117 | # obtain semantic label
118 | color_label = color_mask_to_label(color_mask_PIL, v_pallete)
119 | color_label = torch.from_numpy(color_label) # [H, W]
120 | color_label = color_label.unsqueeze(0)
121 | # binary_mask = (color_label != (cfg_avs.NUM_CLASSES-1)).float()
122 | # return color_label, binary_mask # both [1, H, W]
123 | return color_label # both [1, H, W]
124 |
125 |
126 | def load_audio_lm(audio_lm_path):
127 | with open(audio_lm_path, 'rb') as fr:
128 | audio_log_mel = pickle.load(fr)
129 | audio_log_mel = audio_log_mel.detach() # [5, 1, 96, 64]
130 | return audio_log_mel
131 |
132 |
133 | class V2Dataset(Dataset):
134 | """Dataset for audio visual semantic segmentation of AVSBench-semantic (V2)"""
135 |
136 | def __init__(self, split='train', cfg=None, debug_flag=False):
137 | super(V2Dataset, self).__init__()
138 | self.split = split
139 | self.cfg = cfg
140 | self.mask_num = cfg.mask_num
141 | df_all = pd.read_csv(cfg.meta_csv_path, sep=',')
142 | self.df_split = df_all[df_all['split'] == split]
143 | if debug_flag:
144 | self.df_split = self.df_split[:100]
145 | print("{}/{} videos are used for {}.".format(len(self.df_split),
146 | len(df_all), self.split))
147 | self.img_transform = transforms.Compose([
148 | transforms.ToTensor(),
149 | transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
150 | ])
151 | self.v2_pallete = get_v2_pallete(
152 | cfg.label_idx_path, num_cls=cfg.num_class)
153 |
154 | def __getitem__(self, index):
155 | df_one_video = self.df_split.iloc[index]
156 | video_name, set = df_one_video['uid'], df_one_video['label']
157 | img_base_path = os.path.join(
158 | self.cfg.dir_base, set, video_name, 'frames')
159 | audio_path = os.path.join(
160 | self.cfg.dir_base, set, video_name, 'audio.wav')
161 | color_mask_base_path = os.path.join(
162 | self.cfg.dir_base, set, video_name, 'labels_rgb')
163 |
164 | # data from AVSBench-object single-source subset (5s, gt is only the first annotated frame)
165 | if set == 'v1s':
166 | vid_temporal_mask_flag = torch.Tensor(
167 | [1, 1, 1, 1, 1, 0, 0, 0, 0, 0]) # .bool()
168 | gt_temporal_mask_flag = torch.Tensor(
169 | [1, 0, 0, 0, 0, 0, 0, 0, 0, 0]) # .bool()
170 | # data from AVSBench-object multi-sources subset (5s, all 5 extracted frames are annotated)
171 | elif set == 'v1m':
172 | vid_temporal_mask_flag = torch.Tensor(
173 | [1, 1, 1, 1, 1, 0, 0, 0, 0, 0]) # .bool()
174 | gt_temporal_mask_flag = torch.Tensor(
175 | [1, 1, 1, 1, 1, 0, 0, 0, 0, 0]) # .bool()
176 | # data from newly collected videos in AVSBench-semantic (10s, all 10 extracted frames are annotated))
177 | elif set == 'v2':
178 | vid_temporal_mask_flag = torch.ones(10) # .bool()
179 | gt_temporal_mask_flag = torch.ones(10) # .bool()
180 |
181 |         img_path_list = sorted(os.listdir(img_base_path))
182 |         # 5 frames for v1 videos, 10 for new v2 videos
183 | imgs_num = len(img_path_list)
184 | imgs_pad_zero_num = 10 - imgs_num
185 | imgs = []
186 | for img_id in range(imgs_num):
187 | img_path = os.path.join(img_base_path, "%d.jpg" % (img_id))
188 | img = load_image_in_PIL_to_Tensor(
189 | img_path, split=self.split, transform=self.img_transform, cfg=self.cfg)
190 | imgs.append(img)
191 | for pad_i in range(imgs_pad_zero_num): # ! pad black image?
192 | img = torch.zeros_like(img)
193 | imgs.append(img)
194 |
195 | labels = []
196 | mask_path_list = sorted(os.listdir(color_mask_base_path))
197 |         # keep only .png masks; build a new list instead of removing items
198 |         # from mask_path_list while iterating over it
199 |         mask_path_list = [p for p in mask_path_list if p.endswith(".png")]
200 | mask_num = len(mask_path_list)
201 | if self.split != 'train':
202 | if set == 'v2':
203 | assert mask_num == 10
204 | else:
205 | assert mask_num == 5
206 |
207 | mask_num = len(mask_path_list)
208 | label_pad_zero_num = 10 - mask_num
209 | for mask_id in range(mask_num):
210 | mask_path = os.path.join(
211 | color_mask_base_path, "%d.png" % (mask_id))
212 | # mask_path = os.path.join(color_mask_base_path, mask_path_list[mask_id])
213 | color_label = load_color_mask_in_PIL_to_Tensor(
214 | mask_path, v_pallete=self.v2_pallete, split=self.split, cfg=self.cfg)
215 | # print('color_label.shape: ', color_label.shape)
216 | labels.append(color_label)
217 | for pad_j in range(label_pad_zero_num):
218 | color_label = torch.zeros_like(color_label)
219 | labels.append(color_label)
220 |
221 | imgs_tensor = torch.stack(imgs, dim=0)
222 | labels_tensor = torch.stack(labels, dim=0)
223 |
224 | return imgs_tensor, audio_path, labels_tensor, \
225 | vid_temporal_mask_flag, gt_temporal_mask_flag, video_name
226 |
227 | def __len__(self):
228 | return len(self.df_split)
229 |
230 | @property
231 | def num_classes(self):
232 | """Number of categories (including background)."""
233 | return self.cfg.num_class
234 |
235 | @property
236 | def classes(self):
237 | """Category names."""
238 | with open(self.cfg.label_idx_path, 'r') as fr:
239 | classes = json.load(fr)
240 | return [label for label in classes.keys()]
241 |
--------------------------------------------------------------------------------
/image/arch.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vvvb-github/AVSegFormer/7083cb6586d1bc06d174a34be3c34bb5814cf0c3/image/arch.png
--------------------------------------------------------------------------------
/model/AVSegFormer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from .backbone import build_backbone
4 | # from .neck import build_neck
5 | from .head import build_head
6 | from .vggish import VGGish
7 |
8 |
9 | class AVSegFormer(nn.Module):
10 | def __init__(self,
11 | backbone,
12 | vggish,
13 | head,
14 | neck=None,
15 | audio_dim=128,
16 | embed_dim=256,
17 | T=5,
18 | freeze_audio_backbone=True,
19 | *args, **kwargs):
20 | super().__init__()
21 |
22 | self.embed_dim = embed_dim
23 | self.T = T
24 | self.freeze_audio_backbone = freeze_audio_backbone
25 | self.backbone = build_backbone(**backbone)
26 | self.vggish = VGGish(**vggish)
27 | self.head = build_head(**head)
28 | self.audio_proj = nn.Linear(audio_dim, embed_dim)
29 |
30 | if self.freeze_audio_backbone:
31 | for p in self.vggish.parameters():
32 | p.requires_grad = False
33 |             self.freeze_backbone(True)  # note: this also freezes the visual backbone (cf. freeze_epochs in the configs)
34 |
35 | self.neck = neck
36 | # if neck is not None:
37 | # self.neck = build_neck(**neck)
38 | # else:
39 | # self.neck = None
40 |
41 | def freeze_backbone(self, freeze=False):
42 | for p in self.backbone.parameters():
43 | p.requires_grad = not freeze
44 |
45 | def mul_temporal_mask(self, feats, vid_temporal_mask_flag=None):
46 | if vid_temporal_mask_flag is None:
47 | return feats
48 | else:
49 | if isinstance(feats, list):
50 | out = []
51 | for x in feats:
52 | out.append(x * vid_temporal_mask_flag)
53 | elif isinstance(feats, torch.Tensor):
54 | out = feats * vid_temporal_mask_flag
55 |
56 | return out
57 |
58 | def extract_feat(self, x):
59 | feats = self.backbone(x)
60 | if self.neck is not None:
61 | feats = self.neck(feats)
62 | return feats
63 |
64 | def forward(self, audio, frames, vid_temporal_mask_flag=None):
65 | if vid_temporal_mask_flag is not None:
66 | vid_temporal_mask_flag = vid_temporal_mask_flag.view(-1, 1, 1, 1)
67 | with torch.no_grad():
68 | audio_feat = self.vggish(audio) # [B*T,128]
69 |
70 | audio_feat = audio_feat.unsqueeze(1)
71 | audio_feat = self.audio_proj(audio_feat)
72 | img_feat = self.extract_feat(frames)
73 | img_feat = self.mul_temporal_mask(img_feat, vid_temporal_mask_flag)
74 |
75 | pred, mask_feature = self.head(img_feat, audio_feat)
76 | pred = self.mul_temporal_mask(pred, vid_temporal_mask_flag)
77 | mask_feature = self.mul_temporal_mask(
78 | mask_feature, vid_temporal_mask_flag)
79 |
80 | return pred, mask_feature
81 |
--------------------------------------------------------------------------------
/model/__init__.py:
--------------------------------------------------------------------------------
1 | from .AVSegFormer import AVSegFormer
2 |
3 |
4 | def build_model(type, **kwargs):
5 | if type == 'AVSegFormer':
6 | return AVSegFormer(**kwargs)
7 | else:
8 |         raise ValueError(f'Unknown model type: {type}')
9 |
--------------------------------------------------------------------------------
/model/backbone/__init__.py:
--------------------------------------------------------------------------------
1 | from .resnet import B2_ResNet
2 | from .pvt import pvt_v2_b5
3 |
4 |
5 | def build_backbone(type, **kwargs):
6 |     if type == 'res50':
7 |         return B2_ResNet(**kwargs)
8 |     elif type == 'pvt_v2_b5':
9 |         return pvt_v2_b5(**kwargs)
10 |     else:
11 |         raise ValueError(f'Unknown backbone type: {type}')
12 |
13 |
14 | __all__ = ['build_backbone']
15 |
--------------------------------------------------------------------------------
/model/backbone/pvt.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from functools import partial
5 |
6 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_
7 | from timm.models.registry import register_model
8 | from timm.models.vision_transformer import _cfg
9 | # from mmseg.models.builder import BACKBONES
10 | # from mmseg.utils import get_root_logger
11 | # from mmcv.runner import load_checkpoint
12 | import math
13 |
14 |
15 | class Mlp(nn.Module):
16 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0., linear=False):
17 | super().__init__()
18 | out_features = out_features or in_features
19 | hidden_features = hidden_features or in_features
20 | self.fc1 = nn.Linear(in_features, hidden_features)
21 | self.dwconv = DWConv(hidden_features)
22 | self.act = act_layer()
23 | self.fc2 = nn.Linear(hidden_features, out_features)
24 | self.drop = nn.Dropout(drop)
25 | self.linear = linear
26 |
27 | if self.linear:
28 | self.relu = nn.ReLU(inplace=True)
29 | self.apply(self._init_weights)
30 |
31 | def _init_weights(self, m):
32 | if isinstance(m, nn.Linear):
33 | trunc_normal_(m.weight, std=.02)
34 | if isinstance(m, nn.Linear) and m.bias is not None:
35 | nn.init.constant_(m.bias, 0)
36 | elif isinstance(m, nn.LayerNorm):
37 | nn.init.constant_(m.bias, 0)
38 | nn.init.constant_(m.weight, 1.0)
39 | elif isinstance(m, nn.Conv2d):
40 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
41 | fan_out //= m.groups
42 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
43 | if m.bias is not None:
44 | m.bias.data.zero_()
45 |
46 | def forward(self, x, H, W):
47 | x = self.fc1(x)
48 | if self.linear:
49 | x = self.relu(x)
50 | x = self.dwconv(x, H, W)
51 | x = self.act(x)
52 | x = self.drop(x)
53 | x = self.fc2(x)
54 | x = self.drop(x)
55 | return x
56 |
57 |
58 | class Attention(nn.Module):
59 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1, linear=False):
60 | super().__init__()
61 | assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."
62 |
63 | self.dim = dim
64 | self.num_heads = num_heads
65 | head_dim = dim // num_heads
66 | self.scale = qk_scale or head_dim ** -0.5
67 |
68 | self.q = nn.Linear(dim, dim, bias=qkv_bias)
69 | self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
70 | self.attn_drop = nn.Dropout(attn_drop)
71 | self.proj = nn.Linear(dim, dim)
72 | self.proj_drop = nn.Dropout(proj_drop)
73 |
74 | self.linear = linear
75 | self.sr_ratio = sr_ratio
76 | if not linear:
77 | if sr_ratio > 1:
78 | self.sr = nn.Conv2d(
79 | dim, dim, kernel_size=sr_ratio, stride=sr_ratio)
80 | self.norm = nn.LayerNorm(dim)
81 | else:
82 | self.pool = nn.AdaptiveAvgPool2d(7)
83 | self.sr = nn.Conv2d(dim, dim, kernel_size=1, stride=1)
84 | self.norm = nn.LayerNorm(dim)
85 | self.act = nn.GELU()
86 | self.apply(self._init_weights)
87 |
88 | def _init_weights(self, m):
89 | if isinstance(m, nn.Linear):
90 | trunc_normal_(m.weight, std=.02)
91 | if isinstance(m, nn.Linear) and m.bias is not None:
92 | nn.init.constant_(m.bias, 0)
93 | elif isinstance(m, nn.LayerNorm):
94 | nn.init.constant_(m.bias, 0)
95 | nn.init.constant_(m.weight, 1.0)
96 | elif isinstance(m, nn.Conv2d):
97 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
98 | fan_out //= m.groups
99 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
100 | if m.bias is not None:
101 | m.bias.data.zero_()
102 |
103 | def forward(self, x, H, W):
104 | B, N, C = x.shape
105 | q = self.q(x).reshape(B, N, self.num_heads, C //
106 | self.num_heads).permute(0, 2, 1, 3)
107 |
108 | if not self.linear:
109 | if self.sr_ratio > 1:
110 | x_ = x.permute(0, 2, 1).reshape(B, C, H, W)
111 | x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1)
112 | x_ = self.norm(x_)
113 | kv = self.kv(x_).reshape(B, -1, 2, self.num_heads,
114 | C // self.num_heads).permute(2, 0, 3, 1, 4)
115 | else:
116 | kv = self.kv(x).reshape(B, -1, 2, self.num_heads,
117 | C // self.num_heads).permute(2, 0, 3, 1, 4)
118 | else:
119 | x_ = x.permute(0, 2, 1).reshape(B, C, H, W)
120 | x_ = self.sr(self.pool(x_)).reshape(B, C, -1).permute(0, 2, 1)
121 | x_ = self.norm(x_)
122 | x_ = self.act(x_)
123 | kv = self.kv(x_).reshape(B, -1, 2, self.num_heads,
124 | C // self.num_heads).permute(2, 0, 3, 1, 4)
125 | k, v = kv[0], kv[1]
126 |
127 | attn = (q @ k.transpose(-2, -1)) * self.scale
128 | attn = attn.softmax(dim=-1)
129 | attn = self.attn_drop(attn)
130 |
131 | x = (attn @ v).transpose(1, 2).reshape(B, N, C)
132 | x = self.proj(x)
133 | x = self.proj_drop(x)
134 |
135 | return x
136 |
137 |
138 | class Block(nn.Module):
139 |
140 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
141 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, sr_ratio=1, linear=False):
142 | super().__init__()
143 | self.norm1 = norm_layer(dim)
144 | self.attn = Attention(
145 | dim,
146 | num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
147 | attn_drop=attn_drop, proj_drop=drop, sr_ratio=sr_ratio, linear=linear)
148 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
149 | self.drop_path = DropPath(
150 | drop_path) if drop_path > 0. else nn.Identity()
151 | self.norm2 = norm_layer(dim)
152 | mlp_hidden_dim = int(dim * mlp_ratio)
153 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim,
154 | act_layer=act_layer, drop=drop, linear=linear)
155 |
156 | self.apply(self._init_weights)
157 |
158 | def _init_weights(self, m):
159 | if isinstance(m, nn.Linear):
160 | trunc_normal_(m.weight, std=.02)
161 | if isinstance(m, nn.Linear) and m.bias is not None:
162 | nn.init.constant_(m.bias, 0)
163 | elif isinstance(m, nn.LayerNorm):
164 | nn.init.constant_(m.bias, 0)
165 | nn.init.constant_(m.weight, 1.0)
166 | elif isinstance(m, nn.Conv2d):
167 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
168 | fan_out //= m.groups
169 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
170 | if m.bias is not None:
171 | m.bias.data.zero_()
172 |
173 | def forward(self, x, H, W):
174 | x = x + self.drop_path(self.attn(self.norm1(x), H, W))
175 | x = x + self.drop_path(self.mlp(self.norm2(x), H, W))
176 |
177 | return x
178 |
179 |
180 | class OverlapPatchEmbed(nn.Module):
181 | """ Image to Patch Embedding
182 | """
183 |
184 | def __init__(self, img_size=224, patch_size=7, stride=4, in_chans=3, embed_dim=768):
185 | super().__init__()
186 | img_size = to_2tuple(img_size)
187 | patch_size = to_2tuple(patch_size)
188 |
189 | assert max(patch_size) > stride, "Set larger patch_size than stride"
190 |
191 | self.img_size = img_size
192 | self.patch_size = patch_size
193 | self.H, self.W = img_size[0] // stride, img_size[1] // stride
194 | self.num_patches = self.H * self.W
195 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride,
196 | padding=(patch_size[0] // 2, patch_size[1] // 2))
197 | self.norm = nn.LayerNorm(embed_dim)
198 |
199 | self.apply(self._init_weights)
200 |
201 | def _init_weights(self, m):
202 | if isinstance(m, nn.Linear):
203 | trunc_normal_(m.weight, std=.02)
204 | if isinstance(m, nn.Linear) and m.bias is not None:
205 | nn.init.constant_(m.bias, 0)
206 | elif isinstance(m, nn.LayerNorm):
207 | nn.init.constant_(m.bias, 0)
208 | nn.init.constant_(m.weight, 1.0)
209 | elif isinstance(m, nn.Conv2d):
210 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
211 | fan_out //= m.groups
212 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
213 | if m.bias is not None:
214 | m.bias.data.zero_()
215 |
216 | def forward(self, x):
217 | x = self.proj(x)
218 | _, _, H, W = x.shape
219 | x = x.flatten(2).transpose(1, 2)
220 | x = self.norm(x)
221 |
222 | return x, H, W
223 |
224 |
225 | class PyramidVisionTransformerV2(nn.Module):
226 | def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dims=[64, 128, 256, 512],
227 | num_heads=[1, 2, 4, 8], mlp_ratios=[4, 4, 4, 4], qkv_bias=False, qk_scale=None, drop_rate=0.,
228 | attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm, depths=[3, 4, 6, 3],
229 | sr_ratios=[8, 4, 2, 1], num_stages=4, linear=False, init_weights_path=None):
230 | super().__init__()
231 | # self.num_classes = num_classes
232 | self.depths = depths
233 | self.num_stages = num_stages
234 | self.linear = linear
235 |
236 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate,
237 | sum(depths))] # stochastic depth decay rule
238 | cur = 0
239 |
240 | for i in range(num_stages):
241 | patch_embed = OverlapPatchEmbed(img_size=img_size if i == 0 else img_size // (2 ** (i + 1)),
242 | patch_size=7 if i == 0 else 3,
243 | stride=4 if i == 0 else 2,
244 | in_chans=in_chans if i == 0 else embed_dims[i - 1],
245 | embed_dim=embed_dims[i])
246 |
247 | block = nn.ModuleList([Block(
248 | dim=embed_dims[i], num_heads=num_heads[i], mlp_ratio=mlp_ratios[i], qkv_bias=qkv_bias,
249 | qk_scale=qk_scale,
250 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur +
251 | j], norm_layer=norm_layer,
252 | sr_ratio=sr_ratios[i], linear=linear)
253 | for j in range(depths[i])])
254 | norm = norm_layer(embed_dims[i])
255 | cur += depths[i]
256 |
257 | setattr(self, f"patch_embed{i + 1}", patch_embed)
258 | setattr(self, f"block{i + 1}", block)
259 | setattr(self, f"norm{i + 1}", norm)
260 |
261 | self.apply(self._init_weights)
262 |
263 | if init_weights_path is not None:
264 | self.initialize_weights(init_weights_path)
265 |
266 | def initialize_weights(self, path):
267 | pvt_model_dict = self.state_dict()
268 | pretrained_state_dicts = torch.load(path)
269 | # for k, v in pretrained_state_dicts['model'].items():
270 | # if k in pvt_model_dict.keys():
271 | # print(k, v.requires_grad)
272 | state_dict = {k: v for k, v in pretrained_state_dicts.items()
273 | if k in pvt_model_dict.keys()}
274 | pvt_model_dict.update(state_dict)
275 | self.load_state_dict(pvt_model_dict)
276 | print(f'==> Load pvt-v2-b5 parameters pretrained on ImageNet: {path}')
277 |
278 | def _init_weights(self, m):
279 | if isinstance(m, nn.Linear):
280 | trunc_normal_(m.weight, std=.02)
281 | if isinstance(m, nn.Linear) and m.bias is not None:
282 | nn.init.constant_(m.bias, 0)
283 | elif isinstance(m, nn.LayerNorm):
284 | nn.init.constant_(m.bias, 0)
285 | nn.init.constant_(m.weight, 1.0)
286 | elif isinstance(m, nn.Conv2d):
287 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
288 | fan_out //= m.groups
289 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
290 | if m.bias is not None:
291 | m.bias.data.zero_()
292 |
293 | def freeze_patch_emb(self):
294 | self.patch_embed1.requires_grad = False
295 |
296 | @torch.jit.ignore
297 | def no_weight_decay(self):
298 | # has pos_embed may be better
299 | return {'pos_embed1', 'pos_embed2', 'pos_embed3', 'pos_embed4', 'cls_token'}
300 |
301 | def get_classifier(self):
302 | return self.head
303 |
304 | def reset_classifier(self, num_classes, global_pool=''):
305 | self.num_classes = num_classes
306 | self.head = nn.Linear(
307 | self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
308 |
309 | def forward_features(self, x):
310 | B = x.shape[0]
311 | outs = []
312 |
313 | for i in range(self.num_stages):
314 | patch_embed = getattr(self, f"patch_embed{i + 1}")
315 | block = getattr(self, f"block{i + 1}")
316 | norm = getattr(self, f"norm{i + 1}")
317 | x, H, W = patch_embed(x)
318 | for blk in block:
319 | x = blk(x, H, W)
320 | x = norm(x)
321 | x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
322 | outs.append(x)
323 |
324 | return outs
325 |
326 | def forward(self, x):
327 | x = self.forward_features(x)
328 | # x = self.head(x)
329 |
330 | return x
331 |
332 |
333 | class DWConv(nn.Module):
334 | def __init__(self, dim=768):
335 | super(DWConv, self).__init__()
336 | self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)
337 |
338 | def forward(self, x, H, W):
339 | B, N, C = x.shape
340 | x = x.transpose(1, 2).view(B, C, H, W)
341 | x = self.dwconv(x)
342 | x = x.flatten(2).transpose(1, 2)
343 |
344 | return x
345 |
346 |
347 | @register_model
348 | def pvt_v2_b5(init_weights_path=None, **kwargs):
349 | model = PyramidVisionTransformerV2(
350 | patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], qkv_bias=True,
351 | norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 6, 40, 3], sr_ratios=[8, 4, 2, 1],
352 | drop_rate=0.0, drop_path_rate=0.1, init_weights_path=init_weights_path,
353 | **kwargs)
354 | model.default_cfg = _cfg()
355 |
356 | return model
357 |
--------------------------------------------------------------------------------
/model/backbone/resnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import math
4 | import torchvision.models as models
5 |
6 |
7 | def conv3x3(in_planes, out_planes, stride=1):
8 | """3x3 convolution with padding"""
9 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
10 | padding=1, bias=False)
11 |
12 |
13 | class BasicBlock(nn.Module):
14 | expansion = 1
15 |
16 | def __init__(self, inplanes, planes, stride=1, downsample=None):
17 | super(BasicBlock, self).__init__()
18 | self.conv1 = conv3x3(inplanes, planes, stride)
19 | self.bn1 = nn.BatchNorm2d(planes)
20 | self.relu = nn.ReLU(inplace=True)
21 | self.conv2 = conv3x3(planes, planes)
22 | self.bn2 = nn.BatchNorm2d(planes)
23 | self.downsample = downsample
24 | self.stride = stride
25 |
26 | def forward(self, x):
27 | residual = x
28 |
29 | out = self.conv1(x)
30 | out = self.bn1(out)
31 | out = self.relu(out)
32 |
33 | out = self.conv2(out)
34 | out = self.bn2(out)
35 |
36 | if self.downsample is not None:
37 | residual = self.downsample(x)
38 |
39 | out += residual
40 | out = self.relu(out)
41 |
42 | return out
43 |
44 |
45 | class Bottleneck(nn.Module):
46 | expansion = 4
47 |
48 | def __init__(self, inplanes, planes, stride=1, downsample=None):
49 | super(Bottleneck, self).__init__()
50 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
51 | self.bn1 = nn.BatchNorm2d(planes)
52 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
53 | padding=1, bias=False)
54 | self.bn2 = nn.BatchNorm2d(planes)
55 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
56 | self.bn3 = nn.BatchNorm2d(planes * 4)
57 | self.relu = nn.ReLU(inplace=True)
58 | self.downsample = downsample
59 | self.stride = stride
60 |
61 | def forward(self, x):
62 | residual = x
63 |
64 | out = self.conv1(x)
65 | out = self.bn1(out)
66 | out = self.relu(out)
67 |
68 | out = self.conv2(out)
69 | out = self.bn2(out)
70 | out = self.relu(out)
71 |
72 | out = self.conv3(out)
73 | out = self.bn3(out)
74 |
75 | if self.downsample is not None:
76 | residual = self.downsample(x)
77 |
78 | out += residual
79 | out = self.relu(out)
80 |
81 | return out
82 |
83 |
84 | class B2_ResNet(nn.Module):
85 | # ResNet50 with two branches
86 | def __init__(self, init_weights_path=None):
87 | # self.inplanes = 128
88 | self.inplanes = 64
89 | super(B2_ResNet, self).__init__()
90 |
91 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
92 | bias=False)
93 | self.bn1 = nn.BatchNorm2d(64)
94 | self.relu = nn.ReLU(inplace=True)
95 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
96 | self.layer1 = self._make_layer(Bottleneck, 64, 3)
97 | self.layer2 = self._make_layer(Bottleneck, 128, 4, stride=2)
98 | self.layer3_1 = self._make_layer(Bottleneck, 256, 6, stride=2)
99 | self.layer4_1 = self._make_layer(Bottleneck, 512, 3, stride=2)
100 |
101 | self.inplanes = 512
102 | self.layer3_2 = self._make_layer(Bottleneck, 256, 6, stride=2)
103 | self.layer4_2 = self._make_layer(Bottleneck, 512, 3, stride=2)
104 |
105 | for m in self.modules():
106 | if isinstance(m, nn.Conv2d):
107 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
108 | m.weight.data.normal_(0, math.sqrt(2. / n))
109 | elif isinstance(m, nn.BatchNorm2d):
110 | m.weight.data.fill_(1)
111 | m.bias.data.zero_()
112 |
113 | if init_weights_path is not None:
114 | self.initialize_weights(init_weights_path)
115 |
116 | def initialize_weights(self, path):
117 | res50 = models.resnet50(pretrained=False)
118 | resnet50_dict = torch.load(path)
119 | res50.load_state_dict(resnet50_dict)
120 | pretrained_dict = res50.state_dict()
121 |
122 | all_params = {}
123 | for k, v in self.state_dict().items():
124 | if k in pretrained_dict.keys():
125 | v = pretrained_dict[k]
126 | all_params[k] = v
127 | elif '_1' in k:
128 | name = k.split('_1')[0] + k.split('_1')[1]
129 | v = pretrained_dict[name]
130 | all_params[k] = v
131 | elif '_2' in k:
132 | name = k.split('_2')[0] + k.split('_2')[1]
133 | v = pretrained_dict[name]
134 | all_params[k] = v
135 | assert len(all_params.keys()) == len(self.state_dict().keys())
136 | self.load_state_dict(all_params)
137 | print(f'==> Load pretrained ResNet50 parameters from {path}')
138 |
139 | def _make_layer(self, block, planes, blocks, stride=1):
140 | downsample = None
141 | if stride != 1 or self.inplanes != planes * block.expansion:
142 | downsample = nn.Sequential(
143 | nn.Conv2d(self.inplanes, planes * block.expansion,
144 | kernel_size=1, stride=stride, bias=False),
145 | nn.BatchNorm2d(planes * block.expansion),
146 | )
147 |
148 | layers = []
149 | layers.append(block(self.inplanes, planes, stride, downsample))
150 | self.inplanes = planes * block.expansion
151 | for i in range(1, blocks):
152 | layers.append(block(self.inplanes, planes))
153 |
154 | return nn.Sequential(*layers)
155 |
156 | def forward(self, x, branch=1):
157 | layer3 = getattr(self, f'layer3_{branch}')
158 | layer4 = getattr(self, f'layer4_{branch}')
159 |
160 | x = self.conv1(x)
161 | x = self.bn1(x)
162 | x = self.relu(x)
163 | x = self.maxpool(x)
164 |
165 | x1 = self.layer1(x)
166 | x2 = self.layer2(x1)
167 | x3 = layer3(x2)
168 | x4 = layer4(x3)
169 |
170 | return [x1, x2, x3, x4]
171 |
--------------------------------------------------------------------------------
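`B2_ResNet` above shares the stem and the first two stages between its branches and keeps separate copies of `layer3`/`layer4`. A hedged usage sketch with an illustrative input size (random init unless `init_weights_path` points to a torchvision ResNet-50 checkpoint):

```python
import torch
from model.backbone.resnet import B2_ResNet

backbone = B2_ResNet()                    # init_weights_path=... loads ImageNet weights
x = torch.randn(2, 3, 224, 224)
x1, x2, x3, x4 = backbone(x, branch=1)    # branch=2 routes through layer3_2/layer4_2
print(x1.shape, x2.shape, x3.shape, x4.shape)
# shapes: (2, 256, 56, 56) (2, 512, 28, 28) (2, 1024, 14, 14) (2, 2048, 7, 7)
```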
/model/head/AVSegHead.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from model.utils import build_transformer, build_positional_encoding, build_fusion_block, build_generator
5 | from ops.modules import MSDeformAttn
6 | from torch.nn.init import normal_
7 | from torch.nn.functional import interpolate
8 |
9 |
10 | class Interpolate(nn.Module):
11 | def __init__(self, scale_factor, mode, align_corners=False):
12 | super(Interpolate, self).__init__()
13 |
14 | self.interp = interpolate
15 | self.scale_factor = scale_factor
16 | self.mode = mode
17 | self.align_corners = align_corners
18 |
19 | def forward(self, x):
20 | x = self.interp(
21 | x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners
22 | )
23 | return x
24 |
25 |
26 | class MLP(nn.Module):
27 | """ Very simple multi-layer perceptron (also called FFN)"""
28 |
29 | def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
30 | super().__init__()
31 | self.num_layers = num_layers
32 | h = [hidden_dim] * (num_layers - 1)
33 | self.layers = nn.ModuleList(nn.Conv2d(n, k, kernel_size=1, stride=1, padding=0)
34 | for n, k in zip([input_dim] + h, h + [output_dim]))
35 |
36 | def forward(self, x):
37 | for i, layer in enumerate(self.layers):
38 | x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
39 | return x
40 |
41 |
42 | class SimpleFPN(nn.Module):
43 | def __init__(self, channel=256, layers=3):
44 | super().__init__()
45 |
46 | assert layers == 3 # only designed for 3 layers
47 | self.up1 = nn.Sequential(
48 | Interpolate(scale_factor=2, mode='bilinear'),
49 | nn.Conv2d(channel, channel, kernel_size=3, stride=1, padding=1)
50 | )
51 | self.up2 = nn.Sequential(
52 | Interpolate(scale_factor=2, mode='bilinear'),
53 | nn.Conv2d(channel, channel, kernel_size=3, stride=1, padding=1)
54 | )
55 | self.up3 = nn.Sequential(
56 | Interpolate(scale_factor=2, mode='bilinear'),
57 | nn.Conv2d(channel, channel, kernel_size=3, stride=1, padding=1)
58 | )
59 |
60 | def forward(self, x):
61 | x1 = self.up1(x[-1])
62 | x1 = x1 + x[-2]
63 |
64 | x2 = self.up2(x1)
65 | x2 = x2 + x[-3]
66 |
67 | y = self.up3(x2)
68 | return y
69 |
70 |
71 | class AVSegHead(nn.Module):
72 | def __init__(self,
73 | in_channels,
74 | num_classes,
75 | query_num,
76 | transformer,
77 | query_generator,
78 | embed_dim=256,
79 | valid_indices=[1, 2, 3],
80 | scale_factor=4,
81 | positional_encoding=None,
82 | use_learnable_queries=True,
83 | fusion_block=None) -> None:
84 | super().__init__()
85 |
86 | self.in_channels = in_channels
87 | self.embed_dim = embed_dim
88 | self.num_classes = num_classes
89 | self.query_num = query_num
90 | self.valid_indices = valid_indices
91 | self.num_feats = len(valid_indices)
92 | self.scale_factor = scale_factor
93 | self.use_learnable_queries = use_learnable_queries
94 | self.level_embed = nn.Parameter(
95 | torch.Tensor(self.num_feats, embed_dim))
96 | self.learnable_query = nn.Embedding(query_num, embed_dim)
97 |
98 | self.query_generator = build_generator(**query_generator)
99 |
100 | self.transformer = build_transformer(**transformer)
101 | if positional_encoding is not None:
102 | self.positional_encoding = build_positional_encoding(
103 | **positional_encoding)
104 | else:
105 | self.positional_encoding = None
106 |
107 | in_proj = []
108 | for c in in_channels:
109 | in_proj.append(
110 | nn.Sequential(
111 | nn.Conv2d(c, embed_dim, kernel_size=1),
112 | nn.GroupNorm(32, embed_dim)
113 | )
114 | )
115 | self.in_proj = nn.ModuleList(in_proj)
116 | self.mlp = MLP(query_num, 2048, embed_dim, 3)
117 |
118 | if fusion_block is not None:
119 | self.fusion_block = build_fusion_block(**fusion_block)
120 |
121 | self.lateral_conv = nn.Sequential(
122 | nn.Conv2d(embed_dim, embed_dim,
123 | kernel_size=1, stride=1, padding=0),
124 | nn.GroupNorm(32, embed_dim)
125 | )
126 | self.out_conv = nn.Sequential(
127 | nn.Conv2d(embed_dim, embed_dim,
128 | kernel_size=3, stride=1, padding=1),
129 | nn.GroupNorm(32, embed_dim),
130 | nn.ReLU(True)
131 | )
132 |
133 | self.fpn = SimpleFPN()
134 | self.attn_fc = nn.Sequential(
135 | nn.Conv2d(embed_dim, 128, kernel_size=3, stride=1, padding=1),
136 | Interpolate(scale_factor=scale_factor, mode="bilinear"),
137 | nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1),
138 | nn.ReLU(True),
139 | nn.Conv2d(32, num_classes, kernel_size=1,
140 | stride=1, padding=0, bias=False)
141 | )
142 | self.fc = nn.Sequential(
143 | nn.Conv2d(embed_dim, 128, kernel_size=3, stride=1, padding=1),
144 | Interpolate(scale_factor=scale_factor, mode="bilinear"),
145 | nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1),
146 | nn.ReLU(True),
147 | nn.Conv2d(32, num_classes, kernel_size=1,
148 | stride=1, padding=0, bias=False)
149 | )
150 |
151 | self._reset_parameters()
152 |
153 | def _reset_parameters(self):
154 | for p in self.parameters():
155 | if p.dim() > 1:
156 | nn.init.xavier_uniform_(p)
157 | for m in self.modules():
158 | if isinstance(m, MSDeformAttn):
159 | m._reset_parameters()
160 | normal_(self.level_embed)
161 |
162 | def get_valid_ratio(self, mask):
163 | _, H, W = mask.shape
164 | valid_H = torch.sum(~mask[:, :, 0], 1)
165 | valid_W = torch.sum(~mask[:, 0, :], 1)
166 | valid_ratio_h = valid_H.float() / H
167 | valid_ratio_w = valid_W.float() / W
168 | valid_ratio = torch.stack([valid_ratio_w, valid_ratio_h], -1)
169 | return valid_ratio
170 |
171 | def reform_output_squences(self, memory, spatial_shapes, level_start_index, dim=1):
172 | split_size_or_sections = [None] * self.num_feats
173 | for i in range(self.num_feats):
174 | if i < self.num_feats - 1:
175 | split_size_or_sections[i] = level_start_index[i +
176 | 1] - level_start_index[i]
177 | else:
178 | split_size_or_sections[i] = memory.shape[dim] - \
179 | level_start_index[i]
180 | y = torch.split(memory, split_size_or_sections, dim=dim)
181 | return y
182 |
183 | def forward(self, feats, audio_feat):
184 | feat14 = self.in_proj[0](feats[0])
185 | srcs = [self.in_proj[i](feats[i]) for i in self.valid_indices]
186 | masks = [torch.zeros((x.size(0), x.size(2), x.size(
187 | 3)), device=x.device, dtype=torch.bool) for x in srcs]
188 | pos_embeds = []
189 | for m in masks:
190 | pos_embeds.append(self.positional_encoding(m))
191 | # prepare input for encoder
192 | src_flatten = []
193 | mask_flatten = []
194 | lvl_pos_embed_flatten = []
195 | spatial_shapes = []
196 | for lvl, (src, mask, pos_embed) in enumerate(zip(srcs, masks, pos_embeds)):
197 | bs, c, h, w = src.shape
198 | spatial_shape = (h, w)
199 | spatial_shapes.append(spatial_shape)
200 | src = src.flatten(2).transpose(1, 2)
201 | mask = mask.flatten(1)
202 | pos_embed = pos_embed.flatten(2).transpose(1, 2)
203 | lvl_pos_embed = pos_embed + self.level_embed[lvl].view(1, 1, -1)
204 | lvl_pos_embed_flatten.append(lvl_pos_embed)
205 | src_flatten.append(src)
206 | mask_flatten.append(mask)
207 | src_flatten = torch.cat(src_flatten, 1)
208 | mask_flatten = torch.cat(mask_flatten, 1)
209 | lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1)
210 | spatial_shapes = torch.as_tensor(
211 | spatial_shapes, dtype=torch.long, device=src_flatten.device)
212 | level_start_index = torch.cat((spatial_shapes.new_zeros(
213 | (1, )), spatial_shapes.prod(1).cumsum(0)[:-1]))
214 | valid_ratios = torch.stack([self.get_valid_ratio(m) for m in masks], 1)
215 |
216 | # prepare queries
217 | bs = audio_feat.shape[0]
218 | query = self.query_generator(audio_feat)
219 | if self.use_learnable_queries:
220 | query = query + \
221 | self.learnable_query.weight[None, :, :].repeat(bs, 1, 1)
222 |
223 | memory, outputs = self.transformer(query, src_flatten, spatial_shapes,
224 | level_start_index, valid_ratios, lvl_pos_embed_flatten, mask_flatten)
225 |
226 | # generate mask feature
227 | mask_feats = []
228 | for i, z in enumerate(self.reform_output_squences(memory, spatial_shapes, level_start_index, 1)):
229 | mask_feats.append(z.transpose(1, 2).view(
230 | bs, -1, spatial_shapes[i][0], spatial_shapes[i][1]))
231 | cur_fpn = self.lateral_conv(feat14)
232 | mask_feature = mask_feats[0]
233 | mask_feature = cur_fpn + \
234 | F.interpolate(
235 | mask_feature, size=cur_fpn.shape[-2:], mode='bilinear', align_corners=False)
236 | mask_feature = self.out_conv(mask_feature)
237 | if hasattr(self, 'fusion_block'):
238 | mask_feature = self.fusion_block(mask_feature, audio_feat)
239 |
240 | # predict output mask
241 | pred_feature = torch.einsum(
242 | 'bqc,bchw->bqhw', outputs[-1], mask_feature)
243 | pred_feature = self.mlp(pred_feature)
244 | pred_mask = mask_feature + pred_feature
245 | pred_mask = self.fc(pred_mask)
246 |
247 | return pred_mask, mask_feature
248 |
249 | # def forward_prediction_head(self, output, mask_embed, spatial_shapes, level_start_index):
250 | # masks = torch.einsum('bqc,bqn->bcn', output, mask_embed)
251 | # splitted_masks = self.reform_output_squences(
252 | # masks, spatial_shapes, level_start_index, 2)
253 |
254 | # bs = output.shape[0]
255 | # reforms = []
256 | # for i, embed in enumerate(splitted_masks):
257 | # embed = embed.view(
258 | # bs, -1, spatial_shapes[i][0], spatial_shapes[i][1])
259 | # reforms.append(embed)
260 |
261 | # attn_mask = self.fpn(reforms)
262 | # attn_mask = self.attn_fc(attn_mask)
263 | # return attn_mask
264 |
--------------------------------------------------------------------------------
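The core of `AVSegHead.forward` above is the final mask prediction: each decoder query is projected onto the fused per-pixel feature with `torch.einsum('bqc,bchw->bqhw', ...)`. A self-contained toy illustration of just that step (tensor sizes are made up for the demo):

```python
import torch

bs, num_queries, c, h, w = 2, 300, 256, 56, 56
queries = torch.randn(bs, num_queries, c)    # stands in for outputs[-1] of the decoder
mask_feature = torch.randn(bs, c, h, w)      # stands in for the fused mask feature

# one channel-wise dot product per (query, pixel) pair
per_query_masks = torch.einsum('bqc,bchw->bqhw', queries, mask_feature)
print(per_query_masks.shape)  # torch.Size([2, 300, 56, 56])
```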
/model/head/__init__.py:
--------------------------------------------------------------------------------
1 | from .AVSegHead import AVSegHead
2 |
3 |
4 | def build_head(type, **kwargs):
5 | if type == 'AVSegHead':
6 | return AVSegHead(**kwargs)
7 | else:
8 |         raise ValueError(f'unknown head type: {type}')
9 |
10 |
11 | __all__ = ['build_head']
12 |
--------------------------------------------------------------------------------
/model/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .transformer import build_transformer
2 | from .positional_encoding import build_positional_encoding
3 | from .fusion_block import build_fusion_block
4 | from .query_generator import build_generator
5 |
6 |
7 | __all__ = ['build_transformer', 'build_positional_encoding',
8 | 'build_fusion_block', 'build_generator']
9 |
--------------------------------------------------------------------------------
/model/utils/fusion_block.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | class CrossModalMixer(nn.Module):
6 | def __init__(self, dim=256, n_heads=8, qkv_bias=False, dropout=0.):
7 | super().__init__()
8 |
9 | self.dim = dim
10 | self.n_heads = n_heads
11 | self.dropout = dropout
12 | self.scale = (dim // n_heads)**-0.5
13 |
14 | self.q_proj = nn.Linear(dim, dim, bias=qkv_bias)
15 | self.kv_proj = nn.Linear(dim, dim * 2, bias=qkv_bias)
16 | self.attn_drop = nn.Dropout(dropout)
17 | self.proj = nn.Linear(dim, dim)
18 | self.proj_drop = nn.Dropout(dropout)
19 |
20 | def forward(self, feature_map, audio_feature):
21 | """channel attention for modality fusion
22 |
23 | Args:
24 | feature_map (Tensor): (bs, c, h, w)
25 | audio_feature (Tensor): (bs, 1, c)
26 |
27 | Returns:
28 | Tensor: (bs, c, h, w)
29 | """
30 | flatten_map = feature_map.flatten(2).transpose(1, 2)
31 | B, N, C = flatten_map.shape
32 |
33 | q = self.q_proj(audio_feature).reshape(
34 | B, 1, self.n_heads, C // self.n_heads).permute(0, 2, 1, 3)
35 | kv = self.kv_proj(flatten_map).reshape(
36 | B, N, 2, self.n_heads, C // self.n_heads).permute(2, 0, 3, 1, 4)
37 | k, v = kv.unbind(0)
38 | attn = (q @ k.transpose(-2, -1)) * self.scale
39 | attn = attn.softmax(dim=-1)
40 | x = (attn @ v).transpose(1, 2).reshape(B, 1, C)
41 | x = self.proj_drop(self.proj(x))
42 |
43 | x = x.sigmoid()
44 |         fusion_map = torch.einsum('bchw,bc->bchw', feature_map, x.squeeze(1))  # squeeze(1): keep the batch dim even when bs == 1
45 | return fusion_map
46 |
47 |
48 | def build_fusion_block(type, **kwargs):
49 | if type == 'CrossModalMixer':
50 | return CrossModalMixer(**kwargs)
51 | else:
52 |         raise ValueError(f'unknown fusion block type: {type}')
53 |
--------------------------------------------------------------------------------
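`CrossModalMixer` above attention-pools one audio vector per sample, squashes it with a sigmoid, and re-weights the visual channels with it. A quick shape check with illustrative sizes (importing through the `model` package assumes the repo root on `PYTHONPATH` and the `ops` extension built, since `model.utils` also pulls in the transformer):

```python
import torch
from model.utils.fusion_block import CrossModalMixer

mixer = CrossModalMixer(dim=256, n_heads=8)
feature_map = torch.randn(2, 256, 56, 56)   # (bs, c, h, w) visual features
audio_feature = torch.randn(2, 1, 256)      # (bs, 1, c) audio embedding
fused = mixer(feature_map, audio_feature)
print(fused.shape)  # torch.Size([2, 256, 56, 56])
```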
/model/utils/positional_encoding.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import math
4 |
5 |
6 | class SinePositionalEncoding(nn.Module):
7 | """Position encoding with sine and cosine functions.
8 |
9 |     See `End-to-End Object Detection with Transformers
10 |     <https://arxiv.org/abs/2005.12872>`_ for details.
11 |
12 | Args:
13 | num_feats (int): The feature dimension for each position
14 | along x-axis or y-axis. Note the final returned dimension
15 | for each position is 2 times of this value.
16 | temperature (int, optional): The temperature used for scaling
17 | the position embedding. Defaults to 10000.
18 | normalize (bool, optional): Whether to normalize the position
19 | embedding. Defaults to False.
20 | scale (float, optional): A scale factor that scales the position
21 | embedding. The scale will be used only when `normalize` is True.
22 | Defaults to 2*pi.
23 | eps (float, optional): A value added to the denominator for
24 | numerical stability. Defaults to 1e-6.
25 |         offset (float): offset added to the embedding when normalization
26 |             is applied. Defaults to 0.
27 | init_cfg (dict or list[dict], optional): Initialization config dict.
28 | Default: None
29 | """
30 |
31 | def __init__(self,
32 | num_feats,
33 | temperature=10000,
34 | normalize=False,
35 | scale=2 * math.pi,
36 | eps=1e-6,
37 | offset=0.):
38 | super().__init__()
39 | if normalize:
40 | assert isinstance(scale, (float, int)), 'when normalize is set,' \
41 | 'scale should be provided and in float or int type, ' \
42 | f'found {type(scale)}'
43 | self.num_feats = num_feats
44 | self.temperature = temperature
45 | self.normalize = normalize
46 | self.scale = scale
47 | self.eps = eps
48 | self.offset = offset
49 |
50 | def forward(self, mask):
51 | """Forward function for `SinePositionalEncoding`.
52 |
53 | Args:
54 |             mask (Tensor): ByteTensor mask. Non-zero values represent
55 |                 ignored positions, while zero values mark valid positions
56 |                 for this image. Shape [bs, h, w].
57 |
58 | Returns:
59 | pos (Tensor): Returned position embedding with shape
60 | [bs, num_feats*2, h, w].
61 | """
62 | # For convenience of exporting to ONNX, it's required to convert
63 | # `masks` from bool to int.
64 | mask = mask.to(torch.int)
65 | not_mask = 1 - mask # logical_not
66 | y_embed = not_mask.cumsum(1, dtype=torch.float32)
67 | x_embed = not_mask.cumsum(2, dtype=torch.float32)
68 | if self.normalize:
69 | y_embed = (y_embed + self.offset) / \
70 | (y_embed[:, -1:, :] + self.eps) * self.scale
71 | x_embed = (x_embed + self.offset) / \
72 | (x_embed[:, :, -1:] + self.eps) * self.scale
73 | dim_t = torch.arange(
74 | self.num_feats, dtype=torch.float32, device=mask.device)
75 | dim_t = self.temperature**(2 * (dim_t // 2) / self.num_feats)
76 | pos_x = x_embed[:, :, :, None] / dim_t
77 | pos_y = y_embed[:, :, :, None] / dim_t
78 | # use `view` instead of `flatten` for dynamically exporting to ONNX
79 | B, H, W = mask.size()
80 | pos_x = torch.stack(
81 | (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()),
82 | dim=4).view(B, H, W, -1)
83 | pos_y = torch.stack(
84 | (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()),
85 | dim=4).view(B, H, W, -1)
86 | pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
87 | return pos
88 |
89 |
90 | def build_positional_encoding(type, **kwargs):
91 | if type == 'SinePositionalEncoding':
92 | return SinePositionalEncoding(**kwargs)
93 | else:
94 |         raise ValueError(f'unknown positional encoding type: {type}')
95 |
--------------------------------------------------------------------------------
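`SinePositionalEncoding` above follows the DETR recipe: cumulative sums of the valid-pixel mask give x/y coordinates, which are embedded with interleaved sine/cosine at `num_feats` frequencies per axis. A small sketch with illustrative `num_feats` and mask size (a package import again assumes the `ops` extension is built):

```python
import torch
from model.utils.positional_encoding import SinePositionalEncoding

pos_enc = SinePositionalEncoding(num_feats=128, normalize=True)
mask = torch.zeros(2, 28, 28, dtype=torch.bool)   # all positions valid
pos = pos_enc(mask)
print(pos.shape)  # torch.Size([2, 256, 28, 28]) -> 2 * num_feats channels
```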
/model/utils/query_generator.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | class RepeatGenerator(nn.Module):
6 | def __init__(self, query_num) -> None:
7 | super().__init__()
8 | self.query_num = query_num
9 |
10 | def forward(self, audio_feat):
11 | return audio_feat.repeat(1, self.query_num, 1)
12 |
13 |
14 | class AttentionLayer(nn.Module):
15 | def __init__(self, embed_dim, num_heads, hidden_dim) -> None:
16 | super().__init__()
17 | self.self_attn = nn.MultiheadAttention(
18 | embed_dim, num_heads, bias=False, batch_first=True)
19 | self.cross_attn = nn.MultiheadAttention(
20 | embed_dim, num_heads, bias=False, batch_first=True)
21 | self.ffn = nn.Sequential(
22 | nn.Linear(embed_dim, hidden_dim),
23 | nn.GELU(),
24 | nn.Linear(hidden_dim, embed_dim)
25 | )
26 | self.norm1 = nn.LayerNorm(embed_dim)
27 | self.norm2 = nn.LayerNorm(embed_dim)
28 | self.norm3 = nn.LayerNorm(embed_dim)
29 |
30 | def forward(self, query, audio_feat):
31 | out1 = self.self_attn(query, query, query)[0]
32 | query = self.norm1(query+out1)
33 | out2 = self.cross_attn(query, audio_feat, audio_feat)[0]
34 | query = self.norm2(query+out2)
35 | out3 = self.ffn(query)
36 | query = self.norm3(query+out3)
37 | return query
38 |
39 |
40 | class AttentionGenerator(nn.Module):
41 | def __init__(self, num_layers, query_num, embed_dim=256, num_heads=8, hidden_dim=1024):
42 | super().__init__()
43 | self.num_layers = num_layers
44 | self.query_num = query_num
45 | self.embed_dim = embed_dim
46 | self.query = nn.Embedding(query_num, embed_dim)
47 | self.layers = nn.ModuleList(
48 | [AttentionLayer(embed_dim, num_heads, hidden_dim)
49 | for i in range(num_layers)]
50 | )
51 |
52 | self._reset_parameters()
53 |
54 | def _reset_parameters(self):
55 | for p in self.parameters():
56 | if p.dim() > 1:
57 | nn.init.xavier_uniform_(p)
58 |
59 | def forward(self, audio_feat):
60 | bs = audio_feat.shape[0]
61 | query = self.query.weight[None, :, :].repeat(bs, 1, 1)
62 | for layer in self.layers:
63 | query = layer(query, audio_feat)
64 | return query
65 |
66 |
67 | def build_generator(type, **kwargs):
68 | if type == 'AttentionGenerator':
69 | return AttentionGenerator(**kwargs)
70 | elif type == 'RepeatGenerator':
71 | return RepeatGenerator(**kwargs)
72 | else:
73 |         raise ValueError(f'unknown query generator type: {type}')
74 |
--------------------------------------------------------------------------------
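The query generators above turn the audio embedding into decoder queries: `RepeatGenerator` simply tiles it, while `AttentionGenerator` refines a learned query bank with self- and cross-attention against the audio feature. A hedged sketch with illustrative sizes (the real `num_layers`/`query_num` come from the config files):

```python
import torch
from model.utils.query_generator import AttentionGenerator

gen = AttentionGenerator(num_layers=2, query_num=300, embed_dim=256)
audio_feat = torch.randn(4, 1, 256)   # (bs, 1, c) audio embedding
queries = gen(audio_feat)
print(queries.shape)  # torch.Size([4, 300, 256])
```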
/model/utils/transformer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from ops.modules import MSDeformAttn
4 |
5 |
6 | class AVSTransformerEncoderLayer(nn.Module):
7 | def __init__(self, dim=256, ffn_dim=2048, dropout=0.1, num_levels=3, num_heads=8, num_points=4):
8 | super().__init__()
9 |
10 | # self attention
11 | self.self_attn = MSDeformAttn(dim, num_levels, num_heads, num_points)
12 | self.dropout1 = nn.Dropout(dropout)
13 | self.norm1 = nn.LayerNorm(dim)
14 |
15 | # ffn
16 | self.linear1 = nn.Linear(dim, ffn_dim)
17 | self.activation = nn.GELU()
18 | self.dropout2 = nn.Dropout(dropout)
19 | self.linear2 = nn.Linear(ffn_dim, dim)
20 | self.dropout3 = nn.Dropout(dropout)
21 | self.norm2 = nn.LayerNorm(dim)
22 |
23 | @staticmethod
24 | def with_pos_embed(tensor, pos):
25 | return tensor if pos is None else tensor + pos
26 |
27 | def ffn(self, src):
28 | src2 = self.linear2(self.dropout2(self.activation(self.linear1(src))))
29 | src = src + self.dropout3(src2)
30 | src = self.norm2(src)
31 | return src
32 |
33 | def forward(self, src, pos, reference_points, spatial_shapes, level_start_index, padding_mask=None):
34 | # self attention
35 | src2 = self.self_attn(self.with_pos_embed(
36 | src, pos), reference_points, src, spatial_shapes, level_start_index, padding_mask)
37 | src = src + self.dropout1(src2)
38 | src = self.norm1(src)
39 | # ffn
40 | src = self.ffn(src)
41 | return src
42 |
43 |
44 | class AVSTransformerEncoder(nn.Module):
45 | def __init__(self, num_layers, layer, *args, **kwargs) -> None:
46 | super().__init__()
47 |
48 | self.num_layers = num_layers
49 | self.layers = nn.ModuleList(
50 | [AVSTransformerEncoderLayer(**layer) for i in range(num_layers)]
51 | )
52 |
53 | @staticmethod
54 | def get_reference_points(spatial_shapes, valid_ratios, device):
55 | reference_points_list = []
56 | for lvl, (H_, W_) in enumerate(spatial_shapes):
57 |
58 | ref_y, ref_x = torch.meshgrid(torch.linspace(0.5, H_ - 0.5, H_, dtype=torch.float32, device=device),
59 | torch.linspace(0.5, W_ - 0.5, W_, dtype=torch.float32, device=device))
60 | ref_y = ref_y.reshape(-1)[None] / \
61 | (valid_ratios[:, None, lvl, 1] * H_)
62 | ref_x = ref_x.reshape(-1)[None] / \
63 | (valid_ratios[:, None, lvl, 0] * W_)
64 | ref = torch.stack((ref_x, ref_y), -1)
65 | reference_points_list.append(ref)
66 | reference_points = torch.cat(reference_points_list, 1)
67 | reference_points = reference_points[:, :, None] * valid_ratios[:, None]
68 | return reference_points
69 |
70 | def forward(self, src, spatial_shapes, level_start_index, valid_ratios, pos=None, padding_mask=None):
71 | out = src
72 | reference_points = self.get_reference_points(
73 | spatial_shapes, valid_ratios, device=src.device)
74 | for layer in self.layers:
75 | out = layer(out, pos, reference_points, spatial_shapes,
76 | level_start_index, padding_mask)
77 | return out, reference_points
78 |
79 |
80 | class AVSTransformerDecoderLayer(nn.Module):
81 | def __init__(self, dim=256, num_heads=8, ffn_dim=2048, dropout=0.1, num_levels=3, num_points=4, *args, **kwargs) -> None:
82 | super().__init__()
83 |
84 | # self attention
85 | self.self_attn = nn.MultiheadAttention(
86 | dim, num_heads, dropout=dropout, batch_first=True)
87 | self.dropout1 = nn.Dropout(dropout)
88 | self.norm1 = nn.LayerNorm(dim)
89 |
90 | # cross attention
91 | # self.cross_attn = MSDeformAttn(dim, num_levels, num_heads, num_points)
92 | self.cross_attn = nn.MultiheadAttention(
93 | dim, num_heads, dropout=dropout, batch_first=True)
94 | self.dropout2 = nn.Dropout(dropout)
95 | self.norm2 = nn.LayerNorm(dim)
96 |
97 | # ffn
98 | self.linear1 = nn.Linear(dim, ffn_dim)
99 | self.activation = nn.GELU()
100 | self.dropout3 = nn.Dropout(dropout)
101 | self.linear2 = nn.Linear(ffn_dim, dim)
102 | self.dropout4 = nn.Dropout(dropout)
103 | self.norm3 = nn.LayerNorm(dim)
104 |
105 | def ffn(self, src):
106 | src2 = self.linear2(self.dropout3(self.activation(self.linear1(src))))
107 | src = src + self.dropout4(src2)
108 | src = self.norm3(src)
109 | return src
110 |
111 | def forward(self, query, src, reference_points, spatial_shapes, level_start_index, padding_mask=None):
112 | # self attention
113 | out1 = self.self_attn(query, query, query)[0]
114 | query = query + self.dropout1(out1)
115 | query = self.norm1(query)
116 | # cross attention
117 | out2 = self.cross_attn(
118 | query, src, src, key_padding_mask=padding_mask)[0]
119 | query = query + self.dropout2(out2)
120 | query = self.norm2(query)
121 | # ffn
122 | query = self.ffn(query)
123 | return query
124 |
125 |
126 | class AVSTransformerDecoder(nn.Module):
127 | def __init__(self, num_layers, layer, *args, **kwargs) -> None:
128 | super().__init__()
129 |
130 | self.num_layers = num_layers
131 | self.layers = nn.ModuleList(
132 | [AVSTransformerDecoderLayer(**layer) for i in range(num_layers)]
133 | )
134 |
135 | def forward(self, query, src, reference_points, spatial_shapes, level_start_index, padding_mask=None):
136 | out = query
137 | outputs = []
138 | for layer in self.layers:
139 | out = layer(out, src, reference_points, spatial_shapes,
140 | level_start_index, padding_mask)
141 | outputs.append(out)
142 | return outputs
143 |
144 |
145 | class AVSTransformer(nn.Module):
146 | def __init__(self, encoder, decoder, *args, **kwargs) -> None:
147 | super().__init__()
148 |
149 | self.encoder = AVSTransformerEncoder(**encoder)
150 | self.decoder = AVSTransformerDecoder(**decoder)
151 |
152 | def _reset_parameters(self):
153 | for p in self.parameters():
154 | if p.dim() > 1:
155 | nn.init.xavier_uniform_(p)
156 |
157 | def forward(self, query, src, spatial_shapes, level_start_index, valid_ratios, pos=None, padding_mask=None):
158 | memory, reference_points = self.encoder(
159 | src, spatial_shapes, level_start_index, valid_ratios, pos, padding_mask)
160 | outputs = self.decoder(query, memory, reference_points,
161 | spatial_shapes, level_start_index, padding_mask)
162 | return memory, outputs
163 |
164 |
165 | def build_transformer(type, **kwargs):
166 | if type == 'AVSTransformer':
167 | return AVSTransformer(**kwargs)
168 | else:
169 |         raise ValueError(f'unknown transformer type: {type}')
170 |
--------------------------------------------------------------------------------
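`AVSTransformerEncoder.get_reference_points` above precomputes a normalized (x, y) reference per pixel of every level, to which the deformable-attention sampling offsets are added. A toy shape check (illustrative level sizes; importing this module requires the MSDeformAttn extension built via `ops/make.sh`):

```python
import torch
from model.utils.transformer import AVSTransformerEncoder

spatial_shapes = [(8, 8), (4, 4)]                       # two feature levels
valid_ratios = torch.ones(2, len(spatial_shapes), 2)    # (bs, num_levels, 2), fully valid
ref = AVSTransformerEncoder.get_reference_points(
    spatial_shapes, valid_ratios, device='cpu')
print(ref.shape)  # torch.Size([2, 80, 2, 2]) -> (bs, sum(H*W), num_levels, xy)
```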
/model/vggish/__init__.py:
--------------------------------------------------------------------------------
1 | from .vggish import VGGish
2 |
3 |
4 | __all__ = ['VGGish']
--------------------------------------------------------------------------------
/model/vggish/mel_features.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 The TensorFlow Authors All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Defines routines to compute mel spectrogram features from audio waveform."""
17 |
18 | import numpy as np
19 |
20 |
21 | def frame(data, window_length, hop_length):
22 | """Convert array into a sequence of successive possibly overlapping frames.
23 |
24 | An n-dimensional array of shape (num_samples, ...) is converted into an
25 | (n+1)-D array of shape (num_frames, window_length, ...), where each frame
26 | starts hop_length points after the preceding one.
27 |
28 | This is accomplished using stride_tricks, so the original data is not
29 | copied. However, there is no zero-padding, so any incomplete frames at the
30 | end are not included.
31 |
32 | Args:
33 | data: np.array of dimension N >= 1.
34 | window_length: Number of samples in each frame.
35 | hop_length: Advance (in samples) between each window.
36 |
37 | Returns:
38 | (N+1)-D np.array with as many rows as there are complete frames that can be
39 | extracted.
40 | """
41 | num_samples = data.shape[0]
42 | num_frames = 1 + int(np.floor((num_samples - window_length) / hop_length))
43 | shape = (num_frames, window_length) + data.shape[1:]
44 | strides = (data.strides[0] * hop_length,) + data.strides
45 | return np.lib.stride_tricks.as_strided(data, shape=shape, strides=strides)
46 |
47 |
48 | def periodic_hann(window_length):
49 | """Calculate a "periodic" Hann window.
50 |
51 | The classic Hann window is defined as a raised cosine that starts and
52 | ends on zero, and where every value appears twice, except the middle
53 | point for an odd-length window. Matlab calls this a "symmetric" window
54 | and np.hanning() returns it. However, for Fourier analysis, this
55 | actually represents just over one cycle of a period N-1 cosine, and
56 | thus is not compactly expressed on a length-N Fourier basis. Instead,
57 | it's better to use a raised cosine that ends just before the final
58 | zero value - i.e. a complete cycle of a period-N cosine. Matlab
59 | calls this a "periodic" window. This routine calculates it.
60 |
61 | Args:
62 | window_length: The number of points in the returned window.
63 |
64 | Returns:
65 | A 1D np.array containing the periodic hann window.
66 | """
67 | return 0.5 - (0.5 * np.cos(2 * np.pi / window_length *
68 | np.arange(window_length)))
69 |
70 |
71 | def stft_magnitude(signal, fft_length,
72 | hop_length=None,
73 | window_length=None):
74 | """Calculate the short-time Fourier transform magnitude.
75 |
76 | Args:
77 | signal: 1D np.array of the input time-domain signal.
78 | fft_length: Size of the FFT to apply.
79 | hop_length: Advance (in samples) between each frame passed to FFT.
80 | window_length: Length of each block of samples to pass to FFT.
81 |
82 | Returns:
83 | 2D np.array where each row contains the magnitudes of the fft_length/2+1
84 | unique values of the FFT for the corresponding frame of input samples.
85 | """
86 | frames = frame(signal, window_length, hop_length)
87 | # Apply frame window to each frame. We use a periodic Hann (cosine of period
88 | # window_length) instead of the symmetric Hann of np.hanning (period
89 | # window_length-1).
90 | window = periodic_hann(window_length)
91 | windowed_frames = frames * window
92 | return np.abs(np.fft.rfft(windowed_frames, int(fft_length)))
93 |
94 |
95 | # Mel spectrum constants and functions.
96 | _MEL_BREAK_FREQUENCY_HERTZ = 700.0
97 | _MEL_HIGH_FREQUENCY_Q = 1127.0
98 |
99 |
100 | def hertz_to_mel(frequencies_hertz):
101 | """Convert frequencies to mel scale using HTK formula.
102 |
103 | Args:
104 | frequencies_hertz: Scalar or np.array of frequencies in hertz.
105 |
106 | Returns:
107 | Object of same size as frequencies_hertz containing corresponding values
108 | on the mel scale.
109 | """
110 | return _MEL_HIGH_FREQUENCY_Q * np.log(
111 | 1.0 + (frequencies_hertz / _MEL_BREAK_FREQUENCY_HERTZ))
112 |
113 |
114 | def spectrogram_to_mel_matrix(num_mel_bins=20,
115 | num_spectrogram_bins=129,
116 | audio_sample_rate=8000,
117 | lower_edge_hertz=125.0,
118 | upper_edge_hertz=3800.0):
119 | """Return a matrix that can post-multiply spectrogram rows to make mel.
120 |
121 | Returns a np.array matrix A that can be used to post-multiply a matrix S of
122 | spectrogram values (STFT magnitudes) arranged as frames x bins to generate a
123 | "mel spectrogram" M of frames x num_mel_bins. M = S A.
124 |
125 | The classic HTK algorithm exploits the complementarity of adjacent mel bands
126 | to multiply each FFT bin by only one mel weight, then add it, with positive
127 | and negative signs, to the two adjacent mel bands to which that bin
128 | contributes. Here, by expressing this operation as a matrix multiply, we go
129 | from num_fft multiplies per frame (plus around 2*num_fft adds) to around
130 | num_fft^2 multiplies and adds. However, because these are all presumably
131 | accomplished in a single call to np.dot(), it's not clear which approach is
132 | faster in Python. The matrix multiplication has the attraction of being more
133 | general and flexible, and much easier to read.
134 |
135 | Args:
136 | num_mel_bins: How many bands in the resulting mel spectrum. This is
137 | the number of columns in the output matrix.
138 | num_spectrogram_bins: How many bins there are in the source spectrogram
139 | data, which is understood to be fft_size/2 + 1, i.e. the spectrogram
140 | only contains the nonredundant FFT bins.
141 | audio_sample_rate: Samples per second of the audio at the input to the
142 | spectrogram. We need this to figure out the actual frequencies for
143 | each spectrogram bin, which dictates how they are mapped into mel.
144 | lower_edge_hertz: Lower bound on the frequencies to be included in the mel
145 | spectrum. This corresponds to the lower edge of the lowest triangular
146 | band.
147 | upper_edge_hertz: The desired top edge of the highest frequency band.
148 |
149 | Returns:
150 | An np.array with shape (num_spectrogram_bins, num_mel_bins).
151 |
152 | Raises:
153 | ValueError: if frequency edges are incorrectly ordered or out of range.
154 | """
155 | nyquist_hertz = audio_sample_rate / 2.
156 | if lower_edge_hertz < 0.0:
157 | raise ValueError("lower_edge_hertz %.1f must be >= 0" % lower_edge_hertz)
158 | if lower_edge_hertz >= upper_edge_hertz:
159 | raise ValueError("lower_edge_hertz %.1f >= upper_edge_hertz %.1f" %
160 | (lower_edge_hertz, upper_edge_hertz))
161 | if upper_edge_hertz > nyquist_hertz:
162 | raise ValueError("upper_edge_hertz %.1f is greater than Nyquist %.1f" %
163 | (upper_edge_hertz, nyquist_hertz))
164 | spectrogram_bins_hertz = np.linspace(0.0, nyquist_hertz, num_spectrogram_bins)
165 | spectrogram_bins_mel = hertz_to_mel(spectrogram_bins_hertz)
166 | # The i'th mel band (starting from i=1) has center frequency
167 | # band_edges_mel[i], lower edge band_edges_mel[i-1], and higher edge
168 | # band_edges_mel[i+1]. Thus, we need num_mel_bins + 2 values in
169 | # the band_edges_mel arrays.
170 | band_edges_mel = np.linspace(hertz_to_mel(lower_edge_hertz),
171 | hertz_to_mel(upper_edge_hertz), num_mel_bins + 2)
172 | # Matrix to post-multiply feature arrays whose rows are num_spectrogram_bins
173 | # of spectrogram values.
174 | mel_weights_matrix = np.empty((num_spectrogram_bins, num_mel_bins))
175 | for i in range(num_mel_bins):
176 | lower_edge_mel, center_mel, upper_edge_mel = band_edges_mel[i:i + 3]
177 | # Calculate lower and upper slopes for every spectrogram bin.
178 | # Line segments are linear in the *mel* domain, not hertz.
179 | lower_slope = ((spectrogram_bins_mel - lower_edge_mel) /
180 | (center_mel - lower_edge_mel))
181 | upper_slope = ((upper_edge_mel - spectrogram_bins_mel) /
182 | (upper_edge_mel - center_mel))
183 | # .. then intersect them with each other and zero.
184 | mel_weights_matrix[:, i] = np.maximum(0.0, np.minimum(lower_slope,
185 | upper_slope))
186 | # HTK excludes the spectrogram DC bin; make sure it always gets a zero
187 | # coefficient.
188 | mel_weights_matrix[0, :] = 0.0
189 | return mel_weights_matrix
190 |
191 |
192 | def log_mel_spectrogram(data,
193 | audio_sample_rate=8000,
194 | log_offset=0.0,
195 | window_length_secs=0.025,
196 | hop_length_secs=0.010,
197 | **kwargs):
198 | """Convert waveform to a log magnitude mel-frequency spectrogram.
199 |
200 | Args:
201 | data: 1D np.array of waveform data.
202 | audio_sample_rate: The sampling rate of data.
203 | log_offset: Add this to values when taking log to avoid -Infs.
204 | window_length_secs: Duration of each window to analyze.
205 | hop_length_secs: Advance between successive analysis windows.
206 | **kwargs: Additional arguments to pass to spectrogram_to_mel_matrix.
207 |
208 | Returns:
209 | 2D np.array of (num_frames, num_mel_bins) consisting of log mel filterbank
210 | magnitudes for successive frames.
211 | """
212 | window_length_samples = int(round(audio_sample_rate * window_length_secs))
213 | hop_length_samples = int(round(audio_sample_rate * hop_length_secs))
214 | fft_length = 2 ** int(np.ceil(np.log(window_length_samples) / np.log(2.0)))
215 | spectrogram = stft_magnitude(
216 | data,
217 | fft_length=fft_length,
218 | hop_length=hop_length_samples,
219 | window_length=window_length_samples)
220 | mel_spectrogram = np.dot(spectrogram, spectrogram_to_mel_matrix(
221 | num_spectrogram_bins=spectrogram.shape[1],
222 | audio_sample_rate=audio_sample_rate, **kwargs))
223 | return np.log(mel_spectrogram + log_offset)
224 |
--------------------------------------------------------------------------------
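`log_mel_spectrogram` above chains `frame`, the periodic Hann window, the STFT magnitude, the mel weighting matrix, and the log. A quick sanity check with a synthetic tone (parameter values mirror `vggish_params`; the exact frame count follows from the rounding above, and importing via the `model` package assumes the repo dependencies are installed):

```python
import numpy as np
from model.vggish import mel_features

sr = 16000
tone = np.sin(2 * np.pi * 440.0 * np.arange(sr) / sr)   # 1 s of a 440 Hz tone

log_mel = mel_features.log_mel_spectrogram(
    tone, audio_sample_rate=sr, log_offset=0.01,
    window_length_secs=0.025, hop_length_secs=0.010, num_mel_bins=64)
print(log_mel.shape)  # about (98, 64): ~100 hops of 10 ms, 64 mel bins
```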
/model/vggish/vggish.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | from . import vggish_input, vggish_params
5 |
6 |
7 | class VGG(nn.Module):
8 | def __init__(self, features):
9 | super(VGG, self).__init__()
10 | self.features = features
11 | self.embeddings = nn.Sequential(
12 | nn.Linear(512 * 4 * 6, 4096),
13 | nn.ReLU(True),
14 | nn.Linear(4096, 4096),
15 | nn.ReLU(True),
16 | nn.Linear(4096, 128),
17 | nn.ReLU(True))
18 |
19 | def forward(self, x):
20 | x = self.features(x)
21 |
22 | # Transpose the output from features to
23 | # remain compatible with vggish embeddings
24 | x = torch.transpose(x, 1, 3)
25 | x = torch.transpose(x, 1, 2)
26 | x = x.contiguous()
27 | x = x.view(x.size(0), -1)
28 |
29 | return self.embeddings(x)
30 |
31 |
32 | class Postprocessor(nn.Module):
33 | """Post-processes VGGish embeddings. Returns a torch.Tensor instead of a
34 | numpy array in order to preserve the gradient.
35 |
36 | "The initial release of AudioSet included 128-D VGGish embeddings for each
37 | segment of AudioSet. These released embeddings were produced by applying
38 | a PCA transformation (technically, a whitening transform is included as well)
39 | and 8-bit quantization to the raw embedding output from VGGish, in order to
40 | stay compatible with the YouTube-8M project which provides visual embeddings
41 | in the same format for a large set of YouTube videos. This class implements
42 | the same PCA (with whitening) and quantization transformations."
43 | """
44 |
45 | def __init__(self):
46 | """Constructs a postprocessor."""
47 | super(Postprocessor, self).__init__()
48 |         # Create empty matrices for the pretrained state_dict to populate on load
49 | self.pca_eigen_vectors = torch.empty(
50 | (vggish_params.EMBEDDING_SIZE, vggish_params.EMBEDDING_SIZE,),
51 | dtype=torch.float,
52 | )
53 | self.pca_means = torch.empty(
54 | (vggish_params.EMBEDDING_SIZE, 1), dtype=torch.float
55 | )
56 |
57 | self.pca_eigen_vectors = nn.Parameter(
58 | self.pca_eigen_vectors, requires_grad=False)
59 | self.pca_means = nn.Parameter(self.pca_means, requires_grad=False)
60 |
61 | def postprocess(self, embeddings_batch):
62 | """Applies tensor postprocessing to a batch of embeddings.
63 |
64 | Args:
65 |             embeddings_batch: A tensor of shape [batch_size, embedding_size]
66 | containing output from the embedding layer of VGGish.
67 |
68 | Returns:
69 | A tensor of the same shape as the input, containing the PCA-transformed,
70 | quantized, and clipped version of the input.
71 | """
72 | assert len(embeddings_batch.shape) == 2, "Expected 2-d batch, got %r" % (
73 | embeddings_batch.shape,
74 | )
75 | assert (
76 | embeddings_batch.shape[1] == vggish_params.EMBEDDING_SIZE
77 | ), "Bad batch shape: %r" % (embeddings_batch.shape,)
78 |
79 | # Apply PCA.
80 | # - Embeddings come in as [batch_size, embedding_size].
81 | # - Transpose to [embedding_size, batch_size].
82 | # - Subtract pca_means column vector from each column.
83 | # - Premultiply by PCA matrix of shape [output_dims, input_dims]
84 |         #     where both are equal to embedding_size in our case.
85 | # - Transpose result back to [batch_size, embedding_size].
86 | pca_applied = torch.mm(self.pca_eigen_vectors,
87 | (embeddings_batch.t() - self.pca_means)).t()
88 |
89 | # Quantize by:
90 | # - clipping to [min, max] range
91 | clipped_embeddings = torch.clamp(
92 | pca_applied, vggish_params.QUANTIZE_MIN_VAL, vggish_params.QUANTIZE_MAX_VAL
93 | )
94 | # - convert to 8-bit in range [0.0, 255.0]
95 | quantized_embeddings = torch.round(
96 | (clipped_embeddings - vggish_params.QUANTIZE_MIN_VAL)
97 | * (
98 | 255.0
99 | / (vggish_params.QUANTIZE_MAX_VAL - vggish_params.QUANTIZE_MIN_VAL)
100 | )
101 | )
102 | return torch.squeeze(quantized_embeddings)
103 |
104 | def forward(self, x):
105 | return self.postprocess(x)
106 |
107 |
108 | def make_layers():
109 | layers = []
110 | in_channels = 1
111 | for v in [64, "M", 128, "M", 256, 256, "M", 512, 512, "M"]:
112 | if v == "M":
113 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
114 | else:
115 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
116 | layers += [conv2d, nn.ReLU(inplace=True)]
117 | in_channels = v
118 | return nn.Sequential(*layers)
119 |
120 |
121 | def _vgg():
122 | return VGG(make_layers())
123 |
124 |
125 | class VGGish(VGG):
126 | def __init__(self,
127 | freeze_audio_extractor,
128 | pretrained_vggish_model_path,
129 | preprocess_audio_to_log_mel,
130 | postprocess_log_mel_with_pca,
131 | pretrained_pca_params_path,
132 | *args, **kwargs):
133 | super().__init__(make_layers())
134 | if freeze_audio_extractor:
135 | state_dict = torch.load(pretrained_vggish_model_path)
136 | super().load_state_dict(state_dict)
137 |
138 | self.preprocess = preprocess_audio_to_log_mel
139 | self.postprocess = postprocess_log_mel_with_pca
140 | if self.postprocess:
141 | self.pproc = Postprocessor()
142 | if freeze_audio_extractor:
143 | state_dict = torch.load(pretrained_pca_params_path)
144 | # TODO: Convert the state_dict to torch
145 | state_dict[vggish_params.PCA_EIGEN_VECTORS_NAME] = torch.as_tensor(
146 | state_dict[vggish_params.PCA_EIGEN_VECTORS_NAME], dtype=torch.float
147 | )
148 | state_dict[vggish_params.PCA_MEANS_NAME] = torch.as_tensor(
149 | state_dict[vggish_params.PCA_MEANS_NAME].reshape(-1, 1), dtype=torch.float
150 | )
151 | self.pproc.load_state_dict(state_dict)
152 |
153 | def forward(self, x):
154 | if self.preprocess:
155 | x = self._preprocess(x)
156 | x = VGG.forward(self, x)
157 | if self.postprocess:
158 | x = self._postprocess(x)
159 | return x
160 |
161 | def _preprocess(self, x):
162 | if isinstance(x, str):
163 |             x = vggish_input.wavfile_to_examples(x)  # a single wav path
164 | return x
165 | else:
166 | batch_num = len(x)
167 | audio_fea_list = []
168 | for xx in x:
169 | if isinstance(xx, str):
170 | xx = vggish_input.wavfile_to_examples(
171 | xx) # [5 or 10, 1, 96, 64]
172 | #! notice:
173 | if xx.shape[0] != 10:
174 | new_xx = torch.zeros(10, 1, 96, 64)
175 | new_xx[:xx.shape[0]] = xx
176 | audio_fea_list.append(new_xx)
177 | else:
178 | audio_fea_list.append(xx)
179 |
180 | # [bs, 10, 1, 96, 64]
181 | audio_fea = torch.stack(audio_fea_list, dim=0)
182 | audio_fea = audio_fea.view(
183 | batch_num * 10, xx.shape[1], xx.shape[2], xx.shape[3]) # [bs*10, 1, 96, 64]
184 | audio_fea = audio_fea.cuda()
185 | return audio_fea
186 |
187 | def _postprocess(self, x):
188 | return self.pproc(x)
189 |
--------------------------------------------------------------------------------
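`VGGish` above wraps the VGG trunk with optional waveform preprocessing and PCA postprocessing. A minimal random-weight sketch that skips both and feeds already-framed log-mel patches (the real pipeline passes pretrained checkpoint paths from the configs):

```python
import torch
from model.vggish.vggish import VGGish

model = VGGish(freeze_audio_extractor=False,
               pretrained_vggish_model_path=None,
               preprocess_audio_to_log_mel=False,
               postprocess_log_mel_with_pca=False,
               pretrained_pca_params_path=None)
log_mel_batch = torch.randn(20, 1, 96, 64)   # e.g. 2 clips x 10 one-second examples
embeddings = model(log_mel_batch)
print(embeddings.shape)  # torch.Size([20, 128])
```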
/model/vggish/vggish_input.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 The TensorFlow Authors All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Compute input examples for VGGish from audio waveform."""
17 |
18 | # Modification: Return torch tensors rather than numpy arrays
19 | import torch
20 |
21 | import numpy as np
22 | import resampy
23 |
24 | from . import mel_features
25 | from . import vggish_params
26 |
27 | import soundfile as sf
28 |
29 |
30 | def waveform_to_examples(data, sample_rate, return_tensor=True):
31 | """Converts audio waveform into an array of examples for VGGish.
32 |
33 | Args:
34 | data: np.array of either one dimension (mono) or two dimensions
35 | (multi-channel, with the outer dimension representing channels).
36 | Each sample is generally expected to lie in the range [-1.0, +1.0],
37 | although this is not required.
38 | sample_rate: Sample rate of data.
39 | return_tensor: Return data as a Pytorch tensor ready for VGGish
40 |
41 | Returns:
42 | 3-D np.array of shape [num_examples, num_frames, num_bands] which represents
43 | a sequence of examples, each of which contains a patch of log mel
44 | spectrogram, covering num_frames frames of audio and num_bands mel frequency
45 | bands, where the frame length is vggish_params.STFT_HOP_LENGTH_SECONDS.
46 |
47 | """
48 | # Convert to mono.
49 | if len(data.shape) > 1:
50 | data = np.mean(data, axis=1)
51 | # Resample to the rate assumed by VGGish.
52 | if sample_rate != vggish_params.SAMPLE_RATE:
53 | data = resampy.resample(data, sample_rate, vggish_params.SAMPLE_RATE)
54 |
55 | # Compute log mel spectrogram features.
56 | log_mel = mel_features.log_mel_spectrogram(
57 | data,
58 | audio_sample_rate=vggish_params.SAMPLE_RATE,
59 | log_offset=vggish_params.LOG_OFFSET,
60 | window_length_secs=vggish_params.STFT_WINDOW_LENGTH_SECONDS,
61 | hop_length_secs=vggish_params.STFT_HOP_LENGTH_SECONDS,
62 | num_mel_bins=vggish_params.NUM_MEL_BINS,
63 | lower_edge_hertz=vggish_params.MEL_MIN_HZ,
64 | upper_edge_hertz=vggish_params.MEL_MAX_HZ)
65 |
66 | # Frame features into examples.
67 | features_sample_rate = 1.0 / vggish_params.STFT_HOP_LENGTH_SECONDS
68 | example_window_length = int(round(
69 | vggish_params.EXAMPLE_WINDOW_SECONDS * features_sample_rate))
70 | example_hop_length = int(round(
71 | vggish_params.EXAMPLE_HOP_SECONDS * features_sample_rate))
72 | log_mel_examples = mel_features.frame(
73 | log_mel,
74 | window_length=example_window_length,
75 | hop_length=example_hop_length)
76 |
77 | if return_tensor:
78 | log_mel_examples = torch.tensor(
79 | log_mel_examples, requires_grad=True)[:, None, :, :].float()
80 |
81 | return log_mel_examples
82 |
83 |
84 | def wavfile_to_examples(wav_file, return_tensor=True):
85 | """Convenience wrapper around waveform_to_examples() for a common WAV format.
86 |
87 | Args:
88 | wav_file: String path to a file, or a file-like object. The file
89 | is assumed to contain WAV audio data with signed 16-bit PCM samples.
90 |     return_tensor: Return data as a Pytorch tensor ready for VGGish
91 |
92 | Returns:
93 | See waveform_to_examples.
94 | """
95 | wav_data, sr = sf.read(wav_file, dtype='int16')
96 | assert wav_data.dtype == np.int16, 'Bad sample type: %r' % wav_data.dtype
97 | samples = wav_data / 32768.0 # Convert to [-1.0, +1.0]
98 | return waveform_to_examples(samples, sr, return_tensor)
99 |
--------------------------------------------------------------------------------
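`waveform_to_examples` above resamples arbitrary-rate audio to 16 kHz and frames the log-mel spectrogram into non-overlapping 0.96 s examples. A short sketch with synthetic audio (sizes follow from `vggish_params`):

```python
import numpy as np
from model.vggish import vggish_input

waveform = np.random.uniform(-1.0, 1.0, size=2 * 44100)   # 2 s of noise at 44.1 kHz
examples = vggish_input.waveform_to_examples(waveform, sample_rate=44100)
print(examples.shape)  # torch.Size([2, 1, 96, 64]) -> two 0.96 s examples
```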
/model/vggish/vggish_params.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 The TensorFlow Authors All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Global parameters for the VGGish model.
17 |
18 | See vggish_slim.py for more information.
19 | """
20 |
21 | # Architectural constants.
22 | NUM_FRAMES = 96 # Frames in input mel-spectrogram patch.
23 | NUM_BANDS = 64 # Frequency bands in input mel-spectrogram patch.
24 | EMBEDDING_SIZE = 128 # Size of embedding layer.
25 |
26 | # Hyperparameters used in feature and example generation.
27 | SAMPLE_RATE = 16000
28 | STFT_WINDOW_LENGTH_SECONDS = 0.025
29 | STFT_HOP_LENGTH_SECONDS = 0.010
30 | NUM_MEL_BINS = NUM_BANDS
31 | MEL_MIN_HZ = 125
32 | MEL_MAX_HZ = 7500
33 | LOG_OFFSET = 0.01 # Offset used for stabilized log of input mel-spectrogram.
34 | EXAMPLE_WINDOW_SECONDS = 0.96 # Each example contains 96 10ms frames
35 | EXAMPLE_HOP_SECONDS = 0.96 # with zero overlap.
36 |
37 | # Parameters used for embedding postprocessing.
38 | PCA_EIGEN_VECTORS_NAME = 'pca_eigen_vectors'
39 | PCA_MEANS_NAME = 'pca_means'
40 | QUANTIZE_MIN_VAL = -2.0
41 | QUANTIZE_MAX_VAL = +2.0
42 |
43 | # Hyperparameters used in training.
44 | INIT_STDDEV = 0.01 # Standard deviation used to initialize weights.
45 | LEARNING_RATE = 1e-4 # Learning rate for the Adam optimizer.
46 | ADAM_EPSILON = 1e-8 # Epsilon for the Adam optimizer.
47 |
48 | # Names of ops, tensors, and features.
49 | INPUT_OP_NAME = 'vggish/input_features'
50 | INPUT_TENSOR_NAME = INPUT_OP_NAME + ':0'
51 | OUTPUT_OP_NAME = 'vggish/embedding'
52 | OUTPUT_TENSOR_NAME = OUTPUT_OP_NAME + ':0'
53 | AUDIO_EMBEDDING_FEATURE_NAME = 'audio_embedding'
54 |
--------------------------------------------------------------------------------
/ops/functions/__init__.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------------------------
2 | # Deformable DETR
3 | # Copyright (c) 2020 SenseTime. All Rights Reserved.
4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5 | # ------------------------------------------------------------------------------------------------
6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
7 | # ------------------------------------------------------------------------------------------------
8 |
9 | # Copyright (c) Facebook, Inc. and its affiliates.
10 | # Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR
11 |
12 | from .ms_deform_attn_func import MSDeformAttnFunction
13 |
14 |
--------------------------------------------------------------------------------
/ops/functions/ms_deform_attn_func.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------------------------
2 | # Deformable DETR
3 | # Copyright (c) 2020 SenseTime. All Rights Reserved.
4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5 | # ------------------------------------------------------------------------------------------------
6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
7 | # ------------------------------------------------------------------------------------------------
8 |
9 | # Copyright (c) Facebook, Inc. and its affiliates.
10 | # Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR
11 |
12 | from __future__ import absolute_import
13 | from __future__ import print_function
14 | from __future__ import division
15 |
16 | import torch
17 | import torch.nn.functional as F
18 | from torch.autograd import Function
19 | from torch.autograd.function import once_differentiable
20 |
21 | try:
22 | import MultiScaleDeformableAttention as MSDA
23 | except ModuleNotFoundError as e:
24 | info_string = (
25 | "\n\nPlease compile MultiScaleDeformableAttention CUDA op with the following commands:\n"
26 |         "\t`cd ops`\n"
27 | "\t`sh make.sh`\n"
28 | )
29 | raise ModuleNotFoundError(info_string)
30 |
31 |
32 | class MSDeformAttnFunction(Function):
33 | @staticmethod
34 | def forward(ctx, value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, im2col_step):
35 | ctx.im2col_step = im2col_step
36 | output = MSDA.ms_deform_attn_forward(
37 | value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, ctx.im2col_step)
38 | ctx.save_for_backward(value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights)
39 | return output
40 |
41 | @staticmethod
42 | @once_differentiable
43 | def backward(ctx, grad_output):
44 | value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights = ctx.saved_tensors
45 | grad_value, grad_sampling_loc, grad_attn_weight = \
46 | MSDA.ms_deform_attn_backward(
47 | value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, grad_output, ctx.im2col_step)
48 |
49 | return grad_value, None, None, grad_sampling_loc, grad_attn_weight, None
50 |
51 |
52 | def ms_deform_attn_core_pytorch(value, value_spatial_shapes, sampling_locations, attention_weights):
53 | # for debug and test only,
54 | # need to use cuda version instead
55 | N_, S_, M_, D_ = value.shape
56 | _, Lq_, M_, L_, P_, _ = sampling_locations.shape
57 | value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes], dim=1)
58 | sampling_grids = 2 * sampling_locations - 1
59 | sampling_value_list = []
60 | for lid_, (H_, W_) in enumerate(value_spatial_shapes):
61 | # N_, H_*W_, M_, D_ -> N_, H_*W_, M_*D_ -> N_, M_*D_, H_*W_ -> N_*M_, D_, H_, W_
62 | value_l_ = value_list[lid_].flatten(2).transpose(1, 2).reshape(N_*M_, D_, H_, W_)
63 | # N_, Lq_, M_, P_, 2 -> N_, M_, Lq_, P_, 2 -> N_*M_, Lq_, P_, 2
64 | sampling_grid_l_ = sampling_grids[:, :, :, lid_].transpose(1, 2).flatten(0, 1)
65 | # N_*M_, D_, Lq_, P_
66 | sampling_value_l_ = F.grid_sample(value_l_, sampling_grid_l_,
67 | mode='bilinear', padding_mode='zeros', align_corners=False)
68 | sampling_value_list.append(sampling_value_l_)
69 | # (N_, Lq_, M_, L_, P_) -> (N_, M_, Lq_, L_, P_) -> (N_, M_, 1, Lq_, L_*P_)
70 | attention_weights = attention_weights.transpose(1, 2).reshape(N_*M_, 1, Lq_, L_*P_)
71 | output = (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights).sum(-1).view(N_, M_*D_, Lq_)
72 | return output.transpose(1, 2).contiguous()
73 |
--------------------------------------------------------------------------------
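`ms_deform_attn_core_pytorch` above is the pure-PyTorch reference for the CUDA kernel, useful for debugging. A toy shape check (sizes are illustrative and the weights are normalized only for the demo; importing this module still requires the compiled extension from `sh make.sh`):

```python
import torch
from ops.functions.ms_deform_attn_func import ms_deform_attn_core_pytorch

N, M, D, Lq, L, P = 2, 8, 32, 10, 2, 4           # batch, heads, head dim, queries, levels, points
spatial_shapes = [(8, 8), (4, 4)]
S = sum(h * w for h, w in spatial_shapes)        # 80 flattened key positions

value = torch.randn(N, S, M, D)
sampling_locations = torch.rand(N, Lq, M, L, P, 2)           # normalized to [0, 1]
attention_weights = torch.rand(N, Lq, M, L, P).softmax(-1)   # demo-only normalization

out = ms_deform_attn_core_pytorch(value, spatial_shapes, sampling_locations, attention_weights)
print(out.shape)  # torch.Size([2, 10, 256]) -> (N, Lq, M*D)
```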
/ops/make.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | # ------------------------------------------------------------------------------------------------
3 | # Deformable DETR
4 | # Copyright (c) 2020 SenseTime. All Rights Reserved.
5 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6 | # ------------------------------------------------------------------------------------------------
7 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
8 | # ------------------------------------------------------------------------------------------------
9 |
10 | # Copyright (c) Facebook, Inc. and its affiliates.
11 | # Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR
12 |
13 | python setup.py build install
14 |
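15 | # Optional: sanity-check the compiled op afterwards with the forward/backward tests
16 | # (requires a CUDA GPU):
17 | #   python test.py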
--------------------------------------------------------------------------------
/ops/modules/__init__.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------------------------
2 | # Deformable DETR
3 | # Copyright (c) 2020 SenseTime. All Rights Reserved.
4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5 | # ------------------------------------------------------------------------------------------------
6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
7 | # ------------------------------------------------------------------------------------------------
8 |
9 | # Copyright (c) Facebook, Inc. and its affiliates.
10 | # Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR
11 |
12 | from .ms_deform_attn import MSDeformAttn
13 |
--------------------------------------------------------------------------------
/ops/modules/ms_deform_attn.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------------------------
2 | # Deformable DETR
3 | # Copyright (c) 2020 SenseTime. All Rights Reserved.
4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5 | # ------------------------------------------------------------------------------------------------
6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
7 | # ------------------------------------------------------------------------------------------------
8 |
9 | # Copyright (c) Facebook, Inc. and its affiliates.
10 | # Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR
11 |
12 | from __future__ import absolute_import
13 | from __future__ import print_function
14 | from __future__ import division
15 |
16 | import warnings
17 | import math
18 |
19 | import torch
20 | from torch import nn
21 | import torch.nn.functional as F
22 | from torch.nn.init import xavier_uniform_, constant_
23 |
24 | from ..functions import MSDeformAttnFunction
25 | from ..functions.ms_deform_attn_func import ms_deform_attn_core_pytorch
26 |
27 |
28 | def _is_power_of_2(n):
29 | if (not isinstance(n, int)) or (n < 0):
30 | raise ValueError("invalid input for _is_power_of_2: {} (type: {})".format(n, type(n)))
31 | return (n & (n-1) == 0) and n != 0
32 |
33 |
34 | class MSDeformAttn(nn.Module):
35 | def __init__(self, d_model=256, n_levels=4, n_heads=8, n_points=4):
36 | """
37 | Multi-Scale Deformable Attention Module
38 | :param d_model hidden dimension
39 | :param n_levels number of feature levels
40 | :param n_heads number of attention heads
41 | :param n_points number of sampling points per attention head per feature level
42 | """
43 | super().__init__()
44 | if d_model % n_heads != 0:
45 | raise ValueError('d_model must be divisible by n_heads, but got {} and {}'.format(d_model, n_heads))
46 | _d_per_head = d_model // n_heads
47 |         # setting _d_per_head to a power of 2 makes our CUDA implementation more efficient
48 | if not _is_power_of_2(_d_per_head):
49 |             warnings.warn("Set d_model in MSDeformAttn so that the dimension of each attention head "
50 |                           "is a power of 2, which is more efficient in our CUDA implementation.")
51 |
52 | self.im2col_step = 128
53 |
54 | self.d_model = d_model
55 | self.n_levels = n_levels
56 | self.n_heads = n_heads
57 | self.n_points = n_points
58 |
59 | self.sampling_offsets = nn.Linear(d_model, n_heads * n_levels * n_points * 2)
60 | self.attention_weights = nn.Linear(d_model, n_heads * n_levels * n_points)
61 | self.value_proj = nn.Linear(d_model, d_model)
62 | self.output_proj = nn.Linear(d_model, d_model)
63 |
64 | self._reset_parameters()
65 |
66 | def _reset_parameters(self):
67 | constant_(self.sampling_offsets.weight.data, 0.)
68 | thetas = torch.arange(self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads)
69 | grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
70 | grid_init = (grid_init / grid_init.abs().max(-1, keepdim=True)[0]).view(self.n_heads, 1, 1, 2).repeat(1, self.n_levels, self.n_points, 1)
71 | for i in range(self.n_points):
72 | grid_init[:, :, i, :] *= i + 1
73 | with torch.no_grad():
74 | self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1))
75 | constant_(self.attention_weights.weight.data, 0.)
76 | constant_(self.attention_weights.bias.data, 0.)
77 | xavier_uniform_(self.value_proj.weight.data)
78 | constant_(self.value_proj.bias.data, 0.)
79 | xavier_uniform_(self.output_proj.weight.data)
80 | constant_(self.output_proj.bias.data, 0.)
81 |
82 | def forward(self, query, reference_points, input_flatten, input_spatial_shapes, input_level_start_index, input_padding_mask=None):
83 | """
84 | :param query (N, Length_{query}, C)
85 | :param reference_points (N, Length_{query}, n_levels, 2), range in [0, 1], top-left (0,0), bottom-right (1, 1), including padding area
86 | or (N, Length_{query}, n_levels, 4), add additional (w, h) to form reference boxes
87 | :param input_flatten (N, \sum_{l=0}^{L-1} H_l \cdot W_l, C)
88 | :param input_spatial_shapes (n_levels, 2), [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})]
89 | :param input_level_start_index (n_levels, ), [0, H_0*W_0, H_0*W_0+H_1*W_1, H_0*W_0+H_1*W_1+H_2*W_2, ..., H_0*W_0+H_1*W_1+...+H_{L-1}*W_{L-1}]
90 | :param input_padding_mask (N, \sum_{l=0}^{L-1} H_l \cdot W_l), True for padding elements, False for non-padding elements
91 |
92 | :return output (N, Length_{query}, C)
93 | """
94 | N, Len_q, _ = query.shape
95 | N, Len_in, _ = input_flatten.shape
96 | assert (input_spatial_shapes[:, 0] * input_spatial_shapes[:, 1]).sum() == Len_in
97 |
98 | value = self.value_proj(input_flatten)
99 | if input_padding_mask is not None:
100 | value = value.masked_fill(input_padding_mask[..., None], float(0))
101 | value = value.view(N, Len_in, self.n_heads, self.d_model // self.n_heads)
102 | sampling_offsets = self.sampling_offsets(query).view(N, Len_q, self.n_heads, self.n_levels, self.n_points, 2)
103 | attention_weights = self.attention_weights(query).view(N, Len_q, self.n_heads, self.n_levels * self.n_points)
104 | attention_weights = F.softmax(attention_weights, -1).view(N, Len_q, self.n_heads, self.n_levels, self.n_points)
105 | # N, Len_q, n_heads, n_levels, n_points, 2
106 | if reference_points.shape[-1] == 2:
107 | offset_normalizer = torch.stack([input_spatial_shapes[..., 1], input_spatial_shapes[..., 0]], -1)
108 | sampling_locations = reference_points[:, :, None, :, None, :] \
109 | + sampling_offsets / offset_normalizer[None, None, None, :, None, :]
110 | elif reference_points.shape[-1] == 4:
111 | sampling_locations = reference_points[:, :, None, :, None, :2] \
112 | + sampling_offsets / self.n_points * reference_points[:, :, None, :, None, 2:] * 0.5
113 | else:
114 | raise ValueError(
115 | 'Last dim of reference_points must be 2 or 4, but get {} instead.'.format(reference_points.shape[-1]))
116 | try:
117 | output = MSDeformAttnFunction.apply(
118 | value, input_spatial_shapes, input_level_start_index, sampling_locations, attention_weights, self.im2col_step)
119 |         except Exception:
120 |             # fall back to the pure-PyTorch implementation (e.g. on CPU tensors)
121 | output = ms_deform_attn_core_pytorch(value, input_spatial_shapes, sampling_locations, attention_weights)
122 | # # For FLOPs calculation only
123 | # output = ms_deform_attn_core_pytorch(value, input_spatial_shapes, sampling_locations, attention_weights)
124 | output = self.output_proj(output)
125 | return output
126 |
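127 |
128 | # Illustrative usage sketch for the forward() signature documented above. It is written as a
129 | # plain helper (not a __main__ block) because this module uses relative imports and is normally
130 | # imported as part of the package, e.g. `from ops.modules import MSDeformAttn` from the repo
131 | # root. All sizes below are arbitrary assumptions; a CUDA build of the op is presumed.
132 | def _example_forward():
133 |     spatial_shapes = torch.as_tensor([[16, 16], [8, 8], [4, 4], [2, 2]], dtype=torch.long).cuda()
134 |     level_start_index = torch.cat((spatial_shapes.new_zeros((1,)),
135 |                                    spatial_shapes.prod(1).cumsum(0)[:-1]))
136 |     n, len_in = 2, int(spatial_shapes.prod(1).sum())     # 2 samples, 340 flattened pixels
137 |     attn = MSDeformAttn(d_model=256, n_levels=4, n_heads=8, n_points=4).cuda()
138 |     query = torch.rand(n, len_in, 256).cuda()
139 |     # one normalized (x, y) reference point per query, repeated for every level
140 |     reference_points = torch.rand(n, len_in, 1, 2).repeat(1, 1, 4, 1).cuda()
141 |     out = attn(query, reference_points, query, spatial_shapes, level_start_index)
142 |     return out.shape                                     # torch.Size([2, 340, 256])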
--------------------------------------------------------------------------------
/ops/setup.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------------------------
2 | # Deformable DETR
3 | # Copyright (c) 2020 SenseTime. All Rights Reserved.
4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5 | # ------------------------------------------------------------------------------------------------
6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
7 | # ------------------------------------------------------------------------------------------------
8 |
9 | # Copyright (c) Facebook, Inc. and its affiliates.
10 | # Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR
11 |
12 | import os
13 | import glob
14 |
15 | import torch
16 |
17 | from torch.utils.cpp_extension import CUDA_HOME
18 | from torch.utils.cpp_extension import CppExtension
19 | from torch.utils.cpp_extension import CUDAExtension
20 |
21 | from setuptools import find_packages
22 | from setuptools import setup
23 |
24 | requirements = ["torch", "torchvision"]
25 |
26 | def get_extensions():
27 | this_dir = os.path.dirname(os.path.abspath(__file__))
28 | extensions_dir = os.path.join(this_dir, "src")
29 |
30 | main_file = glob.glob(os.path.join(extensions_dir, "*.cpp"))
31 | source_cpu = glob.glob(os.path.join(extensions_dir, "cpu", "*.cpp"))
32 | source_cuda = glob.glob(os.path.join(extensions_dir, "cuda", "*.cu"))
33 |
34 | sources = main_file + source_cpu
35 | extension = CppExtension
36 | extra_compile_args = {"cxx": []}
37 | define_macros = []
38 |
39 |     # FORCE_CUDA builds the CUDA extension even if torch.cuda.is_available() is False at build time.
40 | if (os.environ.get('FORCE_CUDA') or torch.cuda.is_available()) and CUDA_HOME is not None:
41 | extension = CUDAExtension
42 | sources += source_cuda
43 | define_macros += [("WITH_CUDA", None)]
44 | extra_compile_args["nvcc"] = [
45 | "-DCUDA_HAS_FP16=1",
46 | "-D__CUDA_NO_HALF_OPERATORS__",
47 | "-D__CUDA_NO_HALF_CONVERSIONS__",
48 | "-D__CUDA_NO_HALF2_OPERATORS__",
49 | ]
50 | else:
51 | if CUDA_HOME is None:
52 | raise NotImplementedError('CUDA_HOME is None. Please set environment variable CUDA_HOME.')
53 | else:
54 |             raise NotImplementedError('No CUDA runtime was found. Set FORCE_CUDA=1 to build anyway, or check torch.cuda.is_available().')
55 |
56 | sources = [os.path.join(extensions_dir, s) for s in sources]
57 | include_dirs = [extensions_dir]
58 | ext_modules = [
59 | extension(
60 | "MultiScaleDeformableAttention",
61 | sources,
62 | include_dirs=include_dirs,
63 | define_macros=define_macros,
64 | extra_compile_args=extra_compile_args,
65 | )
66 | ]
67 | return ext_modules
68 |
69 | setup(
70 | name="MultiScaleDeformableAttention",
71 | version="1.0",
72 | author="Weijie Su",
73 | url="https://github.com/fundamentalvision/Deformable-DETR",
74 | description="PyTorch Wrapper for CUDA Functions of Multi-Scale Deformable Attention",
75 | packages=find_packages(exclude=("configs", "tests",)),
76 | ext_modules=get_extensions(),
77 | cmdclass={"build_ext": torch.utils.cpp_extension.BuildExtension},
78 | )
79 |
--------------------------------------------------------------------------------
/ops/src/cpu/ms_deform_attn_cpu.cpp:
--------------------------------------------------------------------------------
1 | /*!
2 | **************************************************************************************************
3 | * Deformable DETR
4 | * Copyright (c) 2020 SenseTime. All Rights Reserved.
5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6 | **************************************************************************************************
7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
8 | **************************************************************************************************
9 | */
10 |
11 | /*!
12 | * Copyright (c) Facebook, Inc. and its affiliates.
13 | * Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR
14 | */
15 |
16 | #include <vector>
17 |
18 | #include <ATen/ATen.h>
19 | #include <ATen/cuda/CUDAContext.h>
20 |
21 |
22 | at::Tensor
23 | ms_deform_attn_cpu_forward(
24 | const at::Tensor &value,
25 | const at::Tensor &spatial_shapes,
26 | const at::Tensor &level_start_index,
27 | const at::Tensor &sampling_loc,
28 | const at::Tensor &attn_weight,
29 | const int im2col_step)
30 | {
31 |     AT_ERROR("Not implemented on the CPU");
32 | }
33 |
34 | std::vector<at::Tensor>
35 | ms_deform_attn_cpu_backward(
36 | const at::Tensor &value,
37 | const at::Tensor &spatial_shapes,
38 | const at::Tensor &level_start_index,
39 | const at::Tensor &sampling_loc,
40 | const at::Tensor &attn_weight,
41 | const at::Tensor &grad_output,
42 | const int im2col_step)
43 | {
44 |     AT_ERROR("Not implemented on the CPU");
45 | }
46 |
47 |
--------------------------------------------------------------------------------
/ops/src/cpu/ms_deform_attn_cpu.h:
--------------------------------------------------------------------------------
1 | /*!
2 | **************************************************************************************************
3 | * Deformable DETR
4 | * Copyright (c) 2020 SenseTime. All Rights Reserved.
5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6 | **************************************************************************************************
7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
8 | **************************************************************************************************
9 | */
10 |
11 | /*!
12 | * Copyright (c) Facebook, Inc. and its affiliates.
13 | * Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR
14 | */
15 |
16 | #pragma once
17 | #include <torch/extension.h>
18 |
19 | at::Tensor
20 | ms_deform_attn_cpu_forward(
21 | const at::Tensor &value,
22 | const at::Tensor &spatial_shapes,
23 | const at::Tensor &level_start_index,
24 | const at::Tensor &sampling_loc,
25 | const at::Tensor &attn_weight,
26 | const int im2col_step);
27 |
28 | std::vector<at::Tensor>
29 | ms_deform_attn_cpu_backward(
30 | const at::Tensor &value,
31 | const at::Tensor &spatial_shapes,
32 | const at::Tensor &level_start_index,
33 | const at::Tensor &sampling_loc,
34 | const at::Tensor &attn_weight,
35 | const at::Tensor &grad_output,
36 | const int im2col_step);
37 |
38 |
39 |
--------------------------------------------------------------------------------
/ops/src/cuda/ms_deform_attn_cuda.cu:
--------------------------------------------------------------------------------
1 | /*!
2 | **************************************************************************************************
3 | * Deformable DETR
4 | * Copyright (c) 2020 SenseTime. All Rights Reserved.
5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6 | **************************************************************************************************
7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
8 | **************************************************************************************************
9 | */
10 |
11 | /*!
12 | * Copyright (c) Facebook, Inc. and its affiliates.
13 | * Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR
14 | */
15 |
16 | #include <vector>
17 | #include "cuda/ms_deform_im2col_cuda.cuh"
18 |
19 | #include <ATen/ATen.h>
20 | #include <ATen/cuda/CUDAContext.h>
21 | #include <cuda.h>
22 | #include <cuda_runtime.h>
23 |
24 |
25 | at::Tensor ms_deform_attn_cuda_forward(
26 | const at::Tensor &value,
27 | const at::Tensor &spatial_shapes,
28 | const at::Tensor &level_start_index,
29 | const at::Tensor &sampling_loc,
30 | const at::Tensor &attn_weight,
31 | const int im2col_step)
32 | {
33 | AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
34 | AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous");
35 | AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous");
36 | AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
37 | AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
38 |
39 | AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor");
40 | AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor");
41 | AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor");
42 | AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor");
43 | AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor");
44 |
45 | const int batch = value.size(0);
46 | const int spatial_size = value.size(1);
47 | const int num_heads = value.size(2);
48 | const int channels = value.size(3);
49 |
50 | const int num_levels = spatial_shapes.size(0);
51 |
52 | const int num_query = sampling_loc.size(1);
53 | const int num_point = sampling_loc.size(4);
54 |
55 | const int im2col_step_ = std::min(batch, im2col_step);
56 |
57 | AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_);
58 |
59 | auto output = at::zeros({batch, num_query, num_heads, channels}, value.options());
60 |
61 | const int batch_n = im2col_step_;
62 | auto output_n = output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels});
63 | auto per_value_size = spatial_size * num_heads * channels;
64 | auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
65 | auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
66 | for (int n = 0; n < batch/im2col_step_; ++n)
67 | {
68 | auto columns = output_n.select(0, n);
69 | AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_forward_cuda", ([&] {
70 | ms_deformable_im2col_cuda(at::cuda::getCurrentCUDAStream(),
71 |                 value.data<scalar_t>() + n * im2col_step_ * per_value_size,
72 |                 spatial_shapes.data<int64_t>(),
73 |                 level_start_index.data<int64_t>(),
74 |                 sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
75 |                 attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size,
76 |                 batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point,
77 |                 columns.data<scalar_t>());
78 |
79 | }));
80 | }
81 |
82 | output = output.view({batch, num_query, num_heads*channels});
83 |
84 | return output;
85 | }
86 |
87 |
88 | std::vector<at::Tensor> ms_deform_attn_cuda_backward(
89 | const at::Tensor &value,
90 | const at::Tensor &spatial_shapes,
91 | const at::Tensor &level_start_index,
92 | const at::Tensor &sampling_loc,
93 | const at::Tensor &attn_weight,
94 | const at::Tensor &grad_output,
95 | const int im2col_step)
96 | {
97 |
98 | AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
99 | AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous");
100 | AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous");
101 | AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
102 | AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
103 | AT_ASSERTM(grad_output.is_contiguous(), "grad_output tensor has to be contiguous");
104 |
105 | AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor");
106 | AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor");
107 | AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor");
108 | AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor");
109 | AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor");
110 | AT_ASSERTM(grad_output.type().is_cuda(), "grad_output must be a CUDA tensor");
111 |
112 | const int batch = value.size(0);
113 | const int spatial_size = value.size(1);
114 | const int num_heads = value.size(2);
115 | const int channels = value.size(3);
116 |
117 | const int num_levels = spatial_shapes.size(0);
118 |
119 | const int num_query = sampling_loc.size(1);
120 | const int num_point = sampling_loc.size(4);
121 |
122 | const int im2col_step_ = std::min(batch, im2col_step);
123 |
124 | AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_);
125 |
126 | auto grad_value = at::zeros_like(value);
127 | auto grad_sampling_loc = at::zeros_like(sampling_loc);
128 | auto grad_attn_weight = at::zeros_like(attn_weight);
129 |
130 | const int batch_n = im2col_step_;
131 | auto per_value_size = spatial_size * num_heads * channels;
132 | auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
133 | auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
134 | auto grad_output_n = grad_output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels});
135 |
136 | for (int n = 0; n < batch/im2col_step_; ++n)
137 | {
138 | auto grad_output_g = grad_output_n.select(0, n);
139 | AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_backward_cuda", ([&] {
140 | ms_deformable_col2im_cuda(at::cuda::getCurrentCUDAStream(),
141 |                               grad_output_g.data<scalar_t>(),
142 |                               value.data<scalar_t>() + n * im2col_step_ * per_value_size,
143 |                               spatial_shapes.data<int64_t>(),
144 |                               level_start_index.data<int64_t>(),
145 |                               sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
146 |                               attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size,
147 |                               batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point,
148 |                               grad_value.data<scalar_t>() + n * im2col_step_ * per_value_size,
149 |                               grad_sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
150 |                               grad_attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size);
151 |
152 | }));
153 | }
154 |
155 | return {
156 | grad_value, grad_sampling_loc, grad_attn_weight
157 | };
158 | }
--------------------------------------------------------------------------------
/ops/src/cuda/ms_deform_attn_cuda.h:
--------------------------------------------------------------------------------
1 | /*!
2 | **************************************************************************************************
3 | * Deformable DETR
4 | * Copyright (c) 2020 SenseTime. All Rights Reserved.
5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6 | **************************************************************************************************
7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
8 | **************************************************************************************************
9 | */
10 |
11 | /*!
12 | * Copyright (c) Facebook, Inc. and its affiliates.
13 | * Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR
14 | */
15 |
16 | #pragma once
17 | #include <torch/extension.h>
18 |
19 | at::Tensor ms_deform_attn_cuda_forward(
20 | const at::Tensor &value,
21 | const at::Tensor &spatial_shapes,
22 | const at::Tensor &level_start_index,
23 | const at::Tensor &sampling_loc,
24 | const at::Tensor &attn_weight,
25 | const int im2col_step);
26 |
27 | std::vector<at::Tensor> ms_deform_attn_cuda_backward(
28 | const at::Tensor &value,
29 | const at::Tensor &spatial_shapes,
30 | const at::Tensor &level_start_index,
31 | const at::Tensor &sampling_loc,
32 | const at::Tensor &attn_weight,
33 | const at::Tensor &grad_output,
34 | const int im2col_step);
35 |
36 |
--------------------------------------------------------------------------------
/ops/src/ms_deform_attn.h:
--------------------------------------------------------------------------------
1 | /*!
2 | **************************************************************************************************
3 | * Deformable DETR
4 | * Copyright (c) 2020 SenseTime. All Rights Reserved.
5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6 | **************************************************************************************************
7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
8 | **************************************************************************************************
9 | */
10 |
11 | /*!
12 | * Copyright (c) Facebook, Inc. and its affiliates.
13 | * Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR
14 | */
15 |
16 | #pragma once
17 |
18 | #include "cpu/ms_deform_attn_cpu.h"
19 |
20 | #ifdef WITH_CUDA
21 | #include "cuda/ms_deform_attn_cuda.h"
22 | #endif
23 |
24 |
25 | at::Tensor
26 | ms_deform_attn_forward(
27 | const at::Tensor &value,
28 | const at::Tensor &spatial_shapes,
29 | const at::Tensor &level_start_index,
30 | const at::Tensor &sampling_loc,
31 | const at::Tensor &attn_weight,
32 | const int im2col_step)
33 | {
34 | if (value.type().is_cuda())
35 | {
36 | #ifdef WITH_CUDA
37 | return ms_deform_attn_cuda_forward(
38 | value, spatial_shapes, level_start_index, sampling_loc, attn_weight, im2col_step);
39 | #else
40 | AT_ERROR("Not compiled with GPU support");
41 | #endif
42 | }
43 | AT_ERROR("Not implemented on the CPU");
44 | }
45 |
46 | std::vector<at::Tensor>
47 | ms_deform_attn_backward(
48 | const at::Tensor &value,
49 | const at::Tensor &spatial_shapes,
50 | const at::Tensor &level_start_index,
51 | const at::Tensor &sampling_loc,
52 | const at::Tensor &attn_weight,
53 | const at::Tensor &grad_output,
54 | const int im2col_step)
55 | {
56 | if (value.type().is_cuda())
57 | {
58 | #ifdef WITH_CUDA
59 | return ms_deform_attn_cuda_backward(
60 | value, spatial_shapes, level_start_index, sampling_loc, attn_weight, grad_output, im2col_step);
61 | #else
62 | AT_ERROR("Not compiled with GPU support");
63 | #endif
64 | }
65 | AT_ERROR("Not implemented on the CPU");
66 | }
67 |
68 |
--------------------------------------------------------------------------------
/ops/src/vision.cpp:
--------------------------------------------------------------------------------
1 | /*!
2 | **************************************************************************************************
3 | * Deformable DETR
4 | * Copyright (c) 2020 SenseTime. All Rights Reserved.
5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6 | **************************************************************************************************
7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
8 | **************************************************************************************************
9 | */
10 |
11 | /*!
12 | * Copyright (c) Facebook, Inc. and its affiliates.
13 | * Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR
14 | */
15 |
16 | #include "ms_deform_attn.h"
17 |
18 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
19 | m.def("ms_deform_attn_forward", &ms_deform_attn_forward, "ms_deform_attn_forward");
20 | m.def("ms_deform_attn_backward", &ms_deform_attn_backward, "ms_deform_attn_backward");
21 | }
22 |
--------------------------------------------------------------------------------
/ops/test.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------------------------
2 | # Deformable DETR
3 | # Copyright (c) 2020 SenseTime. All Rights Reserved.
4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5 | # ------------------------------------------------------------------------------------------------
6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
7 | # ------------------------------------------------------------------------------------------------
8 |
9 | # Copyright (c) Facebook, Inc. and its affiliates.
10 | # Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR
11 |
12 | from __future__ import absolute_import
13 | from __future__ import print_function
14 | from __future__ import division
15 |
16 | import time
17 | import torch
18 | import torch.nn as nn
19 | from torch.autograd import gradcheck
20 |
21 | from functions.ms_deform_attn_func import MSDeformAttnFunction, ms_deform_attn_core_pytorch
22 |
23 |
24 | N, M, D = 1, 2, 2
25 | Lq, L, P = 2, 2, 2
26 | shapes = torch.as_tensor([(6, 4), (3, 2)], dtype=torch.long).cuda()
27 | level_start_index = torch.cat((shapes.new_zeros((1, )), shapes.prod(1).cumsum(0)[:-1]))
28 | S = sum([(H*W).item() for H, W in shapes])
29 |
30 |
31 | torch.manual_seed(3)
32 |
33 |
34 | @torch.no_grad()
35 | def check_forward_equal_with_pytorch_double():
36 | value = torch.rand(N, S, M, D).cuda() * 0.01
37 | sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda()
38 | attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5
39 | attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True)
40 | im2col_step = 2
41 | output_pytorch = ms_deform_attn_core_pytorch(value.double(), shapes, sampling_locations.double(), attention_weights.double()).detach().cpu()
42 | output_cuda = MSDeformAttnFunction.apply(value.double(), shapes, level_start_index, sampling_locations.double(), attention_weights.double(), im2col_step).detach().cpu()
43 | fwdok = torch.allclose(output_cuda, output_pytorch)
44 | max_abs_err = (output_cuda - output_pytorch).abs().max()
45 | max_rel_err = ((output_cuda - output_pytorch).abs() / output_pytorch.abs()).max()
46 |
47 | print(f'* {fwdok} check_forward_equal_with_pytorch_double: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}')
48 |
49 |
50 | @torch.no_grad()
51 | def check_forward_equal_with_pytorch_float():
52 | value = torch.rand(N, S, M, D).cuda() * 0.01
53 | sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda()
54 | attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5
55 | attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True)
56 | im2col_step = 2
57 | output_pytorch = ms_deform_attn_core_pytorch(value, shapes, sampling_locations, attention_weights).detach().cpu()
58 | output_cuda = MSDeformAttnFunction.apply(value, shapes, level_start_index, sampling_locations, attention_weights, im2col_step).detach().cpu()
59 | fwdok = torch.allclose(output_cuda, output_pytorch, rtol=1e-2, atol=1e-3)
60 | max_abs_err = (output_cuda - output_pytorch).abs().max()
61 | max_rel_err = ((output_cuda - output_pytorch).abs() / output_pytorch.abs()).max()
62 |
63 | print(f'* {fwdok} check_forward_equal_with_pytorch_float: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}')
64 |
65 |
66 | def check_gradient_numerical(channels=4, grad_value=True, grad_sampling_loc=True, grad_attn_weight=True):
67 |
68 | value = torch.rand(N, S, M, channels).cuda() * 0.01
69 | sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda()
70 | attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5
71 | attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True)
72 | im2col_step = 2
73 | func = MSDeformAttnFunction.apply
74 |
75 | value.requires_grad = grad_value
76 | sampling_locations.requires_grad = grad_sampling_loc
77 | attention_weights.requires_grad = grad_attn_weight
78 |
79 | gradok = gradcheck(func, (value.double(), shapes, level_start_index, sampling_locations.double(), attention_weights.double(), im2col_step))
80 |
81 | print(f'* {gradok} check_gradient_numerical(D={channels})')
82 |
83 |
84 | if __name__ == '__main__':
85 | check_forward_equal_with_pytorch_double()
86 | check_forward_equal_with_pytorch_float()
87 |
88 | for channels in [30, 32, 64, 71, 1025, 2048, 3096]:
89 | check_gradient_numerical(channels, True, True, True)
90 |
91 |
92 |
93 |
--------------------------------------------------------------------------------
/scripts/avss/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | def F10_IoU_BCELoss(pred_mask, ten_gt_masks, gt_temporal_mask_flag):
7 | """
8 |     cross entropy loss (referred to as iou_loss) over the ten frames for semantic audio-visual segmentation
9 |
10 | Args:
11 | pred_mask: predicted masks for a batch of data, shape:[bs*10, N_CLASSES, 224, 224]
12 | ten_gt_masks: ground truth mask of the total ten frames, shape: [bs*10, 224, 224]
13 | """
14 | assert len(pred_mask.shape) == 4
15 | if ten_gt_masks.shape[1] == 1:
16 | ten_gt_masks = ten_gt_masks.squeeze(1)
17 |
18 | loss = nn.CrossEntropyLoss(reduction='none')(
19 | pred_mask, ten_gt_masks) # [bs*10, 224, 224]
20 | loss = loss.mean(-1).mean(-1) # [bs*10]
21 | loss = loss * gt_temporal_mask_flag # [bs*10]
22 | loss = torch.sum(loss) / torch.sum(gt_temporal_mask_flag)
23 |
24 | return loss
25 |
26 |
27 | def Mix_Dice_loss(pred_mask, norm_gt_mask, gt_temporal_mask_flag):
28 | """dice loss for aux loss
29 |
30 | Args:
31 |         pred_mask (Tensor): (bs*10, 1, h, w)
32 |         norm_gt_mask (Tensor): (bs*10, h, w), binarized ground truth masks
33 | """
34 | assert len(pred_mask.shape) == 4
35 | pred_mask = torch.sigmoid(pred_mask)
36 |
37 | pred_mask = pred_mask.flatten(1)
38 | gt_mask = norm_gt_mask.flatten(1)
39 | a = (pred_mask * gt_mask).sum(-1)
40 | b = (pred_mask * pred_mask).sum(-1) + 0.001
41 | c = (gt_mask * gt_mask).sum(-1) + 0.001
42 | d = (2 * a) / (b + c)
43 | loss = 1 - d
44 | loss = loss * gt_temporal_mask_flag
45 | loss = torch.sum(loss) / torch.sum(gt_temporal_mask_flag)
46 | return loss
47 |
48 |
49 | def IouSemanticAwareLoss(pred_masks, mask_feature, gt_mask, gt_temporal_mask_flag, weight_dict, **kwargs):
50 | total_loss = 0
51 | loss_dict = {}
52 |
53 | iou_loss = weight_dict['iou_loss'] * \
54 | F10_IoU_BCELoss(pred_masks, gt_mask, gt_temporal_mask_flag)
55 | total_loss += iou_loss
56 | loss_dict['iou_loss'] = iou_loss.item()
57 |
58 | mask_feature = torch.mean(mask_feature, dim=1, keepdim=True)
59 | mask_feature = F.interpolate(
60 | mask_feature, gt_mask.shape[-2:], mode='bilinear', align_corners=False)
61 | one_mask = torch.ones_like(gt_mask)
62 | norm_gt_mask = torch.where(gt_mask > 0, one_mask, gt_mask)
63 | mix_loss = weight_dict['mix_loss'] * \
64 | Mix_Dice_loss(mask_feature, norm_gt_mask, gt_temporal_mask_flag)
65 | total_loss += mix_loss
66 | loss_dict['mix_loss'] = mix_loss.item()
67 |
68 | return total_loss, loss_dict
69 |
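70 |
71 | # Shape sketch for the combined loss above. All sizes and weights are arbitrary assumptions
72 | # (not the configured values): 1 clip x 10 frames, 71 semantic classes, 56x56 mask features.
73 | if __name__ == '__main__':
74 |     bs, T, n_classes, H, W = 1, 10, 71, 224, 224
75 |     pred = torch.randn(bs * T, n_classes, H, W)               # per-class logits
76 |     mask_feat = torch.randn(bs * T, 8, 56, 56)                # intermediate mask features
77 |     gt = torch.randint(0, n_classes, (bs * T, H, W))          # semantic labels
78 |     flag = torch.ones(bs * T)                                 # all ten frames annotated
79 |     loss, loss_dict = IouSemanticAwareLoss(
80 |         pred, mask_feat, gt, flag, weight_dict={'iou_loss': 1.0, 'mix_loss': 0.1})
81 |     print(loss.item(), loss_dict)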
--------------------------------------------------------------------------------
/scripts/avss/test.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn
3 | import os
4 | from mmcv import Config
5 | import argparse
6 | from utils.vis_mask import save_color_mask
7 | from utils.compute_color_metrics import calc_color_miou_fscore
8 | from utils.logger import getLogger
9 | from model import build_model
10 | from dataloader import build_dataset, get_v2_pallete
11 |
12 |
13 | def main():
14 | # logger
15 | logger = getLogger(None, __name__)
16 | dir_name = os.path.splitext(os.path.split(args.cfg)[-1])[0]
17 | logger.info(f'Load config from {args.cfg}')
18 |
19 | # config
20 | cfg = Config.fromfile(args.cfg)
21 | logger.info(cfg.pretty_text)
22 |
23 | # model
24 | model = build_model(**cfg.model)
25 | model.load_state_dict(torch.load(args.weights))
26 | model = torch.nn.DataParallel(model).cuda()
27 | model.eval()
28 | logger.info('Load trained model %s' % args.weights)
29 |
30 | # Test data
31 | test_dataset = build_dataset(**cfg.dataset.test)
32 | test_dataloader = torch.utils.data.DataLoader(test_dataset,
33 | batch_size=cfg.dataset.test.batch_size,
34 | shuffle=False,
35 | num_workers=cfg.process.num_works,
36 | pin_memory=True)
37 | N_CLASSES = test_dataset.num_classes
38 |
39 | # for save predicted rgb masks
40 | v2_pallete = get_v2_pallete(cfg.dataset.test.label_idx_path)
41 | resize_pred_mask = cfg.dataset.test.resize_pred_mask
42 | if resize_pred_mask:
43 | pred_mask_img_size = cfg.dataset.test.save_pred_mask_img_size
44 | else:
45 | pred_mask_img_size = cfg.dataset.test.img_size
46 |
47 | # metrics
48 | miou_pc = torch.zeros((N_CLASSES)) # miou value per class (total sum)
49 | Fs_pc = torch.zeros((N_CLASSES)) # f-score per class (total sum)
50 | cls_pc = torch.zeros((N_CLASSES)) # count per class
51 | with torch.no_grad():
52 | for n_iter, batch_data in enumerate(test_dataloader):
53 | imgs, audio, mask, vid_temporal_mask_flag, gt_temporal_mask_flag, video_name_list = batch_data
54 | vid_temporal_mask_flag = vid_temporal_mask_flag.cuda()
55 | gt_temporal_mask_flag = gt_temporal_mask_flag.cuda()
56 |
57 | imgs = imgs.cuda()
58 | # audio = audio.cuda()
59 | mask = mask.cuda()
60 | B, frame, C, H, W = imgs.shape
61 | imgs = imgs.view(B * frame, C, H, W)
62 | mask = mask.view(B * frame, H, W)
63 |
64 | vid_temporal_mask_flag = vid_temporal_mask_flag.view(
65 | B * frame) # [B*T]
66 | gt_temporal_mask_flag = gt_temporal_mask_flag.view(
67 | B * frame) # [B*T]
68 |
69 | output, _ = model(audio, imgs, vid_temporal_mask_flag)
70 | if args.save_pred_mask:
71 | mask_save_path = os.path.join(
72 | args.save_dir, dir_name, 'pred_masks')
73 | save_color_mask(output, mask_save_path, video_name_list,
74 | v2_pallete, resize_pred_mask, pred_mask_img_size, T=10)
75 |
76 | _miou_pc, _fscore_pc, _cls_pc, _ = calc_color_miou_fscore(
77 | output, mask)
78 | # compute miou, J-measure
79 | miou_pc += _miou_pc
80 | cls_pc += _cls_pc
81 | # compute f-score, F-measure
82 | Fs_pc += _fscore_pc
83 |
84 | batch_iou = miou_pc / cls_pc
85 | batch_iou[torch.isnan(batch_iou)] = 0
86 | batch_iou = torch.sum(batch_iou) / torch.sum(cls_pc != 0)
87 | batch_fscore = Fs_pc / cls_pc
88 | batch_fscore[torch.isnan(batch_fscore)] = 0
89 | batch_fscore = torch.sum(batch_fscore) / torch.sum(cls_pc != 0)
90 | logger.info('n_iter: {}, iou: {}, F_score: {}, cls_num: {}'.format(
91 | n_iter, batch_iou, batch_fscore, torch.sum(cls_pc != 0).item()))
92 |
93 | miou_pc = miou_pc / cls_pc
94 | logger.info(
95 | f"[test miou] {torch.sum(torch.isnan(miou_pc)).item()} classes are not predicted in this batch")
96 | miou_pc[torch.isnan(miou_pc)] = 0
97 | miou = torch.mean(miou_pc).item()
98 | miou_noBg = torch.mean(miou_pc[:-1]).item()
99 | f_score_pc = Fs_pc / cls_pc
100 | logger.info(
101 | f"[test fscore] {torch.sum(torch.isnan(f_score_pc)).item()} classes are not predicted in this batch")
102 | f_score_pc[torch.isnan(f_score_pc)] = 0
103 | f_score = torch.mean(f_score_pc).item()
104 | f_score_noBg = torch.mean(f_score_pc[:-1]).item()
105 |
106 | logger.info('test | cls {}, miou: {:.4f}, miou_noBg: {:.4f}, F_score: {:.4f}, F_score_noBg: {:.4f}'.format(
107 | torch.sum(cls_pc != 0).item(), miou, miou_noBg, f_score, f_score_noBg))
108 |
109 |
110 | if __name__ == '__main__':
111 | parser = argparse.ArgumentParser()
112 | parser.add_argument('cfg', type=str, help='config file path')
113 | parser.add_argument('weights', type=str, help='model weights path')
114 | parser.add_argument("--save_pred_mask", action='store_true',
115 |                         default=False, help="save predicted masks or not")
116 | parser.add_argument('--save_dir', type=str,
117 | default='work_dir', help='save path')
118 |
119 | args = parser.parse_args()
120 | main()
121 |
--------------------------------------------------------------------------------
/scripts/avss/train.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import time
3 | import torch.nn
4 | import os
5 | import random
6 | import numpy as np
7 | from mmcv import Config
8 | import argparse
9 | from utils import pyutils
10 | from utils.loss_util import LossUtil
11 | from utils.logger import getLogger
12 | from model import build_model
13 | from dataloader import build_dataset
14 | from loss import IouSemanticAwareLoss
15 | from utils.compute_color_metrics import calc_color_miou_fscore
16 |
17 |
18 | def main():
19 | # Fix seed
20 | FixSeed = 123
21 | random.seed(FixSeed)
22 | np.random.seed(FixSeed)
23 | torch.manual_seed(FixSeed)
24 | torch.cuda.manual_seed(FixSeed)
25 |
26 | # logger
27 | log_name = time.strftime('%Y%m%d-%H%M%S', time.localtime())
28 | dir_name = os.path.splitext(os.path.split(args.cfg)[-1])[0]
29 | if not os.path.exists(args.log_dir):
30 | os.mkdir(args.log_dir)
31 | if not os.path.exists(os.path.join(args.log_dir, dir_name)):
32 | os.mkdir(os.path.join(args.log_dir, dir_name))
33 | log_file = os.path.join(args.log_dir, dir_name, f'{log_name}.log')
34 | logger = getLogger(log_file, __name__)
35 | logger.info(f'Load config from {args.cfg}')
36 |
37 | # config
38 | cfg = Config.fromfile(args.cfg)
39 | logger.info(cfg.pretty_text)
40 | checkpoint_dir = os.path.join(args.checkpoint_dir, dir_name)
41 |
42 | # model
43 | model = build_model(**cfg.model)
44 | model = torch.nn.DataParallel(model).cuda()
45 | model.train()
46 | logger.info("Total params: %.2fM" % (sum(p.numel()
47 | for p in model.parameters()) / 1e6))
48 |
49 | # dataset
50 | train_dataset = build_dataset(**cfg.dataset.train)
51 | train_dataloader = torch.utils.data.DataLoader(train_dataset,
52 | batch_size=cfg.dataset.train.batch_size,
53 | shuffle=True,
54 | num_workers=cfg.process.num_works,
55 | pin_memory=True)
56 | max_step = (len(train_dataset) // cfg.dataset.train.batch_size) * \
57 | cfg.process.train_epochs
58 | val_dataset = build_dataset(**cfg.dataset.val)
59 | val_dataloader = torch.utils.data.DataLoader(val_dataset,
60 | batch_size=cfg.dataset.val.batch_size,
61 | shuffle=False,
62 | num_workers=cfg.process.num_works,
63 | pin_memory=True)
64 | N_CLASSES = train_dataset.num_classes
65 |
66 | # optimizer
67 | optimizer = pyutils.get_optimizer(model, cfg.optimizer)
68 | loss_util = LossUtil(**cfg.loss)
69 |
70 | # Train
71 | best_epoch = 0
72 | global_step = 0
73 | miou_list = []
74 | max_miou = 0
75 | miou_noBg_list = []
76 | fscore_list, fscore_noBg_list = [], []
77 | max_fs, max_fs_noBg = 0, 0
78 | for epoch in range(cfg.process.train_epochs):
79 | if epoch == cfg.process.freeze_epochs:
80 | model.module.freeze_backbone(False)
81 |
82 | for n_iter, batch_data in enumerate(train_dataloader):
83 | imgs, audio, label, vid_temporal_mask_flag, gt_temporal_mask_flag, _ = batch_data
84 | vid_temporal_mask_flag = vid_temporal_mask_flag.cuda()
85 | gt_temporal_mask_flag = gt_temporal_mask_flag.cuda()
86 |
87 | imgs = imgs.cuda()
88 | # audio = audio.cuda()
89 | label = label.cuda()
90 | B, frame, C, H, W = imgs.shape
91 | imgs = imgs.view(B * frame, C, H, W)
92 | mask_num = 10
93 | label = label.view(B * mask_num, H, W)
94 | vid_temporal_mask_flag = vid_temporal_mask_flag.view(
95 | B * frame) # [B*T]
96 | gt_temporal_mask_flag = gt_temporal_mask_flag.view(
97 | B * frame) # [B*T]
98 |
99 |             # output: [bs*10, N_CLASSES, 224, 224]
100 | output, mask_feature = model(audio, imgs, vid_temporal_mask_flag)
101 | loss, loss_dict = IouSemanticAwareLoss(
102 | output, mask_feature, label, gt_temporal_mask_flag, **cfg.loss)
103 | loss_util.add_loss(loss, loss_dict)
104 | optimizer.zero_grad()
105 | loss.backward()
106 | optimizer.step()
107 |
108 | global_step += 1
109 | if (global_step - 1) % 20 == 0:
110 | train_log = 'Iter:%5d/%5d, %slr: %.6f' % (
111 | global_step - 1,
112 | max_step,
113 | loss_util.pretty_out(),
114 | optimizer.param_groups[0]['lr'])
115 | logger.info(train_log)
116 |
117 | # Validation:
118 | if epoch >= cfg.process.start_eval_epoch and epoch % cfg.process.eval_interval == 0:
119 | model.eval()
120 |
121 | miou_pc = torch.zeros((N_CLASSES))
122 | Fs_pc = torch.zeros((N_CLASSES)) # f-score per class (total sum)
123 | cls_pc = torch.zeros((N_CLASSES)) # count per class
124 | with torch.no_grad():
125 | for n_iter, batch_data in enumerate(val_dataloader):
126 |                     # ten frames per video in the AVSS setting
127 | imgs, audio, mask, vid_temporal_mask_flag, gt_temporal_mask_flag, _ = batch_data
128 |
129 | vid_temporal_mask_flag = vid_temporal_mask_flag.cuda()
130 | gt_temporal_mask_flag = gt_temporal_mask_flag.cuda()
131 |
132 | imgs = imgs.cuda()
133 | # audio = audio.cuda()
134 | mask = mask.cuda()
135 | B, frame, C, H, W = imgs.shape
136 | imgs = imgs.view(B * frame, C, H, W)
137 | mask = mask.view(B * frame, H, W)
138 | #! notice
139 | vid_temporal_mask_flag = vid_temporal_mask_flag.view(
140 | B * frame) # [B*T]
141 | gt_temporal_mask_flag = gt_temporal_mask_flag.view(
142 | B * frame) # [B*T]
143 |
144 | output, _ = model(audio, imgs, vid_temporal_mask_flag)
145 |
146 | _miou_pc, _fscore_pc, _cls_pc, _ = calc_color_miou_fscore(
147 | output, mask)
148 | # compute miou, J-measure
149 | miou_pc += _miou_pc
150 | cls_pc += _cls_pc
151 | # compute f-score, F-measure
152 | Fs_pc += _fscore_pc
153 |
154 | miou_pc = miou_pc / cls_pc
155 | logger.info(
156 | f"[miou] {torch.sum(torch.isnan(miou_pc)).item()} classes are not predicted in this batch")
157 | miou_pc[torch.isnan(miou_pc)] = 0
158 | miou = torch.mean(miou_pc).item()
159 | miou_noBg = torch.mean(miou_pc[:-1]).item()
160 | f_score_pc = Fs_pc / cls_pc
161 | logger.info(
162 | f"[fscore] {torch.sum(torch.isnan(f_score_pc)).item()} classes are not predicted in this batch")
163 | f_score_pc[torch.isnan(f_score_pc)] = 0
164 | f_score = torch.mean(f_score_pc).item()
165 | f_score_noBg = torch.mean(f_score_pc[:-1]).item()
166 |
167 | if miou > max_miou:
168 | model_save_path = os.path.join(
169 | checkpoint_dir, '%s_miou_best.pth' % (args.session_name))
170 | torch.save(model.module.state_dict(), model_save_path)
171 | best_epoch = epoch
172 | logger.info('save miou best model to %s' % model_save_path)
173 | if (miou + f_score) > (max_miou + max_fs):
174 | model_save_path = os.path.join(
175 | checkpoint_dir, '%s_miou_and_fscore_best.pth' % (args.session_name))
176 | torch.save(model.module.state_dict(), model_save_path)
177 | best_epoch = epoch
178 | logger.info('save miou and fscore best model to %s' %
179 | model_save_path)
180 |
181 | miou_list.append(miou)
182 | miou_noBg_list.append(miou_noBg)
183 | max_miou = max(miou_list)
184 | max_miou_noBg = max(miou_noBg_list)
185 | fscore_list.append(f_score)
186 | fscore_noBg_list.append(f_score_noBg)
187 | max_fs = max(fscore_list)
188 | max_fs_noBg = max(fscore_noBg_list)
189 |
190 | val_log = 'Epoch: {}, Miou: {}, maxMiou: {}, Miou(no bg): {}, maxMiou (no bg): {} '.format(
191 | epoch, miou, max_miou, miou_noBg, max_miou_noBg)
192 | val_log += ' Fscore: {}, maxFs: {}, Fscore(no bg): {}, max Fscore (no bg): {}'.format(
193 | f_score, max_fs, f_score_noBg, max_fs_noBg)
194 | logger.info(val_log)
195 |
196 | model.train()
197 |
198 |     logger.info('best val Miou {} at epoch: {}'.format(max_miou, best_epoch))
199 |
200 |
201 | if __name__ == '__main__':
202 | parser = argparse.ArgumentParser()
203 | parser.add_argument('cfg', type=str, help='config file path')
204 | parser.add_argument('--log_dir', type=str,
205 | default='work_dir', help='log dir')
206 | parser.add_argument('--checkpoint_dir', type=str,
207 | default='work_dir', help='dir to save checkpoints')
208 | parser.add_argument("--session_name", default="AVSS",
209 | type=str, help="the AVSS setting")
210 |
211 | args = parser.parse_args()
212 | main()
213 |
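214 | # Example invocation (a sketch, not the official entry point -- the repository also ships a
215 | # train.sh wrapper). It assumes the repo root is on PYTHONPATH so `model`, `dataloader` and
216 | # `utils` resolve:
217 | #   PYTHONPATH=. python scripts/avss/train.py config/avss/AVSegFormer_pvt2_avss.py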
--------------------------------------------------------------------------------
/scripts/ms3/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | def F5_IoU_BCELoss(pred_mask, five_gt_masks):
7 | """
8 | binary cross entropy loss (iou loss) of the total five frames for multiple sound source segmentation
9 |
10 | Args:
11 | pred_mask: predicted masks for a batch of data, shape:[bs*5, 1, 224, 224]
12 | five_gt_masks: ground truth mask of the total five frames, shape: [bs*5, 1, 224, 224]
13 | """
14 | assert len(pred_mask.shape) == 4
15 | pred_mask = torch.sigmoid(pred_mask) # [bs*5, 1, 224, 224]
16 | # five_gt_masks = five_gt_masks.view(-1, 1, five_gt_masks.shape[-2], five_gt_masks.shape[-1]) # [bs*5, 1, 224, 224]
17 | loss = nn.BCELoss()(pred_mask, five_gt_masks)
18 |
19 | return loss
20 |
21 |
22 | def F5_Dice_loss(pred_mask, five_gt_masks):
23 | """dice loss for aux loss
24 |
25 | Args:
26 | pred_mask (Tensor): (bs, 1, h, w)
27 | five_gt_masks (Tensor): (bs, 1, h, w)
28 | """
29 | assert len(pred_mask.shape) == 4
30 | pred_mask = torch.sigmoid(pred_mask)
31 |
32 | pred_mask = pred_mask.flatten(1)
33 | gt_mask = five_gt_masks.flatten(1)
34 | a = (pred_mask * gt_mask).sum(-1)
35 | b = (pred_mask * pred_mask).sum(-1) + 0.001
36 | c = (gt_mask * gt_mask).sum(-1) + 0.001
37 | d = (2 * a) / (b + c)
38 | loss = 1 - d
39 | return loss.mean()
40 |
41 |
42 | def IouSemanticAwareLoss(pred_mask, mask_feature, gt_mask, weight_dict, loss_type='bce', **kwargs):
43 | total_loss = 0
44 | loss_dict = {}
45 |
46 | if loss_type == 'bce':
47 | loss_func = F5_IoU_BCELoss
48 | elif loss_type == 'dice':
49 | loss_func = F5_Dice_loss
50 | else:
51 | raise ValueError
52 |
53 | iou_loss = weight_dict['iou_loss'] * loss_func(pred_mask, gt_mask)
54 | total_loss += iou_loss
55 | loss_dict['iou_loss'] = iou_loss.item()
56 |
57 | mask_feature = torch.mean(mask_feature, dim=1, keepdim=True)
58 | mask_feature = F.interpolate(
59 | mask_feature, gt_mask.shape[-2:], mode='bilinear', align_corners=False)
60 | mix_loss = weight_dict['mix_loss']*loss_func(mask_feature, gt_mask)
61 | total_loss += mix_loss
62 | loss_dict['mix_loss'] = mix_loss.item()
63 |
64 | return total_loss, loss_dict
65 |
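66 |
67 | # Shape sketch for the MS3 loss above. Sizes and weights are arbitrary assumptions:
68 | # 2 clips x 5 frames, binary masks, 56x56 intermediate mask features.
69 | if __name__ == '__main__':
70 |     bs, T, H, W = 2, 5, 224, 224
71 |     pred = torch.randn(bs * T, 1, H, W)                       # mask logits
72 |     mask_feat = torch.randn(bs * T, 8, 56, 56)                # intermediate mask features
73 |     gt = torch.randint(0, 2, (bs * T, 1, H, W)).float()       # binary ground truth
74 |     loss, loss_dict = IouSemanticAwareLoss(
75 |         pred, mask_feat, gt, weight_dict={'iou_loss': 1.0, 'mix_loss': 0.1}, loss_type='bce')
76 |     print(loss.item(), loss_dict)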
--------------------------------------------------------------------------------
/scripts/ms3/test.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn
3 | import os
4 | from mmcv import Config
5 | import argparse
6 | from utils import pyutils
7 | from utility import mask_iou, Eval_Fmeasure, save_mask
8 | from utils.logger import getLogger
9 | from model import build_model
10 | from dataloader import build_dataset
11 |
12 |
13 | def main():
14 | # logger
15 | logger = getLogger(None, __name__)
16 | dir_name = os.path.splitext(os.path.split(args.cfg)[-1])[0]
17 | logger.info(f'Load config from {args.cfg}')
18 |
19 | # config
20 | cfg = Config.fromfile(args.cfg)
21 | logger.info(cfg.pretty_text)
22 |
23 | # model
24 | model = build_model(**cfg.model)
25 | model.load_state_dict(torch.load(args.weights))
26 | model = torch.nn.DataParallel(model).cuda()
27 | model.eval()
28 | logger.info('Load trained model %s' % args.weights)
29 |
30 | # Test data
31 | test_dataset = build_dataset(**cfg.dataset.test)
32 | test_dataloader = torch.utils.data.DataLoader(test_dataset,
33 | batch_size=cfg.dataset.test.batch_size,
34 | shuffle=False,
35 | num_workers=cfg.process.num_works,
36 | pin_memory=True)
37 | avg_meter_miou = pyutils.AverageMeter('miou')
38 | avg_meter_F = pyutils.AverageMeter('F_score')
39 |
40 | # Test
41 | with torch.no_grad():
42 | for n_iter, batch_data in enumerate(test_dataloader):
43 | imgs, audio, mask, video_name_list = batch_data
44 |
45 | imgs = imgs.cuda()
46 | audio = audio.cuda()
47 | mask = mask.cuda()
48 | B, frame, C, H, W = imgs.shape
49 | imgs = imgs.view(B * frame, C, H, W)
50 | mask = mask.view(B * frame, H, W)
51 | audio = audio.view(-1, audio.shape[2],
52 | audio.shape[3], audio.shape[4])
53 |
54 | output, _ = model(audio, imgs)
55 | if args.save_pred_mask:
56 | mask_save_path = os.path.join(
57 | args.save_dir, dir_name, 'pred_masks')
58 | save_mask(output.squeeze(1), mask_save_path, video_name_list)
59 |
60 | miou = mask_iou(output.squeeze(1), mask)
61 | avg_meter_miou.add({'miou': miou})
62 | F_score = Eval_Fmeasure(output.squeeze(1), mask)
63 | avg_meter_F.add({'F_score': F_score})
64 | logger.info('n_iter: {}, iou: {}, F_score: {}'.format(
65 | n_iter, miou, F_score))
66 | miou = (avg_meter_miou.pop('miou'))
67 | F_score = (avg_meter_F.pop('F_score'))
68 | logger.info(f'test miou: {miou.item()}')
69 | logger.info(f'test F_score: {F_score}')
70 | logger.info('test miou: {}, F_score: {}'.format(miou.item(), F_score))
71 |
72 |
73 | if __name__ == '__main__':
74 | parser = argparse.ArgumentParser()
75 | parser.add_argument('cfg', type=str, help='config file path')
76 | parser.add_argument('weights', type=str, help='model weights path')
77 | parser.add_argument("--save_pred_mask", action='store_true',
78 |                         default=False, help="save predicted masks or not")
79 | parser.add_argument('--save_dir', type=str,
80 | default='work_dir', help='save path')
81 |
82 | args = parser.parse_args()
83 | main()
84 |
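85 | # Example invocation (a sketch; it assumes the repo root is on PYTHONPATH and that a trained
86 | # checkpoint exists -- the checkpoint path below is hypothetical):
87 | #   PYTHONPATH=. python scripts/ms3/test.py config/ms3/AVSegFormer_pvt2_ms3.py \
88 | #       work_dir/AVSegFormer_pvt2_ms3/MS3_best.pth --save_pred_mask
89 | # With --save_pred_mask, predicted masks are written to <save_dir>/<config name>/pred_masks.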
--------------------------------------------------------------------------------
/scripts/ms3/train.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import time
3 | import torch.nn
4 | import os
5 | import random
6 | import numpy as np
7 | from mmcv import Config
8 | import argparse
9 | from utils import pyutils
10 | from utils.loss_util import LossUtil
11 | from utility import mask_iou
12 | from utils.logger import getLogger
13 | from model import build_model
14 | from dataloader import build_dataset
15 | from loss import IouSemanticAwareLoss
16 |
17 |
18 | def main():
19 | # Fix seed
20 | FixSeed = 123
21 | random.seed(FixSeed)
22 | np.random.seed(FixSeed)
23 | torch.manual_seed(FixSeed)
24 | torch.cuda.manual_seed(FixSeed)
25 |
26 | # logger
27 | log_name = time.strftime('%Y%m%d-%H%M%S', time.localtime())
28 | dir_name = os.path.splitext(os.path.split(args.cfg)[-1])[0]
29 | if not os.path.exists(args.log_dir):
30 | os.mkdir(args.log_dir)
31 | if not os.path.exists(os.path.join(args.log_dir, dir_name)):
32 | os.mkdir(os.path.join(args.log_dir, dir_name))
33 | log_file = os.path.join(args.log_dir, dir_name, f'{log_name}.log')
34 | logger = getLogger(log_file, __name__)
35 | logger.info(f'Load config from {args.cfg}')
36 |
37 | # config
38 | cfg = Config.fromfile(args.cfg)
39 | logger.info(cfg.pretty_text)
40 | checkpoint_dir = os.path.join(args.checkpoint_dir, dir_name)
41 |
42 | # model
43 | model = build_model(**cfg.model)
44 | model = torch.nn.DataParallel(model).cuda()
45 | model.train()
46 | logger.info("Total params: %.2fM" % (sum(p.numel()
47 | for p in model.parameters()) / 1e6))
48 |
49 | # dataset
50 | train_dataset = build_dataset(**cfg.dataset.train)
51 | train_dataloader = torch.utils.data.DataLoader(train_dataset,
52 | batch_size=cfg.dataset.train.batch_size,
53 | shuffle=True,
54 | num_workers=cfg.process.num_works,
55 | pin_memory=True)
56 | max_step = (len(train_dataset) // cfg.dataset.train.batch_size) * \
57 | cfg.process.train_epochs
58 | val_dataset = build_dataset(**cfg.dataset.val)
59 | val_dataloader = torch.utils.data.DataLoader(val_dataset,
60 | batch_size=cfg.dataset.val.batch_size,
61 | shuffle=False,
62 | num_workers=cfg.process.num_works,
63 | pin_memory=True)
64 |
65 | # optimizer
66 | optimizer = pyutils.get_optimizer(model, cfg.optimizer)
67 | loss_util = LossUtil(**cfg.loss)
68 | avg_meter_miou = pyutils.AverageMeter('miou')
69 |
70 | # Train
71 | best_epoch = 0
72 | global_step = 0
73 | miou_list = []
74 | max_miou = 0
75 | for epoch in range(cfg.process.train_epochs):
76 | if epoch == cfg.process.freeze_epochs:
77 | model.module.freeze_backbone(False)
78 |
79 | for n_iter, batch_data in enumerate(train_dataloader):
80 | imgs, audio, mask, _ = batch_data
81 |
82 | imgs = imgs.cuda()
83 | audio = audio.cuda()
84 | mask = mask.cuda()
85 | B, frame, C, H, W = imgs.shape
86 | imgs = imgs.view(B * frame, C, H, W)
87 | mask_num = 5
88 | mask = mask.view(B * mask_num, 1, H, W)
89 | audio = audio.view(-1, audio.shape[2],
90 | audio.shape[3], audio.shape[4])
91 |
92 | output, mask_feature = model(audio, imgs)
93 | loss, loss_dict = IouSemanticAwareLoss(
94 | output, mask_feature, mask, **cfg.loss)
95 | loss_util.add_loss(loss, loss_dict)
96 | optimizer.zero_grad()
97 | loss.backward()
98 | optimizer.step()
99 |
100 | global_step += 1
101 | if (global_step - 1) % 20 == 0:
102 | train_log = 'Iter:%5d/%5d, %slr: %.6f' % (
103 | global_step - 1, max_step, loss_util.pretty_out(), optimizer.param_groups[0]['lr'])
104 | logger.info(train_log)
105 |
106 | # Validation:
107 | model.eval()
108 | with torch.no_grad():
109 | for n_iter, batch_data in enumerate(val_dataloader):
110 | # [bs, 5, 3, 224, 224], [bs, 5, 1, 96, 64], [bs, 5, 1, 224, 224]
111 | imgs, audio, mask, _ = batch_data
112 |
113 | imgs = imgs.cuda()
114 | audio = audio.cuda()
115 | mask = mask.cuda()
116 | B, frame, C, H, W = imgs.shape
117 | imgs = imgs.view(B * frame, C, H, W)
118 | mask = mask.view(B * frame, H, W)
119 | audio = audio.view(-1, audio.shape[2],
120 | audio.shape[3], audio.shape[4])
121 |
122 | # [bs*5, 1, 224, 224]
123 | output, _ = model(audio, imgs)
124 |
125 | miou = mask_iou(output.squeeze(1), mask)
126 | avg_meter_miou.add({'miou': miou})
127 |
128 | miou = (avg_meter_miou.pop('miou'))
129 | if miou > max_miou:
130 | model_save_path = os.path.join(
131 | checkpoint_dir, '%s_best.pth' % (args.session_name))
132 | torch.save(model.module.state_dict(), model_save_path)
133 | best_epoch = epoch
134 | logger.info('save best model to %s' % model_save_path)
135 |
136 | miou_list.append(miou)
137 | max_miou = max(miou_list)
138 |
139 | val_log = 'Epoch: {}, Miou: {}, maxMiou: {}'.format(
140 | epoch, miou, max_miou)
141 | logger.info(val_log)
142 |
143 | model.train()
144 |     logger.info('best val Miou {} at epoch: {}'.format(max_miou, best_epoch))
145 |
146 |
147 | if __name__ == '__main__':
148 | parser = argparse.ArgumentParser()
149 | parser.add_argument('cfg', type=str, help='config file path')
150 | parser.add_argument('--log_dir', type=str,
151 | default='work_dir', help='log dir')
152 | parser.add_argument('--checkpoint_dir', type=str,
153 | default='work_dir', help='dir to save checkpoints')
154 | parser.add_argument("--session_name", default="MS3",
155 | type=str, help="the MS3 setting")
156 |
157 | args = parser.parse_args()
158 | main()
159 |
--------------------------------------------------------------------------------
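
A note on the reshaping in the MS3 training loop above: in this setting every one of the 5 sampled frames carries a ground-truth mask, so both the images and the masks are flattened along the frame axis before the forward pass. A minimal sketch of that layout with random stand-in tensors (shapes follow the comments in the script; this is not real dataloader output):

```python
import torch

B, T, C, H, W = 2, 5, 3, 224, 224                    # 2 videos, 5 frames each
imgs = torch.randn(B, T, C, H, W)                    # [bs, 5, 3, 224, 224]
mask = torch.randint(0, 2, (B, T, 1, H, W)).float()  # one binary mask per frame (MS3)

imgs = imgs.view(B * T, C, H, W)                     # [bs*5, 3, 224, 224] -> backbone input
mask = mask.view(B * T, 1, H, W)                     # [bs*5, 1, 224, 224] -> matches the model output

print(imgs.shape, mask.shape)
# torch.Size([10, 3, 224, 224]) torch.Size([10, 1, 224, 224])
```
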
/scripts/ms3/utility.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.nn import functional as F
3 |
4 | import os
5 | import shutil
6 | import logging
7 | import cv2
8 | import numpy as np
9 | from PIL import Image
10 |
11 | import sys
12 | import time
13 | import pandas as pd
14 | import pdb
15 | from torchvision import transforms
16 |
17 | logger = logging.getLogger(__name__)
18 |
19 |
20 | def save_checkpoint(state, epoch, is_best, checkpoint_dir='./models', filename='checkpoint', thres=100):
21 | """
22 | - state
23 | - epoch
24 | - is_best
25 | - checkpoint_dir: default, ./models
26 | - filename: default, checkpoint
27 | - freq: default, 10
28 | - thres: default, 100
29 | """
30 | if not os.path.isdir(checkpoint_dir):
31 | os.makedirs(checkpoint_dir)
32 |
33 | if epoch >= thres:
34 | file_path = os.path.join(
35 | checkpoint_dir, filename + '_{}'.format(str(epoch)) + '.pth.tar')
36 | else:
37 | file_path = os.path.join(checkpoint_dir, filename + '.pth.tar')
38 | torch.save(state, file_path)
39 | logger.info('==> save model at {}'.format(file_path))
40 |
41 | if is_best:
42 | cpy_file = os.path.join(
43 | checkpoint_dir, filename + '_model_best.pth.tar')
44 | shutil.copyfile(file_path, cpy_file)
45 | logger.info('==> save best model at {}'.format(cpy_file))
46 |
47 |
48 | def mask_iou(pred, target, eps=1e-7, size_average=True):
49 | r"""
50 | param:
51 | pred: size [N x H x W]
52 | target: size [N x H x W]
53 | output:
54 | iou: size [1] (size_average=True) or [N] (size_average=False)
55 | """
56 | assert len(pred.shape) == 3 and pred.shape == target.shape
57 |
58 | N = pred.size(0)
59 | num_pixels = pred.size(-1) * pred.size(-2)
60 | no_obj_flag = (target.sum(2).sum(1) == 0)
61 |
62 | temp_pred = torch.sigmoid(pred)
63 | pred = (temp_pred > 0.5).int()
64 | inter = (pred * target).sum(2).sum(1)
65 | union = torch.max(pred, target).sum(2).sum(1)
66 |
67 | inter_no_obj = ((1 - target) * (1 - pred)).sum(2).sum(1)
68 | inter[no_obj_flag] = inter_no_obj[no_obj_flag]
69 | union[no_obj_flag] = num_pixels
70 |
71 | iou = torch.sum(inter / (union + eps)) / N
72 |
73 | return iou
74 |
75 |
76 | def _eval_pr(y_pred, y, num, cuda_flag=True):
77 | if cuda_flag:
78 | prec, recall = torch.zeros(num).cuda(), torch.zeros(num).cuda()
79 | thlist = torch.linspace(0, 1 - 1e-10, num).cuda()
80 | else:
81 | prec, recall = torch.zeros(num), torch.zeros(num)
82 | thlist = torch.linspace(0, 1 - 1e-10, num)
83 | for i in range(num):
84 | y_temp = (y_pred >= thlist[i]).float()
85 | tp = (y_temp * y).sum()
86 | prec[i], recall[i] = tp / \
87 | (y_temp.sum() + 1e-20), tp / (y.sum() + 1e-20)
88 |
89 | return prec, recall
90 |
91 |
92 | def Eval_Fmeasure(pred, gt, pr_num=255):
93 | r"""
94 | param:
95 | pred: size [N x H x W]
96 | gt: size [N x H x W]
97 | output:
98 |         F_score: scalar, the best F-measure over the pr_num thresholds
99 | """
100 | print('=> eval [FMeasure]..')
101 | # =======================================[important]
102 | pred = torch.sigmoid(pred)
103 | N = pred.size(0)
104 | beta2 = 0.3
105 | avg_f, img_num = 0.0, 0
106 | score = torch.zeros(pr_num)
107 | print("{} videos in this batch".format(N))
108 |
109 | for img_id in range(N):
110 | # examples with totally black GTs are out of consideration
111 | if torch.mean(gt[img_id]) == 0.0:
112 | continue
113 | prec, recall = _eval_pr(pred[img_id], gt[img_id], pr_num)
114 | f_score = (1 + beta2) * prec * recall / (beta2 * prec + recall)
115 | f_score[f_score != f_score] = 0 # for Nan
116 | avg_f += f_score
117 | img_num += 1
118 | score = avg_f / img_num
119 |
120 | return score.max().item()
121 |
122 |
123 | def save_mask(pred_masks, save_base_path, video_name_list):
124 | # pred_mask: [bs*5, 1, 224, 224]
125 | # print(f"=> {len(video_name_list)} videos in this batch")
126 |
127 | if not os.path.exists(save_base_path):
128 | os.makedirs(save_base_path, exist_ok=True)
129 |
130 | pred_masks = pred_masks.squeeze(2)
131 | pred_masks = (torch.sigmoid(pred_masks) > 0.5).int()
132 |
133 | pred_masks = pred_masks.view(-1, 5,
134 | pred_masks.shape[-2], pred_masks.shape[-1])
135 | pred_masks = pred_masks.cpu().data.numpy().astype(np.uint8)
136 | pred_masks *= 255
137 | bs = pred_masks.shape[0]
138 |
139 | for idx in range(bs):
140 | video_name = video_name_list[idx]
141 | mask_save_path = os.path.join(save_base_path, video_name)
142 | if not os.path.exists(mask_save_path):
143 | os.makedirs(mask_save_path, exist_ok=True)
144 |         one_video_masks = pred_masks[idx]  # [5, 224, 224]
145 | for video_id in range(len(one_video_masks)):
146 | one_mask = one_video_masks[video_id]
147 | output_name = "%s_%d.png" % (video_name, video_id)
148 | im = Image.fromarray(one_mask).convert('P')
149 | im.save(os.path.join(mask_save_path, output_name), format='PNG')
150 |
151 |
152 | def save_raw_img_mask(anno_file_path, raw_img_base_path, mask_base_path, split='test', r=0.5):
153 | df = pd.read_csv(anno_file_path, sep=',')
154 | df_test = df[df['split'] == split]
155 | count = 0
156 | for video_id in range(len(df_test)):
157 | video_name = df_test.iloc[video_id][0]
158 | raw_img_path = os.path.join(raw_img_base_path, video_name)
159 | for img_id in range(5):
160 | img_name = "%s.mp4_%d.png" % (video_name, img_id + 1)
161 | raw_img = cv2.imread(os.path.join(raw_img_path, img_name))
162 | mask = cv2.imread(os.path.join(
163 | mask_base_path, 'pred_masks', video_name, "%s_%d.png" % (video_name, img_id)))
164 | # pdb.set_trace()
165 | raw_img_mask = cv2.addWeighted(raw_img, 1, mask, r, 0)
166 | save_img_path = os.path.join(
167 | mask_base_path, 'img_add_masks', video_name)
168 | if not os.path.exists(save_img_path):
169 | os.makedirs(save_img_path, exist_ok=True)
170 | cv2.imwrite(os.path.join(save_img_path, img_name), raw_img_mask)
171 | count += 1
172 | print(f'count: {count} videos')
173 |
--------------------------------------------------------------------------------
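
The mask_iou above thresholds sigmoid(logits) at 0.5 and, for frames whose ground truth is empty, scores agreement on background pixels instead. A minimal sanity check with random tensors (it assumes scripts/ms3 is on sys.path, which is the case when scripts/ms3/train.py or test.py is run directly):

```python
import torch
from utility import mask_iou, Eval_Fmeasure  # scripts/ms3/utility.py

pred = torch.randn(10, 224, 224)                   # raw logits, [N, H, W]
target = (torch.rand(10, 224, 224) > 0.5).float()  # binary ground-truth masks

print(mask_iou(pred, target))  # mean IoU over the 10 masks (0-dim tensor)

# Eval_Fmeasure calls _eval_pr with cuda_flag=True, so it expects CUDA tensors.
if torch.cuda.is_available():
    print(Eval_Fmeasure(pred.cuda(), target.cuda()))  # best F-measure over 255 thresholds
```
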
/scripts/s4/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | def F1_IoU_BCELoss(pred_masks, first_gt_mask):
7 | """
8 | binary cross entropy loss (iou loss) of the first frame for single sound source segmentation
9 |
10 | Args:
11 | pred_masks: predicted masks for a batch of data, shape:[bs*5, 1, 224, 224]
12 | first_gt_mask: ground truth mask of the first frame, shape: [bs, 1, 1, 224, 224]
13 | """
14 | assert len(pred_masks.shape) == 4
15 | pred_masks = torch.sigmoid(pred_masks) # [bs*5, 1, 224, 224]
16 |
17 | indices = torch.tensor(list(range(0, len(pred_masks), 5)))
18 | indices = indices.cuda()
19 | first_pred = torch.index_select(
20 | pred_masks, dim=0, index=indices) # [bs, 1, 224, 224]
21 |     assert first_pred.requires_grad, "Error when indexing predicted masks"
22 | if len(first_gt_mask.shape) == 5:
23 | first_gt_mask = first_gt_mask.squeeze(1) # [bs, 1, 224, 224]
24 |
25 | first_bce_loss = nn.BCELoss()(first_pred, first_gt_mask)
26 |
27 | return first_bce_loss
28 |
29 |
30 | def F1_Dice_loss(pred_masks, first_gt_mask):
31 | """dice loss for aux loss
32 |
33 | Args:
34 | pred_mask (Tensor): (bs*5, 1, h, w)
35 | five_gt_masks (Tensor): (bs, 1, 1, h, w)
36 | """
37 | assert len(pred_masks.shape) == 4
38 | pred_masks = torch.sigmoid(pred_masks)
39 |
40 | indices = torch.tensor(list(range(0, len(pred_masks), 5)))
41 | indices = indices.cuda()
42 | first_pred = torch.index_select(
43 | pred_masks, dim=0, index=indices) # [bs, 1, 224, 224]
44 |     assert first_pred.requires_grad, "Error when indexing predicted masks"
45 | if len(first_gt_mask.shape) == 5:
46 | first_gt_mask = first_gt_mask.squeeze(1) # [bs, 1, 224, 224]
47 |
48 | pred_mask = first_pred.flatten(1)
49 | gt_mask = first_gt_mask.flatten(1)
50 | a = (pred_mask * gt_mask).sum(-1)
51 | b = (pred_mask * pred_mask).sum(-1) + 0.001
52 | c = (gt_mask * gt_mask).sum(-1) + 0.001
53 | d = (2 * a) / (b + c)
54 | loss = 1 - d
55 | return loss.mean()
56 |
57 |
58 | def IouSemanticAwareLoss(pred_masks, mask_feature, gt_mask, weight_dict, loss_type='bce', **kwargs):
59 | total_loss = 0
60 | loss_dict = {}
61 |
62 | if loss_type == 'bce':
63 | loss_func = F1_IoU_BCELoss
64 | elif loss_type == 'dice':
65 | loss_func = F1_Dice_loss
66 | else:
67 |         raise ValueError(f'unsupported loss_type: {loss_type}')
68 |
69 | iou_loss = loss_func(pred_masks, gt_mask)
70 | total_loss += weight_dict['iou_loss'] * iou_loss
71 | loss_dict['iou_loss'] = weight_dict['iou_loss'] * iou_loss.item()
72 |
73 | mask_feature = torch.mean(mask_feature, dim=1, keepdim=True)
74 | mask_feature = F.interpolate(
75 | mask_feature, gt_mask.shape[-2:], mode='bilinear', align_corners=False)
76 | mix_loss = weight_dict['mix_loss']*loss_func(mask_feature, gt_mask)
77 | total_loss += mix_loss
78 | loss_dict['mix_loss'] = mix_loss.item()
79 |
80 | return total_loss, loss_dict
81 |
--------------------------------------------------------------------------------
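
In the S4 loss above, only the first of every 5 predicted frames is supervised, and an auxiliary "mix" loss is applied to the channel-averaged mask feature. A hedged sketch of one call with dummy CUDA tensors; the mask-feature channel/spatial sizes and the loss weights are illustrative rather than values from the configs, and a GPU is needed because the loss moves its index tensor to CUDA:

```python
import torch
from loss import IouSemanticAwareLoss  # scripts/s4/loss.py, with scripts/s4 on sys.path

bs = 2
pred_masks = torch.randn(bs * 5, 1, 224, 224, device='cuda', requires_grad=True)  # model output
mask_feature = torch.randn(bs * 5, 8, 56, 56, device='cuda', requires_grad=True)  # shape assumed
first_gt = (torch.rand(bs, 1, 1, 224, 224, device='cuda') > 0.5).float()          # first-frame GT only

weight_dict = {'iou_loss': 1.0, 'mix_loss': 0.1}  # illustrative weights
loss, loss_dict = IouSemanticAwareLoss(pred_masks, mask_feature, first_gt,
                                       weight_dict=weight_dict, loss_type='bce')
loss.backward()
print(loss_dict)  # {'iou_loss': ..., 'mix_loss': ...}
```
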
/scripts/s4/test.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn
3 | import os
4 | from mmcv import Config
5 | import argparse
6 | from utils import pyutils
7 | from utility import mask_iou, Eval_Fmeasure, save_mask
8 | from utils.logger import getLogger
9 | from model import build_model
10 | from dataloader import build_dataset
11 |
12 |
13 | def main():
14 | # logger
15 | logger = getLogger(None, __name__)
16 | dir_name = os.path.splitext(os.path.split(args.cfg)[-1])[0]
17 | logger.info(f'Load config from {args.cfg}')
18 |
19 | # config
20 | cfg = Config.fromfile(args.cfg)
21 | logger.info(cfg.pretty_text)
22 |
23 | # model
24 | model = build_model(**cfg.model)
25 | model.load_state_dict(torch.load(args.weights))
26 | model = torch.nn.DataParallel(model).cuda()
27 | model.eval()
28 | logger.info('Load trained model %s' % args.weights)
29 |
30 | # Test data
31 | test_dataset = build_dataset(**cfg.dataset.test)
32 | test_dataloader = torch.utils.data.DataLoader(test_dataset,
33 | batch_size=cfg.dataset.test.batch_size,
34 | shuffle=False,
35 | num_workers=cfg.process.num_works,
36 | pin_memory=True)
37 | avg_meter_miou = pyutils.AverageMeter('miou')
38 | avg_meter_F = pyutils.AverageMeter('F_score')
39 |
40 | # Test
41 | with torch.no_grad():
42 | for n_iter, batch_data in enumerate(test_dataloader):
43 | # [bs, 5, 3, 224, 224], [bs, 5, 1, 96, 64], [bs, 1, 1, 224, 224]
44 | imgs, audio, mask, category_list, video_name_list = batch_data
45 |
46 | imgs = imgs.cuda()
47 | audio = audio.cuda()
48 | mask = mask.cuda()
49 | B, frame, C, H, W = imgs.shape
50 | imgs = imgs.view(B * frame, C, H, W)
51 | mask = mask.view(B * frame, H, W)
52 | audio = audio.view(-1, audio.shape[2],
53 | audio.shape[3], audio.shape[4])
54 |
55 | output, _ = model(audio, imgs)
56 | if args.save_pred_mask:
57 | mask_save_path = os.path.join(
58 | args.save_dir, dir_name, 'pred_masks')
59 | save_mask(output.squeeze(1), mask_save_path,
60 | category_list, video_name_list)
61 |
62 | miou = mask_iou(output.squeeze(1), mask)
63 | avg_meter_miou.add({'miou': miou})
64 | F_score = Eval_Fmeasure(output.squeeze(1), mask)
65 | avg_meter_F.add({'F_score': F_score})
66 | logger.info('n_iter: {}, iou: {}, F_score: {}'.format(
67 | n_iter, miou, F_score))
68 |
69 |         miou = (avg_meter_miou.pop('miou'))
70 |         F_score = (avg_meter_F.pop('F_score'))
71 |         logger.info(f'test miou: {miou.item()}')
72 |         logger.info(f'test F_score: {F_score}')
73 |         logger.info('test miou: {}, F_score: {}'.format(miou.item(), F_score))
74 |
75 |
76 | if __name__ == '__main__':
77 | parser = argparse.ArgumentParser()
78 | parser.add_argument('cfg', type=str, help='config file path')
79 | parser.add_argument('weights', type=str, help='model weights path')
80 | parser.add_argument("--save_pred_mask", action='store_true',
81 |                         default=False, help="save predicted masks or not")
82 | parser.add_argument('--save_dir', type=str,
83 | default='work_dir', help='save path')
84 |
85 | args = parser.parse_args()
86 | main()
87 |
--------------------------------------------------------------------------------
/scripts/s4/train.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import time
3 | import torch.nn
4 | import os
5 | import random
6 | import numpy as np
7 | from mmcv import Config
8 | import argparse
9 | from utils import pyutils
10 | from utils.loss_util import LossUtil
11 | from utility import mask_iou
12 | from utils.logger import getLogger
13 | from model import build_model
14 | from dataloader import build_dataset
15 | from loss import IouSemanticAwareLoss
16 |
17 |
18 | def main():
19 | # Fix seed
20 | FixSeed = 123
21 | random.seed(FixSeed)
22 | np.random.seed(FixSeed)
23 | torch.manual_seed(FixSeed)
24 | torch.cuda.manual_seed(FixSeed)
25 |
26 | # logger
27 | log_name = time.strftime('%Y%m%d-%H%M%S', time.localtime())
28 | dir_name = os.path.splitext(os.path.split(args.cfg)[-1])[0]
29 | if not os.path.exists(args.log_dir):
30 | os.mkdir(args.log_dir)
31 | if not os.path.exists(os.path.join(args.log_dir, dir_name)):
32 | os.mkdir(os.path.join(args.log_dir, dir_name))
33 | log_file = os.path.join(args.log_dir, dir_name, f'{log_name}.log')
34 | logger = getLogger(log_file, __name__)
35 | logger.info(f'Load config from {args.cfg}')
36 |
37 | # config
38 | cfg = Config.fromfile(args.cfg)
39 | logger.info(cfg.pretty_text)
40 | checkpoint_dir = os.path.join(args.checkpoint_dir, dir_name)
41 |
42 | # model
43 | model = build_model(**cfg.model)
44 | model = torch.nn.DataParallel(model).cuda()
45 | model.train()
46 | logger.info("Total params: %.2fM" % (sum(p.numel()
47 | for p in model.parameters()) / 1e6))
48 |
49 | # dataset
50 | train_dataset = build_dataset(**cfg.dataset.train)
51 | train_dataloader = torch.utils.data.DataLoader(train_dataset,
52 | batch_size=cfg.dataset.train.batch_size,
53 | shuffle=True,
54 | num_workers=cfg.process.num_works,
55 | pin_memory=True)
56 | max_step = (len(train_dataset) // cfg.dataset.train.batch_size) * \
57 | cfg.process.train_epochs
58 | val_dataset = build_dataset(**cfg.dataset.val)
59 | val_dataloader = torch.utils.data.DataLoader(val_dataset,
60 | batch_size=cfg.dataset.val.batch_size,
61 | shuffle=False,
62 | num_workers=cfg.process.num_works,
63 | pin_memory=True)
64 |
65 | # optimizer
66 | optimizer = pyutils.get_optimizer(model, cfg.optimizer)
67 | loss_util = LossUtil(**cfg.loss)
68 | avg_meter_miou = pyutils.AverageMeter('miou')
69 |
70 | # Train
71 | best_epoch = 0
72 | global_step = 0
73 | miou_list = []
74 | max_miou = 0
75 | for epoch in range(cfg.process.train_epochs):
76 | if epoch == cfg.process.freeze_epochs:
77 | model.module.freeze_backbone(False)
78 |
79 | for n_iter, batch_data in enumerate(train_dataloader):
80 | # [bs, 5, 3, 224, 224], [bs, 5, 1, 96, 64], [bs, 1, 1, 224, 224]
81 | imgs, audio, mask = batch_data
82 |
83 | imgs = imgs.cuda()
84 | audio = audio.cuda()
85 | mask = mask.cuda()
86 | B, frame, C, H, W = imgs.shape
87 | imgs = imgs.view(B * frame, C, H, W)
88 | mask = mask.view(B, H, W)
89 | audio = audio.view(-1, audio.shape[2],
90 | audio.shape[3], audio.shape[4])
91 |
92 | output, mask_feature = model(audio, imgs) # [bs*5, 1, 224, 224]
93 | loss, loss_dict = IouSemanticAwareLoss(
94 | output, mask_feature, mask.unsqueeze(1).unsqueeze(1), **cfg.loss)
95 | loss_util.add_loss(loss, loss_dict)
96 | optimizer.zero_grad()
97 | loss.backward()
98 | optimizer.step()
99 |
100 | global_step += 1
101 |
102 | if (global_step - 1) % 50 == 0:
103 | train_log = 'Iter:%5d/%5d, %slr: %.6f' % (
104 | global_step - 1, max_step, loss_util.pretty_out(), optimizer.param_groups[0]['lr'])
105 | logger.info(train_log)
106 |
107 | # Validation:
108 | model.eval()
109 | with torch.no_grad():
110 | for n_iter, batch_data in enumerate(val_dataloader):
111 | # [bs, 5, 3, 224, 224], [bs, 5, 1, 96, 64], [bs, 5, 1, 224, 224]
112 | imgs, audio, mask, _, _ = batch_data
113 |
114 | imgs = imgs.cuda()
115 | audio = audio.cuda()
116 | mask = mask.cuda()
117 | B, frame, C, H, W = imgs.shape
118 | imgs = imgs.view(B * frame, C, H, W)
119 | mask = mask.view(B * frame, H, W)
120 | audio = audio.view(-1, audio.shape[2],
121 | audio.shape[3], audio.shape[4])
122 |
123 | output, _ = model(audio, imgs)
124 |
125 | miou = mask_iou(output.squeeze(1), mask)
126 | avg_meter_miou.add({'miou': miou})
127 |
128 | miou = (avg_meter_miou.pop('miou'))
129 | if miou > max_miou:
130 | model_save_path = os.path.join(
131 | checkpoint_dir, '%s_best.pth' % (args.session_name))
132 | torch.save(model.module.state_dict(), model_save_path)
133 | best_epoch = epoch
134 | logger.info('save best model to %s' % model_save_path)
135 |
136 | miou_list.append(miou)
137 | max_miou = max(miou_list)
138 |
139 | val_log = 'Epoch: {}, Miou: {}, maxMiou: {}'.format(
140 | epoch, miou, max_miou)
141 | # print(val_log)
142 | logger.info(val_log)
143 |
144 | model.train()
145 |
146 |     logger.info('best val Miou {} at epoch: {}'.format(max_miou, best_epoch))
147 |
148 |
149 | if __name__ == '__main__':
150 | parser = argparse.ArgumentParser()
151 | parser.add_argument('cfg', type=str, help='config file path')
152 | parser.add_argument('--log_dir', type=str,
153 | default='work_dir', help='log dir')
154 | parser.add_argument('--checkpoint_dir', type=str,
155 | default='work_dir', help='dir to save checkpoints')
156 | parser.add_argument("--session_name", default="S4",
157 |                         type=str, help="session name used in the checkpoint filename")
158 |
159 | args = parser.parse_args()
160 | main()
161 |
--------------------------------------------------------------------------------
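
Both training scripts read the same handful of runtime fields from the config: cfg.process.*, cfg.dataset.*.batch_size, cfg.optimizer and cfg.loss. The real values live in the files under config/; the sketch below only illustrates the field names the loops depend on, with made-up values:

```python
# Illustrative field names only -- see config/s4/*.py and config/ms3/*.py for the real settings.
process = dict(
    num_works=8,       # DataLoader workers (note the repo's spelling: num_works)
    train_epochs=30,   # length of the outer epoch loop
    freeze_epochs=10,  # epoch at which model.module.freeze_backbone(False) is called
)
optimizer = dict(type='AdamW', lr=2e-5)  # consumed by utils.pyutils.get_optimizer
loss = dict(weight_dict=dict(iou_loss=1.0, mix_loss=0.1),
            loss_type='bce')             # 'bce' or 'dice'; unpacked into IouSemanticAwareLoss and LossUtil
```
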
/scripts/s4/utility.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.nn import functional as F
3 |
4 | import os
5 | import shutil
6 | import logging
7 | import cv2
8 | import numpy as np
9 | from PIL import Image
10 |
11 | import sys
12 | import time
13 | import pandas as pd
14 | import pdb
15 | from torchvision import transforms
16 |
17 | logger = logging.getLogger(__name__)
18 |
19 |
20 | def save_checkpoint(state, epoch, is_best, checkpoint_dir='./models', filename='checkpoint', thres=100):
21 | """
22 | - state
23 | - epoch
24 | - is_best
25 | - checkpoint_dir: default, ./models
26 | - filename: default, checkpoint
27 | - freq: default, 10
28 | - thres: default, 100
29 | """
30 | if not os.path.isdir(checkpoint_dir):
31 | os.makedirs(checkpoint_dir)
32 |
33 | if epoch >= thres:
34 | file_path = os.path.join(
35 | checkpoint_dir, filename + '_{}'.format(str(epoch)) + '.pth.tar')
36 | else:
37 | file_path = os.path.join(checkpoint_dir, filename + '.pth.tar')
38 | torch.save(state, file_path)
39 | logger.info('==> save model at {}'.format(file_path))
40 |
41 | if is_best:
42 | cpy_file = os.path.join(
43 | checkpoint_dir, filename + '_model_best.pth.tar')
44 | shutil.copyfile(file_path, cpy_file)
45 | logger.info('==> save best model at {}'.format(cpy_file))
46 |
47 |
48 | def mask_iou(pred, target, eps=1e-7, size_average=True):
49 | r"""
50 | param:
51 | pred: size [N x H x W]
52 | target: size [N x H x W]
53 | output:
54 | iou: size [1] (size_average=True) or [N] (size_average=False)
55 | """
56 | assert len(pred.shape) == 3 and pred.shape == target.shape
57 |
58 | N = pred.size(0)
59 | num_pixels = pred.size(-1) * pred.size(-2)
60 | no_obj_flag = (target.sum(2).sum(1) == 0)
61 |
62 | temp_pred = torch.sigmoid(pred)
63 | pred = (temp_pred > 0.5).int()
64 | inter = (pred * target).sum(2).sum(1)
65 | union = torch.max(pred, target).sum(2).sum(1)
66 |
67 | inter_no_obj = ((1 - target) * (1 - pred)).sum(2).sum(1)
68 | inter[no_obj_flag] = inter_no_obj[no_obj_flag]
69 | union[no_obj_flag] = num_pixels
70 |
71 | iou = torch.sum(inter / (union + eps)) / N
72 |
73 | return iou
74 |
75 |
76 | def _eval_pr(y_pred, y, num, cuda_flag=True):
77 | if cuda_flag:
78 | prec, recall = torch.zeros(num).cuda(), torch.zeros(num).cuda()
79 | thlist = torch.linspace(0, 1 - 1e-10, num).cuda()
80 | else:
81 | prec, recall = torch.zeros(num), torch.zeros(num)
82 | thlist = torch.linspace(0, 1 - 1e-10, num)
83 | for i in range(num):
84 | y_temp = (y_pred >= thlist[i]).float()
85 | tp = (y_temp * y).sum()
86 | prec[i], recall[i] = tp / \
87 | (y_temp.sum() + 1e-20), tp / (y.sum() + 1e-20)
88 |
89 | return prec, recall
90 |
91 |
92 | def Eval_Fmeasure(pred, gt, pr_num=255):
93 | r"""
94 | param:
95 | pred: size [N x H x W]
96 | gt: size [N x H x W]
97 | output:
98 |         F_score: scalar, the best F-measure over the pr_num thresholds
99 | """
100 | print('=> eval [FMeasure]..')
101 | # =======================================[important]
102 | pred = torch.sigmoid(pred)
103 | N = pred.size(0)
104 | beta2 = 0.3
105 | avg_f, img_num = 0.0, 0
106 | score = torch.zeros(pr_num)
107 | print("{} videos in this batch".format(N))
108 |
109 | for img_id in range(N):
110 | # examples with totally black GTs are out of consideration
111 | if torch.mean(gt[img_id]) == 0.0:
112 | continue
113 | prec, recall = _eval_pr(pred[img_id], gt[img_id], pr_num)
114 | f_score = (1 + beta2) * prec * recall / (beta2 * prec + recall)
115 | f_score[f_score != f_score] = 0 # for Nan
116 | avg_f += f_score
117 | img_num += 1
118 | score = avg_f / img_num
119 |
120 | return score.max().item()
121 |
122 |
123 | def save_mask(pred_masks, save_base_path, category_list, video_name_list):
124 | # pred_mask: [bs*5, 1, 224, 224]
125 | # print(f"=> {len(video_name_list)} videos in this batch")
126 |
127 | if not os.path.exists(save_base_path):
128 | os.makedirs(save_base_path, exist_ok=True)
129 |
130 | pred_masks = pred_masks.squeeze(2)
131 | pred_masks = (torch.sigmoid(pred_masks) > 0.5).int()
132 |
133 | pred_masks = pred_masks.view(-1, 5,
134 | pred_masks.shape[-2], pred_masks.shape[-1])
135 | pred_masks = pred_masks.cpu().data.numpy().astype(np.uint8)
136 | pred_masks *= 255
137 | bs = pred_masks.shape[0]
138 |
139 | for idx in range(bs):
140 | category, video_name = category_list[idx], video_name_list[idx]
141 | mask_save_path = os.path.join(save_base_path, category, video_name)
142 | if not os.path.exists(mask_save_path):
143 | os.makedirs(mask_save_path, exist_ok=True)
144 |         one_video_masks = pred_masks[idx]  # [5, 224, 224]
145 | for video_id in range(len(one_video_masks)):
146 | one_mask = one_video_masks[video_id]
147 | output_name = "%s_%d.png" % (video_name, video_id)
148 | im = Image.fromarray(one_mask).convert('P')
149 | im.save(os.path.join(mask_save_path, output_name), format='PNG')
150 |
151 |
152 | def save_raw_img_mask(anno_file_path, raw_img_base_path, mask_base_path, split='test', r=0.5):
153 | df = pd.read_csv(anno_file_path, sep=',')
154 | df_test = df[df['split'] == split]
155 | count = 0
156 | for video_id in range(len(df_test)):
157 | video_name, category = df_test.iloc[video_id][0], df_test.iloc[video_id][2]
158 | raw_img_path = os.path.join(
159 | raw_img_base_path, split, category, video_name)
160 | for img_id in range(5):
161 | img_name = "%s_%d.png" % (video_name, img_id + 1)
162 | raw_img = cv2.imread(os.path.join(raw_img_path, img_name))
163 | mask = cv2.imread(os.path.join(mask_base_path, 'pred_masks',
164 | category, video_name, "%s_%d.png" % (video_name, img_id)))
165 | # pdb.set_trace()
166 | raw_img_mask = cv2.addWeighted(raw_img, 1, mask, r, 0)
167 | save_img_path = os.path.join(
168 | mask_base_path, 'img_add_masks', category, video_name)
169 | if not os.path.exists(save_img_path):
170 | os.makedirs(save_img_path, exist_ok=True)
171 | cv2.imwrite(os.path.join(save_img_path, img_name), raw_img_mask)
172 | count += 1
173 | print(f'count: {count} videos')
174 |
--------------------------------------------------------------------------------
/test.sh:
--------------------------------------------------------------------------------
1 | SESSION=$1
2 | CONFIG=$2
3 | WEIGHTS=$3
4 |
5 | PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \
6 | python scripts/$SESSION/test.py \
7 | $CONFIG \
8 | $WEIGHTS \
9 | # --save_pred_mask
10 |
--------------------------------------------------------------------------------
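
test.sh takes the session name (avss/ms3/s4), a config path and a checkpoint path, e.g. `bash test.sh s4 config/s4/AVSegFormer_pvt2_s4.py work_dir/AVSegFormer_pvt2_s4/S4_best.pth` (the checkpoint path shown here is simply where train.py writes its best model by default). Uncomment the trailing `--save_pred_mask` line to also dump predicted masks under `--save_dir`.
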
/train.sh:
--------------------------------------------------------------------------------
1 | SESSION=$1
2 | CONFIG=$2
3 |
4 | PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \
5 | python scripts/$SESSION/train.py $CONFIG
6 |
--------------------------------------------------------------------------------
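
train.sh mirrors test.sh without the weights argument, e.g. `bash train.sh ms3 config/ms3/AVSegFormer_pvt2_ms3.py`; logs and checkpoints end up under `work_dir/<config name>/` unless the `--log_dir` / `--checkpoint_dir` defaults in the training scripts are changed.
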
/utils/compute_color_metrics.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.nn import functional as F
3 |
4 | import os
5 | import shutil
6 | import logging
7 | import cv2
8 | import numpy as np
9 | from PIL import Image
10 |
11 | import sys
12 | import time
13 | import pandas as pd
14 | from torchvision import transforms
15 |
16 | import numpy as np
17 | from multiprocessing import Pool
18 | from tqdm import tqdm
19 |
20 | import pdb
21 |
22 |
23 | def _batch_miou_fscore(output, target, nclass, T, beta2=0.3):
24 | """batch mIoU and Fscore"""
25 | # output: [BF, C, H, W],
26 | # target: [BF, H, W]
27 | mini = 1
28 | maxi = nclass
29 | nbins = nclass
30 | predict = torch.argmax(output, 1) + 1
31 | target = target.float() + 1
32 | # pdb.set_trace()
33 | predict = predict.float() * (target > 0).float() # [BF, H, W]
34 | intersection = predict * (predict == target).float() # [BF, H, W]
35 | # areas of intersection and union
36 |     # zeros in `intersection` mark pixels where prediction and target disagree; torch.histc with min=1 ignores them (unlike np.bincount).
37 | batch_size = target.shape[0] // T
38 | cls_count = torch.zeros(nclass).float()
39 | ious = torch.zeros(nclass).float()
40 | fscores = torch.zeros(nclass).float()
41 |
42 | # vid_miou_list = torch.zeros(target.shape[0]).float()
43 | vid_miou_list = []
44 | for i in range(target.shape[0]):
45 | area_inter = torch.histc(
46 | intersection[i].cpu(), bins=nbins, min=mini, max=maxi) # TP
47 | area_pred = torch.histc(
48 | predict[i].cpu(), bins=nbins, min=mini, max=maxi) # TP + FP
49 | area_lab = torch.histc(
50 | target[i].cpu(), bins=nbins, min=mini, max=maxi) # TP + FN
51 | area_union = area_pred + area_lab - area_inter
52 | assert torch.sum(area_inter > area_union).item(
53 | ) == 0, "Intersection area should be smaller than Union area"
54 | iou = 1.0 * area_inter.float() / (2.220446049250313e-16 + area_union.float())
55 | # iou[torch.isnan(iou)] = 1.
56 | ious += iou
57 | cls_count[torch.nonzero(area_union).squeeze(-1)] += 1
58 |
59 | precision = area_inter / area_pred
60 | recall = area_inter / area_lab
61 | fscore = (1 + beta2) * precision * recall / \
62 | (beta2 * precision + recall)
63 | fscore[torch.isnan(fscore)] = 0.
64 | fscores += fscore
65 |
66 | vid_miou_list.append(torch.sum(iou) / (torch.sum(iou != 0).float()))
67 |
68 | return ious, fscores, cls_count, vid_miou_list
69 |
70 |
71 | def calc_color_miou_fscore(pred, target, T=10):
72 | r"""
73 | J measure
74 | param:
75 | pred: size [BF x C x H x W], C is category number including background
76 | target: size [BF x H x W]
77 | """
78 | nclass = pred.shape[1]
79 | pred = torch.softmax(pred, dim=1) # [BF, C, H, W]
80 | # miou, fscore, cls_count = _batch_miou_fscore(pred, target, nclass, T)
81 | miou, fscore, cls_count, vid_miou_list = _batch_miou_fscore(
82 | pred, target, nclass, T)
83 | return miou, fscore, cls_count, vid_miou_list
84 |
85 |
86 | def _batch_intersection_union(output, target, nclass, T):
87 | """mIoU"""
88 | # output: [BF, C, H, W],
89 | # target: [BF, H, W]
90 | mini = 1
91 | maxi = nclass
92 | nbins = nclass
93 | predict = torch.argmax(output, 1) + 1
94 | target = target.float() + 1
95 |
96 | # pdb.set_trace()
97 |
98 | predict = predict.float() * (target > 0).float() # [BF, H, W]
99 | intersection = predict * (predict == target).float() # [BF, H, W]
100 | # areas of intersection and union
101 |     # zeros in `intersection` mark pixels where prediction and target disagree; torch.histc with min=1 ignores them (unlike np.bincount).
102 | batch_size = target.shape[0] // T
103 | cls_count = torch.zeros(nclass).float()
104 | ious = torch.zeros(nclass).float()
105 | for i in range(target.shape[0]):
106 | area_inter = torch.histc(
107 | intersection[i].cpu(), bins=nbins, min=mini, max=maxi)
108 | area_pred = torch.histc(
109 | predict[i].cpu(), bins=nbins, min=mini, max=maxi)
110 | area_lab = torch.histc(target[i].cpu(), bins=nbins, min=mini, max=maxi)
111 | area_union = area_pred + area_lab - area_inter
112 | assert torch.sum(area_inter > area_union).item(
113 | ) == 0, "Intersection area should be smaller than Union area"
114 | iou = 1.0 * area_inter.float() / (2.220446049250313e-16 + area_union.float())
115 | ious += iou
116 | cls_count[torch.nonzero(area_union).squeeze(-1)] += 1
117 | # pdb.set_trace()
118 | # ious = ious / cls_count
119 | # ious[torch.isnan(ious)] = 0
120 | # pdb.set_trace()
121 | # return area_inter.float(), area_union.float()
122 | # return ious
123 | return ious, cls_count
124 |
125 |
126 | def calc_color_miou(pred, target, T=10):
127 | r"""
128 | J measure
129 | param:
130 | pred: size [BF x C x H x W], C is category number including background
131 | target: size [BF x H x W]
132 | """
133 | nclass = pred.shape[1]
134 | pred = torch.softmax(pred, dim=1) # [BF, C, H, W]
135 | # correct, labeled = _batch_pix_accuracy(pred, target)
136 | # inter, union = _batch_intersection_union(pred, target, nclass, T)
137 | ious, cls_count = _batch_intersection_union(pred, target, nclass, T)
138 |
139 | # pixAcc = 1.0 * correct / (2.220446049250313e-16 + labeled)
140 | # IoU = 1.0 * inter / (2.220446049250313e-16 + union)
141 | # mIoU = IoU.mean().item()
142 | # pdb.set_trace()
143 | # return mIoU
144 | return ious, cls_count
145 |
146 |
147 | def calc_binary_miou(pred, target, eps=1e-7, size_average=True):
148 | r"""
149 | param:
150 | pred: size [N x C x H x W]
151 | target: size [N x H x W]
152 | output:
153 | iou: size [1] (size_average=True) or [N] (size_average=False)
154 | """
155 | # assert len(pred.shape) == 3 and pred.shape == target.shape
156 | nclass = pred.shape[1]
157 | pred = torch.softmax(pred, dim=1) # [BF, C, H, W]
158 | pred = torch.argmax(pred, dim=1) # [BF, H, W]
159 | binary_pred = (pred != (nclass - 1)).float() # [BF, H, W]
160 | # pdb.set_trace()
161 | pred = binary_pred
162 | target = (target != (nclass - 1)).float()
163 |
164 | N = pred.size(0)
165 | num_pixels = pred.size(-1) * pred.size(-2)
166 | no_obj_flag = (target.sum(2).sum(1) == 0)
167 |
168 | temp_pred = torch.sigmoid(pred)
169 | pred = (temp_pred > 0.5).int()
170 | inter = (pred * target).sum(2).sum(1)
171 | union = torch.max(pred, target).sum(2).sum(1)
172 |
173 | inter_no_obj = ((1 - target) * (1 - pred)).sum(2).sum(1)
174 | inter[no_obj_flag] = inter_no_obj[no_obj_flag]
175 | union[no_obj_flag] = num_pixels
176 |
177 | iou = torch.sum(inter / (union + eps)) / N
178 | # pdb.set_trace()
179 | return iou
180 |
--------------------------------------------------------------------------------
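
_batch_miou_fscore above counts per-class areas with torch.histc after shifting class indices by +1, so pixels where prediction and target disagree (value 0 in `intersection`) fall below min=1 and are ignored. A tiny worked example of that counting trick with toy values (3 classes, 4 pixels):

```python
import torch

nclass = 3
predict = torch.tensor([1., 2., 2., 3.])              # predicted classes, already shifted to 1..nclass
target = torch.tensor([1., 2., 3., 3.])               # ground-truth classes, shifted the same way
intersection = predict * (predict == target).float()  # -> [1., 2., 0., 3.]; 0 marks a disagreement

area_inter = torch.histc(intersection, bins=nclass, min=1, max=nclass)  # TP per class: [1., 1., 1.]
area_pred = torch.histc(predict, bins=nclass, min=1, max=nclass)        # TP+FP:        [1., 2., 1.]
area_lab = torch.histc(target, bins=nclass, min=1, max=nclass)          # TP+FN:        [1., 1., 2.]
iou = area_inter / (area_pred + area_lab - area_inter + 2.220446049250313e-16)
print(iou)  # per-class IoU: tensor([1.0000, 0.5000, 0.5000])
```
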
/utils/logger.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 |
4 | def getLogger(log_file, name, fmt='%(asctime)s %(levelname)s ==> %(message)s'):
5 | logger = logging.getLogger(name)
6 | logger.setLevel(logging.DEBUG)
7 | formatter = logging.Formatter(fmt)
8 |
9 | console_handler = logging.StreamHandler()
10 | console_handler.setLevel(logging.DEBUG)
11 | console_handler.setFormatter(formatter)
12 | logger.addHandler(console_handler)
13 | if log_file is not None:
14 | file_handler = logging.FileHandler(log_file)
15 | file_handler.setLevel(logging.INFO)
16 | file_handler.setFormatter(formatter)
17 | logger.addHandler(file_handler)
18 |
19 | return logger
20 |
--------------------------------------------------------------------------------
/utils/loss_util.py:
--------------------------------------------------------------------------------
1 | from utils.pyutils import AverageMeter
2 |
3 |
4 | class LossUtil:
5 | def __init__(self, weight_dict, **kwargs) -> None:
6 | self.loss_weight_dict = weight_dict
7 | self.avg_loss = dict()
8 | self.avg_loss['total_loss'] = AverageMeter('total_loss')
9 | # for k in weight_dict.keys():
10 | # self.avg_loss[k] = AverageMeter(k)
11 |
12 | def add_loss(self, loss, loss_dict):
13 | self.avg_loss['total_loss'].add({'total_loss': loss.item()})
14 | for k, v in loss_dict.items():
15 | meter = self.avg_loss.get(k, None)
16 | if meter is None:
17 | meter = AverageMeter(k)
18 | self.avg_loss[k] = meter
19 |
20 | self.avg_loss[k].add({k: v})
21 |
22 | def pretty_out(self):
23 | f = 'Total_Loss:%.4f, ' % (
24 | self.avg_loss['total_loss'].pop('total_loss'))
25 | for k in self.avg_loss.keys():
26 | if k == 'total_loss':
27 | continue
28 | t = '%s:%.4f, ' % (k, self.avg_loss[k].pop(k))
29 | f += t
30 | return f
31 |
--------------------------------------------------------------------------------
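
LossUtil above keeps one AverageMeter per loss term, and pretty_out() pops (and resets) the running means, which is how the training scripts build their log lines. A tiny hedged example with made-up loss values (run from the repository root):

```python
import torch
from utils.loss_util import LossUtil

loss_util = LossUtil(weight_dict={'iou_loss': 1.0, 'mix_loss': 0.1})  # illustrative weights
for step in range(3):
    total = torch.tensor(0.5 / (step + 1))  # stand-in for the real training loss
    loss_util.add_loss(total, {'iou_loss': 0.4, 'mix_loss': 0.1})
print(loss_util.pretty_out())  # e.g. "Total_Loss:0.3056, iou_loss:0.4000, mix_loss:0.1000, "
```
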
/utils/pyutils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import time
4 | import sys
5 | from multiprocessing.pool import ThreadPool
6 |
7 |
8 | class Logger(object):
9 | def __init__(self, outfile):
10 | self.terminal = sys.stdout
11 | self.log = open(outfile, "w")
12 | sys.stdout = self
13 |
14 | def write(self, message):
15 | self.terminal.write(message)
16 | self.log.write(message)
17 |
18 | def flush(self):
19 | self.terminal.flush()
20 |
21 |
22 | class AverageMeter:
23 | def __init__(self, *keys):
24 | self.__data = dict()
25 | for k in keys:
26 | self.__data[k] = [0.0, 0]
27 |
28 | def add(self, dict):
29 | for k, v in dict.items():
30 | self.__data[k][0] += v
31 | self.__data[k][1] += 1
32 |
33 | def get(self, *keys):
34 | if len(keys) == 1:
35 | return self.__data[keys[0]][0] / self.__data[keys[0]][1]
36 | else:
37 | v_list = [self.__data[k][0] / self.__data[k][1] for k in keys]
38 | return tuple(v_list)
39 |
40 | def pop(self, key=None):
41 | if key is None:
42 | for k in self.__data.keys():
43 | self.__data[k] = [0.0, 0]
44 | else:
45 | v = self.get(key)
46 | self.__data[key] = [0.0, 0]
47 | return v
48 |
49 |
50 | class Timer:
51 | def __init__(self, starting_msg=None):
52 | self.start = time.time()
53 | self.stage_start = self.start
54 |
55 | if starting_msg is not None:
56 | print(starting_msg, time.ctime(time.time()))
57 |
58 | def update_progress(self, progress):
59 | self.elapsed = time.time() - self.start
60 | self.est_total = self.elapsed / progress
61 | self.est_remaining = self.est_total - self.elapsed
62 | self.est_finish = int(self.start + self.est_total)
63 |
64 | def str_est_finish(self):
65 | return str(time.ctime(self.est_finish))
66 |
67 | def get_stage_elapsed(self):
68 | return time.time() - self.stage_start
69 |
70 | def reset_stage(self):
71 | self.stage_start = time.time()
72 |
73 |
74 | class BatchThreader:
75 | def __init__(self, func, args_list, batch_size, prefetch_size=4, processes=12):
76 | self.batch_size = batch_size
77 | self.prefetch_size = prefetch_size
78 |
79 | self.pool = ThreadPool(processes=processes)
80 | self.async_result = []
81 |
82 | self.func = func
83 | self.left_args_list = args_list
84 | self.n_tasks = len(args_list)
85 |
86 | # initial work
87 | self.__start_works(self.__get_n_pending_works())
88 |
89 | def __start_works(self, times):
90 | for _ in range(times):
91 | args = self.left_args_list.pop(0)
92 | self.async_result.append(
93 | self.pool.apply_async(self.func, args))
94 |
95 | def __get_n_pending_works(self):
96 | return min((self.prefetch_size + 1) * self.batch_size - len(self.async_result), len(self.left_args_list))
97 |
98 | def pop_results(self):
99 |
100 | n_inwork = len(self.async_result)
101 |
102 | n_fetch = min(n_inwork, self.batch_size)
103 | rtn = [self.async_result.pop(0).get()
104 | for _ in range(n_fetch)]
105 |
106 | to_fill = self.__get_n_pending_works()
107 | if to_fill == 0:
108 | self.pool.close()
109 | else:
110 | self.__start_works(to_fill)
111 |
112 | return rtn
113 |
114 |
115 | def get_indices_of_pairs(radius, size):
116 |
117 | search_dist = []
118 |
119 | for x in range(1, radius):
120 | search_dist.append((0, x))
121 |
122 | for y in range(1, radius):
123 | for x in range(-radius + 1, radius):
124 | if x * x + y * y < radius * radius:
125 | search_dist.append((y, x))
126 |
127 | radius_floor = radius - 1
128 |
129 | full_indices = np.reshape(np.arange(0, size[0] * size[1], dtype=np.int64),
130 | (size[0], size[1]))
131 |
132 | cropped_height = size[0] - radius_floor
133 | cropped_width = size[1] - 2 * radius_floor
134 |
135 | indices_from = np.reshape(full_indices[:-radius_floor, radius_floor:-radius_floor],
136 | [-1])
137 |
138 | indices_to_list = []
139 |
140 | for dy, dx in search_dist:
141 | indices_to = full_indices[dy:dy + cropped_height,
142 | radius_floor + dx:radius_floor + dx + cropped_width]
143 | indices_to = np.reshape(indices_to, [-1])
144 |
145 | indices_to_list.append(indices_to)
146 |
147 | concat_indices_to = np.concatenate(indices_to_list, axis=0)
148 |
149 | return indices_from, concat_indices_to
150 |
151 |
152 | def get_optimizer(model, cfg):
153 | # backbone_params = list(map(id, model.module.backbone.parameters()))
154 | # base_params = filter(lambda p: id(
155 | # p) not in backbone_params, model.parameters())
156 | # params = [
157 | # {'params': base_params},
158 | # {'params': model.module.backbone.parameters(), 'lr': cfg.backbone_lr}
159 | # ]
160 |
161 | if cfg.type == 'Adam':
162 | opt = torch.optim.Adam(model.parameters(), cfg.lr)
163 | elif cfg.type == 'AdamW':
164 | opt = torch.optim.AdamW(model.parameters(), cfg.lr)
165 |     elif cfg.type == 'SGD':
166 |         opt = torch.optim.SGD(model.parameters(), cfg.lr)
167 | else:
168 |         raise ValueError(f'unsupported optimizer type: {cfg.type}')
169 |
170 | return opt
171 |
--------------------------------------------------------------------------------
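
get_optimizer above only reads cfg.type and cfg.lr (the commented-out block hints at a separate backbone learning rate that is currently unused). A minimal sketch of driving it, together with AverageMeter, from an mmcv Config built out of a plain dict; the tiny model and the values are placeholders:

```python
import torch.nn as nn
from mmcv import Config
from utils.pyutils import get_optimizer, AverageMeter  # run from the repository root

cfg = Config(dict(optimizer=dict(type='AdamW', lr=2e-5)))  # placeholder optimizer settings
model = nn.Linear(4, 2)                                    # stand-in for the real model
opt = get_optimizer(model, cfg.optimizer)                  # -> torch.optim.AdamW

meter = AverageMeter('miou')
meter.add({'miou': 0.5})
meter.add({'miou': 0.7})
print(type(opt).__name__, meter.pop('miou'))  # AdamW ~0.6 (running mean, then reset)
```
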
/utils/vis_mask.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.nn import functional as F
3 |
4 | import os
5 | import shutil
6 | import logging
7 | import cv2
8 | import numpy as np
9 | from PIL import Image
10 |
11 | import sys
12 | import time
13 | import pandas as pd
14 | import pdb
15 | from torchvision import transforms
16 |
17 |
18 | def save_color_mask(pred_masks, save_base_path, video_name_list, v_pallete, resize, resized_mask_size, T=10):
19 | # pred_mask: [bs*5, N_CLASSES, 224, 224]
20 | # print(f"=> {len(video_name_list)} videos in this batch")
21 |
22 | if not os.path.exists(save_base_path):
23 | os.makedirs(save_base_path, exist_ok=True)
24 |
25 | BT, N_CLASSES, H, W = pred_masks.shape
26 | bs = BT // T
27 |
28 | pred_masks = torch.softmax(pred_masks, dim=1)
29 | pred_masks = torch.argmax(pred_masks, dim=1) # [BT, 224, 224]
30 | pred_masks = pred_masks.cpu().numpy()
31 |
32 | pred_rgb_masks = np.zeros((pred_masks.shape + (3,)), np.uint8) # [BT, H, W, 3]
33 | for cls_idx in range(N_CLASSES):
34 | rgb = v_pallete[cls_idx]
35 | pred_rgb_masks[pred_masks == cls_idx] = rgb
36 | pred_rgb_masks = pred_rgb_masks.reshape(bs, T, H, W, 3)
37 |
38 | for idx in range(bs):
39 | video_name = video_name_list[idx]
40 | mask_save_path = os.path.join(save_base_path, video_name)
41 | if not os.path.exists(mask_save_path):
42 | os.makedirs(mask_save_path, exist_ok=True)
43 | one_video_masks = pred_rgb_masks[idx] # [5, 224, 224, 3]
44 | for video_id in range(len(one_video_masks)):
45 | one_mask = one_video_masks[video_id]
46 |             output_name = "%s_%d.png" % (video_name, video_id)
47 |             im = Image.fromarray(one_mask)  # .convert('RGB')
48 | if resize:
49 | im = im.resize(resized_mask_size)
50 | im.save(os.path.join(mask_save_path, output_name), format='PNG')
51 |
--------------------------------------------------------------------------------
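
save_color_mask above expects v_pallete to map each class index to an RGB triple. A hedged usage sketch with random logits and a made-up 4-color palette; the real palette and class count would come from the dataset's label set, and the output directory is arbitrary:

```python
import torch
from utils.vis_mask import save_color_mask  # run from the repository root

N_CLASSES, T = 4, 10
v_pallete = [(0, 0, 0), (255, 0, 0), (0, 255, 0), (0, 0, 255)]  # illustrative RGB per class
logits = torch.randn(1 * T, N_CLASSES, 224, 224)                # [BT, N_CLASSES, H, W] model output

save_color_mask(logits, './temp/vis_demo', ['video_0000'], v_pallete,
                resize=False, resized_mask_size=(224, 224), T=T)
# writes ./temp/vis_demo/video_0000/video_0000_{0..9}.png
```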