├── .gitignore ├── README.md ├── config ├── avss │ ├── AVSegFormer_pvt2_avss.py │ └── AVSegFormer_res50_avss.py ├── ms3 │ ├── AVSegFormer_pvt2_ms3.py │ └── AVSegFormer_res50_ms3.py └── s4 │ ├── AVSegFormer_pvt2_s4.py │ └── AVSegFormer_res50_s4.py ├── dataloader ├── __init__.py ├── ms3_dataset.py ├── s4_dataset.py └── v2_dataset.py ├── image └── arch.png ├── model ├── AVSegFormer.py ├── __init__.py ├── backbone │ ├── __init__.py │ ├── pvt.py │ └── resnet.py ├── head │ ├── AVSegHead.py │ └── __init__.py ├── utils │ ├── __init__.py │ ├── fusion_block.py │ ├── positional_encoding.py │ ├── query_generator.py │ └── transformer.py └── vggish │ ├── __init__.py │ ├── mel_features.py │ ├── vggish.py │ ├── vggish_input.py │ └── vggish_params.py ├── ops ├── functions │ ├── __init__.py │ └── ms_deform_attn_func.py ├── make.sh ├── modules │ ├── __init__.py │ └── ms_deform_attn.py ├── setup.py ├── src │ ├── cpu │ │ ├── ms_deform_attn_cpu.cpp │ │ └── ms_deform_attn_cpu.h │ ├── cuda │ │ ├── ms_deform_attn_cuda.cu │ │ ├── ms_deform_attn_cuda.h │ │ └── ms_deform_im2col_cuda.cuh │ ├── ms_deform_attn.h │ └── vision.cpp └── test.py ├── scripts ├── avss │ ├── loss.py │ ├── test.py │ └── train.py ├── ms3 │ ├── loss.py │ ├── test.py │ ├── train.py │ └── utility.py └── s4 │ ├── loss.py │ ├── test.py │ ├── train.py │ └── utility.py ├── test.sh ├── train.sh └── utils ├── compute_color_metrics.py ├── logger.py ├── loss_util.py ├── pyutils.py └── vis_mask.py /.gitignore: -------------------------------------------------------------------------------- 1 | .vscode 2 | **/__pycache__/** 3 | *.pth 4 | temp 5 | data 6 | pretrained 7 | work_dir 8 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # 💬 AVSegFormer [[paper](https://arxiv.org/abs/2307.01146)] 2 | The combination of vision and audio has long been a topic of interest among researchers in the multi-modal field. Recently, a new audio-visual segmentation task has been introduced, aiming to locate and segment the corresponding sound source objects in a given video. This task demands pixel-level fine-grained features for the first time, posing significant challenges. In this paper, we propose AVSegFormer, a new method for audio-visual segmentation tasks that leverages the Transformer architecture for its outstanding performance in multi-modal tasks. We combine audio features and learnable queries as decoder inputs to facilitate multi-modal information exchange. Furthermore, we design an audio-visual mixer to amplify the features of target objects. Additionally, we devise an intermediate mask loss to enhance training efficacy. Our method demonstrates robust performance and achieves state-of-the-art results in audio-visual segmentation tasks. 3 | 4 | 5 | ## 🚀 What's New 6 | - (2023.04.28) Uploaded pre-trained checkpoints and updated README. 7 | - (2023.04.25) We completed the implementation of AVSegFormer and pushed the code. 8 | 9 | 10 | ## 🏠 Method 11 | ![AVSegFormer architecture](image/arch.png) 12 | 13 | 14 | ## 🛠️ Get Started 15 | 16 | ### 1. 
Environments 17 | ```shell 18 | # recommended 19 | pip install torch==1.10.0+cu111 torchvision==0.11.0+cu111 torchaudio==0.10.0 -f https://download.pytorch.org/whl/torch_stable.html 20 | pip install mmcv-full -f https://download.openmmlab.com/mmcv/dist/cu111/torch1.10.0/index.html 21 | pip install pandas 22 | pip install timm 23 | pip install resampy 24 | pip install soundfile 25 | # build MSDeformAttention 26 | cd ops 27 | sh make.sh 28 | ``` 29 | 30 | 31 | ### 2. Data 32 | 33 | Please refer to [AVSBenchmark](https://github.com/OpenNLPLab/AVSBench) to download the datasets. You can put the data under the `data` folder, or keep your own folder name and modify the corresponding paths in the config files. The `data` directory is organized as below: 34 | ``` 35 | |--data 36 | |--AVSS 37 | |--Multi-sources 38 | |--Single-source 39 | ``` 40 | 41 | 42 | ### 3. Download Pre-Trained Models 43 | 44 | - The pretrained backbones are available from the benchmark: [AVSBench pretrained backbones](https://drive.google.com/drive/folders/1386rcFHJ1QEQQMF6bV1rXJTzy8v26RTV). 45 | - We provide pre-trained models for all three subtasks. You can download them from [AVSegFormer pretrained models](https://drive.google.com/drive/folders/1ZYZOWAfoXcGPDsocswEN7ZYvcAn4H8kY). 46 | 47 | |Method|Backbone|Subset|Lr schd|Config|mIoU|F-score|Download| 48 | |:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:| 49 | |AVSegFormer-R50|ResNet-50|S4|30ep|[config](config/s4/AVSegFormer_res50_s4.py)|76.38|86.7|[ckpt](https://drive.google.com/file/d/1nvIfR-1XZ_BgP8ZSUDuAsGhAwDJRxgC3/view?usp=drive_link)| 50 | |AVSegFormer-PVTv2|PVTv2-B5|S4|30ep|[config](config/s4/AVSegFormer_pvt2_s4.py)|83.06|90.5|[ckpt](https://drive.google.com/file/d/1ZJ55jxoHP1ur-hLBkGcha8sjptE_shfw/view?usp=drive_link)| 51 | |AVSegFormer-R50|ResNet-50|MS3|60ep|[config](config/ms3/AVSegFormer_res50_ms3.py)|53.81|65.6|[ckpt](https://drive.google.com/file/d/1MRk5gQnUtiWwYDpPfB20fO07SVLhfuIV/view?usp=drive_link)| 52 | |AVSegFormer-PVTv2|PVTv2-B5|MS3|60ep|[config](config/ms3/AVSegFormer_pvt2_ms3.py)|61.33|73.0|[ckpt](https://drive.google.com/file/d/1iKTxWtehAgCkNVty-4H1zVyAOaNxipHv/view?usp=drive_link)| 53 | |AVSegFormer-R50|ResNet-50|AVSS|30ep|[config](config/avss/AVSegFormer_res50_avss.py)|26.58|31.5|[ckpt](https://drive.google.com/file/d/1RvL6psDsINuUwd9V1ESgE2Kixh9MXIke/view?usp=drive_link)| 54 | |AVSegFormer-PVTv2|PVTv2-B5|AVSS|30ep|[config](config/avss/AVSegFormer_pvt2_avss.py)|37.31|42.8|[ckpt](https://drive.google.com/file/d/1P8a2dJSUoW0EqFyxyP8B1-Rnscxnh0YY/view?usp=drive_link)| 55 | 56 | 57 | ### 4. Train 58 | ```shell 59 | TASK="s4" # or ms3, avss 60 | CONFIG="config/s4/AVSegFormer_pvt2_s4.py" 61 | 62 | bash train.sh ${TASK} ${CONFIG} 63 | ``` 64 | 65 | 
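The heavy lifting behind `train.sh` is done by the task-specific scripts under `scripts/<task>/train.py`, which are driven by the config files shown below. As a rough, hedged sketch of how those pieces fit together (only `build_model`, `build_dataset` and the config layout from this repository are assumed; the actual loss computation, checkpointing and evaluation live in the scripts and are omitted):

```python
# Hedged sketch: programmatic setup from a config file, not the official training script.
import torch
from mmcv import Config
from torch.utils.data import DataLoader

from model import build_model          # defined in model/__init__.py
from dataloader import build_dataset   # defined in dataloader/__init__.py

cfg = Config.fromfile('config/s4/AVSegFormer_pvt2_s4.py')

model = build_model(**cfg.model)                  # cfg.model carries 'type' plus constructor kwargs
train_set = build_dataset(**cfg.dataset.train)    # likewise 'type', 'split' and the dataset paths
train_loader = DataLoader(train_set,
                          batch_size=cfg.dataset.train.batch_size,
                          num_workers=cfg.process.num_works,
                          shuffle=True)
optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.optimizer.lr)
```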
66 | ### 5. Test 67 | ```shell 68 | TASK="s4" # or ms3, avss 69 | CONFIG="config/s4/AVSegFormer_pvt2_s4.py" 70 | CHECKPOINT="work_dir/AVSegFormer_pvt2_s4/S4_best.pth" 71 | 72 | bash test.sh ${TASK} ${CONFIG} ${CHECKPOINT} 73 | ``` 74 | 75 | 76 | ## 🤝 Citation 77 | 78 | If you use our model, please consider citing the following papers: 79 | ``` 80 | @article{zhou2023avss, 81 | title={Audio-Visual Segmentation with Semantics}, 82 | author={Zhou, Jinxing and Shen, Xuyang and Wang, Jianyuan and Zhang, Jiayi and Sun, Weixuan and Zhang, Jing and Birchfield, Stan and Guo, Dan and Kong, Lingpeng and Wang, Meng and Zhong, Yiran}, 83 | journal={arXiv preprint arXiv:2301.13190}, 84 | year={2023}, 85 | } 86 | 87 | @misc{gao2023avsegformer, 88 | title={AVSegFormer: Audio-Visual Segmentation with Transformer}, 89 | author={Shengyi Gao and Zhe Chen and Guo Chen and Wenhai Wang and Tong Lu}, 90 | year={2023}, 91 | eprint={2307.01146}, 92 | archivePrefix={arXiv}, 93 | primaryClass={cs.CV} 94 | } 95 | ``` 96 | -------------------------------------------------------------------------------- /config/avss/AVSegFormer_pvt2_avss.py: -------------------------------------------------------------------------------- 1 | model = dict( 2 | type='AVSegFormer', 3 | neck=None, 4 | backbone=dict( 5 | type='pvt_v2_b5', 6 | init_weights_path='pretrained/pvt_v2_b5.pth'), 7 | vggish=dict( 8 | freeze_audio_extractor=True, 9 | pretrained_vggish_model_path='pretrained/vggish-10086976.pth', 10 | preprocess_audio_to_log_mel=True, 11 | postprocess_log_mel_with_pca=False, 12 | pretrained_pca_params_path=None), 13 | head=dict( 14 | type='AVSegHead', 15 | in_channels=[64, 128, 320, 512], 16 | num_classes=71, 17 | query_num=300, 18 | use_learnable_queries=True, 19 | fusion_block=dict(type='CrossModalMixer'), 20 | query_generator=dict( 21 | type='AttentionGenerator', 22 | num_layers=6, 23 | query_num=300), 24 | positional_encoding=dict( 25 | type='SinePositionalEncoding', 26 | num_feats=128), 27 | transformer=dict( 28 | type='AVSTransformer', 29 | encoder=dict( 30 | num_layers=6, 31 | layer=dict( 32 | dim=256, 33 | ffn_dim=2048, 34 | dropout=0.1)), 35 | decoder=dict( 36 | num_layers=6, 37 | layer=dict( 38 | dim=256, 39 | ffn_dim=2048, 40 | dropout=0.1)))), 41 | audio_dim=128, 42 | embed_dim=256, 43 | freeze_audio_backbone=True, 44 | T=10) 45 | dataset = dict( 46 | train=dict( 47 | type='V2Dataset', 48 | split='train', 49 | num_class=71, 50 | mask_num=10, 51 | crop_img_and_mask=True, 52 | crop_size=224, 53 | meta_csv_path='data/AVSS/metadata.csv', 54 | label_idx_path='data/AVSS/label2idx.json', 55 | dir_base='data/AVSS', 56 | img_size=(224, 224), 57 | resize_pred_mask=True, 58 | save_pred_mask_img_size=(360, 240), 59 | batch_size=4), 60 | val=dict( 61 | type='V2Dataset', 62 | split='val', 63 | num_class=71, 64 | mask_num=10, 65 | crop_img_and_mask=True, 66 | crop_size=224, 67 | meta_csv_path='data/AVSS/metadata.csv', 68 | label_idx_path='data/AVSS/label2idx.json', 69 | dir_base='data/AVSS', 70 | img_size=(224, 224), 71 | resize_pred_mask=True, 72 | save_pred_mask_img_size=(360, 240), 73 | batch_size=4), 74 | test=dict( 75 | type='V2Dataset', 76 | split='test', 77 | num_class=71, 78 | mask_num=10, 79 | crop_img_and_mask=True, 80 | crop_size=224, 81 | meta_csv_path='data/AVSS/metadata.csv', 82 | label_idx_path='data/AVSS/label2idx.json', 83 | dir_base='data/AVSS', 84 | img_size=(224, 224), 85 | resize_pred_mask=True, 86 | save_pred_mask_img_size=(360, 240), 87 | batch_size=4)) 88 | optimizer = dict( 89 | type='AdamW', 90 | lr=2e-5) 91 | loss = dict( 
92 | weight_dict=dict( 93 | iou_loss=1.0, 94 | mix_loss=0.1)) 95 | process = dict( 96 | num_works=8, 97 | train_epochs=30, 98 | start_eval_epoch=10, 99 | eval_interval=2, 100 | freeze_epochs=0) 101 | -------------------------------------------------------------------------------- /config/avss/AVSegFormer_res50_avss.py: -------------------------------------------------------------------------------- 1 | model = dict( 2 | type='AVSegFormer', 3 | neck=None, 4 | backbone=dict( 5 | type='res50', 6 | init_weights_path='pretrained/resnet50-19c8e357.pth'), 7 | vggish=dict( 8 | freeze_audio_extractor=True, 9 | pretrained_vggish_model_path='pretrained/vggish-10086976.pth', 10 | preprocess_audio_to_log_mel=True, 11 | postprocess_log_mel_with_pca=False, 12 | pretrained_pca_params_path=None), 13 | head=dict( 14 | type='AVSegHead', 15 | in_channels=[256, 512, 1024, 2048], 16 | num_classes=71, 17 | query_num=300, 18 | use_learnable_queries=True, 19 | fusion_block=dict(type='CrossModalMixer'), 20 | query_generator=dict( 21 | type='AttentionGenerator', 22 | num_layers=6, 23 | query_num=300), 24 | positional_encoding=dict( 25 | type='SinePositionalEncoding', 26 | num_feats=128), 27 | transformer=dict( 28 | type='AVSTransformer', 29 | encoder=dict( 30 | num_layers=6, 31 | layer=dict( 32 | dim=256, 33 | ffn_dim=2048, 34 | dropout=0.1)), 35 | decoder=dict( 36 | num_layers=6, 37 | layer=dict( 38 | dim=256, 39 | ffn_dim=2048, 40 | dropout=0.1)))), 41 | audio_dim=128, 42 | embed_dim=256, 43 | freeze_audio_backbone=True, 44 | T=10) 45 | dataset = dict( 46 | train=dict( 47 | type='V2Dataset', 48 | split='train', 49 | num_class=71, 50 | mask_num=10, 51 | crop_img_and_mask=True, 52 | crop_size=224, 53 | meta_csv_path='data/AVSS/metadata.csv', 54 | label_idx_path='data/AVSS/label2idx.json', 55 | dir_base='data/AVSS', 56 | img_size=(224, 224), 57 | resize_pred_mask=True, 58 | save_pred_mask_img_size=(360, 240), 59 | batch_size=4), 60 | val=dict( 61 | type='V2Dataset', 62 | split='val', 63 | num_class=71, 64 | mask_num=10, 65 | crop_img_and_mask=True, 66 | crop_size=224, 67 | meta_csv_path='data/AVSS/metadata.csv', 68 | label_idx_path='data/AVSS/label2idx.json', 69 | dir_base='data/AVSS', 70 | img_size=(224, 224), 71 | resize_pred_mask=True, 72 | save_pred_mask_img_size=(360, 240), 73 | batch_size=4), 74 | test=dict( 75 | type='V2Dataset', 76 | split='test', 77 | num_class=71, 78 | mask_num=10, 79 | crop_img_and_mask=True, 80 | crop_size=224, 81 | meta_csv_path='data/AVSS/metadata.csv', 82 | label_idx_path='data/AVSS/label2idx.json', 83 | dir_base='data/AVSS', 84 | img_size=(224, 224), 85 | resize_pred_mask=True, 86 | save_pred_mask_img_size=(360, 240), 87 | batch_size=4)) 88 | optimizer = dict( 89 | type='AdamW', 90 | lr=2e-5) 91 | loss = dict( 92 | weight_dict=dict( 93 | iou_loss=1.0, 94 | mix_loss=0.1)) 95 | process = dict( 96 | num_works=8, 97 | train_epochs=30, 98 | start_eval_epoch=10, 99 | eval_interval=2, 100 | freeze_epochs=0) 101 | -------------------------------------------------------------------------------- /config/ms3/AVSegFormer_pvt2_ms3.py: -------------------------------------------------------------------------------- 1 | model = dict( 2 | type='AVSegFormer', 3 | neck=None, 4 | backbone=dict( 5 | type='pvt_v2_b5', 6 | init_weights_path='pretrained/pvt_v2_b5.pth'), 7 | vggish=dict( 8 | freeze_audio_extractor=True, 9 | pretrained_vggish_model_path='pretrained/vggish-10086976.pth', 10 | preprocess_audio_to_log_mel=False, 11 | postprocess_log_mel_with_pca=False, 12 | pretrained_pca_params_path=None), 13 | 
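# Note: the MS3/S4 configs keep num_classes=1 in the head below (binary sounding-object masks), whereas the AVSS configs above use num_classes=71 semantic classes; preprocess_audio_to_log_mel is False here because the MS3/S4 datasets load pre-extracted log-mel features from .pkl files.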
head=dict( 14 | type='AVSegHead', 15 | in_channels=[64, 128, 320, 512], 16 | num_classes=1, 17 | query_num=300, 18 | use_learnable_queries=True, 19 | fusion_block=dict(type='CrossModalMixer'), 20 | query_generator=dict( 21 | type='AttentionGenerator', 22 | num_layers=6, 23 | query_num=300), 24 | positional_encoding=dict( 25 | type='SinePositionalEncoding', 26 | num_feats=128), 27 | transformer=dict( 28 | type='AVSTransformer', 29 | encoder=dict( 30 | num_layers=6, 31 | layer=dict( 32 | dim=256, 33 | ffn_dim=2048, 34 | dropout=0.1)), 35 | decoder=dict( 36 | num_layers=6, 37 | layer=dict( 38 | dim=256, 39 | ffn_dim=2048, 40 | dropout=0.1)))), 41 | audio_dim=128, 42 | embed_dim=256, 43 | freeze_audio_backbone=True, 44 | T=5) 45 | dataset = dict( 46 | train=dict( 47 | type='MS3Dataset', 48 | split='train', 49 | anno_csv='data/Multi-sources/ms3_meta_data.csv', 50 | dir_img='data/Multi-sources/ms3_data/visual_frames', 51 | dir_audio_log_mel='data/Multi-sources/ms3_data/audio_log_mel', 52 | dir_mask='data/Multi-sources/ms3_data/gt_masks', 53 | img_size=(224, 224), 54 | batch_size=2), 55 | val=dict( 56 | type='MS3Dataset', 57 | split='val', 58 | anno_csv='data/Multi-sources/ms3_meta_data.csv', 59 | dir_img='data/Multi-sources/ms3_data/visual_frames', 60 | dir_audio_log_mel='data/Multi-sources/ms3_data/audio_log_mel', 61 | dir_mask='data/Multi-sources/ms3_data/gt_masks', 62 | img_size=(224, 224), 63 | batch_size=2), 64 | test=dict( 65 | type='MS3Dataset', 66 | split='test', 67 | anno_csv='data/Multi-sources/ms3_meta_data.csv', 68 | dir_img='data/Multi-sources/ms3_data/visual_frames', 69 | dir_audio_log_mel='data/Multi-sources/ms3_data/audio_log_mel', 70 | dir_mask='data/Multi-sources/ms3_data/gt_masks', 71 | img_size=(224, 224), 72 | batch_size=2)) 73 | optimizer = dict( 74 | type='AdamW', 75 | lr=2e-5) 76 | loss = dict( 77 | weight_dict=dict( 78 | iou_loss=1.0, 79 | mix_loss=0.1), 80 | loss_type='dice') 81 | process = dict( 82 | num_works=8, 83 | train_epochs=60, 84 | freeze_epochs=10) 85 | -------------------------------------------------------------------------------- /config/ms3/AVSegFormer_res50_ms3.py: -------------------------------------------------------------------------------- 1 | model = dict( 2 | type='AVSegFormer', 3 | neck=None, 4 | backbone=dict( 5 | type='res50', 6 | init_weights_path='pretrained/resnet50-19c8e357.pth'), 7 | vggish=dict( 8 | freeze_audio_extractor=True, 9 | pretrained_vggish_model_path='pretrained/vggish-10086976.pth', 10 | preprocess_audio_to_log_mel=False, 11 | postprocess_log_mel_with_pca=False, 12 | pretrained_pca_params_path=None), 13 | head=dict( 14 | type='AVSegHead', 15 | in_channels=[256, 512, 1024, 2048], 16 | num_classes=1, 17 | query_num=300, 18 | use_learnable_queries=True, 19 | fusion_block=dict(type='CrossModalMixer'), 20 | query_generator=dict( 21 | type='AttentionGenerator', 22 | num_layers=6, 23 | query_num=300), 24 | positional_encoding=dict( 25 | type='SinePositionalEncoding', 26 | num_feats=128), 27 | transformer=dict( 28 | type='AVSTransformer', 29 | encoder=dict( 30 | num_layers=6, 31 | layer=dict( 32 | dim=256, 33 | ffn_dim=2048, 34 | dropout=0.1)), 35 | decoder=dict( 36 | num_layers=6, 37 | layer=dict( 38 | dim=256, 39 | ffn_dim=2048, 40 | dropout=0.1)))), 41 | audio_dim=128, 42 | embed_dim=256, 43 | freeze_audio_backbone=True, 44 | T=5) 45 | dataset = dict( 46 | train=dict( 47 | type='MS3Dataset', 48 | split='train', 49 | anno_csv='data/Multi-sources/ms3_meta_data.csv', 50 | dir_img='data/Multi-sources/ms3_data/visual_frames', 51 | 
dir_audio_log_mel='data/Multi-sources/ms3_data/audio_log_mel', 52 | dir_mask='data/Multi-sources/ms3_data/gt_masks', 53 | img_size=(224, 224), 54 | batch_size=2), 55 | val=dict( 56 | type='MS3Dataset', 57 | split='val', 58 | anno_csv='data/Multi-sources/ms3_meta_data.csv', 59 | dir_img='data/Multi-sources/ms3_data/visual_frames', 60 | dir_audio_log_mel='data/Multi-sources/ms3_data/audio_log_mel', 61 | dir_mask='data/Multi-sources/ms3_data/gt_masks', 62 | img_size=(224, 224), 63 | batch_size=2), 64 | test=dict( 65 | type='MS3Dataset', 66 | split='test', 67 | anno_csv='data/Multi-sources/ms3_meta_data.csv', 68 | dir_img='data/Multi-sources/ms3_data/visual_frames', 69 | dir_audio_log_mel='data/Multi-sources/ms3_data/audio_log_mel', 70 | dir_mask='data/Multi-sources/ms3_data/gt_masks', 71 | img_size=(224, 224), 72 | batch_size=2)) 73 | optimizer = dict( 74 | type='AdamW', 75 | lr=2e-5) 76 | loss = dict( 77 | weight_dict=dict( 78 | iou_loss=1.0, 79 | mix_loss=0.1), 80 | loss_type='dice') 81 | process = dict( 82 | num_works=8, 83 | train_epochs=60, 84 | freeze_epochs=10) 85 | -------------------------------------------------------------------------------- /config/s4/AVSegFormer_pvt2_s4.py: -------------------------------------------------------------------------------- 1 | model = dict( 2 | type='AVSegFormer', 3 | neck=None, 4 | backbone=dict( 5 | type='pvt_v2_b5', 6 | init_weights_path='pretrained/pvt_v2_b5.pth'), 7 | vggish=dict( 8 | freeze_audio_extractor=True, 9 | pretrained_vggish_model_path='pretrained/vggish-10086976.pth', 10 | preprocess_audio_to_log_mel=False, 11 | postprocess_log_mel_with_pca=False, 12 | pretrained_pca_params_path=None), 13 | head=dict( 14 | type='AVSegHead', 15 | in_channels=[64, 128, 320, 512], 16 | num_classes=1, 17 | query_num=300, 18 | use_learnable_queries=True, 19 | fusion_block=dict(type='CrossModalMixer'), 20 | positional_encoding=dict( 21 | type='SinePositionalEncoding', 22 | num_feats=128), 23 | query_generator=dict( 24 | type='AttentionGenerator', 25 | num_layers=6, 26 | query_num=300), 27 | transformer=dict( 28 | type='AVSTransformer', 29 | encoder=dict( 30 | num_layers=6, 31 | layer=dict( 32 | dim=256, 33 | ffn_dim=2048, 34 | dropout=0.1)), 35 | decoder=dict( 36 | num_layers=6, 37 | layer=dict( 38 | dim=256, 39 | ffn_dim=2048, 40 | dropout=0.1)))), 41 | audio_dim=128, 42 | embed_dim=256, 43 | freeze_audio_backbone=True, 44 | T=5) 45 | dataset = dict( 46 | train=dict( 47 | type='S4Dataset', 48 | split='train', 49 | anno_csv='data/Single-source/s4_meta_data.csv', 50 | dir_img='data/Single-source/s4_data/visual_frames', 51 | dir_audio_log_mel='data/Single-source/s4_data/audio_log_mel', 52 | dir_mask='data/Single-source/s4_data/gt_masks', 53 | img_size=(224, 224), 54 | batch_size=2), 55 | val=dict( 56 | type='S4Dataset', 57 | split='val', 58 | anno_csv='data/Single-source/s4_meta_data.csv', 59 | dir_img='data/Single-source/s4_data/visual_frames', 60 | dir_audio_log_mel='data/Single-source/s4_data/audio_log_mel', 61 | dir_mask='data/Single-source/s4_data/gt_masks', 62 | img_size=(224, 224), 63 | batch_size=2), 64 | test=dict( 65 | type='S4Dataset', 66 | split='test', 67 | anno_csv='data/Single-source/s4_meta_data.csv', 68 | dir_img='data/Single-source/s4_data/visual_frames', 69 | dir_audio_log_mel='data/Single-source/s4_data/audio_log_mel', 70 | dir_mask='data/Single-source/s4_data/gt_masks', 71 | img_size=(224, 224), 72 | batch_size=2)) 73 | optimizer = dict( 74 | type='AdamW', 75 | lr=2e-5) 76 | loss = dict( 77 | weight_dict=dict( 78 | iou_loss=1.0, 79 | 
mix_loss=0.1), 80 | loss_type='dice') 81 | process = dict( 82 | num_works=8, 83 | train_epochs=30, 84 | freeze_epochs=5) 85 | -------------------------------------------------------------------------------- /config/s4/AVSegFormer_res50_s4.py: -------------------------------------------------------------------------------- 1 | model = dict( 2 | type='AVSegFormer', 3 | neck=None, 4 | backbone=dict( 5 | type='res50', 6 | init_weights_path='pretrained/resnet50-19c8e357.pth'), 7 | vggish=dict( 8 | freeze_audio_extractor=True, 9 | pretrained_vggish_model_path='pretrained/vggish-10086976.pth', 10 | preprocess_audio_to_log_mel=False, 11 | postprocess_log_mel_with_pca=False, 12 | pretrained_pca_params_path=None), 13 | head=dict( 14 | type='AVSegHead', 15 | in_channels=[256, 512, 1024, 2048], 16 | num_classes=1, 17 | query_num=300, 18 | use_learnable_queries=True, 19 | fusion_block=dict(type='CrossModalMixer'), 20 | positional_encoding=dict( 21 | type='SinePositionalEncoding', 22 | num_feats=128), 23 | query_generator=dict( 24 | type='AttentionGenerator', 25 | num_layers=6, 26 | query_num=300), 27 | transformer=dict( 28 | type='AVSTransformer', 29 | encoder=dict( 30 | num_layers=6, 31 | layer=dict( 32 | dim=256, 33 | ffn_dim=2048, 34 | dropout=0.1)), 35 | decoder=dict( 36 | num_layers=6, 37 | layer=dict( 38 | dim=256, 39 | ffn_dim=2048, 40 | dropout=0.1)))), 41 | audio_dim=128, 42 | embed_dim=256, 43 | freeze_audio_backbone=True, 44 | T=5) 45 | dataset = dict( 46 | train=dict( 47 | type='S4Dataset', 48 | split='train', 49 | anno_csv='data/Single-source/s4_meta_data.csv', 50 | dir_img='data/Single-source/s4_data/visual_frames', 51 | dir_audio_log_mel='data/Single-source/s4_data/audio_log_mel', 52 | dir_mask='data/Single-source/s4_data/gt_masks', 53 | img_size=(224, 224), 54 | batch_size=2), 55 | val=dict( 56 | type='S4Dataset', 57 | split='val', 58 | anno_csv='data/Single-source/s4_meta_data.csv', 59 | dir_img='data/Single-source/s4_data/visual_frames', 60 | dir_audio_log_mel='data/Single-source/s4_data/audio_log_mel', 61 | dir_mask='data/Single-source/s4_data/gt_masks', 62 | img_size=(224, 224), 63 | batch_size=2), 64 | test=dict( 65 | type='S4Dataset', 66 | split='test', 67 | anno_csv='data/Single-source/s4_meta_data.csv', 68 | dir_img='data/Single-source/s4_data/visual_frames', 69 | dir_audio_log_mel='data/Single-source/s4_data/audio_log_mel', 70 | dir_mask='data/Single-source/s4_data/gt_masks', 71 | img_size=(224, 224), 72 | batch_size=2)) 73 | optimizer = dict( 74 | type='AdamW', 75 | lr=2e-5) 76 | loss = dict( 77 | weight_dict=dict( 78 | iou_loss=1.0, 79 | mix_loss=0.1), 80 | loss_type='dice') 81 | process = dict( 82 | num_works=8, 83 | train_epochs=30, 84 | freeze_epochs=5) 85 | -------------------------------------------------------------------------------- /dataloader/__init__.py: -------------------------------------------------------------------------------- 1 | from .v2_dataset import V2Dataset, get_v2_pallete 2 | from .s4_dataset import S4Dataset 3 | from .ms3_dataset import MS3Dataset 4 | from mmcv import Config 5 | 6 | 7 | def build_dataset(type, split, **kwargs): 8 | if type == 'V2Dataset': 9 | return V2Dataset(split=split, cfg=Config(kwargs)) 10 | elif type == 'S4Dataset': 11 | return S4Dataset(split=split, cfg=Config(kwargs)) 12 | elif type == 'MS3Dataset': 13 | return MS3Dataset(split=split, cfg=Config(kwargs)) 14 | else: 15 | raise ValueError 16 | 17 | 18 | __all__ = ['build_dataset', 'get_v2_pallete'] 19 | 
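For reference, a hedged sketch of how `build_dataset` is typically consumed (the per-sample shapes follow from the S4Dataset code and its comments further below; the real consumers are the training/testing scripts under `scripts/`):

```python
# Hedged usage sketch, not part of the original package.
from mmcv import Config
from torch.utils.data import DataLoader
from dataloader import build_dataset

cfg = Config.fromfile('config/s4/AVSegFormer_pvt2_s4.py')
train_set = build_dataset(**cfg.dataset.train)   # -> S4Dataset(split='train', cfg=Config(...))
loader = DataLoader(train_set, batch_size=cfg.dataset.train.batch_size, shuffle=True)

imgs, audio_log_mel, masks = next(iter(loader))
# imgs:          [B, 5, 3, 512, 512]  five frames per clip, resized by the dataset transforms
# audio_log_mel: [B, 5, 1, 96, 64]    pre-extracted VGGish log-mel features
# masks:         [B, 1, 1, 512, 512]  a single annotated frame per training clip
```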
-------------------------------------------------------------------------------- /dataloader/ms3_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn as nn 4 | from torch.utils.data import Dataset 5 | 6 | import numpy as np 7 | import pandas as pd 8 | import pickle 9 | 10 | import cv2 11 | from PIL import Image 12 | from torchvision import transforms 13 | 14 | 15 | def load_image_in_PIL_to_Tensor(path, mode='RGB', transform=None): 16 | img_PIL = Image.open(path).convert(mode) 17 | if transform: 18 | img_tensor = transform(img_PIL) 19 | return img_tensor 20 | return img_PIL 21 | 22 | 23 | def load_audio_lm(audio_lm_path): 24 | with open(audio_lm_path, 'rb') as fr: 25 | audio_log_mel = pickle.load(fr) 26 | audio_log_mel = audio_log_mel.detach() # [5, 1, 96, 64] 27 | return audio_log_mel 28 | 29 | 30 | class MS3Dataset(Dataset): 31 | """Dataset for multiple sound source segmentation""" 32 | 33 | def __init__(self, split='train', cfg=None): 34 | super(MS3Dataset, self).__init__() 35 | self.split = split 36 | self.mask_num = 5 37 | self.cfg = cfg 38 | df_all = pd.read_csv(cfg.anno_csv, sep=',') 39 | self.df_split = df_all[df_all['split'] == split] 40 | print("{}/{} videos are used for {}".format(len(self.df_split), 41 | len(df_all), self.split)) 42 | self.img_transform = transforms.Compose([ 43 | transforms.Resize([512, 512]), 44 | transforms.ToTensor(), 45 | transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)) 46 | ]) 47 | self.mask_transform = transforms.Compose([ 48 | transforms.Resize([512, 512]), 49 | transforms.ToTensor(), 50 | ]) 51 | 52 | def __getitem__(self, index): 53 | df_one_video = self.df_split.iloc[index] 54 | video_name = df_one_video[0] 55 | img_base_path = os.path.join(self.cfg.dir_img, video_name) 56 | audio_lm_path = os.path.join( 57 | self.cfg.dir_audio_log_mel, self.split, video_name + '.pkl') 58 | mask_base_path = os.path.join( 59 | self.cfg.dir_mask, self.split, video_name) 60 | audio_log_mel = load_audio_lm(audio_lm_path) 61 | # audio_lm_tensor = torch.from_numpy(audio_log_mel) 62 | imgs, masks = [], [] 63 | for img_id in range(1, 6): 64 | img = load_image_in_PIL_to_Tensor(os.path.join(img_base_path, "%s.mp4_%d.png" % ( 65 | video_name, img_id)), transform=self.img_transform) 66 | imgs.append(img) 67 | for mask_id in range(1, self.mask_num + 1): 68 | mask = load_image_in_PIL_to_Tensor(os.path.join(mask_base_path, "%s_%d.png" % ( 69 | video_name, mask_id)), transform=self.mask_transform, mode='P') 70 | masks.append(mask) 71 | imgs_tensor = torch.stack(imgs, dim=0) 72 | masks_tensor = torch.stack(masks, dim=0) 73 | 74 | return imgs_tensor, audio_log_mel, masks_tensor, video_name 75 | 76 | def __len__(self): 77 | return len(self.df_split) 78 | -------------------------------------------------------------------------------- /dataloader/s4_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | from wave import _wave_params 3 | import torch 4 | import torch.nn as nn 5 | from torch.utils.data import Dataset 6 | 7 | import numpy as np 8 | import pandas as pd 9 | import pickle 10 | 11 | import cv2 12 | from PIL import Image 13 | from torchvision import transforms 14 | 15 | 16 | def load_image_in_PIL_to_Tensor(path, mode='RGB', transform=None): 17 | img_PIL = Image.open(path).convert(mode) 18 | if transform: 19 | img_tensor = transform(img_PIL) 20 | return img_tensor 21 | return img_PIL 22 | 23 | 24 | def 
load_audio_lm(audio_lm_path): 25 | with open(audio_lm_path, 'rb') as fr: 26 | audio_log_mel = pickle.load(fr) 27 | audio_log_mel = audio_log_mel.detach() # [5, 1, 96, 64] 28 | return audio_log_mel 29 | 30 | 31 | class S4Dataset(Dataset): 32 | """Dataset for single sound source segmentation""" 33 | 34 | def __init__(self, split='train', cfg=None): 35 | super(S4Dataset, self).__init__() 36 | self.split = split 37 | self.cfg = cfg 38 | self.mask_num = 1 if self.split == 'train' else 5 39 | df_all = pd.read_csv(cfg.anno_csv, sep=',') 40 | self.df_split = df_all[df_all['split'] == split] 41 | print("{}/{} videos are used for {}".format(len(self.df_split), 42 | len(df_all), self.split)) 43 | self.img_transform = transforms.Compose([ 44 | transforms.Resize([512, 512]), 45 | transforms.ToTensor(), 46 | transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)) 47 | ]) 48 | self.mask_transform = transforms.Compose([ 49 | transforms.Resize([512, 512]), 50 | transforms.ToTensor(), 51 | ]) 52 | 53 | def __getitem__(self, index): 54 | df_one_video = self.df_split.iloc[index] 55 | video_name, category = df_one_video[0], df_one_video[2] 56 | img_base_path = os.path.join( 57 | self.cfg.dir_img, self.split, category, video_name) 58 | audio_lm_path = os.path.join( 59 | self.cfg.dir_audio_log_mel, self.split, category, video_name + '.pkl') 60 | mask_base_path = os.path.join( 61 | self.cfg.dir_mask, self.split, category, video_name) 62 | audio_log_mel = load_audio_lm(audio_lm_path) 63 | # audio_lm_tensor = torch.from_numpy(audio_log_mel) 64 | imgs, masks = [], [] 65 | for img_id in range(1, 6): 66 | img = load_image_in_PIL_to_Tensor(os.path.join(img_base_path, "%s_%d.png" % ( 67 | video_name, img_id)), transform=self.img_transform) 68 | imgs.append(img) 69 | for mask_id in range(1, self.mask_num + 1): 70 | mask = load_image_in_PIL_to_Tensor(os.path.join(mask_base_path, "%s_%d.png" % ( 71 | video_name, mask_id)), transform=self.mask_transform, mode='1') 72 | masks.append(mask) 73 | imgs_tensor = torch.stack(imgs, dim=0) 74 | masks_tensor = torch.stack(masks, dim=0) 75 | 76 | if self.split == 'train': 77 | return imgs_tensor, audio_log_mel, masks_tensor 78 | else: 79 | return imgs_tensor, audio_log_mel, masks_tensor, category, video_name 80 | 81 | def __len__(self): 82 | return len(self.df_split) 83 | -------------------------------------------------------------------------------- /dataloader/v2_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | # from wave import _wave_params 3 | import torch 4 | import torch.nn as nn 5 | from torch.utils.data import Dataset 6 | 7 | import numpy as np 8 | import pandas as pd 9 | import pickle 10 | import json 11 | 12 | # import cv2 13 | from PIL import Image 14 | from torchvision import transforms 15 | 16 | # from .config import cfg_avs 17 | 18 | 19 | def get_v2_pallete(label_to_idx_path, num_cls=71): 20 | def _getpallete(num_cls=71): 21 | """build the unified color pallete for AVSBench-object (V1) and AVSBench-semantic (V2), 22 | 71 is the total category number of V2 dataset, you should not change that""" 23 | n = num_cls 24 | pallete = [0] * (n * 3) 25 | for j in range(0, n): 26 | lab = j 27 | pallete[j * 3 + 0] = 0 28 | pallete[j * 3 + 1] = 0 29 | pallete[j * 3 + 2] = 0 30 | i = 0 31 | while (lab > 0): 32 | pallete[j * 3 + 0] |= (((lab >> 0) & 1) << (7 - i)) 33 | pallete[j * 3 + 1] |= (((lab >> 1) & 1) << (7 - i)) 34 | pallete[j * 3 + 2] |= (((lab >> 2) & 1) << (7 - i)) 35 | i = i + 1 36 | lab >>= 3 37 | 
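# (The bit-twiddling loop above generates the standard PASCAL-VOC-style colormap: each class index is expanded bit by bit into a fixed RGB triple, so label i always maps to the same colour.)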
return pallete # list, lenth is n_classes*3 38 | 39 | with open(label_to_idx_path, 'r') as fr: 40 | label_to_pallete_idx = json.load(fr) 41 | v2_pallete = _getpallete(num_cls) # list 42 | v2_pallete = np.array(v2_pallete).reshape(-1, 3) 43 | assert len(v2_pallete) == len(label_to_pallete_idx) 44 | return v2_pallete 45 | 46 | 47 | def crop_resize_img(crop_size, img, img_is_mask=False): 48 | outsize = crop_size 49 | short_size = outsize 50 | w, h = img.size 51 | if w > h: 52 | oh = short_size 53 | ow = int(1.0 * w * oh / h) 54 | else: 55 | ow = short_size 56 | oh = int(1.0 * h * ow / w) 57 | if not img_is_mask: 58 | img = img.resize((ow, oh), Image.BILINEAR) 59 | else: 60 | img = img.resize((ow, oh), Image.NEAREST) 61 | # center crop 62 | w, h = img.size 63 | x1 = int(round((w - outsize) / 2.)) 64 | y1 = int(round((h - outsize) / 2.)) 65 | img = img.crop((x1, y1, x1 + outsize, y1 + outsize)) 66 | # print("crop for train. set") 67 | return img 68 | 69 | 70 | def resize_img(crop_size, img, img_is_mask=False): 71 | outsize = crop_size 72 | # only resize for val./test. set 73 | if not img_is_mask: 74 | img = img.resize((outsize, outsize), Image.BILINEAR) 75 | else: 76 | img = img.resize((outsize, outsize), Image.NEAREST) 77 | return img 78 | 79 | 80 | def color_mask_to_label(mask, v_pallete): 81 | mask_array = np.array(mask).astype('int32') 82 | semantic_map = [] 83 | for colour in v_pallete: 84 | equality = np.equal(mask_array, colour) 85 | class_map = np.all(equality, axis=-1) 86 | semantic_map.append(class_map) 87 | semantic_map = np.stack(semantic_map, axis=-1).astype(np.float32) 88 | # pdb.set_trace() # there is only one '1' value for each pixel, run np.sum(semantic_map, axis=-1) 89 | label = np.argmax(semantic_map, axis=-1) 90 | return label 91 | 92 | 93 | def load_image_in_PIL_to_Tensor(path, split='train', mode='RGB', transform=None, cfg=None): 94 | img_PIL = Image.open(path).convert(mode) 95 | if cfg.crop_img_and_mask: 96 | if split == 'train': 97 | img_PIL = crop_resize_img( 98 | cfg.crop_size, img_PIL, img_is_mask=False) 99 | else: 100 | img_PIL = resize_img(cfg.crop_size, 101 | img_PIL, img_is_mask=False) 102 | if transform: 103 | img_tensor = transform(img_PIL) 104 | return img_tensor 105 | return img_PIL 106 | 107 | 108 | def load_color_mask_in_PIL_to_Tensor(path, v_pallete, split='train', mode='RGB', cfg=None): 109 | color_mask_PIL = Image.open(path).convert(mode) 110 | if cfg.crop_img_and_mask: 111 | if split == 'train': 112 | color_mask_PIL = crop_resize_img( 113 | cfg.crop_size, color_mask_PIL, img_is_mask=True) 114 | else: 115 | color_mask_PIL = resize_img( 116 | cfg.crop_size, color_mask_PIL, img_is_mask=True) 117 | # obtain semantic label 118 | color_label = color_mask_to_label(color_mask_PIL, v_pallete) 119 | color_label = torch.from_numpy(color_label) # [H, W] 120 | color_label = color_label.unsqueeze(0) 121 | # binary_mask = (color_label != (cfg_avs.NUM_CLASSES-1)).float() 122 | # return color_label, binary_mask # both [1, H, W] 123 | return color_label # both [1, H, W] 124 | 125 | 126 | def load_audio_lm(audio_lm_path): 127 | with open(audio_lm_path, 'rb') as fr: 128 | audio_log_mel = pickle.load(fr) 129 | audio_log_mel = audio_log_mel.detach() # [5, 1, 96, 64] 130 | return audio_log_mel 131 | 132 | 133 | class V2Dataset(Dataset): 134 | """Dataset for audio visual semantic segmentation of AVSBench-semantic (V2)""" 135 | 136 | def __init__(self, split='train', cfg=None, debug_flag=False): 137 | super(V2Dataset, self).__init__() 138 | self.split = split 139 | self.cfg = 
cfg 140 | self.mask_num = cfg.mask_num 141 | df_all = pd.read_csv(cfg.meta_csv_path, sep=',') 142 | self.df_split = df_all[df_all['split'] == split] 143 | if debug_flag: 144 | self.df_split = self.df_split[:100] 145 | print("{}/{} videos are used for {}.".format(len(self.df_split), 146 | len(df_all), self.split)) 147 | self.img_transform = transforms.Compose([ 148 | transforms.ToTensor(), 149 | transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)) 150 | ]) 151 | self.v2_pallete = get_v2_pallete( 152 | cfg.label_idx_path, num_cls=cfg.num_class) 153 | 154 | def __getitem__(self, index): 155 | df_one_video = self.df_split.iloc[index] 156 | video_name, set = df_one_video['uid'], df_one_video['label'] 157 | img_base_path = os.path.join( 158 | self.cfg.dir_base, set, video_name, 'frames') 159 | audio_path = os.path.join( 160 | self.cfg.dir_base, set, video_name, 'audio.wav') 161 | color_mask_base_path = os.path.join( 162 | self.cfg.dir_base, set, video_name, 'labels_rgb') 163 | 164 | # data from AVSBench-object single-source subset (5s, gt is only the first annotated frame) 165 | if set == 'v1s': 166 | vid_temporal_mask_flag = torch.Tensor( 167 | [1, 1, 1, 1, 1, 0, 0, 0, 0, 0]) # .bool() 168 | gt_temporal_mask_flag = torch.Tensor( 169 | [1, 0, 0, 0, 0, 0, 0, 0, 0, 0]) # .bool() 170 | # data from AVSBench-object multi-sources subset (5s, all 5 extracted frames are annotated) 171 | elif set == 'v1m': 172 | vid_temporal_mask_flag = torch.Tensor( 173 | [1, 1, 1, 1, 1, 0, 0, 0, 0, 0]) # .bool() 174 | gt_temporal_mask_flag = torch.Tensor( 175 | [1, 1, 1, 1, 1, 0, 0, 0, 0, 0]) # .bool() 176 | # data from newly collected videos in AVSBench-semantic (10s, all 10 extracted frames are annotated)) 177 | elif set == 'v2': 178 | vid_temporal_mask_flag = torch.ones(10) # .bool() 179 | gt_temporal_mask_flag = torch.ones(10) # .bool() 180 | 181 | img_path_list = sorted(os.listdir(img_base_path) 182 | ) # 5 for v1, 10 for new v2 183 | imgs_num = len(img_path_list) 184 | imgs_pad_zero_num = 10 - imgs_num 185 | imgs = [] 186 | for img_id in range(imgs_num): 187 | img_path = os.path.join(img_base_path, "%d.jpg" % (img_id)) 188 | img = load_image_in_PIL_to_Tensor( 189 | img_path, split=self.split, transform=self.img_transform, cfg=self.cfg) 190 | imgs.append(img) 191 | for pad_i in range(imgs_pad_zero_num): # ! pad black image? 
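# v1s/v1m clips provide only 5 frames, so the list is padded with all-zero (black) frames up to the fixed length of 10; vid_temporal_mask_flag zeroes out these padded positions downstream.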
192 | img = torch.zeros_like(img) 193 | imgs.append(img) 194 | 195 | labels = [] 196 | mask_path_list = sorted(os.listdir(color_mask_base_path)) 197 | for mask_path in mask_path_list: 198 | if not mask_path.endswith(".png"): 199 | mask_path_list.remove(mask_path) 200 | mask_num = len(mask_path_list) 201 | if self.split != 'train': 202 | if set == 'v2': 203 | assert mask_num == 10 204 | else: 205 | assert mask_num == 5 206 | 207 | mask_num = len(mask_path_list) 208 | label_pad_zero_num = 10 - mask_num 209 | for mask_id in range(mask_num): 210 | mask_path = os.path.join( 211 | color_mask_base_path, "%d.png" % (mask_id)) 212 | # mask_path = os.path.join(color_mask_base_path, mask_path_list[mask_id]) 213 | color_label = load_color_mask_in_PIL_to_Tensor( 214 | mask_path, v_pallete=self.v2_pallete, split=self.split, cfg=self.cfg) 215 | # print('color_label.shape: ', color_label.shape) 216 | labels.append(color_label) 217 | for pad_j in range(label_pad_zero_num): 218 | color_label = torch.zeros_like(color_label) 219 | labels.append(color_label) 220 | 221 | imgs_tensor = torch.stack(imgs, dim=0) 222 | labels_tensor = torch.stack(labels, dim=0) 223 | 224 | return imgs_tensor, audio_path, labels_tensor, \ 225 | vid_temporal_mask_flag, gt_temporal_mask_flag, video_name 226 | 227 | def __len__(self): 228 | return len(self.df_split) 229 | 230 | @property 231 | def num_classes(self): 232 | """Number of categories (including background).""" 233 | return self.cfg.num_class 234 | 235 | @property 236 | def classes(self): 237 | """Category names.""" 238 | with open(self.cfg.label_idx_path, 'r') as fr: 239 | classes = json.load(fr) 240 | return [label for label in classes.keys()] 241 | -------------------------------------------------------------------------------- /image/arch.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vvvb-github/AVSegFormer/7083cb6586d1bc06d174a34be3c34bb5814cf0c3/image/arch.png -------------------------------------------------------------------------------- /model/AVSegFormer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from .backbone import build_backbone 4 | # from .neck import build_neck 5 | from .head import build_head 6 | from .vggish import VGGish 7 | 8 | 9 | class AVSegFormer(nn.Module): 10 | def __init__(self, 11 | backbone, 12 | vggish, 13 | head, 14 | neck=None, 15 | audio_dim=128, 16 | embed_dim=256, 17 | T=5, 18 | freeze_audio_backbone=True, 19 | *args, **kwargs): 20 | super().__init__() 21 | 22 | self.embed_dim = embed_dim 23 | self.T = T 24 | self.freeze_audio_backbone = freeze_audio_backbone 25 | self.backbone = build_backbone(**backbone) 26 | self.vggish = VGGish(**vggish) 27 | self.head = build_head(**head) 28 | self.audio_proj = nn.Linear(audio_dim, embed_dim) 29 | 30 | if self.freeze_audio_backbone: 31 | for p in self.vggish.parameters(): 32 | p.requires_grad = False 33 | self.freeze_backbone(True) 34 | 35 | self.neck = neck 36 | # if neck is not None: 37 | # self.neck = build_neck(**neck) 38 | # else: 39 | # self.neck = None 40 | 41 | def freeze_backbone(self, freeze=False): 42 | for p in self.backbone.parameters(): 43 | p.requires_grad = not freeze 44 | 45 | def mul_temporal_mask(self, feats, vid_temporal_mask_flag=None): 46 | if vid_temporal_mask_flag is None: 47 | return feats 48 | else: 49 | if isinstance(feats, list): 50 | out = [] 51 | for x in feats: 52 | out.append(x * vid_temporal_mask_flag) 53 | 
elif isinstance(feats, torch.Tensor): 54 | out = feats * vid_temporal_mask_flag 55 | 56 | return out 57 | 58 | def extract_feat(self, x): 59 | feats = self.backbone(x) 60 | if self.neck is not None: 61 | feats = self.neck(feats) 62 | return feats 63 | 64 | def forward(self, audio, frames, vid_temporal_mask_flag=None): 65 | if vid_temporal_mask_flag is not None: 66 | vid_temporal_mask_flag = vid_temporal_mask_flag.view(-1, 1, 1, 1) 67 | with torch.no_grad(): 68 | audio_feat = self.vggish(audio) # [B*T,128] 69 | 70 | audio_feat = audio_feat.unsqueeze(1) 71 | audio_feat = self.audio_proj(audio_feat) 72 | img_feat = self.extract_feat(frames) 73 | img_feat = self.mul_temporal_mask(img_feat, vid_temporal_mask_flag) 74 | 75 | pred, mask_feature = self.head(img_feat, audio_feat) 76 | pred = self.mul_temporal_mask(pred, vid_temporal_mask_flag) 77 | mask_feature = self.mul_temporal_mask( 78 | mask_feature, vid_temporal_mask_flag) 79 | 80 | return pred, mask_feature 81 | -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- 1 | from .AVSegFormer import AVSegFormer 2 | 3 | 4 | def build_model(type, **kwargs): 5 | if type == 'AVSegFormer': 6 | return AVSegFormer(**kwargs) 7 | else: 8 | raise ValueError 9 | -------------------------------------------------------------------------------- /model/backbone/__init__.py: -------------------------------------------------------------------------------- 1 | from .resnet import B2_ResNet 2 | from .pvt import pvt_v2_b5 3 | 4 | 5 | def build_backbone(type, **kwargs): 6 | if type == 'res50': 7 | return B2_ResNet(**kwargs) 8 | elif type=='pvt_v2_b5': 9 | return pvt_v2_b5(**kwargs) 10 | 11 | 12 | __all__=['build_backbone'] 13 | -------------------------------------------------------------------------------- /model/backbone/pvt.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from functools import partial 5 | 6 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_ 7 | from timm.models.registry import register_model 8 | from timm.models.vision_transformer import _cfg 9 | # from mmseg.models.builder import BACKBONES 10 | # from mmseg.utils import get_root_logger 11 | # from mmcv.runner import load_checkpoint 12 | import math 13 | 14 | 15 | class Mlp(nn.Module): 16 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0., linear=False): 17 | super().__init__() 18 | out_features = out_features or in_features 19 | hidden_features = hidden_features or in_features 20 | self.fc1 = nn.Linear(in_features, hidden_features) 21 | self.dwconv = DWConv(hidden_features) 22 | self.act = act_layer() 23 | self.fc2 = nn.Linear(hidden_features, out_features) 24 | self.drop = nn.Dropout(drop) 25 | self.linear = linear 26 | 27 | if self.linear: 28 | self.relu = nn.ReLU(inplace=True) 29 | self.apply(self._init_weights) 30 | 31 | def _init_weights(self, m): 32 | if isinstance(m, nn.Linear): 33 | trunc_normal_(m.weight, std=.02) 34 | if isinstance(m, nn.Linear) and m.bias is not None: 35 | nn.init.constant_(m.bias, 0) 36 | elif isinstance(m, nn.LayerNorm): 37 | nn.init.constant_(m.bias, 0) 38 | nn.init.constant_(m.weight, 1.0) 39 | elif isinstance(m, nn.Conv2d): 40 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 41 | fan_out //= m.groups 42 | m.weight.data.normal_(0, 
math.sqrt(2.0 / fan_out)) 43 | if m.bias is not None: 44 | m.bias.data.zero_() 45 | 46 | def forward(self, x, H, W): 47 | x = self.fc1(x) 48 | if self.linear: 49 | x = self.relu(x) 50 | x = self.dwconv(x, H, W) 51 | x = self.act(x) 52 | x = self.drop(x) 53 | x = self.fc2(x) 54 | x = self.drop(x) 55 | return x 56 | 57 | 58 | class Attention(nn.Module): 59 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1, linear=False): 60 | super().__init__() 61 | assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}." 62 | 63 | self.dim = dim 64 | self.num_heads = num_heads 65 | head_dim = dim // num_heads 66 | self.scale = qk_scale or head_dim ** -0.5 67 | 68 | self.q = nn.Linear(dim, dim, bias=qkv_bias) 69 | self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias) 70 | self.attn_drop = nn.Dropout(attn_drop) 71 | self.proj = nn.Linear(dim, dim) 72 | self.proj_drop = nn.Dropout(proj_drop) 73 | 74 | self.linear = linear 75 | self.sr_ratio = sr_ratio 76 | if not linear: 77 | if sr_ratio > 1: 78 | self.sr = nn.Conv2d( 79 | dim, dim, kernel_size=sr_ratio, stride=sr_ratio) 80 | self.norm = nn.LayerNorm(dim) 81 | else: 82 | self.pool = nn.AdaptiveAvgPool2d(7) 83 | self.sr = nn.Conv2d(dim, dim, kernel_size=1, stride=1) 84 | self.norm = nn.LayerNorm(dim) 85 | self.act = nn.GELU() 86 | self.apply(self._init_weights) 87 | 88 | def _init_weights(self, m): 89 | if isinstance(m, nn.Linear): 90 | trunc_normal_(m.weight, std=.02) 91 | if isinstance(m, nn.Linear) and m.bias is not None: 92 | nn.init.constant_(m.bias, 0) 93 | elif isinstance(m, nn.LayerNorm): 94 | nn.init.constant_(m.bias, 0) 95 | nn.init.constant_(m.weight, 1.0) 96 | elif isinstance(m, nn.Conv2d): 97 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 98 | fan_out //= m.groups 99 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 100 | if m.bias is not None: 101 | m.bias.data.zero_() 102 | 103 | def forward(self, x, H, W): 104 | B, N, C = x.shape 105 | q = self.q(x).reshape(B, N, self.num_heads, C // 106 | self.num_heads).permute(0, 2, 1, 3) 107 | 108 | if not self.linear: 109 | if self.sr_ratio > 1: 110 | x_ = x.permute(0, 2, 1).reshape(B, C, H, W) 111 | x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1) 112 | x_ = self.norm(x_) 113 | kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, 114 | C // self.num_heads).permute(2, 0, 3, 1, 4) 115 | else: 116 | kv = self.kv(x).reshape(B, -1, 2, self.num_heads, 117 | C // self.num_heads).permute(2, 0, 3, 1, 4) 118 | else: 119 | x_ = x.permute(0, 2, 1).reshape(B, C, H, W) 120 | x_ = self.sr(self.pool(x_)).reshape(B, C, -1).permute(0, 2, 1) 121 | x_ = self.norm(x_) 122 | x_ = self.act(x_) 123 | kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, 124 | C // self.num_heads).permute(2, 0, 3, 1, 4) 125 | k, v = kv[0], kv[1] 126 | 127 | attn = (q @ k.transpose(-2, -1)) * self.scale 128 | attn = attn.softmax(dim=-1) 129 | attn = self.attn_drop(attn) 130 | 131 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 132 | x = self.proj(x) 133 | x = self.proj_drop(x) 134 | 135 | return x 136 | 137 | 138 | class Block(nn.Module): 139 | 140 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 141 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, sr_ratio=1, linear=False): 142 | super().__init__() 143 | self.norm1 = norm_layer(dim) 144 | self.attn = Attention( 145 | dim, 146 | num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, 147 | attn_drop=attn_drop, proj_drop=drop, 
sr_ratio=sr_ratio, linear=linear) 148 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here 149 | self.drop_path = DropPath( 150 | drop_path) if drop_path > 0. else nn.Identity() 151 | self.norm2 = norm_layer(dim) 152 | mlp_hidden_dim = int(dim * mlp_ratio) 153 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, 154 | act_layer=act_layer, drop=drop, linear=linear) 155 | 156 | self.apply(self._init_weights) 157 | 158 | def _init_weights(self, m): 159 | if isinstance(m, nn.Linear): 160 | trunc_normal_(m.weight, std=.02) 161 | if isinstance(m, nn.Linear) and m.bias is not None: 162 | nn.init.constant_(m.bias, 0) 163 | elif isinstance(m, nn.LayerNorm): 164 | nn.init.constant_(m.bias, 0) 165 | nn.init.constant_(m.weight, 1.0) 166 | elif isinstance(m, nn.Conv2d): 167 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 168 | fan_out //= m.groups 169 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 170 | if m.bias is not None: 171 | m.bias.data.zero_() 172 | 173 | def forward(self, x, H, W): 174 | x = x + self.drop_path(self.attn(self.norm1(x), H, W)) 175 | x = x + self.drop_path(self.mlp(self.norm2(x), H, W)) 176 | 177 | return x 178 | 179 | 180 | class OverlapPatchEmbed(nn.Module): 181 | """ Image to Patch Embedding 182 | """ 183 | 184 | def __init__(self, img_size=224, patch_size=7, stride=4, in_chans=3, embed_dim=768): 185 | super().__init__() 186 | img_size = to_2tuple(img_size) 187 | patch_size = to_2tuple(patch_size) 188 | 189 | assert max(patch_size) > stride, "Set larger patch_size than stride" 190 | 191 | self.img_size = img_size 192 | self.patch_size = patch_size 193 | self.H, self.W = img_size[0] // stride, img_size[1] // stride 194 | self.num_patches = self.H * self.W 195 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride, 196 | padding=(patch_size[0] // 2, patch_size[1] // 2)) 197 | self.norm = nn.LayerNorm(embed_dim) 198 | 199 | self.apply(self._init_weights) 200 | 201 | def _init_weights(self, m): 202 | if isinstance(m, nn.Linear): 203 | trunc_normal_(m.weight, std=.02) 204 | if isinstance(m, nn.Linear) and m.bias is not None: 205 | nn.init.constant_(m.bias, 0) 206 | elif isinstance(m, nn.LayerNorm): 207 | nn.init.constant_(m.bias, 0) 208 | nn.init.constant_(m.weight, 1.0) 209 | elif isinstance(m, nn.Conv2d): 210 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 211 | fan_out //= m.groups 212 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 213 | if m.bias is not None: 214 | m.bias.data.zero_() 215 | 216 | def forward(self, x): 217 | x = self.proj(x) 218 | _, _, H, W = x.shape 219 | x = x.flatten(2).transpose(1, 2) 220 | x = self.norm(x) 221 | 222 | return x, H, W 223 | 224 | 225 | class PyramidVisionTransformerV2(nn.Module): 226 | def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dims=[64, 128, 256, 512], 227 | num_heads=[1, 2, 4, 8], mlp_ratios=[4, 4, 4, 4], qkv_bias=False, qk_scale=None, drop_rate=0., 228 | attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm, depths=[3, 4, 6, 3], 229 | sr_ratios=[8, 4, 2, 1], num_stages=4, linear=False, init_weights_path=None): 230 | super().__init__() 231 | # self.num_classes = num_classes 232 | self.depths = depths 233 | self.num_stages = num_stages 234 | self.linear = linear 235 | 236 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, 237 | sum(depths))] # stochastic depth decay rule 238 | cur = 0 239 | 240 | for i in range(num_stages): 241 | patch_embed = 
OverlapPatchEmbed(img_size=img_size if i == 0 else img_size // (2 ** (i + 1)), 242 | patch_size=7 if i == 0 else 3, 243 | stride=4 if i == 0 else 2, 244 | in_chans=in_chans if i == 0 else embed_dims[i - 1], 245 | embed_dim=embed_dims[i]) 246 | 247 | block = nn.ModuleList([Block( 248 | dim=embed_dims[i], num_heads=num_heads[i], mlp_ratio=mlp_ratios[i], qkv_bias=qkv_bias, 249 | qk_scale=qk_scale, 250 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + 251 | j], norm_layer=norm_layer, 252 | sr_ratio=sr_ratios[i], linear=linear) 253 | for j in range(depths[i])]) 254 | norm = norm_layer(embed_dims[i]) 255 | cur += depths[i] 256 | 257 | setattr(self, f"patch_embed{i + 1}", patch_embed) 258 | setattr(self, f"block{i + 1}", block) 259 | setattr(self, f"norm{i + 1}", norm) 260 | 261 | self.apply(self._init_weights) 262 | 263 | if init_weights_path is not None: 264 | self.initialize_weights(init_weights_path) 265 | 266 | def initialize_weights(self, path): 267 | pvt_model_dict = self.state_dict() 268 | pretrained_state_dicts = torch.load(path) 269 | # for k, v in pretrained_state_dicts['model'].items(): 270 | # if k in pvt_model_dict.keys(): 271 | # print(k, v.requires_grad) 272 | state_dict = {k: v for k, v in pretrained_state_dicts.items() 273 | if k in pvt_model_dict.keys()} 274 | pvt_model_dict.update(state_dict) 275 | self.load_state_dict(pvt_model_dict) 276 | print(f'==> Load pvt-v2-b5 parameters pretrained on ImageNet: {path}') 277 | 278 | def _init_weights(self, m): 279 | if isinstance(m, nn.Linear): 280 | trunc_normal_(m.weight, std=.02) 281 | if isinstance(m, nn.Linear) and m.bias is not None: 282 | nn.init.constant_(m.bias, 0) 283 | elif isinstance(m, nn.LayerNorm): 284 | nn.init.constant_(m.bias, 0) 285 | nn.init.constant_(m.weight, 1.0) 286 | elif isinstance(m, nn.Conv2d): 287 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 288 | fan_out //= m.groups 289 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 290 | if m.bias is not None: 291 | m.bias.data.zero_() 292 | 293 | def freeze_patch_emb(self): 294 | self.patch_embed1.requires_grad = False 295 | 296 | @torch.jit.ignore 297 | def no_weight_decay(self): 298 | # has pos_embed may be better 299 | return {'pos_embed1', 'pos_embed2', 'pos_embed3', 'pos_embed4', 'cls_token'} 300 | 301 | def get_classifier(self): 302 | return self.head 303 | 304 | def reset_classifier(self, num_classes, global_pool=''): 305 | self.num_classes = num_classes 306 | self.head = nn.Linear( 307 | self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() 308 | 309 | def forward_features(self, x): 310 | B = x.shape[0] 311 | outs = [] 312 | 313 | for i in range(self.num_stages): 314 | patch_embed = getattr(self, f"patch_embed{i + 1}") 315 | block = getattr(self, f"block{i + 1}") 316 | norm = getattr(self, f"norm{i + 1}") 317 | x, H, W = patch_embed(x) 318 | for blk in block: 319 | x = blk(x, H, W) 320 | x = norm(x) 321 | x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() 322 | outs.append(x) 323 | 324 | return outs 325 | 326 | def forward(self, x): 327 | x = self.forward_features(x) 328 | # x = self.head(x) 329 | 330 | return x 331 | 332 | 333 | class DWConv(nn.Module): 334 | def __init__(self, dim=768): 335 | super(DWConv, self).__init__() 336 | self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim) 337 | 338 | def forward(self, x, H, W): 339 | B, N, C = x.shape 340 | x = x.transpose(1, 2).view(B, C, H, W) 341 | x = self.dwconv(x) 342 | x = x.flatten(2).transpose(1, 2) 343 | 344 | return x 345 | 346 | 347 | 
@register_model 348 | def pvt_v2_b5(init_weights_path=False, **kwargs): 349 | model = PyramidVisionTransformerV2( 350 | patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], qkv_bias=True, 351 | norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 6, 40, 3], sr_ratios=[8, 4, 2, 1], 352 | drop_rate=0.0, drop_path_rate=0.1, init_weights_path=init_weights_path, 353 | **kwargs) 354 | model.default_cfg = _cfg() 355 | 356 | return model 357 | -------------------------------------------------------------------------------- /model/backbone/resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import math 4 | import torchvision.models as models 5 | 6 | 7 | def conv3x3(in_planes, out_planes, stride=1): 8 | """3x3 convolution with padding""" 9 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 10 | padding=1, bias=False) 11 | 12 | 13 | class BasicBlock(nn.Module): 14 | expansion = 1 15 | 16 | def __init__(self, inplanes, planes, stride=1, downsample=None): 17 | super(BasicBlock, self).__init__() 18 | self.conv1 = conv3x3(inplanes, planes, stride) 19 | self.bn1 = nn.BatchNorm2d(planes) 20 | self.relu = nn.ReLU(inplace=True) 21 | self.conv2 = conv3x3(planes, planes) 22 | self.bn2 = nn.BatchNorm2d(planes) 23 | self.downsample = downsample 24 | self.stride = stride 25 | 26 | def forward(self, x): 27 | residual = x 28 | 29 | out = self.conv1(x) 30 | out = self.bn1(out) 31 | out = self.relu(out) 32 | 33 | out = self.conv2(out) 34 | out = self.bn2(out) 35 | 36 | if self.downsample is not None: 37 | residual = self.downsample(x) 38 | 39 | out += residual 40 | out = self.relu(out) 41 | 42 | return out 43 | 44 | 45 | class Bottleneck(nn.Module): 46 | expansion = 4 47 | 48 | def __init__(self, inplanes, planes, stride=1, downsample=None): 49 | super(Bottleneck, self).__init__() 50 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 51 | self.bn1 = nn.BatchNorm2d(planes) 52 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 53 | padding=1, bias=False) 54 | self.bn2 = nn.BatchNorm2d(planes) 55 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 56 | self.bn3 = nn.BatchNorm2d(planes * 4) 57 | self.relu = nn.ReLU(inplace=True) 58 | self.downsample = downsample 59 | self.stride = stride 60 | 61 | def forward(self, x): 62 | residual = x 63 | 64 | out = self.conv1(x) 65 | out = self.bn1(out) 66 | out = self.relu(out) 67 | 68 | out = self.conv2(out) 69 | out = self.bn2(out) 70 | out = self.relu(out) 71 | 72 | out = self.conv3(out) 73 | out = self.bn3(out) 74 | 75 | if self.downsample is not None: 76 | residual = self.downsample(x) 77 | 78 | out += residual 79 | out = self.relu(out) 80 | 81 | return out 82 | 83 | 84 | class B2_ResNet(nn.Module): 85 | # ResNet50 with two branches 86 | def __init__(self, init_weights_path=None): 87 | # self.inplanes = 128 88 | self.inplanes = 64 89 | super(B2_ResNet, self).__init__() 90 | 91 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 92 | bias=False) 93 | self.bn1 = nn.BatchNorm2d(64) 94 | self.relu = nn.ReLU(inplace=True) 95 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 96 | self.layer1 = self._make_layer(Bottleneck, 64, 3) 97 | self.layer2 = self._make_layer(Bottleneck, 128, 4, stride=2) 98 | self.layer3_1 = self._make_layer(Bottleneck, 256, 6, stride=2) 99 | self.layer4_1 = self._make_layer(Bottleneck, 512, 3, stride=2) 100 | 101 | 
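# Reset inplanes to the layer2 output width (512) so that the second branch (layer3_2/layer4_2) is built on top of the shared layer1/layer2, in parallel with layer3_1/layer4_1.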
self.inplanes = 512 102 | self.layer3_2 = self._make_layer(Bottleneck, 256, 6, stride=2) 103 | self.layer4_2 = self._make_layer(Bottleneck, 512, 3, stride=2) 104 | 105 | for m in self.modules(): 106 | if isinstance(m, nn.Conv2d): 107 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 108 | m.weight.data.normal_(0, math.sqrt(2. / n)) 109 | elif isinstance(m, nn.BatchNorm2d): 110 | m.weight.data.fill_(1) 111 | m.bias.data.zero_() 112 | 113 | if init_weights_path is not None: 114 | self.initialize_weights(init_weights_path) 115 | 116 | def initialize_weights(self, path): 117 | res50 = models.resnet50(pretrained=False) 118 | resnet50_dict = torch.load(path) 119 | res50.load_state_dict(resnet50_dict) 120 | pretrained_dict = res50.state_dict() 121 | 122 | all_params = {} 123 | for k, v in self.state_dict().items(): 124 | if k in pretrained_dict.keys(): 125 | v = pretrained_dict[k] 126 | all_params[k] = v 127 | elif '_1' in k: 128 | name = k.split('_1')[0] + k.split('_1')[1] 129 | v = pretrained_dict[name] 130 | all_params[k] = v 131 | elif '_2' in k: 132 | name = k.split('_2')[0] + k.split('_2')[1] 133 | v = pretrained_dict[name] 134 | all_params[k] = v 135 | assert len(all_params.keys()) == len(self.state_dict().keys()) 136 | self.load_state_dict(all_params) 137 | print(f'==> Load pretrained ResNet50 parameters from {path}') 138 | 139 | def _make_layer(self, block, planes, blocks, stride=1): 140 | downsample = None 141 | if stride != 1 or self.inplanes != planes * block.expansion: 142 | downsample = nn.Sequential( 143 | nn.Conv2d(self.inplanes, planes * block.expansion, 144 | kernel_size=1, stride=stride, bias=False), 145 | nn.BatchNorm2d(planes * block.expansion), 146 | ) 147 | 148 | layers = [] 149 | layers.append(block(self.inplanes, planes, stride, downsample)) 150 | self.inplanes = planes * block.expansion 151 | for i in range(1, blocks): 152 | layers.append(block(self.inplanes, planes)) 153 | 154 | return nn.Sequential(*layers) 155 | 156 | def forward(self, x, branch=1): 157 | layer3 = getattr(self, f'layer3_{branch}') 158 | layer4 = getattr(self, f'layer4_{branch}') 159 | 160 | x = self.conv1(x) 161 | x = self.bn1(x) 162 | x = self.relu(x) 163 | x = self.maxpool(x) 164 | 165 | x1 = self.layer1(x) 166 | x2 = self.layer2(x1) 167 | x3 = layer3(x2) 168 | x4 = layer4(x3) 169 | 170 | return [x1, x2, x3, x4] 171 | -------------------------------------------------------------------------------- /model/head/AVSegHead.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from model.utils import build_transformer, build_positional_encoding, build_fusion_block, build_generator 5 | from ops.modules import MSDeformAttn 6 | from torch.nn.init import normal_ 7 | from torch.nn.functional import interpolate 8 | 9 | 10 | class Interpolate(nn.Module): 11 | def __init__(self, scale_factor, mode, align_corners=False): 12 | super(Interpolate, self).__init__() 13 | 14 | self.interp = interpolate 15 | self.scale_factor = scale_factor 16 | self.mode = mode 17 | self.align_corners = align_corners 18 | 19 | def forward(self, x): 20 | x = self.interp( 21 | x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners 22 | ) 23 | return x 24 | 25 | 26 | class MLP(nn.Module): 27 | """ Very simple multi-layer perceptron (also called FFN)""" 28 | 29 | def __init__(self, input_dim, hidden_dim, output_dim, num_layers): 30 | super().__init__() 31 | self.num_layers = num_layers 32 | 
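        # Hidden-layer widths; zipping them below builds num_layers 1x1 convolutions
        # (input_dim -> hidden_dim -> ... -> output_dim), so this "MLP" operates
        # position-wise on (B, C, H, W) mask logits rather than on flat vectors.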
h = [hidden_dim] * (num_layers - 1) 33 | self.layers = nn.ModuleList(nn.Conv2d(n, k, kernel_size=1, stride=1, padding=0) 34 | for n, k in zip([input_dim] + h, h + [output_dim])) 35 | 36 | def forward(self, x): 37 | for i, layer in enumerate(self.layers): 38 | x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) 39 | return x 40 | 41 | 42 | class SimpleFPN(nn.Module): 43 | def __init__(self, channel=256, layers=3): 44 | super().__init__() 45 | 46 | assert layers == 3 # only designed for 3 layers 47 | self.up1 = nn.Sequential( 48 | Interpolate(scale_factor=2, mode='bilinear'), 49 | nn.Conv2d(channel, channel, kernel_size=3, stride=1, padding=1) 50 | ) 51 | self.up2 = nn.Sequential( 52 | Interpolate(scale_factor=2, mode='bilinear'), 53 | nn.Conv2d(channel, channel, kernel_size=3, stride=1, padding=1) 54 | ) 55 | self.up3 = nn.Sequential( 56 | Interpolate(scale_factor=2, mode='bilinear'), 57 | nn.Conv2d(channel, channel, kernel_size=3, stride=1, padding=1) 58 | ) 59 | 60 | def forward(self, x): 61 | x1 = self.up1(x[-1]) 62 | x1 = x1 + x[-2] 63 | 64 | x2 = self.up2(x1) 65 | x2 = x2 + x[-3] 66 | 67 | y = self.up3(x2) 68 | return y 69 | 70 | 71 | class AVSegHead(nn.Module): 72 | def __init__(self, 73 | in_channels, 74 | num_classes, 75 | query_num, 76 | transformer, 77 | query_generator, 78 | embed_dim=256, 79 | valid_indices=[1, 2, 3], 80 | scale_factor=4, 81 | positional_encoding=None, 82 | use_learnable_queries=True, 83 | fusion_block=None) -> None: 84 | super().__init__() 85 | 86 | self.in_channels = in_channels 87 | self.embed_dim = embed_dim 88 | self.num_classes = num_classes 89 | self.query_num = query_num 90 | self.valid_indices = valid_indices 91 | self.num_feats = len(valid_indices) 92 | self.scale_factor = scale_factor 93 | self.use_learnable_queries = use_learnable_queries 94 | self.level_embed = nn.Parameter( 95 | torch.Tensor(self.num_feats, embed_dim)) 96 | self.learnable_query = nn.Embedding(query_num, embed_dim) 97 | 98 | self.query_generator = build_generator(**query_generator) 99 | 100 | self.transformer = build_transformer(**transformer) 101 | if positional_encoding is not None: 102 | self.positional_encoding = build_positional_encoding( 103 | **positional_encoding) 104 | else: 105 | self.positional_encoding = None 106 | 107 | in_proj = [] 108 | for c in in_channels: 109 | in_proj.append( 110 | nn.Sequential( 111 | nn.Conv2d(c, embed_dim, kernel_size=1), 112 | nn.GroupNorm(32, embed_dim) 113 | ) 114 | ) 115 | self.in_proj = nn.ModuleList(in_proj) 116 | self.mlp = MLP(query_num, 2048, embed_dim, 3) 117 | 118 | if fusion_block is not None: 119 | self.fusion_block = build_fusion_block(**fusion_block) 120 | 121 | self.lateral_conv = nn.Sequential( 122 | nn.Conv2d(embed_dim, embed_dim, 123 | kernel_size=1, stride=1, padding=0), 124 | nn.GroupNorm(32, embed_dim) 125 | ) 126 | self.out_conv = nn.Sequential( 127 | nn.Conv2d(embed_dim, embed_dim, 128 | kernel_size=3, stride=1, padding=1), 129 | nn.GroupNorm(32, embed_dim), 130 | nn.ReLU(True) 131 | ) 132 | 133 | self.fpn = SimpleFPN() 134 | self.attn_fc = nn.Sequential( 135 | nn.Conv2d(embed_dim, 128, kernel_size=3, stride=1, padding=1), 136 | Interpolate(scale_factor=scale_factor, mode="bilinear"), 137 | nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1), 138 | nn.ReLU(True), 139 | nn.Conv2d(32, num_classes, kernel_size=1, 140 | stride=1, padding=0, bias=False) 141 | ) 142 | self.fc = nn.Sequential( 143 | nn.Conv2d(embed_dim, 128, kernel_size=3, stride=1, padding=1), 144 | Interpolate(scale_factor=scale_factor, 
mode="bilinear"), 145 | nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1), 146 | nn.ReLU(True), 147 | nn.Conv2d(32, num_classes, kernel_size=1, 148 | stride=1, padding=0, bias=False) 149 | ) 150 | 151 | self._reset_parameters() 152 | 153 | def _reset_parameters(self): 154 | for p in self.parameters(): 155 | if p.dim() > 1: 156 | nn.init.xavier_uniform_(p) 157 | for m in self.modules(): 158 | if isinstance(m, MSDeformAttn): 159 | m._reset_parameters() 160 | normal_(self.level_embed) 161 | 162 | def get_valid_ratio(self, mask): 163 | _, H, W = mask.shape 164 | valid_H = torch.sum(~mask[:, :, 0], 1) 165 | valid_W = torch.sum(~mask[:, 0, :], 1) 166 | valid_ratio_h = valid_H.float() / H 167 | valid_ratio_w = valid_W.float() / W 168 | valid_ratio = torch.stack([valid_ratio_w, valid_ratio_h], -1) 169 | return valid_ratio 170 | 171 | def reform_output_squences(self, memory, spatial_shapes, level_start_index, dim=1): 172 | split_size_or_sections = [None] * self.num_feats 173 | for i in range(self.num_feats): 174 | if i < self.num_feats - 1: 175 | split_size_or_sections[i] = level_start_index[i + 176 | 1] - level_start_index[i] 177 | else: 178 | split_size_or_sections[i] = memory.shape[dim] - \ 179 | level_start_index[i] 180 | y = torch.split(memory, split_size_or_sections, dim=dim) 181 | return y 182 | 183 | def forward(self, feats, audio_feat): 184 | feat14 = self.in_proj[0](feats[0]) 185 | srcs = [self.in_proj[i](feats[i]) for i in self.valid_indices] 186 | masks = [torch.zeros((x.size(0), x.size(2), x.size( 187 | 3)), device=x.device, dtype=torch.bool) for x in srcs] 188 | pos_embeds = [] 189 | for m in masks: 190 | pos_embeds.append(self.positional_encoding(m)) 191 | # prepare input for encoder 192 | src_flatten = [] 193 | mask_flatten = [] 194 | lvl_pos_embed_flatten = [] 195 | spatial_shapes = [] 196 | for lvl, (src, mask, pos_embed) in enumerate(zip(srcs, masks, pos_embeds)): 197 | bs, c, h, w = src.shape 198 | spatial_shape = (h, w) 199 | spatial_shapes.append(spatial_shape) 200 | src = src.flatten(2).transpose(1, 2) 201 | mask = mask.flatten(1) 202 | pos_embed = pos_embed.flatten(2).transpose(1, 2) 203 | lvl_pos_embed = pos_embed + self.level_embed[lvl].view(1, 1, -1) 204 | lvl_pos_embed_flatten.append(lvl_pos_embed) 205 | src_flatten.append(src) 206 | mask_flatten.append(mask) 207 | src_flatten = torch.cat(src_flatten, 1) 208 | mask_flatten = torch.cat(mask_flatten, 1) 209 | lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1) 210 | spatial_shapes = torch.as_tensor( 211 | spatial_shapes, dtype=torch.long, device=src_flatten.device) 212 | level_start_index = torch.cat((spatial_shapes.new_zeros( 213 | (1, )), spatial_shapes.prod(1).cumsum(0)[:-1])) 214 | valid_ratios = torch.stack([self.get_valid_ratio(m) for m in masks], 1) 215 | 216 | # prepare queries 217 | bs = audio_feat.shape[0] 218 | query = self.query_generator(audio_feat) 219 | if self.use_learnable_queries: 220 | query = query + \ 221 | self.learnable_query.weight[None, :, :].repeat(bs, 1, 1) 222 | 223 | memory, outputs = self.transformer(query, src_flatten, spatial_shapes, 224 | level_start_index, valid_ratios, lvl_pos_embed_flatten, mask_flatten) 225 | 226 | # generate mask feature 227 | mask_feats = [] 228 | for i, z in enumerate(self.reform_output_squences(memory, spatial_shapes, level_start_index, 1)): 229 | mask_feats.append(z.transpose(1, 2).view( 230 | bs, -1, spatial_shapes[i][0], spatial_shapes[i][1])) 231 | cur_fpn = self.lateral_conv(feat14) 232 | mask_feature = mask_feats[0] 233 | mask_feature = cur_fpn + 
\ 234 | F.interpolate( 235 | mask_feature, size=cur_fpn.shape[-2:], mode='bilinear', align_corners=False) 236 | mask_feature = self.out_conv(mask_feature) 237 | if hasattr(self, 'fusion_block'): 238 | mask_feature = self.fusion_block(mask_feature, audio_feat) 239 | 240 | # predict output mask 241 | pred_feature = torch.einsum( 242 | 'bqc,bchw->bqhw', outputs[-1], mask_feature) 243 | pred_feature = self.mlp(pred_feature) 244 | pred_mask = mask_feature + pred_feature 245 | pred_mask = self.fc(pred_mask) 246 | 247 | return pred_mask, mask_feature 248 | 249 | # def forward_prediction_head(self, output, mask_embed, spatial_shapes, level_start_index): 250 | # masks = torch.einsum('bqc,bqn->bcn', output, mask_embed) 251 | # splitted_masks = self.reform_output_squences( 252 | # masks, spatial_shapes, level_start_index, 2) 253 | 254 | # bs = output.shape[0] 255 | # reforms = [] 256 | # for i, embed in enumerate(splitted_masks): 257 | # embed = embed.view( 258 | # bs, -1, spatial_shapes[i][0], spatial_shapes[i][1]) 259 | # reforms.append(embed) 260 | 261 | # attn_mask = self.fpn(reforms) 262 | # attn_mask = self.attn_fc(attn_mask) 263 | # return attn_mask 264 | -------------------------------------------------------------------------------- /model/head/__init__.py: -------------------------------------------------------------------------------- 1 | from .AVSegHead import AVSegHead 2 | 3 | 4 | def build_head(type, **kwargs): 5 | if type == 'AVSegHead': 6 | return AVSegHead(**kwargs) 7 | else: 8 | raise ValueError 9 | 10 | 11 | __all__ = ['build_head'] 12 | -------------------------------------------------------------------------------- /model/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .transformer import build_transformer 2 | from .positional_encoding import build_positional_encoding 3 | from .fusion_block import build_fusion_block 4 | from .query_generator import build_generator 5 | 6 | 7 | __all__ = ['build_transformer', 'build_positional_encoding', 8 | 'build_fusion_block', 'build_generator'] 9 | -------------------------------------------------------------------------------- /model/utils/fusion_block.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class CrossModalMixer(nn.Module): 6 | def __init__(self, dim=256, n_heads=8, qkv_bias=False, dropout=0.): 7 | super().__init__() 8 | 9 | self.dim = dim 10 | self.n_heads = n_heads 11 | self.dropout = dropout 12 | self.scale = (dim // n_heads)**-0.5 13 | 14 | self.q_proj = nn.Linear(dim, dim, bias=qkv_bias) 15 | self.kv_proj = nn.Linear(dim, dim * 2, bias=qkv_bias) 16 | self.attn_drop = nn.Dropout(dropout) 17 | self.proj = nn.Linear(dim, dim) 18 | self.proj_drop = nn.Dropout(dropout) 19 | 20 | def forward(self, feature_map, audio_feature): 21 | """channel attention for modality fusion 22 | 23 | Args: 24 | feature_map (Tensor): (bs, c, h, w) 25 | audio_feature (Tensor): (bs, 1, c) 26 | 27 | Returns: 28 | Tensor: (bs, c, h, w) 29 | """ 30 | flatten_map = feature_map.flatten(2).transpose(1, 2) 31 | B, N, C = flatten_map.shape 32 | 33 | q = self.q_proj(audio_feature).reshape( 34 | B, 1, self.n_heads, C // self.n_heads).permute(0, 2, 1, 3) 35 | kv = self.kv_proj(flatten_map).reshape( 36 | B, N, 2, self.n_heads, C // self.n_heads).permute(2, 0, 3, 1, 4) 37 | k, v = kv.unbind(0) 38 | attn = (q @ k.transpose(-2, -1)) * self.scale 39 | attn = attn.softmax(dim=-1) 40 | x = (attn @ v).transpose(1, 
2).reshape(B, 1, C) 41 | x = self.proj_drop(self.proj(x)) 42 | 43 | x = x.sigmoid() 44 | fusion_map = torch.einsum('bchw,bc->bchw', feature_map, x.squeeze()) 45 | return fusion_map 46 | 47 | 48 | def build_fusion_block(type, **kwargs): 49 | if type == 'CrossModalMixer': 50 | return CrossModalMixer(**kwargs) 51 | else: 52 | raise ValueError 53 | -------------------------------------------------------------------------------- /model/utils/positional_encoding.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import math 4 | 5 | 6 | class SinePositionalEncoding(nn.Module): 7 | """Position encoding with sine and cosine functions. 8 | 9 | See `End-to-End Object Detection with Transformers 10 | `_ for details. 11 | 12 | Args: 13 | num_feats (int): The feature dimension for each position 14 | along x-axis or y-axis. Note the final returned dimension 15 | for each position is 2 times of this value. 16 | temperature (int, optional): The temperature used for scaling 17 | the position embedding. Defaults to 10000. 18 | normalize (bool, optional): Whether to normalize the position 19 | embedding. Defaults to False. 20 | scale (float, optional): A scale factor that scales the position 21 | embedding. The scale will be used only when `normalize` is True. 22 | Defaults to 2*pi. 23 | eps (float, optional): A value added to the denominator for 24 | numerical stability. Defaults to 1e-6. 25 | offset (float): offset add to embed when do the normalization. 26 | Defaults to 0. 27 | init_cfg (dict or list[dict], optional): Initialization config dict. 28 | Default: None 29 | """ 30 | 31 | def __init__(self, 32 | num_feats, 33 | temperature=10000, 34 | normalize=False, 35 | scale=2 * math.pi, 36 | eps=1e-6, 37 | offset=0.): 38 | super().__init__() 39 | if normalize: 40 | assert isinstance(scale, (float, int)), 'when normalize is set,' \ 41 | 'scale should be provided and in float or int type, ' \ 42 | f'found {type(scale)}' 43 | self.num_feats = num_feats 44 | self.temperature = temperature 45 | self.normalize = normalize 46 | self.scale = scale 47 | self.eps = eps 48 | self.offset = offset 49 | 50 | def forward(self, mask): 51 | """Forward function for `SinePositionalEncoding`. 52 | 53 | Args: 54 | mask (Tensor): ByteTensor mask. Non-zero values representing 55 | ignored positions, while zero values means valid positions 56 | for this image. Shape [bs, h, w]. 57 | 58 | Returns: 59 | pos (Tensor): Returned position embedding with shape 60 | [bs, num_feats*2, h, w]. 61 | """ 62 | # For convenience of exporting to ONNX, it's required to convert 63 | # `masks` from bool to int. 
64 | mask = mask.to(torch.int) 65 | not_mask = 1 - mask # logical_not 66 | y_embed = not_mask.cumsum(1, dtype=torch.float32) 67 | x_embed = not_mask.cumsum(2, dtype=torch.float32) 68 | if self.normalize: 69 | y_embed = (y_embed + self.offset) / \ 70 | (y_embed[:, -1:, :] + self.eps) * self.scale 71 | x_embed = (x_embed + self.offset) / \ 72 | (x_embed[:, :, -1:] + self.eps) * self.scale 73 | dim_t = torch.arange( 74 | self.num_feats, dtype=torch.float32, device=mask.device) 75 | dim_t = self.temperature**(2 * (dim_t // 2) / self.num_feats) 76 | pos_x = x_embed[:, :, :, None] / dim_t 77 | pos_y = y_embed[:, :, :, None] / dim_t 78 | # use `view` instead of `flatten` for dynamically exporting to ONNX 79 | B, H, W = mask.size() 80 | pos_x = torch.stack( 81 | (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), 82 | dim=4).view(B, H, W, -1) 83 | pos_y = torch.stack( 84 | (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), 85 | dim=4).view(B, H, W, -1) 86 | pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) 87 | return pos 88 | 89 | 90 | def build_positional_encoding(type, **kwargs): 91 | if type == 'SinePositionalEncoding': 92 | return SinePositionalEncoding(**kwargs) 93 | else: 94 | raise ValueError 95 | -------------------------------------------------------------------------------- /model/utils/query_generator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class RepeatGenerator(nn.Module): 6 | def __init__(self, query_num) -> None: 7 | super().__init__() 8 | self.query_num = query_num 9 | 10 | def forward(self, audio_feat): 11 | return audio_feat.repeat(1, self.query_num, 1) 12 | 13 | 14 | class AttentionLayer(nn.Module): 15 | def __init__(self, embed_dim, num_heads, hidden_dim) -> None: 16 | super().__init__() 17 | self.self_attn = nn.MultiheadAttention( 18 | embed_dim, num_heads, bias=False, batch_first=True) 19 | self.cross_attn = nn.MultiheadAttention( 20 | embed_dim, num_heads, bias=False, batch_first=True) 21 | self.ffn = nn.Sequential( 22 | nn.Linear(embed_dim, hidden_dim), 23 | nn.GELU(), 24 | nn.Linear(hidden_dim, embed_dim) 25 | ) 26 | self.norm1 = nn.LayerNorm(embed_dim) 27 | self.norm2 = nn.LayerNorm(embed_dim) 28 | self.norm3 = nn.LayerNorm(embed_dim) 29 | 30 | def forward(self, query, audio_feat): 31 | out1 = self.self_attn(query, query, query)[0] 32 | query = self.norm1(query+out1) 33 | out2 = self.cross_attn(query, audio_feat, audio_feat)[0] 34 | query = self.norm2(query+out2) 35 | out3 = self.ffn(query) 36 | query = self.norm3(query+out3) 37 | return query 38 | 39 | 40 | class AttentionGenerator(nn.Module): 41 | def __init__(self, num_layers, query_num, embed_dim=256, num_heads=8, hidden_dim=1024): 42 | super().__init__() 43 | self.num_layers = num_layers 44 | self.query_num = query_num 45 | self.embed_dim = embed_dim 46 | self.query = nn.Embedding(query_num, embed_dim) 47 | self.layers = nn.ModuleList( 48 | [AttentionLayer(embed_dim, num_heads, hidden_dim) 49 | for i in range(num_layers)] 50 | ) 51 | 52 | self._reset_parameters() 53 | 54 | def _reset_parameters(self): 55 | for p in self.parameters(): 56 | if p.dim() > 1: 57 | nn.init.xavier_uniform_(p) 58 | 59 | def forward(self, audio_feat): 60 | bs = audio_feat.shape[0] 61 | query = self.query.weight[None, :, :].repeat(bs, 1, 1) 62 | for layer in self.layers: 63 | query = layer(query, audio_feat) 64 | return query 65 | 66 | 67 | def build_generator(type, **kwargs): 68 | if type == 'AttentionGenerator': 69 | 
return AttentionGenerator(**kwargs) 70 | elif type == 'RepeatGenerator': 71 | return RepeatGenerator(**kwargs) 72 | else: 73 | raise ValueError 74 | -------------------------------------------------------------------------------- /model/utils/transformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from ops.modules import MSDeformAttn 4 | 5 | 6 | class AVSTransformerEncoderLayer(nn.Module): 7 | def __init__(self, dim=256, ffn_dim=2048, dropout=0.1, num_levels=3, num_heads=8, num_points=4): 8 | super().__init__() 9 | 10 | # self attention 11 | self.self_attn = MSDeformAttn(dim, num_levels, num_heads, num_points) 12 | self.dropout1 = nn.Dropout(dropout) 13 | self.norm1 = nn.LayerNorm(dim) 14 | 15 | # ffn 16 | self.linear1 = nn.Linear(dim, ffn_dim) 17 | self.activation = nn.GELU() 18 | self.dropout2 = nn.Dropout(dropout) 19 | self.linear2 = nn.Linear(ffn_dim, dim) 20 | self.dropout3 = nn.Dropout(dropout) 21 | self.norm2 = nn.LayerNorm(dim) 22 | 23 | @staticmethod 24 | def with_pos_embed(tensor, pos): 25 | return tensor if pos is None else tensor + pos 26 | 27 | def ffn(self, src): 28 | src2 = self.linear2(self.dropout2(self.activation(self.linear1(src)))) 29 | src = src + self.dropout3(src2) 30 | src = self.norm2(src) 31 | return src 32 | 33 | def forward(self, src, pos, reference_points, spatial_shapes, level_start_index, padding_mask=None): 34 | # self attention 35 | src2 = self.self_attn(self.with_pos_embed( 36 | src, pos), reference_points, src, spatial_shapes, level_start_index, padding_mask) 37 | src = src + self.dropout1(src2) 38 | src = self.norm1(src) 39 | # ffn 40 | src = self.ffn(src) 41 | return src 42 | 43 | 44 | class AVSTransformerEncoder(nn.Module): 45 | def __init__(self, num_layers, layer, *args, **kwargs) -> None: 46 | super().__init__() 47 | 48 | self.num_layers = num_layers 49 | self.layers = nn.ModuleList( 50 | [AVSTransformerEncoderLayer(**layer) for i in range(num_layers)] 51 | ) 52 | 53 | @staticmethod 54 | def get_reference_points(spatial_shapes, valid_ratios, device): 55 | reference_points_list = [] 56 | for lvl, (H_, W_) in enumerate(spatial_shapes): 57 | 58 | ref_y, ref_x = torch.meshgrid(torch.linspace(0.5, H_ - 0.5, H_, dtype=torch.float32, device=device), 59 | torch.linspace(0.5, W_ - 0.5, W_, dtype=torch.float32, device=device)) 60 | ref_y = ref_y.reshape(-1)[None] / \ 61 | (valid_ratios[:, None, lvl, 1] * H_) 62 | ref_x = ref_x.reshape(-1)[None] / \ 63 | (valid_ratios[:, None, lvl, 0] * W_) 64 | ref = torch.stack((ref_x, ref_y), -1) 65 | reference_points_list.append(ref) 66 | reference_points = torch.cat(reference_points_list, 1) 67 | reference_points = reference_points[:, :, None] * valid_ratios[:, None] 68 | return reference_points 69 | 70 | def forward(self, src, spatial_shapes, level_start_index, valid_ratios, pos=None, padding_mask=None): 71 | out = src 72 | reference_points = self.get_reference_points( 73 | spatial_shapes, valid_ratios, device=src.device) 74 | for layer in self.layers: 75 | out = layer(out, pos, reference_points, spatial_shapes, 76 | level_start_index, padding_mask) 77 | return out, reference_points 78 | 79 | 80 | class AVSTransformerDecoderLayer(nn.Module): 81 | def __init__(self, dim=256, num_heads=8, ffn_dim=2048, dropout=0.1, num_levels=3, num_points=4, *args, **kwargs) -> None: 82 | super().__init__() 83 | 84 | # self attention 85 | self.self_attn = nn.MultiheadAttention( 86 | dim, num_heads, dropout=dropout, batch_first=True) 87 | self.dropout1 
= nn.Dropout(dropout) 88 | self.norm1 = nn.LayerNorm(dim) 89 | 90 | # cross attention 91 | # self.cross_attn = MSDeformAttn(dim, num_levels, num_heads, num_points) 92 | self.cross_attn = nn.MultiheadAttention( 93 | dim, num_heads, dropout=dropout, batch_first=True) 94 | self.dropout2 = nn.Dropout(dropout) 95 | self.norm2 = nn.LayerNorm(dim) 96 | 97 | # ffn 98 | self.linear1 = nn.Linear(dim, ffn_dim) 99 | self.activation = nn.GELU() 100 | self.dropout3 = nn.Dropout(dropout) 101 | self.linear2 = nn.Linear(ffn_dim, dim) 102 | self.dropout4 = nn.Dropout(dropout) 103 | self.norm3 = nn.LayerNorm(dim) 104 | 105 | def ffn(self, src): 106 | src2 = self.linear2(self.dropout3(self.activation(self.linear1(src)))) 107 | src = src + self.dropout4(src2) 108 | src = self.norm3(src) 109 | return src 110 | 111 | def forward(self, query, src, reference_points, spatial_shapes, level_start_index, padding_mask=None): 112 | # self attention 113 | out1 = self.self_attn(query, query, query)[0] 114 | query = query + self.dropout1(out1) 115 | query = self.norm1(query) 116 | # cross attention 117 | out2 = self.cross_attn( 118 | query, src, src, key_padding_mask=padding_mask)[0] 119 | query = query + self.dropout2(out2) 120 | query = self.norm2(query) 121 | # ffn 122 | query = self.ffn(query) 123 | return query 124 | 125 | 126 | class AVSTransformerDecoder(nn.Module): 127 | def __init__(self, num_layers, layer, *args, **kwargs) -> None: 128 | super().__init__() 129 | 130 | self.num_layers = num_layers 131 | self.layers = nn.ModuleList( 132 | [AVSTransformerDecoderLayer(**layer) for i in range(num_layers)] 133 | ) 134 | 135 | def forward(self, query, src, reference_points, spatial_shapes, level_start_index, padding_mask=None): 136 | out = query 137 | outputs = [] 138 | for layer in self.layers: 139 | out = layer(out, src, reference_points, spatial_shapes, 140 | level_start_index, padding_mask) 141 | outputs.append(out) 142 | return outputs 143 | 144 | 145 | class AVSTransformer(nn.Module): 146 | def __init__(self, encoder, decoder, *args, **kwargs) -> None: 147 | super().__init__() 148 | 149 | self.encoder = AVSTransformerEncoder(**encoder) 150 | self.decoder = AVSTransformerDecoder(**decoder) 151 | 152 | def _reset_parameters(self): 153 | for p in self.parameters(): 154 | if p.dim() > 1: 155 | nn.init.xavier_uniform_(p) 156 | 157 | def forward(self, query, src, spatial_shapes, level_start_index, valid_ratios, pos=None, padding_mask=None): 158 | memory, reference_points = self.encoder( 159 | src, spatial_shapes, level_start_index, valid_ratios, pos, padding_mask) 160 | outputs = self.decoder(query, memory, reference_points, 161 | spatial_shapes, level_start_index, padding_mask) 162 | return memory, outputs 163 | 164 | 165 | def build_transformer(type, **kwargs): 166 | if type == 'AVSTransformer': 167 | return AVSTransformer(**kwargs) 168 | else: 169 | raise ValueError 170 | -------------------------------------------------------------------------------- /model/vggish/__init__.py: -------------------------------------------------------------------------------- 1 | from .vggish import VGGish 2 | 3 | 4 | __all__=['VGGish'] -------------------------------------------------------------------------------- /model/vggish/mel_features.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Defines routines to compute mel spectrogram features from audio waveform.""" 17 | 18 | import numpy as np 19 | 20 | 21 | def frame(data, window_length, hop_length): 22 | """Convert array into a sequence of successive possibly overlapping frames. 23 | 24 | An n-dimensional array of shape (num_samples, ...) is converted into an 25 | (n+1)-D array of shape (num_frames, window_length, ...), where each frame 26 | starts hop_length points after the preceding one. 27 | 28 | This is accomplished using stride_tricks, so the original data is not 29 | copied. However, there is no zero-padding, so any incomplete frames at the 30 | end are not included. 31 | 32 | Args: 33 | data: np.array of dimension N >= 1. 34 | window_length: Number of samples in each frame. 35 | hop_length: Advance (in samples) between each window. 36 | 37 | Returns: 38 | (N+1)-D np.array with as many rows as there are complete frames that can be 39 | extracted. 40 | """ 41 | num_samples = data.shape[0] 42 | num_frames = 1 + int(np.floor((num_samples - window_length) / hop_length)) 43 | shape = (num_frames, window_length) + data.shape[1:] 44 | strides = (data.strides[0] * hop_length,) + data.strides 45 | return np.lib.stride_tricks.as_strided(data, shape=shape, strides=strides) 46 | 47 | 48 | def periodic_hann(window_length): 49 | """Calculate a "periodic" Hann window. 50 | 51 | The classic Hann window is defined as a raised cosine that starts and 52 | ends on zero, and where every value appears twice, except the middle 53 | point for an odd-length window. Matlab calls this a "symmetric" window 54 | and np.hanning() returns it. However, for Fourier analysis, this 55 | actually represents just over one cycle of a period N-1 cosine, and 56 | thus is not compactly expressed on a length-N Fourier basis. Instead, 57 | it's better to use a raised cosine that ends just before the final 58 | zero value - i.e. a complete cycle of a period-N cosine. Matlab 59 | calls this a "periodic" window. This routine calculates it. 60 | 61 | Args: 62 | window_length: The number of points in the returned window. 63 | 64 | Returns: 65 | A 1D np.array containing the periodic hann window. 66 | """ 67 | return 0.5 - (0.5 * np.cos(2 * np.pi / window_length * 68 | np.arange(window_length))) 69 | 70 | 71 | def stft_magnitude(signal, fft_length, 72 | hop_length=None, 73 | window_length=None): 74 | """Calculate the short-time Fourier transform magnitude. 75 | 76 | Args: 77 | signal: 1D np.array of the input time-domain signal. 78 | fft_length: Size of the FFT to apply. 79 | hop_length: Advance (in samples) between each frame passed to FFT. 80 | window_length: Length of each block of samples to pass to FFT. 81 | 82 | Returns: 83 | 2D np.array where each row contains the magnitudes of the fft_length/2+1 84 | unique values of the FFT for the corresponding frame of input samples. 85 | """ 86 | frames = frame(signal, window_length, hop_length) 87 | # Apply frame window to each frame. 
We use a periodic Hann (cosine of period 88 | # window_length) instead of the symmetric Hann of np.hanning (period 89 | # window_length-1). 90 | window = periodic_hann(window_length) 91 | windowed_frames = frames * window 92 | return np.abs(np.fft.rfft(windowed_frames, int(fft_length))) 93 | 94 | 95 | # Mel spectrum constants and functions. 96 | _MEL_BREAK_FREQUENCY_HERTZ = 700.0 97 | _MEL_HIGH_FREQUENCY_Q = 1127.0 98 | 99 | 100 | def hertz_to_mel(frequencies_hertz): 101 | """Convert frequencies to mel scale using HTK formula. 102 | 103 | Args: 104 | frequencies_hertz: Scalar or np.array of frequencies in hertz. 105 | 106 | Returns: 107 | Object of same size as frequencies_hertz containing corresponding values 108 | on the mel scale. 109 | """ 110 | return _MEL_HIGH_FREQUENCY_Q * np.log( 111 | 1.0 + (frequencies_hertz / _MEL_BREAK_FREQUENCY_HERTZ)) 112 | 113 | 114 | def spectrogram_to_mel_matrix(num_mel_bins=20, 115 | num_spectrogram_bins=129, 116 | audio_sample_rate=8000, 117 | lower_edge_hertz=125.0, 118 | upper_edge_hertz=3800.0): 119 | """Return a matrix that can post-multiply spectrogram rows to make mel. 120 | 121 | Returns a np.array matrix A that can be used to post-multiply a matrix S of 122 | spectrogram values (STFT magnitudes) arranged as frames x bins to generate a 123 | "mel spectrogram" M of frames x num_mel_bins. M = S A. 124 | 125 | The classic HTK algorithm exploits the complementarity of adjacent mel bands 126 | to multiply each FFT bin by only one mel weight, then add it, with positive 127 | and negative signs, to the two adjacent mel bands to which that bin 128 | contributes. Here, by expressing this operation as a matrix multiply, we go 129 | from num_fft multiplies per frame (plus around 2*num_fft adds) to around 130 | num_fft^2 multiplies and adds. However, because these are all presumably 131 | accomplished in a single call to np.dot(), it's not clear which approach is 132 | faster in Python. The matrix multiplication has the attraction of being more 133 | general and flexible, and much easier to read. 134 | 135 | Args: 136 | num_mel_bins: How many bands in the resulting mel spectrum. This is 137 | the number of columns in the output matrix. 138 | num_spectrogram_bins: How many bins there are in the source spectrogram 139 | data, which is understood to be fft_size/2 + 1, i.e. the spectrogram 140 | only contains the nonredundant FFT bins. 141 | audio_sample_rate: Samples per second of the audio at the input to the 142 | spectrogram. We need this to figure out the actual frequencies for 143 | each spectrogram bin, which dictates how they are mapped into mel. 144 | lower_edge_hertz: Lower bound on the frequencies to be included in the mel 145 | spectrum. This corresponds to the lower edge of the lowest triangular 146 | band. 147 | upper_edge_hertz: The desired top edge of the highest frequency band. 148 | 149 | Returns: 150 | An np.array with shape (num_spectrogram_bins, num_mel_bins). 151 | 152 | Raises: 153 | ValueError: if frequency edges are incorrectly ordered or out of range. 154 | """ 155 | nyquist_hertz = audio_sample_rate / 2. 
156 | if lower_edge_hertz < 0.0: 157 | raise ValueError("lower_edge_hertz %.1f must be >= 0" % lower_edge_hertz) 158 | if lower_edge_hertz >= upper_edge_hertz: 159 | raise ValueError("lower_edge_hertz %.1f >= upper_edge_hertz %.1f" % 160 | (lower_edge_hertz, upper_edge_hertz)) 161 | if upper_edge_hertz > nyquist_hertz: 162 | raise ValueError("upper_edge_hertz %.1f is greater than Nyquist %.1f" % 163 | (upper_edge_hertz, nyquist_hertz)) 164 | spectrogram_bins_hertz = np.linspace(0.0, nyquist_hertz, num_spectrogram_bins) 165 | spectrogram_bins_mel = hertz_to_mel(spectrogram_bins_hertz) 166 | # The i'th mel band (starting from i=1) has center frequency 167 | # band_edges_mel[i], lower edge band_edges_mel[i-1], and higher edge 168 | # band_edges_mel[i+1]. Thus, we need num_mel_bins + 2 values in 169 | # the band_edges_mel arrays. 170 | band_edges_mel = np.linspace(hertz_to_mel(lower_edge_hertz), 171 | hertz_to_mel(upper_edge_hertz), num_mel_bins + 2) 172 | # Matrix to post-multiply feature arrays whose rows are num_spectrogram_bins 173 | # of spectrogram values. 174 | mel_weights_matrix = np.empty((num_spectrogram_bins, num_mel_bins)) 175 | for i in range(num_mel_bins): 176 | lower_edge_mel, center_mel, upper_edge_mel = band_edges_mel[i:i + 3] 177 | # Calculate lower and upper slopes for every spectrogram bin. 178 | # Line segments are linear in the *mel* domain, not hertz. 179 | lower_slope = ((spectrogram_bins_mel - lower_edge_mel) / 180 | (center_mel - lower_edge_mel)) 181 | upper_slope = ((upper_edge_mel - spectrogram_bins_mel) / 182 | (upper_edge_mel - center_mel)) 183 | # .. then intersect them with each other and zero. 184 | mel_weights_matrix[:, i] = np.maximum(0.0, np.minimum(lower_slope, 185 | upper_slope)) 186 | # HTK excludes the spectrogram DC bin; make sure it always gets a zero 187 | # coefficient. 188 | mel_weights_matrix[0, :] = 0.0 189 | return mel_weights_matrix 190 | 191 | 192 | def log_mel_spectrogram(data, 193 | audio_sample_rate=8000, 194 | log_offset=0.0, 195 | window_length_secs=0.025, 196 | hop_length_secs=0.010, 197 | **kwargs): 198 | """Convert waveform to a log magnitude mel-frequency spectrogram. 199 | 200 | Args: 201 | data: 1D np.array of waveform data. 202 | audio_sample_rate: The sampling rate of data. 203 | log_offset: Add this to values when taking log to avoid -Infs. 204 | window_length_secs: Duration of each window to analyze. 205 | hop_length_secs: Advance between successive analysis windows. 206 | **kwargs: Additional arguments to pass to spectrogram_to_mel_matrix. 207 | 208 | Returns: 209 | 2D np.array of (num_frames, num_mel_bins) consisting of log mel filterbank 210 | magnitudes for successive frames. 
211 | """ 212 | window_length_samples = int(round(audio_sample_rate * window_length_secs)) 213 | hop_length_samples = int(round(audio_sample_rate * hop_length_secs)) 214 | fft_length = 2 ** int(np.ceil(np.log(window_length_samples) / np.log(2.0))) 215 | spectrogram = stft_magnitude( 216 | data, 217 | fft_length=fft_length, 218 | hop_length=hop_length_samples, 219 | window_length=window_length_samples) 220 | mel_spectrogram = np.dot(spectrogram, spectrogram_to_mel_matrix( 221 | num_spectrogram_bins=spectrogram.shape[1], 222 | audio_sample_rate=audio_sample_rate, **kwargs)) 223 | return np.log(mel_spectrogram + log_offset) 224 | -------------------------------------------------------------------------------- /model/vggish/vggish.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | from . import vggish_input, vggish_params 5 | 6 | 7 | class VGG(nn.Module): 8 | def __init__(self, features): 9 | super(VGG, self).__init__() 10 | self.features = features 11 | self.embeddings = nn.Sequential( 12 | nn.Linear(512 * 4 * 6, 4096), 13 | nn.ReLU(True), 14 | nn.Linear(4096, 4096), 15 | nn.ReLU(True), 16 | nn.Linear(4096, 128), 17 | nn.ReLU(True)) 18 | 19 | def forward(self, x): 20 | x = self.features(x) 21 | 22 | # Transpose the output from features to 23 | # remain compatible with vggish embeddings 24 | x = torch.transpose(x, 1, 3) 25 | x = torch.transpose(x, 1, 2) 26 | x = x.contiguous() 27 | x = x.view(x.size(0), -1) 28 | 29 | return self.embeddings(x) 30 | 31 | 32 | class Postprocessor(nn.Module): 33 | """Post-processes VGGish embeddings. Returns a torch.Tensor instead of a 34 | numpy array in order to preserve the gradient. 35 | 36 | "The initial release of AudioSet included 128-D VGGish embeddings for each 37 | segment of AudioSet. These released embeddings were produced by applying 38 | a PCA transformation (technically, a whitening transform is included as well) 39 | and 8-bit quantization to the raw embedding output from VGGish, in order to 40 | stay compatible with the YouTube-8M project which provides visual embeddings 41 | in the same format for a large set of YouTube videos. This class implements 42 | the same PCA (with whitening) and quantization transformations." 43 | """ 44 | 45 | def __init__(self): 46 | """Constructs a postprocessor.""" 47 | super(Postprocessor, self).__init__() 48 | # Create empty matrix, for user's state_dict to load 49 | self.pca_eigen_vectors = torch.empty( 50 | (vggish_params.EMBEDDING_SIZE, vggish_params.EMBEDDING_SIZE,), 51 | dtype=torch.float, 52 | ) 53 | self.pca_means = torch.empty( 54 | (vggish_params.EMBEDDING_SIZE, 1), dtype=torch.float 55 | ) 56 | 57 | self.pca_eigen_vectors = nn.Parameter( 58 | self.pca_eigen_vectors, requires_grad=False) 59 | self.pca_means = nn.Parameter(self.pca_means, requires_grad=False) 60 | 61 | def postprocess(self, embeddings_batch): 62 | """Applies tensor postprocessing to a batch of embeddings. 63 | 64 | Args: 65 | embeddings_batch: An tensor of shape [batch_size, embedding_size] 66 | containing output from the embedding layer of VGGish. 67 | 68 | Returns: 69 | A tensor of the same shape as the input, containing the PCA-transformed, 70 | quantized, and clipped version of the input. 
71 | """ 72 | assert len(embeddings_batch.shape) == 2, "Expected 2-d batch, got %r" % ( 73 | embeddings_batch.shape, 74 | ) 75 | assert ( 76 | embeddings_batch.shape[1] == vggish_params.EMBEDDING_SIZE 77 | ), "Bad batch shape: %r" % (embeddings_batch.shape,) 78 | 79 | # Apply PCA. 80 | # - Embeddings come in as [batch_size, embedding_size]. 81 | # - Transpose to [embedding_size, batch_size]. 82 | # - Subtract pca_means column vector from each column. 83 | # - Premultiply by PCA matrix of shape [output_dims, input_dims] 84 | # where both are are equal to embedding_size in our case. 85 | # - Transpose result back to [batch_size, embedding_size]. 86 | pca_applied = torch.mm(self.pca_eigen_vectors, 87 | (embeddings_batch.t() - self.pca_means)).t() 88 | 89 | # Quantize by: 90 | # - clipping to [min, max] range 91 | clipped_embeddings = torch.clamp( 92 | pca_applied, vggish_params.QUANTIZE_MIN_VAL, vggish_params.QUANTIZE_MAX_VAL 93 | ) 94 | # - convert to 8-bit in range [0.0, 255.0] 95 | quantized_embeddings = torch.round( 96 | (clipped_embeddings - vggish_params.QUANTIZE_MIN_VAL) 97 | * ( 98 | 255.0 99 | / (vggish_params.QUANTIZE_MAX_VAL - vggish_params.QUANTIZE_MIN_VAL) 100 | ) 101 | ) 102 | return torch.squeeze(quantized_embeddings) 103 | 104 | def forward(self, x): 105 | return self.postprocess(x) 106 | 107 | 108 | def make_layers(): 109 | layers = [] 110 | in_channels = 1 111 | for v in [64, "M", 128, "M", 256, 256, "M", 512, 512, "M"]: 112 | if v == "M": 113 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 114 | else: 115 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1) 116 | layers += [conv2d, nn.ReLU(inplace=True)] 117 | in_channels = v 118 | return nn.Sequential(*layers) 119 | 120 | 121 | def _vgg(): 122 | return VGG(make_layers()) 123 | 124 | 125 | class VGGish(VGG): 126 | def __init__(self, 127 | freeze_audio_extractor, 128 | pretrained_vggish_model_path, 129 | preprocess_audio_to_log_mel, 130 | postprocess_log_mel_with_pca, 131 | pretrained_pca_params_path, 132 | *args, **kwargs): 133 | super().__init__(make_layers()) 134 | if freeze_audio_extractor: 135 | state_dict = torch.load(pretrained_vggish_model_path) 136 | super().load_state_dict(state_dict) 137 | 138 | self.preprocess = preprocess_audio_to_log_mel 139 | self.postprocess = postprocess_log_mel_with_pca 140 | if self.postprocess: 141 | self.pproc = Postprocessor() 142 | if freeze_audio_extractor: 143 | state_dict = torch.load(pretrained_pca_params_path) 144 | # TODO: Convert the state_dict to torch 145 | state_dict[vggish_params.PCA_EIGEN_VECTORS_NAME] = torch.as_tensor( 146 | state_dict[vggish_params.PCA_EIGEN_VECTORS_NAME], dtype=torch.float 147 | ) 148 | state_dict[vggish_params.PCA_MEANS_NAME] = torch.as_tensor( 149 | state_dict[vggish_params.PCA_MEANS_NAME].reshape(-1, 1), dtype=torch.float 150 | ) 151 | self.pproc.load_state_dict(state_dict) 152 | 153 | def forward(self, x): 154 | if self.preprocess: 155 | x = self._preprocess(x) 156 | x = VGG.forward(self, x) 157 | if self.postprocess: 158 | x = self._postprocess(x) 159 | return x 160 | 161 | def _preprocess(self, x): 162 | if isinstance(x, str): 163 | x = vggish_input.waveform_to_examples(x) 164 | return x 165 | else: 166 | batch_num = len(x) 167 | audio_fea_list = [] 168 | for xx in x: 169 | if isinstance(xx, str): 170 | xx = vggish_input.wavfile_to_examples( 171 | xx) # [5 or 10, 1, 96, 64] 172 | #! 
notice: 173 | if xx.shape[0] != 10: 174 | new_xx = torch.zeros(10, 1, 96, 64) 175 | new_xx[:xx.shape[0]] = xx 176 | audio_fea_list.append(new_xx) 177 | else: 178 | audio_fea_list.append(xx) 179 | 180 | # [bs, 10, 1, 96, 64] 181 | audio_fea = torch.stack(audio_fea_list, dim=0) 182 | audio_fea = audio_fea.view( 183 | batch_num * 10, xx.shape[1], xx.shape[2], xx.shape[3]) # [bs*10, 1, 96, 64] 184 | audio_fea = audio_fea.cuda() 185 | return audio_fea 186 | 187 | def _postprocess(self, x): 188 | return self.pproc(x) 189 | -------------------------------------------------------------------------------- /model/vggish/vggish_input.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Compute input examples for VGGish from audio waveform.""" 17 | 18 | # Modification: Return torch tensors rather than numpy arrays 19 | import torch 20 | 21 | import numpy as np 22 | import resampy 23 | 24 | from . import mel_features 25 | from . import vggish_params 26 | 27 | import soundfile as sf 28 | 29 | 30 | def waveform_to_examples(data, sample_rate, return_tensor=True): 31 | """Converts audio waveform into an array of examples for VGGish. 32 | 33 | Args: 34 | data: np.array of either one dimension (mono) or two dimensions 35 | (multi-channel, with the outer dimension representing channels). 36 | Each sample is generally expected to lie in the range [-1.0, +1.0], 37 | although this is not required. 38 | sample_rate: Sample rate of data. 39 | return_tensor: Return data as a Pytorch tensor ready for VGGish 40 | 41 | Returns: 42 | 3-D np.array of shape [num_examples, num_frames, num_bands] which represents 43 | a sequence of examples, each of which contains a patch of log mel 44 | spectrogram, covering num_frames frames of audio and num_bands mel frequency 45 | bands, where the frame length is vggish_params.STFT_HOP_LENGTH_SECONDS. 46 | 47 | """ 48 | # Convert to mono. 49 | if len(data.shape) > 1: 50 | data = np.mean(data, axis=1) 51 | # Resample to the rate assumed by VGGish. 52 | if sample_rate != vggish_params.SAMPLE_RATE: 53 | data = resampy.resample(data, sample_rate, vggish_params.SAMPLE_RATE) 54 | 55 | # Compute log mel spectrogram features. 56 | log_mel = mel_features.log_mel_spectrogram( 57 | data, 58 | audio_sample_rate=vggish_params.SAMPLE_RATE, 59 | log_offset=vggish_params.LOG_OFFSET, 60 | window_length_secs=vggish_params.STFT_WINDOW_LENGTH_SECONDS, 61 | hop_length_secs=vggish_params.STFT_HOP_LENGTH_SECONDS, 62 | num_mel_bins=vggish_params.NUM_MEL_BINS, 63 | lower_edge_hertz=vggish_params.MEL_MIN_HZ, 64 | upper_edge_hertz=vggish_params.MEL_MAX_HZ) 65 | 66 | # Frame features into examples. 
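  # With the shipped vggish_params (16 kHz audio, 10 ms STFT hop, 0.96 s example
  # window and hop), this yields non-overlapping examples of NUM_FRAMES x NUM_BANDS
  # = 96 x 64 log-mel values, i.e. shape [num_examples, 96, 64]; a channel axis is
  # added below when return_tensor=True.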
67 | features_sample_rate = 1.0 / vggish_params.STFT_HOP_LENGTH_SECONDS 68 | example_window_length = int(round( 69 | vggish_params.EXAMPLE_WINDOW_SECONDS * features_sample_rate)) 70 | example_hop_length = int(round( 71 | vggish_params.EXAMPLE_HOP_SECONDS * features_sample_rate)) 72 | log_mel_examples = mel_features.frame( 73 | log_mel, 74 | window_length=example_window_length, 75 | hop_length=example_hop_length) 76 | 77 | if return_tensor: 78 | log_mel_examples = torch.tensor( 79 | log_mel_examples, requires_grad=True)[:, None, :, :].float() 80 | 81 | return log_mel_examples 82 | 83 | 84 | def wavfile_to_examples(wav_file, return_tensor=True): 85 | """Convenience wrapper around waveform_to_examples() for a common WAV format. 86 | 87 | Args: 88 | wav_file: String path to a file, or a file-like object. The file 89 | is assumed to contain WAV audio data with signed 16-bit PCM samples. 90 | torch: Return data as a Pytorch tensor ready for VGGish 91 | 92 | Returns: 93 | See waveform_to_examples. 94 | """ 95 | wav_data, sr = sf.read(wav_file, dtype='int16') 96 | assert wav_data.dtype == np.int16, 'Bad sample type: %r' % wav_data.dtype 97 | samples = wav_data / 32768.0 # Convert to [-1.0, +1.0] 98 | return waveform_to_examples(samples, sr, return_tensor) 99 | -------------------------------------------------------------------------------- /model/vggish/vggish_params.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Global parameters for the VGGish model. 17 | 18 | See vggish_slim.py for more information. 19 | """ 20 | 21 | # Architectural constants. 22 | NUM_FRAMES = 96 # Frames in input mel-spectrogram patch. 23 | NUM_BANDS = 64 # Frequency bands in input mel-spectrogram patch. 24 | EMBEDDING_SIZE = 128 # Size of embedding layer. 25 | 26 | # Hyperparameters used in feature and example generation. 27 | SAMPLE_RATE = 16000 28 | STFT_WINDOW_LENGTH_SECONDS = 0.025 29 | STFT_HOP_LENGTH_SECONDS = 0.010 30 | NUM_MEL_BINS = NUM_BANDS 31 | MEL_MIN_HZ = 125 32 | MEL_MAX_HZ = 7500 33 | LOG_OFFSET = 0.01 # Offset used for stabilized log of input mel-spectrogram. 34 | EXAMPLE_WINDOW_SECONDS = 0.96 # Each example contains 96 10ms frames 35 | EXAMPLE_HOP_SECONDS = 0.96 # with zero overlap. 36 | 37 | # Parameters used for embedding postprocessing. 38 | PCA_EIGEN_VECTORS_NAME = 'pca_eigen_vectors' 39 | PCA_MEANS_NAME = 'pca_means' 40 | QUANTIZE_MIN_VAL = -2.0 41 | QUANTIZE_MAX_VAL = +2.0 42 | 43 | # Hyperparameters used in training. 44 | INIT_STDDEV = 0.01 # Standard deviation used to initialize weights. 45 | LEARNING_RATE = 1e-4 # Learning rate for the Adam optimizer. 46 | ADAM_EPSILON = 1e-8 # Epsilon for the Adam optimizer. 47 | 48 | # Names of ops, tensors, and features. 
49 | INPUT_OP_NAME = 'vggish/input_features' 50 | INPUT_TENSOR_NAME = INPUT_OP_NAME + ':0' 51 | OUTPUT_OP_NAME = 'vggish/embedding' 52 | OUTPUT_TENSOR_NAME = OUTPUT_OP_NAME + ':0' 53 | AUDIO_EMBEDDING_FEATURE_NAME = 'audio_embedding' 54 | -------------------------------------------------------------------------------- /ops/functions/__init__.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------ 2 | # Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------------------ 6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 7 | # ------------------------------------------------------------------------------------------------ 8 | 9 | # Copyright (c) Facebook, Inc. and its affiliates. 10 | # Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR 11 | 12 | from .ms_deform_attn_func import MSDeformAttnFunction 13 | 14 | -------------------------------------------------------------------------------- /ops/functions/ms_deform_attn_func.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------ 2 | # Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------------------ 6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 7 | # ------------------------------------------------------------------------------------------------ 8 | 9 | # Copyright (c) Facebook, Inc. and its affiliates. 
10 | # Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR 11 | 12 | from __future__ import absolute_import 13 | from __future__ import print_function 14 | from __future__ import division 15 | 16 | import torch 17 | import torch.nn.functional as F 18 | from torch.autograd import Function 19 | from torch.autograd.function import once_differentiable 20 | 21 | try: 22 | import MultiScaleDeformableAttention as MSDA 23 | except ModuleNotFoundError as e: 24 | info_string = ( 25 | "\n\nPlease compile MultiScaleDeformableAttention CUDA op with the following commands:\n" 26 | "\t`cd maskdino/modeling/pixel_decoder/ops`\n" 27 | "\t`sh make.sh`\n" 28 | ) 29 | raise ModuleNotFoundError(info_string) 30 | 31 | 32 | class MSDeformAttnFunction(Function): 33 | @staticmethod 34 | def forward(ctx, value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, im2col_step): 35 | ctx.im2col_step = im2col_step 36 | output = MSDA.ms_deform_attn_forward( 37 | value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, ctx.im2col_step) 38 | ctx.save_for_backward(value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights) 39 | return output 40 | 41 | @staticmethod 42 | @once_differentiable 43 | def backward(ctx, grad_output): 44 | value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights = ctx.saved_tensors 45 | grad_value, grad_sampling_loc, grad_attn_weight = \ 46 | MSDA.ms_deform_attn_backward( 47 | value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, grad_output, ctx.im2col_step) 48 | 49 | return grad_value, None, None, grad_sampling_loc, grad_attn_weight, None 50 | 51 | 52 | def ms_deform_attn_core_pytorch(value, value_spatial_shapes, sampling_locations, attention_weights): 53 | # for debug and test only, 54 | # need to use cuda version instead 55 | N_, S_, M_, D_ = value.shape 56 | _, Lq_, M_, L_, P_, _ = sampling_locations.shape 57 | value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes], dim=1) 58 | sampling_grids = 2 * sampling_locations - 1 59 | sampling_value_list = [] 60 | for lid_, (H_, W_) in enumerate(value_spatial_shapes): 61 | # N_, H_*W_, M_, D_ -> N_, H_*W_, M_*D_ -> N_, M_*D_, H_*W_ -> N_*M_, D_, H_, W_ 62 | value_l_ = value_list[lid_].flatten(2).transpose(1, 2).reshape(N_*M_, D_, H_, W_) 63 | # N_, Lq_, M_, P_, 2 -> N_, M_, Lq_, P_, 2 -> N_*M_, Lq_, P_, 2 64 | sampling_grid_l_ = sampling_grids[:, :, :, lid_].transpose(1, 2).flatten(0, 1) 65 | # N_*M_, D_, Lq_, P_ 66 | sampling_value_l_ = F.grid_sample(value_l_, sampling_grid_l_, 67 | mode='bilinear', padding_mode='zeros', align_corners=False) 68 | sampling_value_list.append(sampling_value_l_) 69 | # (N_, Lq_, M_, L_, P_) -> (N_, M_, Lq_, L_, P_) -> (N_, M_, 1, Lq_, L_*P_) 70 | attention_weights = attention_weights.transpose(1, 2).reshape(N_*M_, 1, Lq_, L_*P_) 71 | output = (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights).sum(-1).view(N_, M_*D_, Lq_) 72 | return output.transpose(1, 2).contiguous() 73 | -------------------------------------------------------------------------------- /ops/make.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # ------------------------------------------------------------------------------------------------ 3 | # Deformable DETR 4 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 
5 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | # ------------------------------------------------------------------------------------------------ 7 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | # ------------------------------------------------------------------------------------------------ 9 | 10 | # Copyright (c) Facebook, Inc. and its affiliates. 11 | # Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR 12 | 13 | python setup.py build install 14 | -------------------------------------------------------------------------------- /ops/modules/__init__.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------ 2 | # Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------------------ 6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 7 | # ------------------------------------------------------------------------------------------------ 8 | 9 | # Copyright (c) Facebook, Inc. and its affiliates. 10 | # Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR 11 | 12 | from .ms_deform_attn import MSDeformAttn 13 | -------------------------------------------------------------------------------- /ops/modules/ms_deform_attn.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------ 2 | # Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------------------ 6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 7 | # ------------------------------------------------------------------------------------------------ 8 | 9 | # Copyright (c) Facebook, Inc. and its affiliates. 
10 | # Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR 11 | 12 | from __future__ import absolute_import 13 | from __future__ import print_function 14 | from __future__ import division 15 | 16 | import warnings 17 | import math 18 | 19 | import torch 20 | from torch import nn 21 | import torch.nn.functional as F 22 | from torch.nn.init import xavier_uniform_, constant_ 23 | 24 | from ..functions import MSDeformAttnFunction 25 | from ..functions.ms_deform_attn_func import ms_deform_attn_core_pytorch 26 | 27 | 28 | def _is_power_of_2(n): 29 | if (not isinstance(n, int)) or (n < 0): 30 | raise ValueError("invalid input for _is_power_of_2: {} (type: {})".format(n, type(n))) 31 | return (n & (n-1) == 0) and n != 0 32 | 33 | 34 | class MSDeformAttn(nn.Module): 35 | def __init__(self, d_model=256, n_levels=4, n_heads=8, n_points=4): 36 | """ 37 | Multi-Scale Deformable Attention Module 38 | :param d_model hidden dimension 39 | :param n_levels number of feature levels 40 | :param n_heads number of attention heads 41 | :param n_points number of sampling points per attention head per feature level 42 | """ 43 | super().__init__() 44 | if d_model % n_heads != 0: 45 | raise ValueError('d_model must be divisible by n_heads, but got {} and {}'.format(d_model, n_heads)) 46 | _d_per_head = d_model // n_heads 47 | # you'd better set _d_per_head to a power of 2 which is more efficient in our CUDA implementation 48 | if not _is_power_of_2(_d_per_head): 49 | warnings.warn("You'd better set d_model in MSDeformAttn to make the dimension of each attention head a power of 2 " 50 | "which is more efficient in our CUDA implementation.") 51 | 52 | self.im2col_step = 128 53 | 54 | self.d_model = d_model 55 | self.n_levels = n_levels 56 | self.n_heads = n_heads 57 | self.n_points = n_points 58 | 59 | self.sampling_offsets = nn.Linear(d_model, n_heads * n_levels * n_points * 2) 60 | self.attention_weights = nn.Linear(d_model, n_heads * n_levels * n_points) 61 | self.value_proj = nn.Linear(d_model, d_model) 62 | self.output_proj = nn.Linear(d_model, d_model) 63 | 64 | self._reset_parameters() 65 | 66 | def _reset_parameters(self): 67 | constant_(self.sampling_offsets.weight.data, 0.) 68 | thetas = torch.arange(self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads) 69 | grid_init = torch.stack([thetas.cos(), thetas.sin()], -1) 70 | grid_init = (grid_init / grid_init.abs().max(-1, keepdim=True)[0]).view(self.n_heads, 1, 1, 2).repeat(1, self.n_levels, self.n_points, 1) 71 | for i in range(self.n_points): 72 | grid_init[:, :, i, :] *= i + 1 73 | with torch.no_grad(): 74 | self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1)) 75 | constant_(self.attention_weights.weight.data, 0.) 76 | constant_(self.attention_weights.bias.data, 0.) 77 | xavier_uniform_(self.value_proj.weight.data) 78 | constant_(self.value_proj.bias.data, 0.) 79 | xavier_uniform_(self.output_proj.weight.data) 80 | constant_(self.output_proj.bias.data, 0.) 
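The `_reset_parameters` above biases each head's sampling offsets toward its own direction, with the i-th point starting (i + 1) unit steps away from the reference point, so the module samples a sensible neighborhood before any training. A standalone sketch of that initialization (default head/level/point counts assumed; the result is flattened into `sampling_offsets.bias`):

```python
import math
import torch

n_heads, n_levels, n_points = 8, 4, 4

# one unit direction per head, evenly spaced around the circle
thetas = torch.arange(n_heads, dtype=torch.float32) * (2.0 * math.pi / n_heads)
grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
# normalize so the largest |component| is 1, then repeat over levels and points
grid_init = (grid_init / grid_init.abs().max(-1, keepdim=True)[0]) \
    .view(n_heads, 1, 1, 2).repeat(1, n_levels, n_points, 1)
# the i-th sampling point starts (i + 1) steps from the reference point
for i in range(n_points):
    grid_init[:, :, i, :] *= i + 1

print(grid_init.shape)   # torch.Size([8, 4, 4, 2])
print(grid_init[0, 0])   # head 0, level 0: points at (1, 0), (2, 0), (3, 0), (4, 0)
```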
81 | 82 | def forward(self, query, reference_points, input_flatten, input_spatial_shapes, input_level_start_index, input_padding_mask=None): 83 | """ 84 | :param query (N, Length_{query}, C) 85 | :param reference_points (N, Length_{query}, n_levels, 2), range in [0, 1], top-left (0,0), bottom-right (1, 1), including padding area 86 | or (N, Length_{query}, n_levels, 4), add additional (w, h) to form reference boxes 87 | :param input_flatten (N, \sum_{l=0}^{L-1} H_l \cdot W_l, C) 88 | :param input_spatial_shapes (n_levels, 2), [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})] 89 | :param input_level_start_index (n_levels, ), [0, H_0*W_0, H_0*W_0+H_1*W_1, H_0*W_0+H_1*W_1+H_2*W_2, ..., H_0*W_0+H_1*W_1+...+H_{L-1}*W_{L-1}] 90 | :param input_padding_mask (N, \sum_{l=0}^{L-1} H_l \cdot W_l), True for padding elements, False for non-padding elements 91 | 92 | :return output (N, Length_{query}, C) 93 | """ 94 | N, Len_q, _ = query.shape 95 | N, Len_in, _ = input_flatten.shape 96 | assert (input_spatial_shapes[:, 0] * input_spatial_shapes[:, 1]).sum() == Len_in 97 | 98 | value = self.value_proj(input_flatten) 99 | if input_padding_mask is not None: 100 | value = value.masked_fill(input_padding_mask[..., None], float(0)) 101 | value = value.view(N, Len_in, self.n_heads, self.d_model // self.n_heads) 102 | sampling_offsets = self.sampling_offsets(query).view(N, Len_q, self.n_heads, self.n_levels, self.n_points, 2) 103 | attention_weights = self.attention_weights(query).view(N, Len_q, self.n_heads, self.n_levels * self.n_points) 104 | attention_weights = F.softmax(attention_weights, -1).view(N, Len_q, self.n_heads, self.n_levels, self.n_points) 105 | # N, Len_q, n_heads, n_levels, n_points, 2 106 | if reference_points.shape[-1] == 2: 107 | offset_normalizer = torch.stack([input_spatial_shapes[..., 1], input_spatial_shapes[..., 0]], -1) 108 | sampling_locations = reference_points[:, :, None, :, None, :] \ 109 | + sampling_offsets / offset_normalizer[None, None, None, :, None, :] 110 | elif reference_points.shape[-1] == 4: 111 | sampling_locations = reference_points[:, :, None, :, None, :2] \ 112 | + sampling_offsets / self.n_points * reference_points[:, :, None, :, None, 2:] * 0.5 113 | else: 114 | raise ValueError( 115 | 'Last dim of reference_points must be 2 or 4, but get {} instead.'.format(reference_points.shape[-1])) 116 | try: 117 | output = MSDeformAttnFunction.apply( 118 | value, input_spatial_shapes, input_level_start_index, sampling_locations, attention_weights, self.im2col_step) 119 | except: 120 | # CPU 121 | output = ms_deform_attn_core_pytorch(value, input_spatial_shapes, sampling_locations, attention_weights) 122 | # # For FLOPs calculation only 123 | # output = ms_deform_attn_core_pytorch(value, input_spatial_shapes, sampling_locations, attention_weights) 124 | output = self.output_proj(output) 125 | return output 126 | -------------------------------------------------------------------------------- /ops/setup.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------ 2 | # Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 
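A usage sketch of the `forward` above, with dummy feature-pyramid sizes. It assumes the deformable-attention op has been built via `ops/make.sh` and the snippet is run from the repository root; on CPU tensors the bare `except` falls back to the pure-PyTorch path, so the example runs without a GPU once the op is compiled:

```python
import torch
from ops.modules import MSDeformAttn

d_model, n_levels, n_heads, n_points = 256, 4, 8, 4
attn = MSDeformAttn(d_model, n_levels, n_heads, n_points)

spatial_shapes = torch.as_tensor([[28, 28], [14, 14], [7, 7], [4, 4]], dtype=torch.long)
level_start_index = torch.cat((spatial_shapes.new_zeros((1,)),
                               spatial_shapes.prod(1).cumsum(0)[:-1]))
Len_in = int(spatial_shapes.prod(1).sum())

N, Len_q = 2, 300
query = torch.rand(N, Len_q, d_model)
input_flatten = torch.rand(N, Len_in, d_model)                           # all levels flattened and concatenated
reference_points = torch.rand(N, Len_q, 1, 2).repeat(1, 1, n_levels, 1)  # normalized (x, y) per query and level

out = attn(query, reference_points, input_flatten, spatial_shapes, level_start_index)
print(out.shape)   # torch.Size([2, 300, 256])
```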
4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------------------ 6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 7 | # ------------------------------------------------------------------------------------------------ 8 | 9 | # Copyright (c) Facebook, Inc. and its affiliates. 10 | # Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR 11 | 12 | import os 13 | import glob 14 | 15 | import torch 16 | 17 | from torch.utils.cpp_extension import CUDA_HOME 18 | from torch.utils.cpp_extension import CppExtension 19 | from torch.utils.cpp_extension import CUDAExtension 20 | 21 | from setuptools import find_packages 22 | from setuptools import setup 23 | 24 | requirements = ["torch", "torchvision"] 25 | 26 | def get_extensions(): 27 | this_dir = os.path.dirname(os.path.abspath(__file__)) 28 | extensions_dir = os.path.join(this_dir, "src") 29 | 30 | main_file = glob.glob(os.path.join(extensions_dir, "*.cpp")) 31 | source_cpu = glob.glob(os.path.join(extensions_dir, "cpu", "*.cpp")) 32 | source_cuda = glob.glob(os.path.join(extensions_dir, "cuda", "*.cu")) 33 | 34 | sources = main_file + source_cpu 35 | extension = CppExtension 36 | extra_compile_args = {"cxx": []} 37 | define_macros = [] 38 | 39 | # Force cuda since torch ask for a device, not if cuda is in fact available. 40 | if (os.environ.get('FORCE_CUDA') or torch.cuda.is_available()) and CUDA_HOME is not None: 41 | extension = CUDAExtension 42 | sources += source_cuda 43 | define_macros += [("WITH_CUDA", None)] 44 | extra_compile_args["nvcc"] = [ 45 | "-DCUDA_HAS_FP16=1", 46 | "-D__CUDA_NO_HALF_OPERATORS__", 47 | "-D__CUDA_NO_HALF_CONVERSIONS__", 48 | "-D__CUDA_NO_HALF2_OPERATORS__", 49 | ] 50 | else: 51 | if CUDA_HOME is None: 52 | raise NotImplementedError('CUDA_HOME is None. Please set environment variable CUDA_HOME.') 53 | else: 54 | raise NotImplementedError('No CUDA runtime is found. Please set FORCE_CUDA=1 or test it by running torch.cuda.is_available().') 55 | 56 | sources = [os.path.join(extensions_dir, s) for s in sources] 57 | include_dirs = [extensions_dir] 58 | ext_modules = [ 59 | extension( 60 | "MultiScaleDeformableAttention", 61 | sources, 62 | include_dirs=include_dirs, 63 | define_macros=define_macros, 64 | extra_compile_args=extra_compile_args, 65 | ) 66 | ] 67 | return ext_modules 68 | 69 | setup( 70 | name="MultiScaleDeformableAttention", 71 | version="1.0", 72 | author="Weijie Su", 73 | url="https://github.com/fundamentalvision/Deformable-DETR", 74 | description="PyTorch Wrapper for CUDA Functions of Multi-Scale Deformable Attention", 75 | packages=find_packages(exclude=("configs", "tests",)), 76 | ext_modules=get_extensions(), 77 | cmdclass={"build_ext": torch.utils.cpp_extension.BuildExtension}, 78 | ) 79 | -------------------------------------------------------------------------------- /ops/src/cpu/ms_deform_attn_cpu.cpp: -------------------------------------------------------------------------------- 1 | /*! 2 | ************************************************************************************************** 3 | * Deformable DETR 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved. 
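As a quick sanity check before running `make.sh`, the snippet below reports whether the CUDA branch of `get_extensions` above would actually be taken; it only mirrors the condition in the setup script and does not build anything:

```python
import os
import torch
from torch.utils.cpp_extension import CUDA_HOME

# Mirrors ops/setup.py: CUDA sources are compiled only when a CUDA runtime
# (or FORCE_CUDA=1) is visible *and* CUDA_HOME is set.
will_build_cuda = (os.environ.get("FORCE_CUDA") or torch.cuda.is_available()) and CUDA_HOME is not None
print(f"CUDA_HOME={CUDA_HOME}, cuda available={torch.cuda.is_available()}, "
      f"will build CUDA op: {bool(will_build_cuda)}")
```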
5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | ************************************************************************************************** 7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | ************************************************************************************************** 9 | */ 10 | 11 | /*! 12 | * Copyright (c) Facebook, Inc. and its affiliates. 13 | * Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR 14 | */ 15 | 16 | #include 17 | 18 | #include 19 | #include 20 | 21 | 22 | at::Tensor 23 | ms_deform_attn_cpu_forward( 24 | const at::Tensor &value, 25 | const at::Tensor &spatial_shapes, 26 | const at::Tensor &level_start_index, 27 | const at::Tensor &sampling_loc, 28 | const at::Tensor &attn_weight, 29 | const int im2col_step) 30 | { 31 | AT_ERROR("Not implement on cpu"); 32 | } 33 | 34 | std::vector 35 | ms_deform_attn_cpu_backward( 36 | const at::Tensor &value, 37 | const at::Tensor &spatial_shapes, 38 | const at::Tensor &level_start_index, 39 | const at::Tensor &sampling_loc, 40 | const at::Tensor &attn_weight, 41 | const at::Tensor &grad_output, 42 | const int im2col_step) 43 | { 44 | AT_ERROR("Not implement on cpu"); 45 | } 46 | 47 | -------------------------------------------------------------------------------- /ops/src/cpu/ms_deform_attn_cpu.h: -------------------------------------------------------------------------------- 1 | /*! 2 | ************************************************************************************************** 3 | * Deformable DETR 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved. 5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | ************************************************************************************************** 7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | ************************************************************************************************** 9 | */ 10 | 11 | /*! 12 | * Copyright (c) Facebook, Inc. and its affiliates. 13 | * Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR 14 | */ 15 | 16 | #pragma once 17 | #include 18 | 19 | at::Tensor 20 | ms_deform_attn_cpu_forward( 21 | const at::Tensor &value, 22 | const at::Tensor &spatial_shapes, 23 | const at::Tensor &level_start_index, 24 | const at::Tensor &sampling_loc, 25 | const at::Tensor &attn_weight, 26 | const int im2col_step); 27 | 28 | std::vector 29 | ms_deform_attn_cpu_backward( 30 | const at::Tensor &value, 31 | const at::Tensor &spatial_shapes, 32 | const at::Tensor &level_start_index, 33 | const at::Tensor &sampling_loc, 34 | const at::Tensor &attn_weight, 35 | const at::Tensor &grad_output, 36 | const int im2col_step); 37 | 38 | 39 | -------------------------------------------------------------------------------- /ops/src/cuda/ms_deform_attn_cuda.cu: -------------------------------------------------------------------------------- 1 | /*! 2 | ************************************************************************************************** 3 | * Deformable DETR 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved. 
5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | ************************************************************************************************** 7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | ************************************************************************************************** 9 | */ 10 | 11 | /*! 12 | * Copyright (c) Facebook, Inc. and its affiliates. 13 | * Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR 14 | */ 15 | 16 | #include 17 | #include "cuda/ms_deform_im2col_cuda.cuh" 18 | 19 | #include 20 | #include 21 | #include 22 | #include 23 | 24 | 25 | at::Tensor ms_deform_attn_cuda_forward( 26 | const at::Tensor &value, 27 | const at::Tensor &spatial_shapes, 28 | const at::Tensor &level_start_index, 29 | const at::Tensor &sampling_loc, 30 | const at::Tensor &attn_weight, 31 | const int im2col_step) 32 | { 33 | AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous"); 34 | AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous"); 35 | AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous"); 36 | AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous"); 37 | AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous"); 38 | 39 | AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor"); 40 | AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor"); 41 | AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor"); 42 | AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor"); 43 | AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor"); 44 | 45 | const int batch = value.size(0); 46 | const int spatial_size = value.size(1); 47 | const int num_heads = value.size(2); 48 | const int channels = value.size(3); 49 | 50 | const int num_levels = spatial_shapes.size(0); 51 | 52 | const int num_query = sampling_loc.size(1); 53 | const int num_point = sampling_loc.size(4); 54 | 55 | const int im2col_step_ = std::min(batch, im2col_step); 56 | 57 | AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_); 58 | 59 | auto output = at::zeros({batch, num_query, num_heads, channels}, value.options()); 60 | 61 | const int batch_n = im2col_step_; 62 | auto output_n = output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels}); 63 | auto per_value_size = spatial_size * num_heads * channels; 64 | auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2; 65 | auto per_attn_weight_size = num_query * num_heads * num_levels * num_point; 66 | for (int n = 0; n < batch/im2col_step_; ++n) 67 | { 68 | auto columns = output_n.select(0, n); 69 | AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_forward_cuda", ([&] { 70 | ms_deformable_im2col_cuda(at::cuda::getCurrentCUDAStream(), 71 | value.data() + n * im2col_step_ * per_value_size, 72 | spatial_shapes.data(), 73 | level_start_index.data(), 74 | sampling_loc.data() + n * im2col_step_ * per_sample_loc_size, 75 | attn_weight.data() + n * im2col_step_ * per_attn_weight_size, 76 | batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point, 77 | columns.data()); 78 | 79 | })); 80 | } 81 | 82 | output = output.view({batch, num_query, num_heads*channels}); 83 | 84 | return output; 85 | } 86 | 87 
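The forward kernel above walks the batch in chunks of `min(batch, im2col_step)` samples and asserts that the batch divides evenly into those chunks; since the Python wrapper passes `im2col_step=128`, the check is satisfied automatically whenever the flattened batch (videos times frames) is at most 128. A small illustration of that arithmetic (the batch sizes below are hypothetical):

```python
# Illustrates the batch-chunking constraint enforced by the CUDA forward/backward above.
def im2col_chunks(batch, im2col_step=128):
    step = min(batch, im2col_step)
    assert batch % step == 0, f"batch ({batch}) must be divisible by im2col_step ({step})"
    return batch // step, step   # (number of chunks, chunk size)

print(im2col_chunks(20))    # (1, 20)  e.g. 4 videos x 5 frames
print(im2col_chunks(256))   # (2, 128)
```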
| 88 | std::vector ms_deform_attn_cuda_backward( 89 | const at::Tensor &value, 90 | const at::Tensor &spatial_shapes, 91 | const at::Tensor &level_start_index, 92 | const at::Tensor &sampling_loc, 93 | const at::Tensor &attn_weight, 94 | const at::Tensor &grad_output, 95 | const int im2col_step) 96 | { 97 | 98 | AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous"); 99 | AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous"); 100 | AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous"); 101 | AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous"); 102 | AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous"); 103 | AT_ASSERTM(grad_output.is_contiguous(), "grad_output tensor has to be contiguous"); 104 | 105 | AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor"); 106 | AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor"); 107 | AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor"); 108 | AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor"); 109 | AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor"); 110 | AT_ASSERTM(grad_output.type().is_cuda(), "grad_output must be a CUDA tensor"); 111 | 112 | const int batch = value.size(0); 113 | const int spatial_size = value.size(1); 114 | const int num_heads = value.size(2); 115 | const int channels = value.size(3); 116 | 117 | const int num_levels = spatial_shapes.size(0); 118 | 119 | const int num_query = sampling_loc.size(1); 120 | const int num_point = sampling_loc.size(4); 121 | 122 | const int im2col_step_ = std::min(batch, im2col_step); 123 | 124 | AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_); 125 | 126 | auto grad_value = at::zeros_like(value); 127 | auto grad_sampling_loc = at::zeros_like(sampling_loc); 128 | auto grad_attn_weight = at::zeros_like(attn_weight); 129 | 130 | const int batch_n = im2col_step_; 131 | auto per_value_size = spatial_size * num_heads * channels; 132 | auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2; 133 | auto per_attn_weight_size = num_query * num_heads * num_levels * num_point; 134 | auto grad_output_n = grad_output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels}); 135 | 136 | for (int n = 0; n < batch/im2col_step_; ++n) 137 | { 138 | auto grad_output_g = grad_output_n.select(0, n); 139 | AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_backward_cuda", ([&] { 140 | ms_deformable_col2im_cuda(at::cuda::getCurrentCUDAStream(), 141 | grad_output_g.data(), 142 | value.data() + n * im2col_step_ * per_value_size, 143 | spatial_shapes.data(), 144 | level_start_index.data(), 145 | sampling_loc.data() + n * im2col_step_ * per_sample_loc_size, 146 | attn_weight.data() + n * im2col_step_ * per_attn_weight_size, 147 | batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point, 148 | grad_value.data() + n * im2col_step_ * per_value_size, 149 | grad_sampling_loc.data() + n * im2col_step_ * per_sample_loc_size, 150 | grad_attn_weight.data() + n * im2col_step_ * per_attn_weight_size); 151 | 152 | })); 153 | } 154 | 155 | return { 156 | grad_value, grad_sampling_loc, grad_attn_weight 157 | }; 158 | } -------------------------------------------------------------------------------- /ops/src/cuda/ms_deform_attn_cuda.h: 
-------------------------------------------------------------------------------- 1 | /*! 2 | ************************************************************************************************** 3 | * Deformable DETR 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved. 5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | ************************************************************************************************** 7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | ************************************************************************************************** 9 | */ 10 | 11 | /*! 12 | * Copyright (c) Facebook, Inc. and its affiliates. 13 | * Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR 14 | */ 15 | 16 | #pragma once 17 | #include 18 | 19 | at::Tensor ms_deform_attn_cuda_forward( 20 | const at::Tensor &value, 21 | const at::Tensor &spatial_shapes, 22 | const at::Tensor &level_start_index, 23 | const at::Tensor &sampling_loc, 24 | const at::Tensor &attn_weight, 25 | const int im2col_step); 26 | 27 | std::vector ms_deform_attn_cuda_backward( 28 | const at::Tensor &value, 29 | const at::Tensor &spatial_shapes, 30 | const at::Tensor &level_start_index, 31 | const at::Tensor &sampling_loc, 32 | const at::Tensor &attn_weight, 33 | const at::Tensor &grad_output, 34 | const int im2col_step); 35 | 36 | -------------------------------------------------------------------------------- /ops/src/ms_deform_attn.h: -------------------------------------------------------------------------------- 1 | /*! 2 | ************************************************************************************************** 3 | * Deformable DETR 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved. 5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | ************************************************************************************************** 7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | ************************************************************************************************** 9 | */ 10 | 11 | /*! 12 | * Copyright (c) Facebook, Inc. and its affiliates. 
13 | * Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR 14 | */ 15 | 16 | #pragma once 17 | 18 | #include "cpu/ms_deform_attn_cpu.h" 19 | 20 | #ifdef WITH_CUDA 21 | #include "cuda/ms_deform_attn_cuda.h" 22 | #endif 23 | 24 | 25 | at::Tensor 26 | ms_deform_attn_forward( 27 | const at::Tensor &value, 28 | const at::Tensor &spatial_shapes, 29 | const at::Tensor &level_start_index, 30 | const at::Tensor &sampling_loc, 31 | const at::Tensor &attn_weight, 32 | const int im2col_step) 33 | { 34 | if (value.type().is_cuda()) 35 | { 36 | #ifdef WITH_CUDA 37 | return ms_deform_attn_cuda_forward( 38 | value, spatial_shapes, level_start_index, sampling_loc, attn_weight, im2col_step); 39 | #else 40 | AT_ERROR("Not compiled with GPU support"); 41 | #endif 42 | } 43 | AT_ERROR("Not implemented on the CPU"); 44 | } 45 | 46 | std::vector 47 | ms_deform_attn_backward( 48 | const at::Tensor &value, 49 | const at::Tensor &spatial_shapes, 50 | const at::Tensor &level_start_index, 51 | const at::Tensor &sampling_loc, 52 | const at::Tensor &attn_weight, 53 | const at::Tensor &grad_output, 54 | const int im2col_step) 55 | { 56 | if (value.type().is_cuda()) 57 | { 58 | #ifdef WITH_CUDA 59 | return ms_deform_attn_cuda_backward( 60 | value, spatial_shapes, level_start_index, sampling_loc, attn_weight, grad_output, im2col_step); 61 | #else 62 | AT_ERROR("Not compiled with GPU support"); 63 | #endif 64 | } 65 | AT_ERROR("Not implemented on the CPU"); 66 | } 67 | 68 | -------------------------------------------------------------------------------- /ops/src/vision.cpp: -------------------------------------------------------------------------------- 1 | /*! 2 | ************************************************************************************************** 3 | * Deformable DETR 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved. 5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | ************************************************************************************************** 7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | ************************************************************************************************** 9 | */ 10 | 11 | /*! 12 | * Copyright (c) Facebook, Inc. and its affiliates. 13 | * Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR 14 | */ 15 | 16 | #include "ms_deform_attn.h" 17 | 18 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 19 | m.def("ms_deform_attn_forward", &ms_deform_attn_forward, "ms_deform_attn_forward"); 20 | m.def("ms_deform_attn_backward", &ms_deform_attn_backward, "ms_deform_attn_backward"); 21 | } 22 | -------------------------------------------------------------------------------- /ops/test.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------ 2 | # Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------------------ 6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 7 | # ------------------------------------------------------------------------------------------------ 8 | 9 | # Copyright (c) Facebook, Inc. and its affiliates. 
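`vision.cpp` above exposes exactly two functions to Python, and `MSDeformAttnFunction` wraps the first of them. Once the extension is built they can also be called directly, as sketched below with dummy sizes (CUDA tensors only, since the CPU branch of the dispatcher raises):

```python
import torch
import MultiScaleDeformableAttention as MSDA   # the pybind11 module built by ops/setup.py

N, M, D = 1, 2, 4
Lq, L, P = 3, 2, 4
shapes = torch.as_tensor([(8, 8), (4, 4)], dtype=torch.long).cuda()
level_start_index = torch.cat((shapes.new_zeros((1,)), shapes.prod(1).cumsum(0)[:-1]))
S = int(shapes.prod(1).sum())

value = torch.rand(N, S, M, D).cuda()
sampling_loc = torch.rand(N, Lq, M, L, P, 2).cuda()
attn_weight = torch.softmax(torch.rand(N, Lq, M, L * P).cuda(), -1).view(N, Lq, M, L, P)

out = MSDA.ms_deform_attn_forward(value, shapes, level_start_index, sampling_loc, attn_weight, 64)
print(out.shape)   # torch.Size([1, 3, 8]) == (batch, num_query, num_heads * channels)
```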
10 | # Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR 11 | 12 | from __future__ import absolute_import 13 | from __future__ import print_function 14 | from __future__ import division 15 | 16 | import time 17 | import torch 18 | import torch.nn as nn 19 | from torch.autograd import gradcheck 20 | 21 | from functions.ms_deform_attn_func import MSDeformAttnFunction, ms_deform_attn_core_pytorch 22 | 23 | 24 | N, M, D = 1, 2, 2 25 | Lq, L, P = 2, 2, 2 26 | shapes = torch.as_tensor([(6, 4), (3, 2)], dtype=torch.long).cuda() 27 | level_start_index = torch.cat((shapes.new_zeros((1, )), shapes.prod(1).cumsum(0)[:-1])) 28 | S = sum([(H*W).item() for H, W in shapes]) 29 | 30 | 31 | torch.manual_seed(3) 32 | 33 | 34 | @torch.no_grad() 35 | def check_forward_equal_with_pytorch_double(): 36 | value = torch.rand(N, S, M, D).cuda() * 0.01 37 | sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda() 38 | attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5 39 | attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True) 40 | im2col_step = 2 41 | output_pytorch = ms_deform_attn_core_pytorch(value.double(), shapes, sampling_locations.double(), attention_weights.double()).detach().cpu() 42 | output_cuda = MSDeformAttnFunction.apply(value.double(), shapes, level_start_index, sampling_locations.double(), attention_weights.double(), im2col_step).detach().cpu() 43 | fwdok = torch.allclose(output_cuda, output_pytorch) 44 | max_abs_err = (output_cuda - output_pytorch).abs().max() 45 | max_rel_err = ((output_cuda - output_pytorch).abs() / output_pytorch.abs()).max() 46 | 47 | print(f'* {fwdok} check_forward_equal_with_pytorch_double: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}') 48 | 49 | 50 | @torch.no_grad() 51 | def check_forward_equal_with_pytorch_float(): 52 | value = torch.rand(N, S, M, D).cuda() * 0.01 53 | sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda() 54 | attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5 55 | attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True) 56 | im2col_step = 2 57 | output_pytorch = ms_deform_attn_core_pytorch(value, shapes, sampling_locations, attention_weights).detach().cpu() 58 | output_cuda = MSDeformAttnFunction.apply(value, shapes, level_start_index, sampling_locations, attention_weights, im2col_step).detach().cpu() 59 | fwdok = torch.allclose(output_cuda, output_pytorch, rtol=1e-2, atol=1e-3) 60 | max_abs_err = (output_cuda - output_pytorch).abs().max() 61 | max_rel_err = ((output_cuda - output_pytorch).abs() / output_pytorch.abs()).max() 62 | 63 | print(f'* {fwdok} check_forward_equal_with_pytorch_float: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}') 64 | 65 | 66 | def check_gradient_numerical(channels=4, grad_value=True, grad_sampling_loc=True, grad_attn_weight=True): 67 | 68 | value = torch.rand(N, S, M, channels).cuda() * 0.01 69 | sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda() 70 | attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5 71 | attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True) 72 | im2col_step = 2 73 | func = MSDeformAttnFunction.apply 74 | 75 | value.requires_grad = grad_value 76 | sampling_locations.requires_grad = grad_sampling_loc 77 | attention_weights.requires_grad = grad_attn_weight 78 | 79 | gradok = gradcheck(func, (value.double(), shapes, level_start_index, sampling_locations.double(), attention_weights.double(), im2col_step)) 80 | 81 | print(f'* 
{gradok} check_gradient_numerical(D={channels})') 82 | 83 | 84 | if __name__ == '__main__': 85 | check_forward_equal_with_pytorch_double() 86 | check_forward_equal_with_pytorch_float() 87 | 88 | for channels in [30, 32, 64, 71, 1025, 2048, 3096]: 89 | check_gradient_numerical(channels, True, True, True) 90 | 91 | 92 | 93 | -------------------------------------------------------------------------------- /scripts/avss/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | def F10_IoU_BCELoss(pred_mask, ten_gt_masks, gt_temporal_mask_flag): 7 | """ 8 | binary cross entropy loss (iou loss) of the total ten frames for multiple sound source segmentation 9 | 10 | Args: 11 | pred_mask: predicted masks for a batch of data, shape:[bs*10, N_CLASSES, 224, 224] 12 | ten_gt_masks: ground truth mask of the total ten frames, shape: [bs*10, 224, 224] 13 | """ 14 | assert len(pred_mask.shape) == 4 15 | if ten_gt_masks.shape[1] == 1: 16 | ten_gt_masks = ten_gt_masks.squeeze(1) 17 | 18 | loss = nn.CrossEntropyLoss(reduction='none')( 19 | pred_mask, ten_gt_masks) # [bs*10, 224, 224] 20 | loss = loss.mean(-1).mean(-1) # [bs*10] 21 | loss = loss * gt_temporal_mask_flag # [bs*10] 22 | loss = torch.sum(loss) / torch.sum(gt_temporal_mask_flag) 23 | 24 | return loss 25 | 26 | 27 | def Mix_Dice_loss(pred_mask, norm_gt_mask, gt_temporal_mask_flag): 28 | """dice loss for aux loss 29 | 30 | Args: 31 | pred_mask (Tensor): (bs, 1, h, w) 32 | five_gt_masks (Tensor): (bs, 1, h, w) 33 | """ 34 | assert len(pred_mask.shape) == 4 35 | pred_mask = torch.sigmoid(pred_mask) 36 | 37 | pred_mask = pred_mask.flatten(1) 38 | gt_mask = norm_gt_mask.flatten(1) 39 | a = (pred_mask * gt_mask).sum(-1) 40 | b = (pred_mask * pred_mask).sum(-1) + 0.001 41 | c = (gt_mask * gt_mask).sum(-1) + 0.001 42 | d = (2 * a) / (b + c) 43 | loss = 1 - d 44 | loss = loss * gt_temporal_mask_flag 45 | loss = torch.sum(loss) / torch.sum(gt_temporal_mask_flag) 46 | return loss 47 | 48 | 49 | def IouSemanticAwareLoss(pred_masks, mask_feature, gt_mask, gt_temporal_mask_flag, weight_dict, **kwargs): 50 | total_loss = 0 51 | loss_dict = {} 52 | 53 | iou_loss = weight_dict['iou_loss'] * \ 54 | F10_IoU_BCELoss(pred_masks, gt_mask, gt_temporal_mask_flag) 55 | total_loss += iou_loss 56 | loss_dict['iou_loss'] = iou_loss.item() 57 | 58 | mask_feature = torch.mean(mask_feature, dim=1, keepdim=True) 59 | mask_feature = F.interpolate( 60 | mask_feature, gt_mask.shape[-2:], mode='bilinear', align_corners=False) 61 | one_mask = torch.ones_like(gt_mask) 62 | norm_gt_mask = torch.where(gt_mask > 0, one_mask, gt_mask) 63 | mix_loss = weight_dict['mix_loss'] * \ 64 | Mix_Dice_loss(mask_feature, norm_gt_mask, gt_temporal_mask_flag) 65 | total_loss += mix_loss 66 | loss_dict['mix_loss'] = mix_loss.item() 67 | 68 | return total_loss, loss_dict 69 | -------------------------------------------------------------------------------- /scripts/avss/test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn 3 | import os 4 | from mmcv import Config 5 | import argparse 6 | from utils.vis_mask import save_color_mask 7 | from utils.compute_color_metrics import calc_color_miou_fscore 8 | from utils.logger import getLogger 9 | from model import build_model 10 | from dataloader import build_dataset, get_v2_pallete 11 | 12 | 13 | def main(): 14 | # logger 15 | logger = getLogger(None, __name__) 16 | 
dir_name = os.path.splitext(os.path.split(args.cfg)[-1])[0] 17 | logger.info(f'Load config from {args.cfg}') 18 | 19 | # config 20 | cfg = Config.fromfile(args.cfg) 21 | logger.info(cfg.pretty_text) 22 | 23 | # model 24 | model = build_model(**cfg.model) 25 | model.load_state_dict(torch.load(args.weights)) 26 | model = torch.nn.DataParallel(model).cuda() 27 | model.eval() 28 | logger.info('Load trained model %s' % args.weights) 29 | 30 | # Test data 31 | test_dataset = build_dataset(**cfg.dataset.test) 32 | test_dataloader = torch.utils.data.DataLoader(test_dataset, 33 | batch_size=cfg.dataset.test.batch_size, 34 | shuffle=False, 35 | num_workers=cfg.process.num_works, 36 | pin_memory=True) 37 | N_CLASSES = test_dataset.num_classes 38 | 39 | # for save predicted rgb masks 40 | v2_pallete = get_v2_pallete(cfg.dataset.test.label_idx_path) 41 | resize_pred_mask = cfg.dataset.test.resize_pred_mask 42 | if resize_pred_mask: 43 | pred_mask_img_size = cfg.dataset.test.save_pred_mask_img_size 44 | else: 45 | pred_mask_img_size = cfg.dataset.test.img_size 46 | 47 | # metrics 48 | miou_pc = torch.zeros((N_CLASSES)) # miou value per class (total sum) 49 | Fs_pc = torch.zeros((N_CLASSES)) # f-score per class (total sum) 50 | cls_pc = torch.zeros((N_CLASSES)) # count per class 51 | with torch.no_grad(): 52 | for n_iter, batch_data in enumerate(test_dataloader): 53 | imgs, audio, mask, vid_temporal_mask_flag, gt_temporal_mask_flag, video_name_list = batch_data 54 | vid_temporal_mask_flag = vid_temporal_mask_flag.cuda() 55 | gt_temporal_mask_flag = gt_temporal_mask_flag.cuda() 56 | 57 | imgs = imgs.cuda() 58 | # audio = audio.cuda() 59 | mask = mask.cuda() 60 | B, frame, C, H, W = imgs.shape 61 | imgs = imgs.view(B * frame, C, H, W) 62 | mask = mask.view(B * frame, H, W) 63 | 64 | vid_temporal_mask_flag = vid_temporal_mask_flag.view( 65 | B * frame) # [B*T] 66 | gt_temporal_mask_flag = gt_temporal_mask_flag.view( 67 | B * frame) # [B*T] 68 | 69 | output, _ = model(audio, imgs, vid_temporal_mask_flag) 70 | if args.save_pred_mask: 71 | mask_save_path = os.path.join( 72 | args.save_dir, dir_name, 'pred_masks') 73 | save_color_mask(output, mask_save_path, video_name_list, 74 | v2_pallete, resize_pred_mask, pred_mask_img_size, T=10) 75 | 76 | _miou_pc, _fscore_pc, _cls_pc, _ = calc_color_miou_fscore( 77 | output, mask) 78 | # compute miou, J-measure 79 | miou_pc += _miou_pc 80 | cls_pc += _cls_pc 81 | # compute f-score, F-measure 82 | Fs_pc += _fscore_pc 83 | 84 | batch_iou = miou_pc / cls_pc 85 | batch_iou[torch.isnan(batch_iou)] = 0 86 | batch_iou = torch.sum(batch_iou) / torch.sum(cls_pc != 0) 87 | batch_fscore = Fs_pc / cls_pc 88 | batch_fscore[torch.isnan(batch_fscore)] = 0 89 | batch_fscore = torch.sum(batch_fscore) / torch.sum(cls_pc != 0) 90 | logger.info('n_iter: {}, iou: {}, F_score: {}, cls_num: {}'.format( 91 | n_iter, batch_iou, batch_fscore, torch.sum(cls_pc != 0).item())) 92 | 93 | miou_pc = miou_pc / cls_pc 94 | logger.info( 95 | f"[test miou] {torch.sum(torch.isnan(miou_pc)).item()} classes are not predicted in this batch") 96 | miou_pc[torch.isnan(miou_pc)] = 0 97 | miou = torch.mean(miou_pc).item() 98 | miou_noBg = torch.mean(miou_pc[:-1]).item() 99 | f_score_pc = Fs_pc / cls_pc 100 | logger.info( 101 | f"[test fscore] {torch.sum(torch.isnan(f_score_pc)).item()} classes are not predicted in this batch") 102 | f_score_pc[torch.isnan(f_score_pc)] = 0 103 | f_score = torch.mean(f_score_pc).item() 104 | f_score_noBg = torch.mean(f_score_pc[:-1]).item() 105 | 106 | logger.info('test | cls {}, 
miou: {:.4f}, miou_noBg: {:.4f}, F_score: {:.4f}, F_score_noBg: {:.4f}'.format( 107 | torch.sum(cls_pc != 0).item(), miou, miou_noBg, f_score, f_score_noBg)) 108 | 109 | 110 | if __name__ == '__main__': 111 | parser = argparse.ArgumentParser() 112 | parser.add_argument('cfg', type=str, help='config file path') 113 | parser.add_argument('weights', type=str, help='model weights path') 114 | parser.add_argument("--save_pred_mask", action='store_true', 115 | default=False, help="save predited masks or not") 116 | parser.add_argument('--save_dir', type=str, 117 | default='work_dir', help='save path') 118 | 119 | args = parser.parse_args() 120 | main() 121 | -------------------------------------------------------------------------------- /scripts/avss/train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import time 3 | import torch.nn 4 | import os 5 | import random 6 | import numpy as np 7 | from mmcv import Config 8 | import argparse 9 | from utils import pyutils 10 | from utils.loss_util import LossUtil 11 | from utils.logger import getLogger 12 | from model import build_model 13 | from dataloader import build_dataset 14 | from loss import IouSemanticAwareLoss 15 | from utils.compute_color_metrics import calc_color_miou_fscore 16 | 17 | 18 | def main(): 19 | # Fix seed 20 | FixSeed = 123 21 | random.seed(FixSeed) 22 | np.random.seed(FixSeed) 23 | torch.manual_seed(FixSeed) 24 | torch.cuda.manual_seed(FixSeed) 25 | 26 | # logger 27 | log_name = time.strftime('%Y%m%d-%H%M%S', time.localtime()) 28 | dir_name = os.path.splitext(os.path.split(args.cfg)[-1])[0] 29 | if not os.path.exists(args.log_dir): 30 | os.mkdir(args.log_dir) 31 | if not os.path.exists(os.path.join(args.log_dir, dir_name)): 32 | os.mkdir(os.path.join(args.log_dir, dir_name)) 33 | log_file = os.path.join(args.log_dir, dir_name, f'{log_name}.log') 34 | logger = getLogger(log_file, __name__) 35 | logger.info(f'Load config from {args.cfg}') 36 | 37 | # config 38 | cfg = Config.fromfile(args.cfg) 39 | logger.info(cfg.pretty_text) 40 | checkpoint_dir = os.path.join(args.checkpoint_dir, dir_name) 41 | 42 | # model 43 | model = build_model(**cfg.model) 44 | model = torch.nn.DataParallel(model).cuda() 45 | model.train() 46 | logger.info("Total params: %.2fM" % (sum(p.numel() 47 | for p in model.parameters()) / 1e6)) 48 | 49 | # dataset 50 | train_dataset = build_dataset(**cfg.dataset.train) 51 | train_dataloader = torch.utils.data.DataLoader(train_dataset, 52 | batch_size=cfg.dataset.train.batch_size, 53 | shuffle=True, 54 | num_workers=cfg.process.num_works, 55 | pin_memory=True) 56 | max_step = (len(train_dataset) // cfg.dataset.train.batch_size) * \ 57 | cfg.process.train_epochs 58 | val_dataset = build_dataset(**cfg.dataset.val) 59 | val_dataloader = torch.utils.data.DataLoader(val_dataset, 60 | batch_size=cfg.dataset.val.batch_size, 61 | shuffle=False, 62 | num_workers=cfg.process.num_works, 63 | pin_memory=True) 64 | N_CLASSES = train_dataset.num_classes 65 | 66 | # optimizer 67 | optimizer = pyutils.get_optimizer(model, cfg.optimizer) 68 | loss_util = LossUtil(**cfg.loss) 69 | 70 | # Train 71 | best_epoch = 0 72 | global_step = 0 73 | miou_list = [] 74 | max_miou = 0 75 | miou_noBg_list = [] 76 | fscore_list, fscore_noBg_list = [], [] 77 | max_fs, max_fs_noBg = 0, 0 78 | for epoch in range(cfg.process.train_epochs): 79 | if epoch == cfg.process.freeze_epochs: 80 | model.module.freeze_backbone(False) 81 | 82 | for n_iter, batch_data in enumerate(train_dataloader): 83 | 
imgs, audio, label, vid_temporal_mask_flag, gt_temporal_mask_flag, _ = batch_data 84 | vid_temporal_mask_flag = vid_temporal_mask_flag.cuda() 85 | gt_temporal_mask_flag = gt_temporal_mask_flag.cuda() 86 | 87 | imgs = imgs.cuda() 88 | # audio = audio.cuda() 89 | label = label.cuda() 90 | B, frame, C, H, W = imgs.shape 91 | imgs = imgs.view(B * frame, C, H, W) 92 | mask_num = 10 93 | label = label.view(B * mask_num, H, W) 94 | vid_temporal_mask_flag = vid_temporal_mask_flag.view( 95 | B * frame) # [B*T] 96 | gt_temporal_mask_flag = gt_temporal_mask_flag.view( 97 | B * frame) # [B*T] 98 | 99 | # [bs*5, 24, 224, 224] 100 | output, mask_feature = model(audio, imgs, vid_temporal_mask_flag) 101 | loss, loss_dict = IouSemanticAwareLoss( 102 | output, mask_feature, label, gt_temporal_mask_flag, **cfg.loss) 103 | loss_util.add_loss(loss, loss_dict) 104 | optimizer.zero_grad() 105 | loss.backward() 106 | optimizer.step() 107 | 108 | global_step += 1 109 | if (global_step - 1) % 20 == 0: 110 | train_log = 'Iter:%5d/%5d, %slr: %.6f' % ( 111 | global_step - 1, 112 | max_step, 113 | loss_util.pretty_out(), 114 | optimizer.param_groups[0]['lr']) 115 | logger.info(train_log) 116 | 117 | # Validation: 118 | if epoch >= cfg.process.start_eval_epoch and epoch % cfg.process.eval_interval == 0: 119 | model.eval() 120 | 121 | miou_pc = torch.zeros((N_CLASSES)) 122 | Fs_pc = torch.zeros((N_CLASSES)) # f-score per class (total sum) 123 | cls_pc = torch.zeros((N_CLASSES)) # count per class 124 | with torch.no_grad(): 125 | for n_iter, batch_data in enumerate(val_dataloader): 126 | # [bs, 5, 3, 224, 224], [bs, 5, 1, 96, 64], [bs, 5, 1, 224, 224] 127 | imgs, audio, mask, vid_temporal_mask_flag, gt_temporal_mask_flag, _ = batch_data 128 | 129 | vid_temporal_mask_flag = vid_temporal_mask_flag.cuda() 130 | gt_temporal_mask_flag = gt_temporal_mask_flag.cuda() 131 | 132 | imgs = imgs.cuda() 133 | # audio = audio.cuda() 134 | mask = mask.cuda() 135 | B, frame, C, H, W = imgs.shape 136 | imgs = imgs.view(B * frame, C, H, W) 137 | mask = mask.view(B * frame, H, W) 138 | #! 
notice 139 | vid_temporal_mask_flag = vid_temporal_mask_flag.view( 140 | B * frame) # [B*T] 141 | gt_temporal_mask_flag = gt_temporal_mask_flag.view( 142 | B * frame) # [B*T] 143 | 144 | output, _ = model(audio, imgs, vid_temporal_mask_flag) 145 | 146 | _miou_pc, _fscore_pc, _cls_pc, _ = calc_color_miou_fscore( 147 | output, mask) 148 | # compute miou, J-measure 149 | miou_pc += _miou_pc 150 | cls_pc += _cls_pc 151 | # compute f-score, F-measure 152 | Fs_pc += _fscore_pc 153 | 154 | miou_pc = miou_pc / cls_pc 155 | logger.info( 156 | f"[miou] {torch.sum(torch.isnan(miou_pc)).item()} classes are not predicted in this batch") 157 | miou_pc[torch.isnan(miou_pc)] = 0 158 | miou = torch.mean(miou_pc).item() 159 | miou_noBg = torch.mean(miou_pc[:-1]).item() 160 | f_score_pc = Fs_pc / cls_pc 161 | logger.info( 162 | f"[fscore] {torch.sum(torch.isnan(f_score_pc)).item()} classes are not predicted in this batch") 163 | f_score_pc[torch.isnan(f_score_pc)] = 0 164 | f_score = torch.mean(f_score_pc).item() 165 | f_score_noBg = torch.mean(f_score_pc[:-1]).item() 166 | 167 | if miou > max_miou: 168 | model_save_path = os.path.join( 169 | checkpoint_dir, '%s_miou_best.pth' % (args.session_name)) 170 | torch.save(model.module.state_dict(), model_save_path) 171 | best_epoch = epoch 172 | logger.info('save miou best model to %s' % model_save_path) 173 | if (miou + f_score) > (max_miou + max_fs): 174 | model_save_path = os.path.join( 175 | checkpoint_dir, '%s_miou_and_fscore_best.pth' % (args.session_name)) 176 | torch.save(model.module.state_dict(), model_save_path) 177 | best_epoch = epoch 178 | logger.info('save miou and fscore best model to %s' % 179 | model_save_path) 180 | 181 | miou_list.append(miou) 182 | miou_noBg_list.append(miou_noBg) 183 | max_miou = max(miou_list) 184 | max_miou_noBg = max(miou_noBg_list) 185 | fscore_list.append(f_score) 186 | fscore_noBg_list.append(f_score_noBg) 187 | max_fs = max(fscore_list) 188 | max_fs_noBg = max(fscore_noBg_list) 189 | 190 | val_log = 'Epoch: {}, Miou: {}, maxMiou: {}, Miou(no bg): {}, maxMiou (no bg): {} '.format( 191 | epoch, miou, max_miou, miou_noBg, max_miou_noBg) 192 | val_log += ' Fscore: {}, maxFs: {}, Fscore(no bg): {}, max Fscore (no bg): {}'.format( 193 | f_score, max_fs, f_score_noBg, max_fs_noBg) 194 | logger.info(val_log) 195 | 196 | model.train() 197 | 198 | logger.info('best val Miou {} at peoch: {}'.format(max_miou, best_epoch)) 199 | 200 | 201 | if __name__ == '__main__': 202 | parser = argparse.ArgumentParser() 203 | parser.add_argument('cfg', type=str, help='config file path') 204 | parser.add_argument('--log_dir', type=str, 205 | default='work_dir', help='log dir') 206 | parser.add_argument('--checkpoint_dir', type=str, 207 | default='work_dir', help='dir to save checkpoints') 208 | parser.add_argument("--session_name", default="AVSS", 209 | type=str, help="the AVSS setting") 210 | 211 | args = parser.parse_args() 212 | main() 213 | -------------------------------------------------------------------------------- /scripts/ms3/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | def F5_IoU_BCELoss(pred_mask, five_gt_masks): 7 | """ 8 | binary cross entropy loss (iou loss) of the total five frames for multiple sound source segmentation 9 | 10 | Args: 11 | pred_mask: predicted masks for a batch of data, shape:[bs*5, 1, 224, 224] 12 | five_gt_masks: ground truth mask of the total five frames, shape: [bs*5, 1, 
224, 224] 13 | """ 14 | assert len(pred_mask.shape) == 4 15 | pred_mask = torch.sigmoid(pred_mask) # [bs*5, 1, 224, 224] 16 | # five_gt_masks = five_gt_masks.view(-1, 1, five_gt_masks.shape[-2], five_gt_masks.shape[-1]) # [bs*5, 1, 224, 224] 17 | loss = nn.BCELoss()(pred_mask, five_gt_masks) 18 | 19 | return loss 20 | 21 | 22 | def F5_Dice_loss(pred_mask, five_gt_masks): 23 | """dice loss for aux loss 24 | 25 | Args: 26 | pred_mask (Tensor): (bs, 1, h, w) 27 | five_gt_masks (Tensor): (bs, 1, h, w) 28 | """ 29 | assert len(pred_mask.shape) == 4 30 | pred_mask = torch.sigmoid(pred_mask) 31 | 32 | pred_mask = pred_mask.flatten(1) 33 | gt_mask = five_gt_masks.flatten(1) 34 | a = (pred_mask * gt_mask).sum(-1) 35 | b = (pred_mask * pred_mask).sum(-1) + 0.001 36 | c = (gt_mask * gt_mask).sum(-1) + 0.001 37 | d = (2 * a) / (b + c) 38 | loss = 1 - d 39 | return loss.mean() 40 | 41 | 42 | def IouSemanticAwareLoss(pred_mask, mask_feature, gt_mask, weight_dict, loss_type='bce', **kwargs): 43 | total_loss = 0 44 | loss_dict = {} 45 | 46 | if loss_type == 'bce': 47 | loss_func = F5_IoU_BCELoss 48 | elif loss_type == 'dice': 49 | loss_func = F5_Dice_loss 50 | else: 51 | raise ValueError 52 | 53 | iou_loss = weight_dict['iou_loss'] * loss_func(pred_mask, gt_mask) 54 | total_loss += iou_loss 55 | loss_dict['iou_loss'] = iou_loss.item() 56 | 57 | mask_feature = torch.mean(mask_feature, dim=1, keepdim=True) 58 | mask_feature = F.interpolate( 59 | mask_feature, gt_mask.shape[-2:], mode='bilinear', align_corners=False) 60 | mix_loss = weight_dict['mix_loss']*loss_func(mask_feature, gt_mask) 61 | total_loss += mix_loss 62 | loss_dict['mix_loss'] = mix_loss.item() 63 | 64 | return total_loss, loss_dict 65 | -------------------------------------------------------------------------------- /scripts/ms3/test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn 3 | import os 4 | from mmcv import Config 5 | import argparse 6 | from utils import pyutils 7 | from utility import mask_iou, Eval_Fmeasure, save_mask 8 | from utils.logger import getLogger 9 | from model import build_model 10 | from dataloader import build_dataset 11 | 12 | 13 | def main(): 14 | # logger 15 | logger = getLogger(None, __name__) 16 | dir_name = os.path.splitext(os.path.split(args.cfg)[-1])[0] 17 | logger.info(f'Load config from {args.cfg}') 18 | 19 | # config 20 | cfg = Config.fromfile(args.cfg) 21 | logger.info(cfg.pretty_text) 22 | 23 | # model 24 | model = build_model(**cfg.model) 25 | model.load_state_dict(torch.load(args.weights)) 26 | model = torch.nn.DataParallel(model).cuda() 27 | model.eval() 28 | logger.info('Load trained model %s' % args.weights) 29 | 30 | # Test data 31 | test_dataset = build_dataset(**cfg.dataset.test) 32 | test_dataloader = torch.utils.data.DataLoader(test_dataset, 33 | batch_size=cfg.dataset.test.batch_size, 34 | shuffle=False, 35 | num_workers=cfg.process.num_works, 36 | pin_memory=True) 37 | avg_meter_miou = pyutils.AverageMeter('miou') 38 | avg_meter_F = pyutils.AverageMeter('F_score') 39 | 40 | # Test 41 | with torch.no_grad(): 42 | for n_iter, batch_data in enumerate(test_dataloader): 43 | imgs, audio, mask, video_name_list = batch_data 44 | 45 | imgs = imgs.cuda() 46 | audio = audio.cuda() 47 | mask = mask.cuda() 48 | B, frame, C, H, W = imgs.shape 49 | imgs = imgs.view(B * frame, C, H, W) 50 | mask = mask.view(B * frame, H, W) 51 | audio = audio.view(-1, audio.shape[2], 52 | audio.shape[3], audio.shape[4]) 53 | 54 | output, _ = 
model(audio, imgs) 55 | if args.save_pred_mask: 56 | mask_save_path = os.path.join( 57 | args.save_dir, dir_name, 'pred_masks') 58 | save_mask(output.squeeze(1), mask_save_path, video_name_list) 59 | 60 | miou = mask_iou(output.squeeze(1), mask) 61 | avg_meter_miou.add({'miou': miou}) 62 | F_score = Eval_Fmeasure(output.squeeze(1), mask) 63 | avg_meter_F.add({'F_score': F_score}) 64 | logger.info('n_iter: {}, iou: {}, F_score: {}'.format( 65 | n_iter, miou, F_score)) 66 | miou = (avg_meter_miou.pop('miou')) 67 | F_score = (avg_meter_F.pop('F_score')) 68 | logger.info(f'test miou: {miou.item()}') 69 | logger.info(f'test F_score: {F_score}') 70 | logger.info('test miou: {}, F_score: {}'.format(miou.item(), F_score)) 71 | 72 | 73 | if __name__ == '__main__': 74 | parser = argparse.ArgumentParser() 75 | parser.add_argument('cfg', type=str, help='config file path') 76 | parser.add_argument('weights', type=str, help='model weights path') 77 | parser.add_argument("--save_pred_mask", action='store_true', 78 | default=False, help="save predited masks or not") 79 | parser.add_argument('--save_dir', type=str, 80 | default='work_dir', help='save path') 81 | 82 | args = parser.parse_args() 83 | main() 84 | -------------------------------------------------------------------------------- /scripts/ms3/train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import time 3 | import torch.nn 4 | import os 5 | import random 6 | import numpy as np 7 | from mmcv import Config 8 | import argparse 9 | from utils import pyutils 10 | from utils.loss_util import LossUtil 11 | from utility import mask_iou 12 | from utils.logger import getLogger 13 | from model import build_model 14 | from dataloader import build_dataset 15 | from loss import IouSemanticAwareLoss 16 | 17 | 18 | def main(): 19 | # Fix seed 20 | FixSeed = 123 21 | random.seed(FixSeed) 22 | np.random.seed(FixSeed) 23 | torch.manual_seed(FixSeed) 24 | torch.cuda.manual_seed(FixSeed) 25 | 26 | # logger 27 | log_name = time.strftime('%Y%m%d-%H%M%S', time.localtime()) 28 | dir_name = os.path.splitext(os.path.split(args.cfg)[-1])[0] 29 | if not os.path.exists(args.log_dir): 30 | os.mkdir(args.log_dir) 31 | if not os.path.exists(os.path.join(args.log_dir, dir_name)): 32 | os.mkdir(os.path.join(args.log_dir, dir_name)) 33 | log_file = os.path.join(args.log_dir, dir_name, f'{log_name}.log') 34 | logger = getLogger(log_file, __name__) 35 | logger.info(f'Load config from {args.cfg}') 36 | 37 | # config 38 | cfg = Config.fromfile(args.cfg) 39 | logger.info(cfg.pretty_text) 40 | checkpoint_dir = os.path.join(args.checkpoint_dir, dir_name) 41 | 42 | # model 43 | model = build_model(**cfg.model) 44 | model = torch.nn.DataParallel(model).cuda() 45 | model.train() 46 | logger.info("Total params: %.2fM" % (sum(p.numel() 47 | for p in model.parameters()) / 1e6)) 48 | 49 | # dataset 50 | train_dataset = build_dataset(**cfg.dataset.train) 51 | train_dataloader = torch.utils.data.DataLoader(train_dataset, 52 | batch_size=cfg.dataset.train.batch_size, 53 | shuffle=True, 54 | num_workers=cfg.process.num_works, 55 | pin_memory=True) 56 | max_step = (len(train_dataset) // cfg.dataset.train.batch_size) * \ 57 | cfg.process.train_epochs 58 | val_dataset = build_dataset(**cfg.dataset.val) 59 | val_dataloader = torch.utils.data.DataLoader(val_dataset, 60 | batch_size=cfg.dataset.val.batch_size, 61 | shuffle=False, 62 | num_workers=cfg.process.num_works, 63 | pin_memory=True) 64 | 65 | # optimizer 66 | optimizer = 
pyutils.get_optimizer(model, cfg.optimizer) 67 | loss_util = LossUtil(**cfg.loss) 68 | avg_meter_miou = pyutils.AverageMeter('miou') 69 | 70 | # Train 71 | best_epoch = 0 72 | global_step = 0 73 | miou_list = [] 74 | max_miou = 0 75 | for epoch in range(cfg.process.train_epochs): 76 | if epoch == cfg.process.freeze_epochs: 77 | model.module.freeze_backbone(False) 78 | 79 | for n_iter, batch_data in enumerate(train_dataloader): 80 | imgs, audio, mask, _ = batch_data 81 | 82 | imgs = imgs.cuda() 83 | audio = audio.cuda() 84 | mask = mask.cuda() 85 | B, frame, C, H, W = imgs.shape 86 | imgs = imgs.view(B * frame, C, H, W) 87 | mask_num = 5 88 | mask = mask.view(B * mask_num, 1, H, W) 89 | audio = audio.view(-1, audio.shape[2], 90 | audio.shape[3], audio.shape[4]) 91 | 92 | output, mask_feature = model(audio, imgs) 93 | loss, loss_dict = IouSemanticAwareLoss( 94 | output, mask_feature, mask, **cfg.loss) 95 | loss_util.add_loss(loss, loss_dict) 96 | optimizer.zero_grad() 97 | loss.backward() 98 | optimizer.step() 99 | 100 | global_step += 1 101 | if (global_step - 1) % 20 == 0: 102 | train_log = 'Iter:%5d/%5d, %slr: %.6f' % ( 103 | global_step - 1, max_step, loss_util.pretty_out(), optimizer.param_groups[0]['lr']) 104 | logger.info(train_log) 105 | 106 | # Validation: 107 | model.eval() 108 | with torch.no_grad(): 109 | for n_iter, batch_data in enumerate(val_dataloader): 110 | # [bs, 5, 3, 224, 224], [bs, 5, 1, 96, 64], [bs, 5, 1, 224, 224] 111 | imgs, audio, mask, _ = batch_data 112 | 113 | imgs = imgs.cuda() 114 | audio = audio.cuda() 115 | mask = mask.cuda() 116 | B, frame, C, H, W = imgs.shape 117 | imgs = imgs.view(B * frame, C, H, W) 118 | mask = mask.view(B * frame, H, W) 119 | audio = audio.view(-1, audio.shape[2], 120 | audio.shape[3], audio.shape[4]) 121 | 122 | # [bs*5, 1, 224, 224] 123 | output, _ = model(audio, imgs) 124 | 125 | miou = mask_iou(output.squeeze(1), mask) 126 | avg_meter_miou.add({'miou': miou}) 127 | 128 | miou = (avg_meter_miou.pop('miou')) 129 | if miou > max_miou: 130 | model_save_path = os.path.join( 131 | checkpoint_dir, '%s_best.pth' % (args.session_name)) 132 | torch.save(model.module.state_dict(), model_save_path) 133 | best_epoch = epoch 134 | logger.info('save best model to %s' % model_save_path) 135 | 136 | miou_list.append(miou) 137 | max_miou = max(miou_list) 138 | 139 | val_log = 'Epoch: {}, Miou: {}, maxMiou: {}'.format( 140 | epoch, miou, max_miou) 141 | logger.info(val_log) 142 | 143 | model.train() 144 | logger.info('best val Miou {} at peoch: {}'.format(max_miou, best_epoch)) 145 | 146 | 147 | if __name__ == '__main__': 148 | parser = argparse.ArgumentParser() 149 | parser.add_argument('cfg', type=str, help='config file path') 150 | parser.add_argument('--log_dir', type=str, 151 | default='work_dir', help='log dir') 152 | parser.add_argument('--checkpoint_dir', type=str, 153 | default='work_dir', help='dir to save checkpoints') 154 | parser.add_argument("--session_name", default="MS3", 155 | type=str, help="the MS3 setting") 156 | 157 | args = parser.parse_args() 158 | main() 159 | -------------------------------------------------------------------------------- /scripts/ms3/utility.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import functional as F 3 | 4 | import os 5 | import shutil 6 | import logging 7 | import cv2 8 | import numpy as np 9 | from PIL import Image 10 | 11 | import sys 12 | import time 13 | import pandas as pd 14 | import pdb 15 | from torchvision import 
transforms 16 | 17 | logger = logging.getLogger(__name__) 18 | 19 | 20 | def save_checkpoint(state, epoch, is_best, checkpoint_dir='./models', filename='checkpoint', thres=100): 21 | """ 22 | - state 23 | - epoch 24 | - is_best 25 | - checkpoint_dir: default, ./models 26 | - filename: default, checkpoint 27 | - freq: default, 10 28 | - thres: default, 100 29 | """ 30 | if not os.path.isdir(checkpoint_dir): 31 | os.makedirs(checkpoint_dir) 32 | 33 | if epoch >= thres: 34 | file_path = os.path.join( 35 | checkpoint_dir, filename + '_{}'.format(str(epoch)) + '.pth.tar') 36 | else: 37 | file_path = os.path.join(checkpoint_dir, filename + '.pth.tar') 38 | torch.save(state, file_path) 39 | logger.info('==> save model at {}'.format(file_path)) 40 | 41 | if is_best: 42 | cpy_file = os.path.join( 43 | checkpoint_dir, filename + '_model_best.pth.tar') 44 | shutil.copyfile(file_path, cpy_file) 45 | logger.info('==> save best model at {}'.format(cpy_file)) 46 | 47 | 48 | def mask_iou(pred, target, eps=1e-7, size_average=True): 49 | r""" 50 | param: 51 | pred: size [N x H x W] 52 | target: size [N x H x W] 53 | output: 54 | iou: size [1] (size_average=True) or [N] (size_average=False) 55 | """ 56 | assert len(pred.shape) == 3 and pred.shape == target.shape 57 | 58 | N = pred.size(0) 59 | num_pixels = pred.size(-1) * pred.size(-2) 60 | no_obj_flag = (target.sum(2).sum(1) == 0) 61 | 62 | temp_pred = torch.sigmoid(pred) 63 | pred = (temp_pred > 0.5).int() 64 | inter = (pred * target).sum(2).sum(1) 65 | union = torch.max(pred, target).sum(2).sum(1) 66 | 67 | inter_no_obj = ((1 - target) * (1 - pred)).sum(2).sum(1) 68 | inter[no_obj_flag] = inter_no_obj[no_obj_flag] 69 | union[no_obj_flag] = num_pixels 70 | 71 | iou = torch.sum(inter / (union + eps)) / N 72 | 73 | return iou 74 | 75 | 76 | def _eval_pr(y_pred, y, num, cuda_flag=True): 77 | if cuda_flag: 78 | prec, recall = torch.zeros(num).cuda(), torch.zeros(num).cuda() 79 | thlist = torch.linspace(0, 1 - 1e-10, num).cuda() 80 | else: 81 | prec, recall = torch.zeros(num), torch.zeros(num) 82 | thlist = torch.linspace(0, 1 - 1e-10, num) 83 | for i in range(num): 84 | y_temp = (y_pred >= thlist[i]).float() 85 | tp = (y_temp * y).sum() 86 | prec[i], recall[i] = tp / \ 87 | (y_temp.sum() + 1e-20), tp / (y.sum() + 1e-20) 88 | 89 | return prec, recall 90 | 91 | 92 | def Eval_Fmeasure(pred, gt, pr_num=255): 93 | r""" 94 | param: 95 | pred: size [N x H x W] 96 | gt: size [N x H x W] 97 | output: 98 | iou: size [1] (size_average=True) or [N] (size_average=False) 99 | """ 100 | print('=> eval [FMeasure]..') 101 | # =======================================[important] 102 | pred = torch.sigmoid(pred) 103 | N = pred.size(0) 104 | beta2 = 0.3 105 | avg_f, img_num = 0.0, 0 106 | score = torch.zeros(pr_num) 107 | print("{} videos in this batch".format(N)) 108 | 109 | for img_id in range(N): 110 | # examples with totally black GTs are out of consideration 111 | if torch.mean(gt[img_id]) == 0.0: 112 | continue 113 | prec, recall = _eval_pr(pred[img_id], gt[img_id], pr_num) 114 | f_score = (1 + beta2) * prec * recall / (beta2 * prec + recall) 115 | f_score[f_score != f_score] = 0 # for Nan 116 | avg_f += f_score 117 | img_num += 1 118 | score = avg_f / img_num 119 | 120 | return score.max().item() 121 | 122 | 123 | def save_mask(pred_masks, save_base_path, video_name_list): 124 | # pred_mask: [bs*5, 1, 224, 224] 125 | # print(f"=> {len(video_name_list)} videos in this batch") 126 | 127 | if not os.path.exists(save_base_path): 128 | os.makedirs(save_base_path, exist_ok=True) 
129 | 130 | pred_masks = pred_masks.squeeze(2) 131 | pred_masks = (torch.sigmoid(pred_masks) > 0.5).int() 132 | 133 | pred_masks = pred_masks.view(-1, 5, 134 | pred_masks.shape[-2], pred_masks.shape[-1]) 135 | pred_masks = pred_masks.cpu().data.numpy().astype(np.uint8) 136 | pred_masks *= 255 137 | bs = pred_masks.shape[0] 138 | 139 | for idx in range(bs): 140 | video_name = video_name_list[idx] 141 | mask_save_path = os.path.join(save_base_path, video_name) 142 | if not os.path.exists(mask_save_path): 143 | os.makedirs(mask_save_path, exist_ok=True) 144 | one_video_masks = pred_masks[idx] # [5, 1, 224, 224] 145 | for video_id in range(len(one_video_masks)): 146 | one_mask = one_video_masks[video_id] 147 | output_name = "%s_%d.png" % (video_name, video_id) 148 | im = Image.fromarray(one_mask).convert('P') 149 | im.save(os.path.join(mask_save_path, output_name), format='PNG') 150 | 151 | 152 | def save_raw_img_mask(anno_file_path, raw_img_base_path, mask_base_path, split='test', r=0.5): 153 | df = pd.read_csv(anno_file_path, sep=',') 154 | df_test = df[df['split'] == split] 155 | count = 0 156 | for video_id in range(len(df_test)): 157 | video_name = df_test.iloc[video_id][0] 158 | raw_img_path = os.path.join(raw_img_base_path, video_name) 159 | for img_id in range(5): 160 | img_name = "%s.mp4_%d.png" % (video_name, img_id + 1) 161 | raw_img = cv2.imread(os.path.join(raw_img_path, img_name)) 162 | mask = cv2.imread(os.path.join( 163 | mask_base_path, 'pred_masks', video_name, "%s_%d.png" % (video_name, img_id))) 164 | # pdb.set_trace() 165 | raw_img_mask = cv2.addWeighted(raw_img, 1, mask, r, 0) 166 | save_img_path = os.path.join( 167 | mask_base_path, 'img_add_masks', video_name) 168 | if not os.path.exists(save_img_path): 169 | os.makedirs(save_img_path, exist_ok=True) 170 | cv2.imwrite(os.path.join(save_img_path, img_name), raw_img_mask) 171 | count += 1 172 | print(f'count: {count} videos') 173 | -------------------------------------------------------------------------------- /scripts/s4/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | def F1_IoU_BCELoss(pred_masks, first_gt_mask): 7 | """ 8 | binary cross entropy loss (iou loss) of the first frame for single sound source segmentation 9 | 10 | Args: 11 | pred_masks: predicted masks for a batch of data, shape:[bs*5, 1, 224, 224] 12 | first_gt_mask: ground truth mask of the first frame, shape: [bs, 1, 1, 224, 224] 13 | """ 14 | assert len(pred_masks.shape) == 4 15 | pred_masks = torch.sigmoid(pred_masks) # [bs*5, 1, 224, 224] 16 | 17 | indices = torch.tensor(list(range(0, len(pred_masks), 5))) 18 | indices = indices.cuda() 19 | first_pred = torch.index_select( 20 | pred_masks, dim=0, index=indices) # [bs, 1, 224, 224] 21 | assert first_pred.requires_grad == True, "Error when indexing predited masks" 22 | if len(first_gt_mask.shape) == 5: 23 | first_gt_mask = first_gt_mask.squeeze(1) # [bs, 1, 224, 224] 24 | 25 | first_bce_loss = nn.BCELoss()(first_pred, first_gt_mask) 26 | 27 | return first_bce_loss 28 | 29 | 30 | def F1_Dice_loss(pred_masks, first_gt_mask): 31 | """dice loss for aux loss 32 | 33 | Args: 34 | pred_mask (Tensor): (bs*5, 1, h, w) 35 | five_gt_masks (Tensor): (bs, 1, 1, h, w) 36 | """ 37 | assert len(pred_masks.shape) == 4 38 | pred_masks = torch.sigmoid(pred_masks) 39 | 40 | indices = torch.tensor(list(range(0, len(pred_masks), 5))) 41 | indices = indices.cuda() 42 | first_pred = 
torch.index_select( 43 | pred_masks, dim=0, index=indices) # [bs, 1, 224, 224] 44 | assert first_pred.requires_grad == True, "Error when indexing predited masks" 45 | if len(first_gt_mask.shape) == 5: 46 | first_gt_mask = first_gt_mask.squeeze(1) # [bs, 1, 224, 224] 47 | 48 | pred_mask = first_pred.flatten(1) 49 | gt_mask = first_gt_mask.flatten(1) 50 | a = (pred_mask * gt_mask).sum(-1) 51 | b = (pred_mask * pred_mask).sum(-1) + 0.001 52 | c = (gt_mask * gt_mask).sum(-1) + 0.001 53 | d = (2 * a) / (b + c) 54 | loss = 1 - d 55 | return loss.mean() 56 | 57 | 58 | def IouSemanticAwareLoss(pred_masks, mask_feature, gt_mask, weight_dict, loss_type='bce', **kwargs): 59 | total_loss = 0 60 | loss_dict = {} 61 | 62 | if loss_type == 'bce': 63 | loss_func = F1_IoU_BCELoss 64 | elif loss_type == 'dice': 65 | loss_func = F1_Dice_loss 66 | else: 67 | raise ValueError 68 | 69 | iou_loss = loss_func(pred_masks, gt_mask) 70 | total_loss += weight_dict['iou_loss'] * iou_loss 71 | loss_dict['iou_loss'] = weight_dict['iou_loss'] * iou_loss.item() 72 | 73 | mask_feature = torch.mean(mask_feature, dim=1, keepdim=True) 74 | mask_feature = F.interpolate( 75 | mask_feature, gt_mask.shape[-2:], mode='bilinear', align_corners=False) 76 | mix_loss = weight_dict['mix_loss']*loss_func(mask_feature, gt_mask) 77 | total_loss += mix_loss 78 | loss_dict['mix_loss'] = mix_loss.item() 79 | 80 | return total_loss, loss_dict 81 | -------------------------------------------------------------------------------- /scripts/s4/test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn 3 | import os 4 | from mmcv import Config 5 | import argparse 6 | from utils import pyutils 7 | from utility import mask_iou, Eval_Fmeasure, save_mask 8 | from utils.logger import getLogger 9 | from model import build_model 10 | from dataloader import build_dataset 11 | 12 | 13 | def main(): 14 | # logger 15 | logger = getLogger(None, __name__) 16 | dir_name = os.path.splitext(os.path.split(args.cfg)[-1])[0] 17 | logger.info(f'Load config from {args.cfg}') 18 | 19 | # config 20 | cfg = Config.fromfile(args.cfg) 21 | logger.info(cfg.pretty_text) 22 | 23 | # model 24 | model = build_model(**cfg.model) 25 | model.load_state_dict(torch.load(args.weights)) 26 | model = torch.nn.DataParallel(model).cuda() 27 | model.eval() 28 | logger.info('Load trained model %s' % args.weights) 29 | 30 | # Test data 31 | test_dataset = build_dataset(**cfg.dataset.test) 32 | test_dataloader = torch.utils.data.DataLoader(test_dataset, 33 | batch_size=cfg.dataset.test.batch_size, 34 | shuffle=False, 35 | num_workers=cfg.process.num_works, 36 | pin_memory=True) 37 | avg_meter_miou = pyutils.AverageMeter('miou') 38 | avg_meter_F = pyutils.AverageMeter('F_score') 39 | 40 | # Test 41 | with torch.no_grad(): 42 | for n_iter, batch_data in enumerate(test_dataloader): 43 | # [bs, 5, 3, 224, 224], [bs, 5, 1, 96, 64], [bs, 1, 1, 224, 224] 44 | imgs, audio, mask, category_list, video_name_list = batch_data 45 | 46 | imgs = imgs.cuda() 47 | audio = audio.cuda() 48 | mask = mask.cuda() 49 | B, frame, C, H, W = imgs.shape 50 | imgs = imgs.view(B * frame, C, H, W) 51 | mask = mask.view(B * frame, H, W) 52 | audio = audio.view(-1, audio.shape[2], 53 | audio.shape[3], audio.shape[4]) 54 | 55 | output, _ = model(audio, imgs) 56 | if args.save_pred_mask: 57 | mask_save_path = os.path.join( 58 | args.save_dir, dir_name, 'pred_masks') 59 | save_mask(output.squeeze(1), mask_save_path, 60 | category_list, video_name_list) 61 | 62 
| miou = mask_iou(output.squeeze(1), mask) 63 | avg_meter_miou.add({'miou': miou}) 64 | F_score = Eval_Fmeasure(output.squeeze(1), mask) 65 | avg_meter_F.add({'F_score': F_score}) 66 | logger.info('n_iter: {}, iou: {}, F_score: {}'.format( 67 | n_iter, miou, F_score)) 68 | 69 | miou = (avg_meter_miou.pop('miou')) 70 | F_score = (avg_meter_F.pop('F_score')) 71 | logger.info(f'test miou: {miou.item}') 72 | logger.info(f'test F_score: {F_score}') 73 | logger.info('test miou: {}, F_score: {}'.format(miou.item(), F_score)) 74 | 75 | 76 | if __name__ == '__main__': 77 | parser = argparse.ArgumentParser() 78 | parser.add_argument('cfg', type=str, help='config file path') 79 | parser.add_argument('weights', type=str, help='model weights path') 80 | parser.add_argument("--save_pred_mask", action='store_true', 81 | default=False, help="save predited masks or not") 82 | parser.add_argument('--save_dir', type=str, 83 | default='work_dir', help='save path') 84 | 85 | args = parser.parse_args() 86 | main() 87 | -------------------------------------------------------------------------------- /scripts/s4/train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import time 3 | import torch.nn 4 | import os 5 | import random 6 | import numpy as np 7 | from mmcv import Config 8 | import argparse 9 | from utils import pyutils 10 | from utils.loss_util import LossUtil 11 | from utility import mask_iou 12 | from utils.logger import getLogger 13 | from model import build_model 14 | from dataloader import build_dataset 15 | from loss import IouSemanticAwareLoss 16 | 17 | 18 | def main(): 19 | # Fix seed 20 | FixSeed = 123 21 | random.seed(FixSeed) 22 | np.random.seed(FixSeed) 23 | torch.manual_seed(FixSeed) 24 | torch.cuda.manual_seed(FixSeed) 25 | 26 | # logger 27 | log_name = time.strftime('%Y%m%d-%H%M%S', time.localtime()) 28 | dir_name = os.path.splitext(os.path.split(args.cfg)[-1])[0] 29 | if not os.path.exists(args.log_dir): 30 | os.mkdir(args.log_dir) 31 | if not os.path.exists(os.path.join(args.log_dir, dir_name)): 32 | os.mkdir(os.path.join(args.log_dir, dir_name)) 33 | log_file = os.path.join(args.log_dir, dir_name, f'{log_name}.log') 34 | logger = getLogger(log_file, __name__) 35 | logger.info(f'Load config from {args.cfg}') 36 | 37 | # config 38 | cfg = Config.fromfile(args.cfg) 39 | logger.info(cfg.pretty_text) 40 | checkpoint_dir = os.path.join(args.checkpoint_dir, dir_name) 41 | 42 | # model 43 | model = build_model(**cfg.model) 44 | model = torch.nn.DataParallel(model).cuda() 45 | model.train() 46 | logger.info("Total params: %.2fM" % (sum(p.numel() 47 | for p in model.parameters()) / 1e6)) 48 | 49 | # dataset 50 | train_dataset = build_dataset(**cfg.dataset.train) 51 | train_dataloader = torch.utils.data.DataLoader(train_dataset, 52 | batch_size=cfg.dataset.train.batch_size, 53 | shuffle=True, 54 | num_workers=cfg.process.num_works, 55 | pin_memory=True) 56 | max_step = (len(train_dataset) // cfg.dataset.train.batch_size) * \ 57 | cfg.process.train_epochs 58 | val_dataset = build_dataset(**cfg.dataset.val) 59 | val_dataloader = torch.utils.data.DataLoader(val_dataset, 60 | batch_size=cfg.dataset.val.batch_size, 61 | shuffle=False, 62 | num_workers=cfg.process.num_works, 63 | pin_memory=True) 64 | 65 | # optimizer 66 | optimizer = pyutils.get_optimizer(model, cfg.optimizer) 67 | loss_util = LossUtil(**cfg.loss) 68 | avg_meter_miou = pyutils.AverageMeter('miou') 69 | 70 | # Train 71 | best_epoch = 0 72 | global_step = 0 73 | miou_list = [] 74 | 
max_miou = 0 75 | for epoch in range(cfg.process.train_epochs): 76 | if epoch == cfg.process.freeze_epochs: 77 | model.module.freeze_backbone(False) 78 | 79 | for n_iter, batch_data in enumerate(train_dataloader): 80 | # [bs, 5, 3, 224, 224], [bs, 5, 1, 96, 64], [bs, 1, 1, 224, 224] 81 | imgs, audio, mask = batch_data 82 | 83 | imgs = imgs.cuda() 84 | audio = audio.cuda() 85 | mask = mask.cuda() 86 | B, frame, C, H, W = imgs.shape 87 | imgs = imgs.view(B * frame, C, H, W) 88 | mask = mask.view(B, H, W) 89 | audio = audio.view(-1, audio.shape[2], 90 | audio.shape[3], audio.shape[4]) 91 | 92 | output, mask_feature = model(audio, imgs) # [bs*5, 1, 224, 224] 93 | loss, loss_dict = IouSemanticAwareLoss( 94 | output, mask_feature, mask.unsqueeze(1).unsqueeze(1), **cfg.loss) 95 | loss_util.add_loss(loss, loss_dict) 96 | optimizer.zero_grad() 97 | loss.backward() 98 | optimizer.step() 99 | 100 | global_step += 1 101 | 102 | if (global_step - 1) % 50 == 0: 103 | train_log = 'Iter:%5d/%5d, %slr: %.6f' % ( 104 | global_step - 1, max_step, loss_util.pretty_out(), optimizer.param_groups[0]['lr']) 105 | logger.info(train_log) 106 | 107 | # Validation: 108 | model.eval() 109 | with torch.no_grad(): 110 | for n_iter, batch_data in enumerate(val_dataloader): 111 | # [bs, 5, 3, 224, 224], [bs, 5, 1, 96, 64], [bs, 5, 1, 224, 224] 112 | imgs, audio, mask, _, _ = batch_data 113 | 114 | imgs = imgs.cuda() 115 | audio = audio.cuda() 116 | mask = mask.cuda() 117 | B, frame, C, H, W = imgs.shape 118 | imgs = imgs.view(B * frame, C, H, W) 119 | mask = mask.view(B * frame, H, W) 120 | audio = audio.view(-1, audio.shape[2], 121 | audio.shape[3], audio.shape[4]) 122 | 123 | output, _ = model(audio, imgs) 124 | 125 | miou = mask_iou(output.squeeze(1), mask) 126 | avg_meter_miou.add({'miou': miou}) 127 | 128 | miou = (avg_meter_miou.pop('miou')) 129 | if miou > max_miou: 130 | model_save_path = os.path.join( 131 | checkpoint_dir, '%s_best.pth' % (args.session_name)) 132 | torch.save(model.module.state_dict(), model_save_path) 133 | best_epoch = epoch 134 | logger.info('save best model to %s' % model_save_path) 135 | 136 | miou_list.append(miou) 137 | max_miou = max(miou_list) 138 | 139 | val_log = 'Epoch: {}, Miou: {}, maxMiou: {}'.format( 140 | epoch, miou, max_miou) 141 | # print(val_log) 142 | logger.info(val_log) 143 | 144 | model.train() 145 | 146 | logger.info('best val Miou {} at peoch: {}'.format(max_miou, best_epoch)) 147 | 148 | 149 | if __name__ == '__main__': 150 | parser = argparse.ArgumentParser() 151 | parser.add_argument('cfg', type=str, help='config file path') 152 | parser.add_argument('--log_dir', type=str, 153 | default='work_dir', help='log dir') 154 | parser.add_argument('--checkpoint_dir', type=str, 155 | default='work_dir', help='dir to save checkpoints') 156 | parser.add_argument("--session_name", default="S4", 157 | type=str, help="the S4 setting") 158 | 159 | args = parser.parse_args() 160 | main() 161 | -------------------------------------------------------------------------------- /scripts/s4/utility.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import functional as F 3 | 4 | import os 5 | import shutil 6 | import logging 7 | import cv2 8 | import numpy as np 9 | from PIL import Image 10 | 11 | import sys 12 | import time 13 | import pandas as pd 14 | import pdb 15 | from torchvision import transforms 16 | 17 | logger = logging.getLogger(__name__) 18 | 19 | 20 | def save_checkpoint(state, epoch, is_best, 
checkpoint_dir='./models', filename='checkpoint', thres=100): 21 | """ 22 | - state 23 | - epoch 24 | - is_best 25 | - checkpoint_dir: default, ./models 26 | - filename: default, checkpoint 27 | - freq: default, 10 28 | - thres: default, 100 29 | """ 30 | if not os.path.isdir(checkpoint_dir): 31 | os.makedirs(checkpoint_dir) 32 | 33 | if epoch >= thres: 34 | file_path = os.path.join( 35 | checkpoint_dir, filename + '_{}'.format(str(epoch)) + '.pth.tar') 36 | else: 37 | file_path = os.path.join(checkpoint_dir, filename + '.pth.tar') 38 | torch.save(state, file_path) 39 | logger.info('==> save model at {}'.format(file_path)) 40 | 41 | if is_best: 42 | cpy_file = os.path.join( 43 | checkpoint_dir, filename + '_model_best.pth.tar') 44 | shutil.copyfile(file_path, cpy_file) 45 | logger.info('==> save best model at {}'.format(cpy_file)) 46 | 47 | 48 | def mask_iou(pred, target, eps=1e-7, size_average=True): 49 | r""" 50 | param: 51 | pred: size [N x H x W] 52 | target: size [N x H x W] 53 | output: 54 | iou: size [1] (size_average=True) or [N] (size_average=False) 55 | """ 56 | assert len(pred.shape) == 3 and pred.shape == target.shape 57 | 58 | N = pred.size(0) 59 | num_pixels = pred.size(-1) * pred.size(-2) 60 | no_obj_flag = (target.sum(2).sum(1) == 0) 61 | 62 | temp_pred = torch.sigmoid(pred) 63 | pred = (temp_pred > 0.5).int() 64 | inter = (pred * target).sum(2).sum(1) 65 | union = torch.max(pred, target).sum(2).sum(1) 66 | 67 | inter_no_obj = ((1 - target) * (1 - pred)).sum(2).sum(1) 68 | inter[no_obj_flag] = inter_no_obj[no_obj_flag] 69 | union[no_obj_flag] = num_pixels 70 | 71 | iou = torch.sum(inter / (union + eps)) / N 72 | 73 | return iou 74 | 75 | 76 | def _eval_pr(y_pred, y, num, cuda_flag=True): 77 | if cuda_flag: 78 | prec, recall = torch.zeros(num).cuda(), torch.zeros(num).cuda() 79 | thlist = torch.linspace(0, 1 - 1e-10, num).cuda() 80 | else: 81 | prec, recall = torch.zeros(num), torch.zeros(num) 82 | thlist = torch.linspace(0, 1 - 1e-10, num) 83 | for i in range(num): 84 | y_temp = (y_pred >= thlist[i]).float() 85 | tp = (y_temp * y).sum() 86 | prec[i], recall[i] = tp / \ 87 | (y_temp.sum() + 1e-20), tp / (y.sum() + 1e-20) 88 | 89 | return prec, recall 90 | 91 | 92 | def Eval_Fmeasure(pred, gt, pr_num=255): 93 | r""" 94 | param: 95 | pred: size [N x H x W] 96 | gt: size [N x H x W] 97 | output: 98 | iou: size [1] (size_average=True) or [N] (size_average=False) 99 | """ 100 | print('=> eval [FMeasure]..') 101 | # =======================================[important] 102 | pred = torch.sigmoid(pred) 103 | N = pred.size(0) 104 | beta2 = 0.3 105 | avg_f, img_num = 0.0, 0 106 | score = torch.zeros(pr_num) 107 | print("{} videos in this batch".format(N)) 108 | 109 | for img_id in range(N): 110 | # examples with totally black GTs are out of consideration 111 | if torch.mean(gt[img_id]) == 0.0: 112 | continue 113 | prec, recall = _eval_pr(pred[img_id], gt[img_id], pr_num) 114 | f_score = (1 + beta2) * prec * recall / (beta2 * prec + recall) 115 | f_score[f_score != f_score] = 0 # for Nan 116 | avg_f += f_score 117 | img_num += 1 118 | score = avg_f / img_num 119 | 120 | return score.max().item() 121 | 122 | 123 | def save_mask(pred_masks, save_base_path, category_list, video_name_list): 124 | # pred_mask: [bs*5, 1, 224, 224] 125 | # print(f"=> {len(video_name_list)} videos in this batch") 126 | 127 | if not os.path.exists(save_base_path): 128 | os.makedirs(save_base_path, exist_ok=True) 129 | 130 | pred_masks = pred_masks.squeeze(2) 131 | pred_masks = (torch.sigmoid(pred_masks) > 
0.5).int() 132 | 133 | pred_masks = pred_masks.view(-1, 5, 134 | pred_masks.shape[-2], pred_masks.shape[-1]) 135 | pred_masks = pred_masks.cpu().data.numpy().astype(np.uint8) 136 | pred_masks *= 255 137 | bs = pred_masks.shape[0] 138 | 139 | for idx in range(bs): 140 | category, video_name = category_list[idx], video_name_list[idx] 141 | mask_save_path = os.path.join(save_base_path, category, video_name) 142 | if not os.path.exists(mask_save_path): 143 | os.makedirs(mask_save_path, exist_ok=True) 144 | one_video_masks = pred_masks[idx] # [5, 1, 224, 224] 145 | for video_id in range(len(one_video_masks)): 146 | one_mask = one_video_masks[video_id] 147 | output_name = "%s_%d.png" % (video_name, video_id) 148 | im = Image.fromarray(one_mask).convert('P') 149 | im.save(os.path.join(mask_save_path, output_name), format='PNG') 150 | 151 | 152 | def save_raw_img_mask(anno_file_path, raw_img_base_path, mask_base_path, split='test', r=0.5): 153 | df = pd.read_csv(anno_file_path, sep=',') 154 | df_test = df[df['split'] == split] 155 | count = 0 156 | for video_id in range(len(df_test)): 157 | video_name, category = df_test.iloc[video_id][0], df_test.iloc[video_id][2] 158 | raw_img_path = os.path.join( 159 | raw_img_base_path, split, category, video_name) 160 | for img_id in range(5): 161 | img_name = "%s_%d.png" % (video_name, img_id + 1) 162 | raw_img = cv2.imread(os.path.join(raw_img_path, img_name)) 163 | mask = cv2.imread(os.path.join(mask_base_path, 'pred_masks', 164 | category, video_name, "%s_%d.png" % (video_name, img_id))) 165 | # pdb.set_trace() 166 | raw_img_mask = cv2.addWeighted(raw_img, 1, mask, r, 0) 167 | save_img_path = os.path.join( 168 | mask_base_path, 'img_add_masks', category, video_name) 169 | if not os.path.exists(save_img_path): 170 | os.makedirs(save_img_path, exist_ok=True) 171 | cv2.imwrite(os.path.join(save_img_path, img_name), raw_img_mask) 172 | count += 1 173 | print(f'count: {count} videos') 174 | -------------------------------------------------------------------------------- /test.sh: -------------------------------------------------------------------------------- 1 | SESSION=$1 2 | CONFIG=$2 3 | WEIGHTS=$3 4 | 5 | PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \ 6 | python scripts/$SESSION/test.py \ 7 | $CONFIG \ 8 | $WEIGHTS \ 9 | # --save_pred_mask 10 | -------------------------------------------------------------------------------- /train.sh: -------------------------------------------------------------------------------- 1 | SESSION=$1 2 | CONFIG=$2 3 | 4 | PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \ 5 | python scripts/$SESSION/train.py $CONFIG 6 | -------------------------------------------------------------------------------- /utils/compute_color_metrics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import functional as F 3 | 4 | import os 5 | import shutil 6 | import logging 7 | import cv2 8 | import numpy as np 9 | from PIL import Image 10 | 11 | import sys 12 | import time 13 | import pandas as pd 14 | from torchvision import transforms 15 | 16 | import numpy as np 17 | from multiprocessing import Pool 18 | from tqdm import tqdm 19 | 20 | import pdb 21 | 22 | 23 | def _batch_miou_fscore(output, target, nclass, T, beta2=0.3): 24 | """batch mIoU and Fscore""" 25 | # output: [BF, C, H, W], 26 | # target: [BF, H, W] 27 | mini = 1 28 | maxi = nclass 29 | nbins = nclass 30 | predict = torch.argmax(output, 1) + 1 31 | target = target.float() + 1 32 | # pdb.set_trace() 33 | predict = 
predict.float() * (target > 0).float() # [BF, H, W] 34 | intersection = predict * (predict == target).float() # [BF, H, W] 35 | # areas of intersection and union 36 | # element 0 in intersection occur the main difference from np.bincount. set boundary to -1 is necessary. 37 | batch_size = target.shape[0] // T 38 | cls_count = torch.zeros(nclass).float() 39 | ious = torch.zeros(nclass).float() 40 | fscores = torch.zeros(nclass).float() 41 | 42 | # vid_miou_list = torch.zeros(target.shape[0]).float() 43 | vid_miou_list = [] 44 | for i in range(target.shape[0]): 45 | area_inter = torch.histc( 46 | intersection[i].cpu(), bins=nbins, min=mini, max=maxi) # TP 47 | area_pred = torch.histc( 48 | predict[i].cpu(), bins=nbins, min=mini, max=maxi) # TP + FP 49 | area_lab = torch.histc( 50 | target[i].cpu(), bins=nbins, min=mini, max=maxi) # TP + FN 51 | area_union = area_pred + area_lab - area_inter 52 | assert torch.sum(area_inter > area_union).item( 53 | ) == 0, "Intersection area should be smaller than Union area" 54 | iou = 1.0 * area_inter.float() / (2.220446049250313e-16 + area_union.float()) 55 | # iou[torch.isnan(iou)] = 1. 56 | ious += iou 57 | cls_count[torch.nonzero(area_union).squeeze(-1)] += 1 58 | 59 | precision = area_inter / area_pred 60 | recall = area_inter / area_lab 61 | fscore = (1 + beta2) * precision * recall / \ 62 | (beta2 * precision + recall) 63 | fscore[torch.isnan(fscore)] = 0. 64 | fscores += fscore 65 | 66 | vid_miou_list.append(torch.sum(iou) / (torch.sum(iou != 0).float())) 67 | 68 | return ious, fscores, cls_count, vid_miou_list 69 | 70 | 71 | def calc_color_miou_fscore(pred, target, T=10): 72 | r""" 73 | J measure 74 | param: 75 | pred: size [BF x C x H x W], C is category number including background 76 | target: size [BF x H x W] 77 | """ 78 | nclass = pred.shape[1] 79 | pred = torch.softmax(pred, dim=1) # [BF, C, H, W] 80 | # miou, fscore, cls_count = _batch_miou_fscore(pred, target, nclass, T) 81 | miou, fscore, cls_count, vid_miou_list = _batch_miou_fscore( 82 | pred, target, nclass, T) 83 | return miou, fscore, cls_count, vid_miou_list 84 | 85 | 86 | def _batch_intersection_union(output, target, nclass, T): 87 | """mIoU""" 88 | # output: [BF, C, H, W], 89 | # target: [BF, H, W] 90 | mini = 1 91 | maxi = nclass 92 | nbins = nclass 93 | predict = torch.argmax(output, 1) + 1 94 | target = target.float() + 1 95 | 96 | # pdb.set_trace() 97 | 98 | predict = predict.float() * (target > 0).float() # [BF, H, W] 99 | intersection = predict * (predict == target).float() # [BF, H, W] 100 | # areas of intersection and union 101 | # element 0 in intersection occur the main difference from np.bincount. set boundary to -1 is necessary. 
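    # Per-image accounting over class bins: torch.histc on `intersection` gives TP,
    # on `predict` gives TP + FP, and on `target` gives TP + FN; the union is then
    # pred + label - inter. Per-class IoU is accumulated across images, while
    # cls_count records, per class, how many images have a non-zero union for it,
    # so the final mIoU can be averaged only over classes that actually occur.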
102 | batch_size = target.shape[0] // T 103 | cls_count = torch.zeros(nclass).float() 104 | ious = torch.zeros(nclass).float() 105 | for i in range(target.shape[0]): 106 | area_inter = torch.histc( 107 | intersection[i].cpu(), bins=nbins, min=mini, max=maxi) 108 | area_pred = torch.histc( 109 | predict[i].cpu(), bins=nbins, min=mini, max=maxi) 110 | area_lab = torch.histc(target[i].cpu(), bins=nbins, min=mini, max=maxi) 111 | area_union = area_pred + area_lab - area_inter 112 | assert torch.sum(area_inter > area_union).item( 113 | ) == 0, "Intersection area should be smaller than Union area" 114 | iou = 1.0 * area_inter.float() / (2.220446049250313e-16 + area_union.float()) 115 | ious += iou 116 | cls_count[torch.nonzero(area_union).squeeze(-1)] += 1 117 | # pdb.set_trace() 118 | # ious = ious / cls_count 119 | # ious[torch.isnan(ious)] = 0 120 | # pdb.set_trace() 121 | # return area_inter.float(), area_union.float() 122 | # return ious 123 | return ious, cls_count 124 | 125 | 126 | def calc_color_miou(pred, target, T=10): 127 | r""" 128 | J measure 129 | param: 130 | pred: size [BF x C x H x W], C is category number including background 131 | target: size [BF x H x W] 132 | """ 133 | nclass = pred.shape[1] 134 | pred = torch.softmax(pred, dim=1) # [BF, C, H, W] 135 | # correct, labeled = _batch_pix_accuracy(pred, target) 136 | # inter, union = _batch_intersection_union(pred, target, nclass, T) 137 | ious, cls_count = _batch_intersection_union(pred, target, nclass, T) 138 | 139 | # pixAcc = 1.0 * correct / (2.220446049250313e-16 + labeled) 140 | # IoU = 1.0 * inter / (2.220446049250313e-16 + union) 141 | # mIoU = IoU.mean().item() 142 | # pdb.set_trace() 143 | # return mIoU 144 | return ious, cls_count 145 | 146 | 147 | def calc_binary_miou(pred, target, eps=1e-7, size_average=True): 148 | r""" 149 | param: 150 | pred: size [N x C x H x W] 151 | target: size [N x H x W] 152 | output: 153 | iou: size [1] (size_average=True) or [N] (size_average=False) 154 | """ 155 | # assert len(pred.shape) == 3 and pred.shape == target.shape 156 | nclass = pred.shape[1] 157 | pred = torch.softmax(pred, dim=1) # [BF, C, H, W] 158 | pred = torch.argmax(pred, dim=1) # [BF, H, W] 159 | binary_pred = (pred != (nclass - 1)).float() # [BF, H, W] 160 | # pdb.set_trace() 161 | pred = binary_pred 162 | target = (target != (nclass - 1)).float() 163 | 164 | N = pred.size(0) 165 | num_pixels = pred.size(-1) * pred.size(-2) 166 | no_obj_flag = (target.sum(2).sum(1) == 0) 167 | 168 | temp_pred = torch.sigmoid(pred) 169 | pred = (temp_pred > 0.5).int() 170 | inter = (pred * target).sum(2).sum(1) 171 | union = torch.max(pred, target).sum(2).sum(1) 172 | 173 | inter_no_obj = ((1 - target) * (1 - pred)).sum(2).sum(1) 174 | inter[no_obj_flag] = inter_no_obj[no_obj_flag] 175 | union[no_obj_flag] = num_pixels 176 | 177 | iou = torch.sum(inter / (union + eps)) / N 178 | # pdb.set_trace() 179 | return iou 180 | -------------------------------------------------------------------------------- /utils/logger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | 4 | def getLogger(log_file, name, fmt='%(asctime)s %(levelname)s ==> %(message)s'): 5 | logger = logging.getLogger(name) 6 | logger.setLevel(logging.DEBUG) 7 | formatter = logging.Formatter(fmt) 8 | 9 | console_handler = logging.StreamHandler() 10 | console_handler.setLevel(logging.DEBUG) 11 | console_handler.setFormatter(formatter) 12 | logger.addHandler(console_handler) 13 | if log_file is not None: 14 | file_handler 
= logging.FileHandler(log_file) 15 | file_handler.setLevel(logging.INFO) 16 | file_handler.setFormatter(formatter) 17 | logger.addHandler(file_handler) 18 | 19 | return logger 20 | -------------------------------------------------------------------------------- /utils/loss_util.py: -------------------------------------------------------------------------------- 1 | from utils.pyutils import AverageMeter 2 | 3 | 4 | class LossUtil: 5 | def __init__(self, weight_dict, **kwargs) -> None: 6 | self.loss_weight_dict = weight_dict 7 | self.avg_loss = dict() 8 | self.avg_loss['total_loss'] = AverageMeter('total_loss') 9 | # for k in weight_dict.keys(): 10 | # self.avg_loss[k] = AverageMeter(k) 11 | 12 | def add_loss(self, loss, loss_dict): 13 | self.avg_loss['total_loss'].add({'total_loss': loss.item()}) 14 | for k, v in loss_dict.items(): 15 | meter = self.avg_loss.get(k, None) 16 | if meter is None: 17 | meter = AverageMeter(k) 18 | self.avg_loss[k] = meter 19 | 20 | self.avg_loss[k].add({k: v}) 21 | 22 | def pretty_out(self): 23 | f = 'Total_Loss:%.4f, ' % ( 24 | self.avg_loss['total_loss'].pop('total_loss')) 25 | for k in self.avg_loss.keys(): 26 | if k == 'total_loss': 27 | continue 28 | t = '%s:%.4f, ' % (k, self.avg_loss[k].pop(k)) 29 | f += t 30 | return f 31 | -------------------------------------------------------------------------------- /utils/pyutils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import time 4 | import sys 5 | from multiprocessing.pool import ThreadPool 6 | 7 | 8 | class Logger(object): 9 | def __init__(self, outfile): 10 | self.terminal = sys.stdout 11 | self.log = open(outfile, "w") 12 | sys.stdout = self 13 | 14 | def write(self, message): 15 | self.terminal.write(message) 16 | self.log.write(message) 17 | 18 | def flush(self): 19 | self.terminal.flush() 20 | 21 | 22 | class AverageMeter: 23 | def __init__(self, *keys): 24 | self.__data = dict() 25 | for k in keys: 26 | self.__data[k] = [0.0, 0] 27 | 28 | def add(self, dict): 29 | for k, v in dict.items(): 30 | self.__data[k][0] += v 31 | self.__data[k][1] += 1 32 | 33 | def get(self, *keys): 34 | if len(keys) == 1: 35 | return self.__data[keys[0]][0] / self.__data[keys[0]][1] 36 | else: 37 | v_list = [self.__data[k][0] / self.__data[k][1] for k in keys] 38 | return tuple(v_list) 39 | 40 | def pop(self, key=None): 41 | if key is None: 42 | for k in self.__data.keys(): 43 | self.__data[k] = [0.0, 0] 44 | else: 45 | v = self.get(key) 46 | self.__data[key] = [0.0, 0] 47 | return v 48 | 49 | 50 | class Timer: 51 | def __init__(self, starting_msg=None): 52 | self.start = time.time() 53 | self.stage_start = self.start 54 | 55 | if starting_msg is not None: 56 | print(starting_msg, time.ctime(time.time())) 57 | 58 | def update_progress(self, progress): 59 | self.elapsed = time.time() - self.start 60 | self.est_total = self.elapsed / progress 61 | self.est_remaining = self.est_total - self.elapsed 62 | self.est_finish = int(self.start + self.est_total) 63 | 64 | def str_est_finish(self): 65 | return str(time.ctime(self.est_finish)) 66 | 67 | def get_stage_elapsed(self): 68 | return time.time() - self.stage_start 69 | 70 | def reset_stage(self): 71 | self.stage_start = time.time() 72 | 73 | 74 | class BatchThreader: 75 | def __init__(self, func, args_list, batch_size, prefetch_size=4, processes=12): 76 | self.batch_size = batch_size 77 | self.prefetch_size = prefetch_size 78 | 79 | self.pool = ThreadPool(processes=processes) 80 | 
self.async_result = [] 81 | 82 | self.func = func 83 | self.left_args_list = args_list 84 | self.n_tasks = len(args_list) 85 | 86 | # initial work 87 | self.__start_works(self.__get_n_pending_works()) 88 | 89 | def __start_works(self, times): 90 | for _ in range(times): 91 | args = self.left_args_list.pop(0) 92 | self.async_result.append( 93 | self.pool.apply_async(self.func, args)) 94 | 95 | def __get_n_pending_works(self): 96 | return min((self.prefetch_size + 1) * self.batch_size - len(self.async_result), len(self.left_args_list)) 97 | 98 | def pop_results(self): 99 | 100 | n_inwork = len(self.async_result) 101 | 102 | n_fetch = min(n_inwork, self.batch_size) 103 | rtn = [self.async_result.pop(0).get() 104 | for _ in range(n_fetch)] 105 | 106 | to_fill = self.__get_n_pending_works() 107 | if to_fill == 0: 108 | self.pool.close() 109 | else: 110 | self.__start_works(to_fill) 111 | 112 | return rtn 113 | 114 | 115 | def get_indices_of_pairs(radius, size): 116 | 117 | search_dist = [] 118 | 119 | for x in range(1, radius): 120 | search_dist.append((0, x)) 121 | 122 | for y in range(1, radius): 123 | for x in range(-radius + 1, radius): 124 | if x * x + y * y < radius * radius: 125 | search_dist.append((y, x)) 126 | 127 | radius_floor = radius - 1 128 | 129 | full_indices = np.reshape(np.arange(0, size[0] * size[1], dtype=np.int64), 130 | (size[0], size[1])) 131 | 132 | cropped_height = size[0] - radius_floor 133 | cropped_width = size[1] - 2 * radius_floor 134 | 135 | indices_from = np.reshape(full_indices[:-radius_floor, radius_floor:-radius_floor], 136 | [-1]) 137 | 138 | indices_to_list = [] 139 | 140 | for dy, dx in search_dist: 141 | indices_to = full_indices[dy:dy + cropped_height, 142 | radius_floor + dx:radius_floor + dx + cropped_width] 143 | indices_to = np.reshape(indices_to, [-1]) 144 | 145 | indices_to_list.append(indices_to) 146 | 147 | concat_indices_to = np.concatenate(indices_to_list, axis=0) 148 | 149 | return indices_from, concat_indices_to 150 | 151 | 152 | def get_optimizer(model, cfg): 153 | # backbone_params = list(map(id, model.module.backbone.parameters())) 154 | # base_params = filter(lambda p: id( 155 | # p) not in backbone_params, model.parameters()) 156 | # params = [ 157 | # {'params': base_params}, 158 | # {'params': model.module.backbone.parameters(), 'lr': cfg.backbone_lr} 159 | # ] 160 | 161 | if cfg.type == 'Adam': 162 | opt = torch.optim.Adam(model.parameters(), cfg.lr) 163 | elif cfg.type == 'AdamW': 164 | opt = torch.optim.AdamW(model.parameters(), cfg.lr) 165 | elif cfg.type=='SGD': 166 | opt=torch.optim.SGD(model.parameters(), cfg.lr) 167 | else: 168 | raise ValueError 169 | 170 | return opt 171 | -------------------------------------------------------------------------------- /utils/vis_mask.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import functional as F 3 | 4 | import os 5 | import shutil 6 | import logging 7 | import cv2 8 | import numpy as np 9 | from PIL import Image 10 | 11 | import sys 12 | import time 13 | import pandas as pd 14 | import pdb 15 | from torchvision import transforms 16 | 17 | 18 | def save_color_mask(pred_masks, save_base_path, video_name_list, v_pallete, resize, resized_mask_size, T=10): 19 | # pred_mask: [bs*5, N_CLASSES, 224, 224] 20 | # print(f"=> {len(video_name_list)} videos in this batch") 21 | 22 | if not os.path.exists(save_base_path): 23 | os.makedirs(save_base_path, exist_ok=True) 24 | 25 | BT, N_CLASSES, H, W = pred_masks.shape 26 | bs = 
BT // T 27 | 28 | pred_masks = torch.softmax(pred_masks, dim=1) 29 | pred_masks = torch.argmax(pred_masks, dim=1) # [BT, 224, 224] 30 | pred_masks = pred_masks.cpu().numpy() 31 | 32 | pred_rgb_masks = np.zeros((pred_masks.shape + (3,)), np.uint8) # [BT, H, W, 3] 33 | for cls_idx in range(N_CLASSES): 34 | rgb = v_pallete[cls_idx] 35 | pred_rgb_masks[pred_masks == cls_idx] = rgb 36 | pred_rgb_masks = pred_rgb_masks.reshape(bs, T, H, W, 3) 37 | 38 | for idx in range(bs): 39 | video_name = video_name_list[idx] 40 | mask_save_path = os.path.join(save_base_path, video_name) 41 | if not os.path.exists(mask_save_path): 42 | os.makedirs(mask_save_path, exist_ok=True) 43 | one_video_masks = pred_rgb_masks[idx] # [T, 224, 224, 3] 44 | for video_id in range(len(one_video_masks)): 45 | one_mask = one_video_masks[video_id] 46 | output_name = "%s_%d.png" % (video_name, video_id) 47 | im = Image.fromarray(one_mask)  # .convert('RGB') 48 | if resize: 49 | im = im.resize(resized_mask_size) 50 | im.save(os.path.join(mask_save_path, output_name), format='PNG') 51 | --------------------------------------------------------------------------------
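Taken together, the train and test scripts above all follow the same pattern: build the model from an mmcv config, wrap it in `DataParallel`, flatten the `(batch, frame)` dimensions, run `model(audio, imgs)`, and score the result with `mask_iou`/`Eval_Fmeasure`. The sketch below condenses that loop into a standalone smoke test. It is an illustration only, not part of the repository: the config path, checkpoint path, random tensors, and the `sys.path` tweak are assumptions, and the printed scores are meaningless because the inputs are random and no checkpoint is loaded.

```python
# Minimal smoke-test sketch (not an official repo entry point). Assumptions:
# the repo root is on PYTHONPATH (as train.sh / test.sh arrange), a CUDA GPU is
# available, the config file below exists, and the dependencies used by the
# scripts (mmcv, opencv, pandas, ...) are installed.
import sys
import torch
from mmcv import Config

from model import build_model

sys.path.append('scripts/s4')  # so `from utility import ...` resolves as in the scripts
from utility import mask_iou, Eval_Fmeasure

cfg = Config.fromfile('config/s4/AVSegFormer_pvt2_s4.py')   # illustrative config path
model = torch.nn.DataParallel(build_model(**cfg.model)).cuda()
model.eval()
# To evaluate a trained model, load weights first, e.g.:
# model.module.load_state_dict(torch.load('work_dir/AVSegFormer_pvt2_s4/S4_best.pth'))

# One dummy clip with the shapes the scripts flatten to: 5 frames of 224x224 RGB
# and 5 log-mel spectrogram patches of size 96x64.
imgs = torch.randn(5, 3, 224, 224).cuda()            # [bs*5, 3, H, W]
audio = torch.randn(5, 1, 96, 64).cuda()             # [bs*5, 1, 96, 64]
gt_mask = torch.randint(0, 2, (5, 224, 224)).float().cuda()

with torch.no_grad():
    output, _ = model(audio, imgs)                   # output: [bs*5, 1, 224, 224]
    print('mIoU   :', mask_iou(output.squeeze(1), gt_mask).item())
    print('F-score:', Eval_Fmeasure(output.squeeze(1), gt_mask))
```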