├── LICENSE ├── README.md ├── __init__.py ├── assets ├── fig1.png ├── fig2.png └── fig3.png ├── configs ├── __init__.py └── config.py ├── data └── metadata.csv ├── datasets ├── __init__.py └── dataset_refavs.py ├── logs └── write_log.py ├── models ├── __init__.py ├── avs_model.py └── local │ └── mask2former │ ├── __init__.py │ ├── image_processor_m2f.py │ ├── model_m2f.py │ └── refavs_transformer.py ├── requirements.txt ├── run.sh ├── run_refavs.py ├── scripts ├── __init__.py └── train.py └── utils ├── __init__.py └── metric ├── pyutils.py └── utility.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 GeWu-Lab 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Ref-AVS 2 | The official repo for "Ref-AVS: Refer and Segment Objects in Audio-Visual Scenes", ECCV 2024 3 | 4 | ### [Project Page](https://gewu-lab.github.io/Ref-AVS/) 5 | ### [Dataset Download](https://gewu-lab.github.io/Ref-AVS/#downloads) 6 | 7 | 8 | 9 | ### >>> Introduction 10 | In this paper, we propose a pixel-level segmentation task called **Ref**erring **A**udio-**V**isual **S**egmentation (Ref-AVS), which requires the network to densely predict whether each pixel corresponds to the given multimodal-cue expression, including dynamic audio-visual information. 11 | 12 | - Top-left of Fig.1 highlights the distinctions between Ref-AVS and previous tasks. 13 | ![Fig.1 Teaser](https://github.com/GeWu-Lab/Ref-AVS/blob/main/assets/fig1.png) 14 | 15 | - Fig.2 shows the proposed baseline model to process multimodal-cues. 16 | ![Fig.2 Baseline](https://github.com/GeWu-Lab/Ref-AVS/blob/main/assets/fig2.png) 17 | 18 | - Fig.3 shows the statistics of this dataset. 19 | ![Fig.3 Statistics](https://github.com/GeWu-Lab/Ref-AVS/blob/main/assets/fig3.png) 20 | 21 | ### >>> Run 22 | Run the training & evaluation: 23 | ```python 24 | cd Ref_AVS 25 | sh run.sh # you should change your path configs. See /configs/config.py for more details. 26 | ``` 27 | You can download the [checkpoint](https://pan.baidu.com/s/1NrNv1hTIqI7QAvNSwl7dvw?pwd=hh58) here. 28 | 29 | Core dependencies: 30 | ``` 31 | transformers=4.30.2 32 | towhee=1.1.3 33 | towhee-models=1.1.3 # Towhee is used for extracting VGGish audio feature. 
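# Example install (assumption: a pip-based environment; versions mirror requirements.txt):
#   pip install transformers==4.30.2 towhee==1.1.3 towhee-models==1.1.3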
34 | ```
35 | 
36 | ### >>> FAQ
37 | ##### (1) Alternative Audio Feature Extraction
38 | If you find Towhee hard to install, consider using the following Google Colab notebook to extract the audio features instead: [link](https://colab.research.google.com/drive/1r_8OnmwXKwmH0n4RxBfuICVBgpbJt_Fs?usp=sharing#scrollTo=MJWFPPSoAQzF).
39 | 
40 | ### Citation
41 | If you find this work useful, please consider citing it:
42 | ```
43 | @article{wang2024refavs,
44 |   title={Ref-AVS: Refer and Segment Objects in Audio-Visual Scenes},
45 |   author={Wang, Yaoting and Sun, Peiwen and Zhou, Dongzhan and Li, Guangyao and Zhang, Honggang and Hu, Di},
46 |   journal={European Conference on Computer Vision (ECCV)},
47 |   year={2024},
48 | }
49 | 
50 | @inproceedings{wang2024prompting,
51 |   title={Prompting segmentation with sound is generalizable audio-visual source localizer},
52 |   author={Wang, Yaoting and Liu, Weisong and Li, Guangyao and Ding, Jian and Hu, Di and Li, Xi},
53 |   booktitle={Proceedings of the AAAI Conference on Artificial Intelligence},
54 |   volume={38},
55 |   number={6},
56 |   pages={5669--5677},
57 |   year={2024}
58 | }
59 | ```
60 | 
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GeWu-Lab/Ref-AVS/82eb6881ffe4b1ab1fd5c3734bf68303f6b9f816/__init__.py
--------------------------------------------------------------------------------
/assets/fig1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GeWu-Lab/Ref-AVS/82eb6881ffe4b1ab1fd5c3734bf68303f6b9f816/assets/fig1.png
--------------------------------------------------------------------------------
/assets/fig2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GeWu-Lab/Ref-AVS/82eb6881ffe4b1ab1fd5c3734bf68303f6b9f816/assets/fig2.png
--------------------------------------------------------------------------------
/assets/fig3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GeWu-Lab/Ref-AVS/82eb6881ffe4b1ab1fd5c3734bf68303f6b9f816/assets/fig3.png
--------------------------------------------------------------------------------
/configs/__init__.py:
--------------------------------------------------------------------------------
1 | from .config import args
--------------------------------------------------------------------------------
/configs/config.py:
--------------------------------------------------------------------------------
1 | from email.policy import default
2 | import os
3 | 
4 | import sys
5 | BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
6 | sys.path.append(BASE_DIR)
7 | 
8 | import cv2 # type: ignore
9 | 
10 | import argparse
11 | import json
12 | import os
13 | from typing import Any, Dict, List
14 | 
15 | parser = argparse.ArgumentParser(
16 |     description=(
17 |         "Ref-AVS, ECCV'2024."
18 | ) 19 | ) 20 | 21 | parser.add_argument( 22 | "--train_params", 23 | type=list, 24 | default=[ 25 | 'audio_proj', 26 | 'text_proj', 27 | 'prompt_proj', 28 | 'avs_adapt', 29 | 'ref_avs_attn', 30 | ], 31 | help="Text model to extract textual reference feature.", 32 | ) 33 | 34 | parser.add_argument( 35 | "--text_model", 36 | type=str, 37 | default='distilbert/distilroberta-base', 38 | help="Text model to extract textual reference feature.", 39 | ) 40 | 41 | parser.add_argument( 42 | "--checkpoint", 43 | type=str, 44 | default=None, 45 | help="The path to load the refavs model checkpoints." 46 | ) 47 | parser.add_argument("--save_ckpt", type=str, default='./ckpt', help='Checkpoints save dir.') 48 | parser.add_argument("--log_path", type=str, default='./logs', help='Log info save path.') 49 | 50 | file_arch = """ 51 | ./data/REFAVS 52 | - /media 53 | - /gt_mask 54 | - /metadata.csv 55 | """ 56 | print(f">>> File arch: {file_arch}") 57 | parser.add_argument( 58 | "--data_dir", 59 | type=str, 60 | default='./data/REFAVS', 61 | help=f"The data paranet dir. File arch should be: {file_arch}" 62 | ) 63 | 64 | parser.add_argument("--show_params", action='store_true', help=f"Show params names with Requires_grad==True.") 65 | parser.add_argument("--m2f_model", type=str, default='facebook/mask2former-swin-base-ade-semantic', help="Pretrained mask2former.") 66 | 67 | parser.add_argument("--lr", type=float, default=1e-4, help='lr to fine tuning adapters.') 68 | parser.add_argument("--epochs", type=int, default=50, help='epochs to fine tuning adapters.') 69 | parser.add_argument("--loss", type=str, default='bce', help='') 70 | 71 | parser.add_argument("--train", default=False, action='store_true', help='start train?') 72 | parser.add_argument("--val", type=str, default=None, help='type: str; val | test') # NOTE: for test and val. 73 | parser.add_argument("--test", default=False, action='store_true', help='start test?') 74 | 75 | 76 | parser.add_argument("--gpu_id", type=str, default="0", help="The GPU device to run generation on.") 77 | 78 | parser.add_argument("--run", type=str, default='train', help="train, test") 79 | 80 | parser.add_argument("--frame_n", type=int, default=10, help="Frame num of each video. 
Fixed to 10.") 81 | parser.add_argument("--text_max_len", type=int, default=25, help="Maximum textual reference length.") 82 | 83 | 84 | 85 | args = parser.parse_args() 86 | 87 | # os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" 88 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id 89 | print(f'>>> Sys: set "CUDA_VISIBLE_DEVICES" - GPU: {args.gpu_id}') 90 | -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .dataset_refavs import REFAVS -------------------------------------------------------------------------------- /datasets/dataset_refavs.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from torch.utils.data import Dataset, DataLoader 6 | import pandas as pd 7 | import pdb 8 | 9 | import sys 10 | import os 11 | import random 12 | 13 | from torchvision import transforms 14 | from collections import defaultdict 15 | import cv2 16 | from transformers import AutoImageProcessor, AutoTokenizer, AutoModel 17 | from PIL import Image 18 | 19 | from towhee import pipe, ops 20 | from transformers import pipeline 21 | 22 | 23 | # logger = log_agent('audio_recs.log') 24 | 25 | import pickle as pkl 26 | 27 | class REFAVS(Dataset): 28 | def __init__(self, split='train', cfg=None): 29 | # metadata: train/test/val 30 | self.data_dir = cfg.data_dir 31 | meta_path = f'{self.data_dir}/metadata.csv' 32 | metadata = pd.read_csv(meta_path, header=0) 33 | self.split = split 34 | self.metadata = metadata[metadata['split'] == split] # split= train,test,val. 35 | 36 | self.media_path = f'{self.data_dir}/media' 37 | self.label_path = f'{self.data_dir}/gt_mask' 38 | self.frame_num = cfg.frame_n 39 | self.text_max_len = cfg.text_max_len 40 | 41 | # modalities processor/pipelines 42 | self.img_process = AutoImageProcessor.from_pretrained(cfg.m2f_model) 43 | 44 | self.audio_vggish_pipeline = ( # pipeline building 45 | pipe.input('path') 46 | .map('path', 'frame', ops.audio_decode.ffmpeg()) 47 | .map('frame', 'vecs', ops.audio_embedding.vggish()) 48 | .output('vecs') 49 | ) 50 | 51 | self.text_tokenizer = AutoTokenizer.from_pretrained(cfg.text_model) 52 | self.text_encoder = AutoModel.from_pretrained(cfg.text_model).cuda().eval() 53 | 54 | def get_audio_emb(self, wav_path): 55 | """ wav string path. """ 56 | emb = torch.tensor(self.audio_vggish_pipeline(wav_path).get()[0]) 57 | # print(len(emb)) 58 | return emb 59 | 60 | def get_text_emb(self, exp): 61 | """ readable textual reference. """ 62 | inputs = self.text_tokenizer(exp, max_length=25, padding="max_length", truncation=True, return_tensors="pt") 63 | inputs['input_ids'] = inputs['input_ids'].cuda() 64 | inputs['attention_mask'] = inputs['attention_mask'].cuda() 65 | with torch.no_grad(): 66 | emb = self.text_encoder(**inputs).last_hidden_state # [1, max_len, 768] 67 | return emb 68 | 69 | def __len__(self): 70 | return len(self.metadata) 71 | 72 | def __getitem__(self, idx): 73 | df_one_video = self.metadata.iloc[idx] 74 | vid, uid, fid, exp = df_one_video['vid'], df_one_video['uid'], df_one_video['fid'], df_one_video['exp'] # uid for vid. 75 | vid = uid.rsplit('_', 2)[0] # TODO: use encoded id. 
76 | 77 | img_recs = [] 78 | mask_recs = [] 79 | images = [] 80 | 81 | rec_audio = f'{self.media_path}/{vid}/audio.wav' 82 | rec_text = exp 83 | 84 | feat_aud = self.get_audio_emb(rec_audio) 85 | feat_text = self.get_text_emb(rec_text) 86 | 87 | for _idx in range(self.frame_num): # set frame_num as the batch_size 88 | # frame 89 | path_frame = f'{self.media_path}/{vid}/frames/{_idx}.jpg' # image 90 | image = Image.open(path_frame) 91 | image_sizes = [image.size[::-1]] 92 | image_inputs = self.img_process(image, return_tensors="pt") # singe frame rec 93 | 94 | # mask label 95 | path_mask = f'{self.label_path}/{vid}/fid_{fid}/0000{_idx}.png' # new 96 | mask_cv2 = cv2.imread(path_mask) 97 | mask_cv2 = cv2.resize(mask_cv2, (256, 256)) 98 | mask_cv2 = cv2.cvtColor(mask_cv2, cv2.COLOR_BGR2GRAY) 99 | gt_binary_mask = torch.as_tensor(mask_cv2 > 0, dtype=torch.float32) 100 | 101 | # video frames collect 102 | img_recs.append(image_inputs) 103 | mask_recs.append(gt_binary_mask) 104 | 105 | return vid, mask_recs, img_recs, image_sizes, feat_aud, feat_text, rec_audio, rec_text -------------------------------------------------------------------------------- /logs/write_log.py: -------------------------------------------------------------------------------- 1 | import os 2 | from datetime import datetime 3 | 4 | DEFAULT = "/home/yaoting_wang/workplace/Mask2Former/AVS/logs/log" 5 | 6 | def write_log(message, _file_name, _dir_name=DEFAULT, tag=None, once=False): 7 | file_path = f"{_dir_name}/{_file_name}" 8 | 9 | with open(file_path, 'a') as file: 10 | if once: 11 | file.write(f'>>> {"="*60}\n') 12 | currentDateAndTime = datetime.now().strftime("%y%m%d_%H_%M_%S_%f\n") 13 | file.write(f"--- {currentDateAndTime}\n") 14 | file.write(f"--- Tag: {tag}\n") 15 | file.write(f'--- {message}\n') 16 | file.close() 17 | 18 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .avs_model import REFAVS_Model_Base -------------------------------------------------------------------------------- /models/avs_model.py: -------------------------------------------------------------------------------- 1 | from models.local.mask2former import Mask2FormerImageProcessorForRefAVS 2 | from models.local.mask2former import Mask2FormerForRefAVS 3 | from models.local.mask2former import logging 4 | 5 | from PIL import Image 6 | import requests 7 | import torch 8 | import torch.nn.functional as F 9 | import torch.nn as nn 10 | from torch.nn import Module 11 | import re 12 | import matplotlib.pyplot as plt 13 | 14 | logging.set_verbosity_error() 15 | 16 | 17 | image_processor = Mask2FormerImageProcessorForRefAVS.from_pretrained("facebook/mask2former-swin-base-ade-semantic") 18 | model_m2f = Mask2FormerForRefAVS.from_pretrained( 19 | "facebook/mask2former-swin-base-ade-semantic" 20 | ) 21 | 22 | # avs_dataset = AVS() 23 | 24 | class REFAVS_Model_Base(nn.Module): 25 | def __init__(self, cfgs): 26 | super().__init__() 27 | self.model_v = model_m2f.cuda() 28 | 29 | self.dim_v = 1024 30 | self.num_heads = 8 31 | 32 | self.audio_proj = nn.Sequential( 33 | nn.Linear(128, 2048), 34 | nn.ReLU(), 35 | nn.Linear(2048, self.dim_v), 36 | ) 37 | 38 | self.text_proj = nn.Sequential( 39 | nn.Linear(768, 2048), 40 | nn.ReLU(), 41 | nn.Linear(2048, self.dim_v), 42 | ) 43 | 44 | self.prompt_proj = nn.Sequential( 45 | nn.Linear(1024, 2048), 46 | nn.ReLU(), 47 | nn.Linear(2048, 256), 48 | ) 49 | 50 | self.cfgs = cfgs 51 | 
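# dim_v = 1024 matches the channel width of the Swin-base pixel encoder, whose last feature
# map is 12x12 for 384x384 inputs (see the reshape in forward()); the BCE-with-logits loss
# below corresponds to the default --loss 'bce' in configs/config.py.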
52 | self.loss_fn = F.binary_cross_entropy_with_logits # 'bce' 53 | 54 | self.mha_A_T = nn.MultiheadAttention(self.dim_v, self.num_heads) 55 | self.mha_V_T = nn.MultiheadAttention(self.dim_v, self.num_heads) 56 | self.mha_mm = nn.MultiheadAttention(self.dim_v, self.num_heads) 57 | 58 | self.cache_mem_beta = 1 59 | 60 | def fusion_mm_to_text(self, feat_a_or_v, feat_t): 61 | assert feat_a_or_v.shape[-1] == self.dim_v 62 | assert feat_t.shape[-1] == self.dim_v 63 | return torch.concat((feat_a_or_v, feat_t), dim=1) 64 | 65 | def process_with_cached_memory(self, feat_mm): 66 | feat_beta = feat_mm * (self.cache_mem_beta + 1) 67 | cached_mem = torch.cumsum(feat_mm, dim=0) 68 | mean_feat_at_each_time_step = cached_mem / torch.arange(1, feat_mm.shape[0] + 1).view(-1, 1, 1).cuda() 69 | diff_feat = feat_beta - mean_feat_at_each_time_step 70 | return diff_feat 71 | 72 | def forward(self, batch_data): 73 | uid, mask_recs, img_recs, image_sizes, feat_aud, feat_text, rec_audio, rec_text = batch_data 74 | bsz = len(uid) 75 | frame_n = len(img_recs[0]) 76 | loss_uid = [] 77 | uid_preds = [] 78 | assert len(uid) == len(img_recs) and len(uid) == len(rec_text) 79 | 80 | mask_recs = [torch.stack(rec) for rec in mask_recs] 81 | gt_label = torch.stack(mask_recs).view(bsz*frame_n, mask_recs[0].shape[-2], mask_recs[0].shape[-1]).squeeze().cuda() 82 | 83 | feat_aud = torch.stack(feat_aud).cuda() 84 | feat_text = torch.stack(feat_text).cuda() 85 | feat_aud = self.audio_proj(feat_aud).view(bsz, feat_aud.shape[-2], self.dim_v) 86 | feat_text = self.text_proj(feat_text).view(bsz, feat_text.shape[-2], self.dim_v) 87 | 88 | batch_pixel_values, batch_pixel_mask = [], [] 89 | 90 | for idx, _ in enumerate(uid): 91 | img_input = img_recs[idx] 92 | 93 | for img in img_input: 94 | batch_pixel_values.append(img['pixel_values']) 95 | batch_pixel_mask.append(img['pixel_mask']) 96 | 97 | batch_pixel_values = torch.stack(batch_pixel_values).squeeze().cuda() 98 | batch_pixel_mask = torch.stack(batch_pixel_mask).squeeze().cuda() 99 | 100 | batch_input = { 101 | 'pixel_values': batch_pixel_values, 102 | 'pixel_mask': batch_pixel_mask, 103 | 'mask_labels': gt_label 104 | } 105 | 106 | outputs = self.model_v(**batch_input) 107 | feat_vis = outputs['encoder_last_hidden_state'].view(bsz, self.dim_v, 12*12, frame_n).view(bsz, -1, self.dim_v) 108 | 109 | fused_T_with_A = self.fusion_mm_to_text(feat_aud, feat_text) 110 | fused_T_with_V = self.fusion_mm_to_text(feat_vis, feat_text) 111 | 112 | fused_T_with_A = fused_T_with_A.permute(1, 0, 2) 113 | fused_T_with_V = fused_T_with_V.permute(1, 0, 2) 114 | 115 | fused_T_with_A, _ = self.mha_A_T(fused_T_with_A, fused_T_with_A, fused_T_with_A) 116 | fused_T_with_V, _ = self.mha_V_T(fused_T_with_V, fused_T_with_V, fused_T_with_V) 117 | 118 | fused_T_with_A_part_A, fused_T_with_A_part_T = \ 119 | fused_T_with_A[:feat_aud.shape[1], :, :], fused_T_with_A[feat_aud.shape[1]:, :, :] 120 | fused_T_with_V_part_V, fused_T_with_V_part_T = \ 121 | fused_T_with_V[:feat_vis.shape[1], :, :], fused_T_with_V[feat_vis.shape[1]:, :, :] 122 | 123 | assert fused_T_with_A_part_A.shape[0] + fused_T_with_A_part_T.shape[0] == fused_T_with_A.shape[0] 124 | 125 | cues_A = self.process_with_cached_memory(fused_T_with_A_part_A).permute(1, 0, 2) # [bsz, len, dim_v] 126 | cues_V = self.process_with_cached_memory(fused_T_with_V_part_V).permute(1, 0, 2) 127 | cues_T = (feat_text + fused_T_with_A_part_T.permute(1, 0, 2) + fused_T_with_V_part_T.permute(1, 0, 2)) \ 128 | / torch.tensor(3.0).cuda() 129 | 130 | tag_A = torch.full([bsz, 
1, self.dim_v], 0).cuda() 131 | tag_V = torch.full([bsz, 1, self.dim_v], 1).cuda() 132 | 133 | cues_V = cues_V.view(bsz, frame_n, 12*12, self.dim_v) 134 | 135 | batch_prompt_emb = [] 136 | for f in range(frame_n): 137 | cues_V_f = cues_V[:, f] 138 | cues_mm = torch.concat([cues_A, tag_A, cues_V_f, tag_V, cues_T], dim=1) 139 | cues_mm, _ = self.mha_mm(cues_mm, cues_mm, cues_mm) 140 | batch_prompt_emb.append(cues_mm) 141 | 142 | batch_prompt_emb = torch.stack(batch_prompt_emb).permute(1, 0, 2, 3) 143 | batch_prompt_emb = batch_prompt_emb.contiguous().view(bsz*frame_n, batch_prompt_emb.shape[-2], self.dim_v) 144 | batch_prompt_emb = self.prompt_proj(batch_prompt_emb) 145 | 146 | batch_input = { 147 | 'pixel_values': batch_pixel_values, 148 | 'pixel_mask': batch_pixel_mask, 149 | 'prompt_features_projected': batch_prompt_emb, 150 | 'mask_labels': gt_label 151 | } 152 | 153 | outputs = self.model_v(**batch_input) 154 | 155 | pred_instance_map = image_processor.post_process_semantic_segmentation( 156 | outputs, target_sizes=[[256, 256]]*(bsz*frame_n), 157 | ) 158 | 159 | pred_instance_map = torch.stack(pred_instance_map, dim=0).view(bsz*frame_n, 256, 256) 160 | 161 | loss_frame = self.loss_fn(input=pred_instance_map.squeeze(), target=gt_label.squeeze().cuda()) 162 | loss_uid.append(loss_frame) 163 | uid_preds.append(pred_instance_map.squeeze()) 164 | 165 | return loss_uid, uid_preds 166 | -------------------------------------------------------------------------------- /models/local/mask2former/__init__.py: -------------------------------------------------------------------------------- 1 | from .image_processor_m2f import * 2 | from .model_m2f import * 3 | from .refavs_transformer import * -------------------------------------------------------------------------------- /models/local/mask2former/image_processor_m2f.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from transformers.models.mask2former.image_processing_mask2former import Mask2FormerImageProcessor 3 | 4 | from typing import Dict, List, Optional, Tuple 5 | from torch import Tensor, nn 6 | 7 | class Mask2FormerImageProcessorForRefAVS(Mask2FormerImageProcessor): 8 | def __init__(self, *args, **kwargs): 9 | super().__init__(*args, **kwargs) 10 | 11 | def post_process_semantic_segmentation( 12 | self, outputs, target_sizes: Optional[List[Tuple[int, int]]] = None 13 | ) -> "torch.Tensor": 14 | """ 15 | Converts the output of [`Mask2FormerForUniversalSegmentation`] into semantic segmentation maps. Only supports 16 | PyTorch. 17 | 18 | Args: 19 | outputs ([`Mask2FormerForUniversalSegmentation`]): 20 | Raw outputs of the model. 21 | target_sizes (`List[Tuple[int, int]]`, *optional*): 22 | List of length (batch_size), where each list item (`Tuple[int, int]]`) corresponds to the requested 23 | final size (height, width) of each prediction. If left to None, predictions will not be resized. 24 | Returns: 25 | `List[torch.Tensor]`: 26 | A list of length `batch_size`, where each item is a semantic segmentation map of shape (height, width) 27 | corresponding to the target_sizes entry (if `target_sizes` is specified). Each entry of each 28 | `torch.Tensor` correspond to a semantic class id. 
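In this Ref-AVS variant (AVS_BINARY below), single-channel per-pixel logits are
returned for the caller to threshold, instead of an argmax over semantic classes.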
29 | """ 30 | AVS_BINARY = True 31 | 32 | class_queries_logits = outputs.class_queries_logits 33 | bsz = class_queries_logits.shape[0] 34 | if AVS_BINARY: 35 | class_queries_logits = outputs.class_queries_logits[:, :, 0].view(bsz, 100, 1) 36 | null_queries_logits = outputs.class_queries_logits[:, :, -1].view(bsz, 100, 1) 37 | class_queries_logits = torch.concat([class_queries_logits, null_queries_logits], dim=-1) 38 | 39 | masks_queries_logits = outputs.masks_queries_logits 40 | 41 | # Scale back to preprocessed image size - (384, 384) for all models 42 | masks_queries_logits = torch.nn.functional.interpolate( 43 | masks_queries_logits, size=(384, 384), mode="bilinear", align_corners=False 44 | ) 45 | 46 | # Remove the null class `[..., :-1]` 47 | masks_classes = class_queries_logits.softmax(dim=-1)[..., :-1] # [1, 100, 1] 48 | 49 | masks_probs = masks_queries_logits.sigmoid() # [batch_size, num_queries, height, width] 50 | if AVS_BINARY: 51 | masks_probs = masks_queries_logits # .sigmoid() 52 | 53 | segmentation = torch.einsum("bqc, bqhw -> bchw", masks_classes, masks_probs) 54 | batch_size = class_queries_logits.shape[0] 55 | 56 | # Resize logits and compute semantic segmentation maps 57 | if target_sizes is not None: 58 | if batch_size != len(target_sizes): 59 | print(f'bsz: {batch_size} | target: {target_sizes}') 60 | raise ValueError( 61 | "Make sure that you pass in as many target sizes as the batch dimension of the logits" 62 | ) 63 | 64 | semantic_segmentation = [] 65 | for idx in range(batch_size): 66 | resized_logits = torch.nn.functional.interpolate( 67 | segmentation[idx].unsqueeze(dim=0), size=target_sizes[idx], mode="bilinear", align_corners=False 68 | ) 69 | 70 | semantic_map = resized_logits[0].argmax(dim=0) 71 | if AVS_BINARY: 72 | semantic_map = resized_logits[0] 73 | 74 | semantic_segmentation.append(semantic_map) 75 | else: 76 | 77 | semantic_segmentation = segmentation.argmax(dim=1) 78 | semantic_segmentation = [semantic_segmentation[i] for i in range(semantic_segmentation.shape[0])] 79 | 80 | return semantic_segmentation -------------------------------------------------------------------------------- /models/local/mask2former/model_m2f.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from transformers import AutoImageProcessor, Mask2FormerForUniversalSegmentation, logging 3 | from transformers.models.mask2former.configuration_mask2former import Mask2FormerConfig 4 | from transformers.models.mask2former.modeling_mask2former import Mask2FormerModel, Mask2FormerTransformerModule, Mask2FormerModelOutput, Mask2FormerPixelDecoderEncoderLayer 5 | from transformers.models.mask2former.modeling_mask2former import Mask2FormerPixelLevelModule, Mask2FormerPixelDecoder, Mask2FormerPixelDecoderEncoderOnly 6 | from transformers.models.mask2former.modeling_mask2former import Mask2FormerForUniversalSegmentationOutput, Mask2FormerMaskedAttentionDecoderOutput 7 | from transformers.models.mask2former.modeling_mask2former import Mask2FormerPixelDecoderEncoderMultiscaleDeformableAttention 8 | from transformers.models.mask2former.modeling_mask2former import multi_scale_deformable_attention 9 | 10 | from typing import Dict, List, Optional, Tuple 11 | from torch import Tensor, nn 12 | 13 | from .refavs_transformer import REF_AVS_Transformer 14 | 15 | 16 | class Mask2FormerTransformerModuleForRefAVS(Mask2FormerTransformerModule): 17 | def __init__(self, *args, **kwargs): 18 | super().__init__(*args, **kwargs) 19 | 20 | print('>>> Init m2f 
for refavs...') 21 | self.ref_avs_attn = REF_AVS_Transformer() 22 | 23 | def prefix_tuning(self, prompt, feature): 24 | feature[:23] = prompt + feature[:2] 25 | return feature 26 | 27 | def check_transformer(self): 28 | # print('>>> Using new module.') 29 | ... 30 | 31 | 32 | def forward( 33 | self, 34 | multi_scale_features: List[Tensor], 35 | mask_features: Tensor, 36 | prompt_features_projected: Tensor = None, 37 | output_hidden_states: bool = False, 38 | output_attentions: bool = False, 39 | ) -> Mask2FormerMaskedAttentionDecoderOutput: 40 | 41 | 42 | multi_stage_features = [] 43 | multi_stage_positional_embeddings = [] 44 | size_list = [] 45 | 46 | for i in range(self.num_feature_levels): 47 | size_list.append(multi_scale_features[i].shape[-2:]) 48 | multi_stage_positional_embeddings.append(self.position_embedder(multi_scale_features[i], None).flatten(2)) 49 | multi_stage_features.append( 50 | self.input_projections[i](multi_scale_features[i]).flatten(2) 51 | + self.level_embed.weight[i][None, :, None] 52 | ) 53 | 54 | multi_stage_positional_embeddings[-1] = multi_stage_positional_embeddings[-1].permute(2, 0, 1) 55 | multi_stage_features[-1] = multi_stage_features[-1].permute(2, 0, 1) 56 | 57 | _, batch_size, _ = multi_stage_features[0].shape 58 | 59 | query_embeddings = self.queries_embedder.weight.unsqueeze(1).repeat(1, batch_size, 1) 60 | query_features = self.queries_features.weight.unsqueeze(1).repeat(1, batch_size, 1) 61 | 62 | if prompt_features_projected is not None: 63 | 64 | bsz = prompt_features_projected.shape[0] 65 | num_queries = query_features.shape[0] 66 | 67 | query_features = self.ref_avs_attn(target=query_features, source=prompt_features_projected) 68 | 69 | decoder_output = self.decoder( 70 | inputs_embeds=query_features, 71 | multi_stage_positional_embeddings=multi_stage_positional_embeddings, 72 | pixel_embeddings=mask_features, 73 | encoder_hidden_states=multi_stage_features, 74 | query_position_embeddings=query_embeddings, 75 | feature_size_list=size_list, 76 | output_hidden_states=output_hidden_states, 77 | output_attentions=output_attentions, 78 | return_dict=True, 79 | ) 80 | 81 | return decoder_output 82 | 83 | class Mask2FormerPixelDecoderEncoderMultiscaleDeformableAttentionForRefAVS(Mask2FormerPixelDecoderEncoderMultiscaleDeformableAttention): 84 | def __init__(self, embed_dim: int, num_heads: int, n_levels: int, n_points: int): 85 | super().__init__(embed_dim, num_heads, n_levels, n_points) 86 | self.avs_adapt = nn.Sequential( 87 | nn.Linear(embed_dim, embed_dim//4), 88 | nn.ReLU(), 89 | nn.Linear(embed_dim//4, embed_dim), 90 | ) 91 | def forward( 92 | self, 93 | hidden_states: torch.Tensor, 94 | attention_mask: Optional[torch.Tensor] = None, 95 | encoder_hidden_states=None, 96 | encoder_attention_mask=None, 97 | position_embeddings: Optional[torch.Tensor] = None, 98 | reference_points=None, 99 | spatial_shapes=None, 100 | level_start_index=None, 101 | output_attentions: bool = False, 102 | ): 103 | # add position embeddings to the hidden states before projecting to queries and keys 104 | if position_embeddings is not None: 105 | hidden_states = self.with_pos_embed(hidden_states, position_embeddings) 106 | 107 | batch_size, num_queries, _ = hidden_states.shape 108 | batch_size, sequence_length, _ = encoder_hidden_states.shape 109 | if (spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum() != sequence_length: 110 | raise ValueError( 111 | "Make sure to align the spatial shapes with the sequence length of the encoder hidden states" 112 | ) 113 | 114 | value = 
self.value_proj(encoder_hidden_states) 115 | if attention_mask is not None: 116 | # we invert the attention_mask 117 | value = value.masked_fill(attention_mask[..., None], float(0)) 118 | value = value.view(batch_size, sequence_length, self.n_heads, self.d_model // self.n_heads) 119 | sampling_offsets = self.sampling_offsets(hidden_states).view( 120 | batch_size, num_queries, self.n_heads, self.n_levels, self.n_points, 2 121 | ) 122 | attention_weights = self.attention_weights(hidden_states).view( 123 | batch_size, num_queries, self.n_heads, self.n_levels * self.n_points 124 | ) 125 | attention_weights = nn.functional.softmax(attention_weights, -1).view( 126 | batch_size, num_queries, self.n_heads, self.n_levels, self.n_points 127 | ) 128 | # batch_size, num_queries, n_heads, n_levels, n_points, 2 129 | if reference_points.shape[-1] == 2: 130 | offset_normalizer = torch.stack([spatial_shapes[..., 1], spatial_shapes[..., 0]], -1) 131 | sampling_locations = ( 132 | reference_points[:, :, None, :, None, :] 133 | + sampling_offsets / offset_normalizer[None, None, None, :, None, :] 134 | ) 135 | elif reference_points.shape[-1] == 4: 136 | sampling_locations = ( 137 | reference_points[:, :, None, :, None, :2] 138 | + sampling_offsets / self.n_points * reference_points[:, :, None, :, None, 2:] * 0.5 139 | ) 140 | else: 141 | raise ValueError(f"Last dim of reference_points must be 2 or 4, but got {reference_points.shape[-1]}") 142 | 143 | output = multi_scale_deformable_attention(value, spatial_shapes, sampling_locations, attention_weights) 144 | output = self.output_proj(output) 145 | 146 | output = self.avs_adapt(output) 147 | 148 | return output, attention_weights 149 | 150 | 151 | class Mask2FormerPixelDecoderEncoderLayerForRefAVS(Mask2FormerPixelDecoderEncoderLayer): 152 | def __init__(self, config: Mask2FormerConfig): 153 | super().__init__(config) 154 | self.self_attn = Mask2FormerPixelDecoderEncoderMultiscaleDeformableAttentionForRefAVS( 155 | embed_dim=self.embed_dim, 156 | num_heads=config.num_attention_heads, 157 | n_levels=3, 158 | n_points=4, 159 | ) 160 | 161 | class Mask2FormerPixelDecoderEncoderOnlyForRefAVS(Mask2FormerPixelDecoderEncoderOnly): 162 | def __init__(self, config: Mask2FormerConfig): 163 | super().__init__(config) 164 | self.layers = nn.ModuleList( 165 | [Mask2FormerPixelDecoderEncoderLayerForRefAVS(config) for _ in range(config.encoder_layers)] 166 | ) 167 | 168 | class Mask2FormerPixelDecoderForRefAVS(Mask2FormerPixelDecoder): 169 | def __init__(self, config: Mask2FormerConfig, feature_channels): 170 | super().__init__(config, feature_channels) 171 | self.encoder = Mask2FormerPixelDecoderEncoderOnlyForRefAVS(config) 172 | 173 | class Mask2FormerPixelLevelModuleForRefAVS(Mask2FormerPixelLevelModule): 174 | def __init__(self, config: Mask2FormerConfig): 175 | super().__init__(config) 176 | self.decoder = Mask2FormerPixelDecoderForRefAVS(config, feature_channels=self.encoder.channels) 177 | 178 | 179 | class Mask2FormerModelForRefAVS(Mask2FormerModel): 180 | def __init__(self, config, *args, **kwargs): 181 | super().__init__(config, *args, **kwargs) 182 | 183 | self.pixel_level_module = Mask2FormerPixelLevelModuleForRefAVS(config) 184 | self.transformer_module = Mask2FormerTransformerModuleForRefAVS(in_features=config.feature_size, config=config) 185 | 186 | def forward( 187 | self, 188 | pixel_values: Tensor, 189 | pixel_mask: Optional[Tensor] = None, 190 | prompt_features_projected: Optional[Tensor] = None, 191 | output_hidden_states: Optional[bool] = None, 192 | 
output_attentions: Optional[bool] = None, 193 | return_dict: Optional[bool] = None, 194 | # memory_last_hidden=None, 195 | ) -> Mask2FormerModelOutput: 196 | 197 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 198 | output_hidden_states = ( 199 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 200 | ) 201 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 202 | 203 | batch_size, _, height, width = pixel_values.shape 204 | 205 | if pixel_mask is None: 206 | pixel_mask = torch.ones((batch_size, height, width), device=pixel_values.device) 207 | 208 | pixel_level_module_output = self.pixel_level_module( 209 | pixel_values=pixel_values, output_hidden_states=output_hidden_states, 210 | ) 211 | 212 | 213 | transformer_module_output = self.transformer_module( 214 | prompt_features_projected=prompt_features_projected, 215 | multi_scale_features=pixel_level_module_output.decoder_hidden_states, 216 | mask_features=pixel_level_module_output.decoder_last_hidden_state, 217 | output_hidden_states=True, 218 | output_attentions=output_attentions, 219 | ) 220 | 221 | encoder_hidden_states = None 222 | pixel_decoder_hidden_states = None 223 | transformer_decoder_hidden_states = None 224 | transformer_decoder_intermediate_states = None 225 | 226 | if output_hidden_states: 227 | encoder_hidden_states = pixel_level_module_output.encoder_hidden_states 228 | pixel_decoder_hidden_states = pixel_level_module_output.decoder_hidden_states 229 | transformer_decoder_hidden_states = transformer_module_output.hidden_states 230 | transformer_decoder_intermediate_states = transformer_module_output.intermediate_hidden_states 231 | 232 | output = Mask2FormerModelOutput( 233 | encoder_last_hidden_state=pixel_level_module_output.encoder_last_hidden_state, 234 | pixel_decoder_last_hidden_state=pixel_level_module_output.decoder_last_hidden_state, 235 | transformer_decoder_last_hidden_state=transformer_module_output.last_hidden_state, 236 | encoder_hidden_states=encoder_hidden_states, 237 | pixel_decoder_hidden_states=pixel_decoder_hidden_states, 238 | transformer_decoder_hidden_states=transformer_decoder_hidden_states, 239 | transformer_decoder_intermediate_states=transformer_decoder_intermediate_states, 240 | attentions=transformer_module_output.attentions, 241 | masks_queries_logits=transformer_module_output.masks_queries_logits, 242 | ) 243 | 244 | if not return_dict: 245 | output = tuple(v for v in output.values() if v is not None) 246 | 247 | return output 248 | 249 | 250 | 251 | class Mask2FormerForRefAVS(Mask2FormerForUniversalSegmentation): 252 | def __init__(self, config: Mask2FormerConfig, *args, **kwargs): 253 | super().__init__(config, *args, **kwargs) 254 | self.model = Mask2FormerModelForRefAVS(config) 255 | 256 | def forward( 257 | self, 258 | pixel_values: Tensor, 259 | prompt_features_projected: Optional[Tensor] = None, 260 | mask_labels: Optional[List[Tensor]] = None, 261 | class_labels: Optional[List[Tensor]] = None, 262 | pixel_mask: Optional[Tensor] = None, 263 | output_hidden_states: Optional[bool] = None, 264 | output_auxiliary_logits: Optional[bool] = None, 265 | output_attentions: Optional[bool] = None, 266 | return_dict: Optional[bool] = None, 267 | # memory_last_hidden=None, 268 | ) -> Mask2FormerForUniversalSegmentationOutput: 269 | 270 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 271 | 
output_hidden_states = ( 272 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 273 | ) 274 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 275 | 276 | outputs = self.model( 277 | pixel_values=pixel_values, 278 | pixel_mask=pixel_mask, 279 | prompt_features_projected=prompt_features_projected, 280 | output_hidden_states=output_hidden_states or self.config.use_auxiliary_loss, 281 | output_attentions=output_attentions, 282 | return_dict=True, 283 | ) 284 | 285 | loss, loss_dict, auxiliary_logits = None, None, None 286 | class_queries_logits = () 287 | 288 | for decoder_output in outputs.transformer_decoder_intermediate_states: 289 | class_prediction = self.class_predictor(decoder_output.transpose(0, 1)) 290 | class_queries_logits += (class_prediction,) 291 | 292 | masks_queries_logits = outputs.masks_queries_logits 293 | 294 | auxiliary_logits = self.get_auxiliary_logits(class_queries_logits, masks_queries_logits) 295 | 296 | if mask_labels is not None and class_labels is not None: 297 | loss_dict = self.get_loss_dict( 298 | masks_queries_logits=masks_queries_logits[-1], 299 | class_queries_logits=class_queries_logits[-1], 300 | mask_labels=mask_labels, 301 | class_labels=class_labels, 302 | auxiliary_predictions=auxiliary_logits, 303 | ) 304 | loss = self.get_loss(loss_dict) 305 | 306 | encoder_hidden_states = None 307 | pixel_decoder_hidden_states = None 308 | transformer_decoder_hidden_states = None 309 | 310 | if output_hidden_states: 311 | encoder_hidden_states = outputs.encoder_hidden_states 312 | pixel_decoder_hidden_states = outputs.pixel_decoder_hidden_states 313 | transformer_decoder_hidden_states = outputs.transformer_decoder_hidden_states 314 | 315 | output_auxiliary_logits = ( 316 | self.config.output_auxiliary_logits if output_auxiliary_logits is None else output_auxiliary_logits 317 | ) 318 | if not output_auxiliary_logits: 319 | auxiliary_logits = None 320 | 321 | output = Mask2FormerForUniversalSegmentationOutput( 322 | loss=loss, 323 | class_queries_logits=class_queries_logits[-1], 324 | masks_queries_logits=masks_queries_logits[-1], 325 | auxiliary_logits=auxiliary_logits, 326 | encoder_last_hidden_state=outputs.encoder_last_hidden_state, 327 | pixel_decoder_last_hidden_state=outputs.pixel_decoder_last_hidden_state, 328 | transformer_decoder_last_hidden_state=outputs.transformer_decoder_last_hidden_state, 329 | encoder_hidden_states=encoder_hidden_states, 330 | pixel_decoder_hidden_states=pixel_decoder_hidden_states, 331 | transformer_decoder_hidden_states=transformer_decoder_hidden_states, 332 | attentions=outputs.attentions, 333 | ) 334 | 335 | if not return_dict: 336 | output = tuple(v for v in output.values() if v is not None) 337 | if loss is not None: 338 | output = ((loss)) + output 339 | return output 340 | 341 | -------------------------------------------------------------------------------- /models/local/mask2former/refavs_transformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | class REF_AVS_Transformer(nn.Module): 6 | def __init__(self, embed_dim=256, num_heads=4): 7 | super(REF_AVS_Transformer, self).__init__() 8 | self.num_heads = num_heads 9 | self.head_dim = embed_dim // num_heads # 256 10 | self.scaling = self.head_dim ** -0.5 11 | 12 | self.query_embedding = nn.Linear(embed_dim, embed_dim) 13 | self.key_embedding = nn.Linear(embed_dim, embed_dim) 
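# (the value projection below completes the Q/K/V set: queries are projected from the decoder
# targets, while keys/values come from the projected multimodal prompt passed in as `source`)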
14 | self.value_embedding = nn.Linear(embed_dim, embed_dim) 15 | 16 | self.out_projection = nn.Linear(embed_dim, embed_dim) 17 | 18 | self.beta_source_pool = nn.Parameter(torch.ones([1])) 19 | self.beta_source_attn = nn.Parameter(torch.ones([1])) 20 | 21 | def forward(self, target, source): 22 | seq_len_tgt, bsz, dim = target.size() 23 | 24 | _, seq_len_src, _ = source.size() 25 | seq_len_q = seq_len_tgt 26 | seq_len_kv = seq_len_src 27 | 28 | q = self.query_embedding(target.permute(1, 0, 2)) 29 | k = self.key_embedding(source) 30 | v = self.value_embedding(source) 31 | 32 | q = q.view(bsz, seq_len_q, self.num_heads, self.head_dim).permute(0, 2, 1, 3) 33 | k = k.view(bsz, seq_len_kv, self.num_heads, self.head_dim).permute(0, 2, 3, 1) 34 | v = v.view(bsz, seq_len_kv, self.num_heads, self.head_dim).permute(0, 2, 1, 3) 35 | 36 | scores = torch.matmul(q, k) * self.scaling 37 | attention_weights = F.softmax(scores, dim=-1) 38 | attended_values = torch.matmul(attention_weights, v) 39 | attended_values = attended_values.permute(0, 2, 1, 3).reshape(bsz, seq_len_q, -1) 40 | 41 | output = self.out_projection(attended_values) 42 | 43 | output = nn.Sigmoid()(self.beta_source_attn) * output 44 | source_pool = nn.Sigmoid()(self.beta_source_pool) * torch.mean(source, dim=1).view(1, bsz, 256) 45 | 46 | return target + output.permute(1, 0, 2) + source_pool -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # This file may be used to create an environment using: 2 | # $ conda create --name --file 3 | # platform: linux-64 4 | _libgcc_mutex=0.1=main 5 | _openmp_mutex=5.1=1_gnu 6 | accelerate=0.29.2=pypi_0 7 | aiofiles=23.2.1=pypi_0 8 | altair=5.3.0=pypi_0 9 | annotated-types=0.6.0=pypi_0 10 | antlr4-python3-runtime=4.9.3=pypi_0 11 | anyio=4.3.0=pypi_0 12 | appdirs=1.4.4=pypi_0 13 | attrs=23.2.0=pypi_0 14 | av=12.0.0=pypi_0 15 | backports-tarfile=1.2.0=pypi_0 16 | bitsandbytes=0.43.0=pypi_0 17 | blas=1.0=mkl 18 | braceexpand=0.1.7=pypi_0 19 | brotli-python=1.0.9=py310h6a678d5_7 20 | bypy=1.8.5=pypi_0 21 | bzip2=1.0.8=h5eee18b_5 22 | ca-certificates=2024.3.11=h06a4308_0 23 | certifi=2024.2.2=py310h06a4308_0 24 | cffi=1.16.0=pypi_0 25 | chardet=5.2.0=pypi_0 26 | charset-normalizer=2.0.4=pyhd3eb1b0_0 27 | click=8.1.7=pypi_0 28 | cloudpickle=3.0.0=pypi_0 29 | contourpy=1.2.1=pypi_0 30 | cryptography=42.0.8=pypi_0 31 | cycler=0.12.1=pypi_0 32 | cython=3.0.10=pypi_0 33 | decorator=4.4.2=pypi_0 34 | decord=0.6.0=pypi_0 35 | dill=0.3.8=pypi_0 36 | distro=1.9.0=pypi_0 37 | docker-pycreds=0.4.0=pypi_0 38 | docopt=0.6.2=pypi_0 39 | docutils=0.21.2=pypi_0 40 | einops=0.8.0=pypi_0 41 | exceptiongroup=1.2.0=pypi_0 42 | fastapi=0.110.1=pypi_0 43 | ffmpy=0.3.2=pypi_0 44 | filelock=3.13.1=py310h06a4308_0 45 | fonttools=4.51.0=pypi_0 46 | fsspec=2023.10.0=py310h06a4308_0 47 | ftfy=6.2.0=pypi_0 48 | fvcore=0.1.5.post20221221=pypi_0 49 | gitdb=4.0.11=pypi_0 50 | gitpython=3.1.43=pypi_0 51 | gradio=4.26.0=pypi_0 52 | gradio-client=0.15.1=pypi_0 53 | h11=0.14.0=pypi_0 54 | h5py=3.11.0=pypi_0 55 | httpcore=1.0.5=pypi_0 56 | httpx=0.27.0=pypi_0 57 | huggingface-hub=0.23.4=pypi_0 58 | idna=3.4=py310h06a4308_0 59 | imageio=2.34.0=pypi_0 60 | imageio-ffmpeg=0.4.9=pypi_0 61 | img2txt-py=2.4=pypi_0 62 | importlib-metadata=7.0.1=py310h06a4308_0 63 | importlib-resources=6.4.0=pypi_0 64 | intel-openmp=2023.1.0=hdb19cb5_46306 65 | iopath=0.1.10=pypi_0 66 | jaraco-classes=3.4.0=pypi_0 67 | 
jaraco-context=5.3.0=pypi_0 68 | jaraco-functools=4.0.1=pypi_0 69 | jeepney=0.8.0=pypi_0 70 | jinja2=3.1.3=pypi_0 71 | joblib=1.4.0=pypi_0 72 | jsonschema=4.21.1=pypi_0 73 | jsonschema-specifications=2023.12.1=pypi_0 74 | keyring=25.2.1=pypi_0 75 | kiwisolver=1.4.5=pypi_0 76 | lazy-loader=0.4=pypi_0 77 | ld_impl_linux-64=2.38=h1181459_1 78 | libffi=3.4.4=h6a678d5_0 79 | libgcc-ng=11.2.0=h1234567_1 80 | libgomp=11.2.0=h1234567_1 81 | libstdcxx-ng=11.2.0=h1234567_1 82 | libuuid=1.41.5=h5eee18b_0 83 | markdown-it-py=3.0.0=pypi_0 84 | markupsafe=2.1.5=pypi_0 85 | matplotlib=3.8.4=pypi_0 86 | mdurl=0.1.2=pypi_0 87 | mkl=2023.1.0=h213fc3f_46344 88 | mkl-service=2.4.0=py310h5eee18b_1 89 | mkl_fft=1.3.8=py310h5eee18b_0 90 | mkl_random=1.2.4=py310hdb19cb5_0 91 | more-itertools=10.2.0=pypi_0 92 | moviepy=1.0.3=pypi_0 93 | mpmath=1.3.0=pypi_0 94 | multiprocess=0.70.16=pypi_0 95 | ncurses=6.4=h6a678d5_0 96 | networkx=3.3=pypi_0 97 | nh3=0.2.17=pypi_0 98 | nltk=3.8.1=pypi_0 99 | numpy=1.26.4=py310h5f9d8c6_0 100 | numpy-base=1.26.4=py310hb5e798b_0 101 | nvidia-cublas-cu12=12.1.3.1=pypi_0 102 | nvidia-cuda-cupti-cu12=12.1.105=pypi_0 103 | nvidia-cuda-nvrtc-cu12=12.1.105=pypi_0 104 | nvidia-cuda-runtime-cu12=12.1.105=pypi_0 105 | nvidia-cudnn-cu12=8.9.2.26=pypi_0 106 | nvidia-cufft-cu12=11.0.2.54=pypi_0 107 | nvidia-curand-cu12=10.3.2.106=pypi_0 108 | nvidia-cusolver-cu12=11.4.5.107=pypi_0 109 | nvidia-cusparse-cu12=12.1.0.106=pypi_0 110 | nvidia-nccl-cu12=2.19.3=pypi_0 111 | nvidia-nvjitlink-cu12=12.4.127=pypi_0 112 | nvidia-nvtx-cu12=12.1.105=pypi_0 113 | omegaconf=2.3.0=pypi_0 114 | openai=1.17.0=pypi_0 115 | opencv-python=4.10.0.82=pypi_0 116 | openssl=3.0.13=h7f8727e_0 117 | orjson=3.10.0=pypi_0 118 | outcome=1.3.0.post0=pypi_0 119 | packaging=23.2=py310h06a4308_0 120 | pandas=2.2.2=pypi_0 121 | panopticapi=0.1=pypi_0 122 | parameterized=0.9.0=pypi_0 123 | peft=0.2.0=pypi_0 124 | pillow=10.3.0=pypi_0 125 | pip=23.3.1=py310h06a4308_0 126 | pkginfo=1.11.0=pypi_0 127 | portalocker=2.8.2=pypi_0 128 | proglog=0.1.10=pypi_0 129 | progressbar2=4.4.2=pypi_0 130 | protobuf=4.25.3=pypi_0 131 | psutil=5.9.8=pypi_0 132 | pycocoevalcap=1.2=pypi_0 133 | pycocotools=2.0.7=pypi_0 134 | pycparser=2.22=pypi_0 135 | pydantic=2.6.4=pypi_0 136 | pydantic-core=2.16.3=pypi_0 137 | pydub=0.25.1=pypi_0 138 | pygments=2.17.2=pypi_0 139 | pyparsing=3.1.2=pypi_0 140 | pysocks=1.7.1=py310h06a4308_0 141 | pysrt=1.1.2=pypi_0 142 | python=3.10.14=h955ad1f_0 143 | python-dateutil=2.9.0.post0=pypi_0 144 | python-multipart=0.0.9=pypi_0 145 | python-utils=3.8.2=pypi_0 146 | pytorchvideo=0.1.5=pypi_0 147 | pytube=15.0.0=pypi_0 148 | pytubefix=5.1.1=pypi_0 149 | pytz=2024.1=pypi_0 150 | pyyaml=6.0.1=py310h5eee18b_0 151 | readline=8.2=h5eee18b_0 152 | readme-renderer=43.0=pypi_0 153 | referencing=0.34.0=pypi_0 154 | regex=2023.10.3=py310h5eee18b_0 155 | requests=2.31.0=py310h06a4308_1 156 | requests-toolbelt=1.0.0=pypi_0 157 | rfc3986=2.0.0=pypi_0 158 | rich=13.7.1=pypi_0 159 | rpds-py=0.18.0=pypi_0 160 | ruff=0.3.5=pypi_0 161 | safetensors=0.4.2=py310ha89cbab_0 162 | scikit-image=0.23.1=pypi_0 163 | scikit-learn=1.5.0=pypi_0 164 | scipy=1.13.0=pypi_0 165 | secretstorage=3.3.3=pypi_0 166 | selenium=4.21.0=pypi_0 167 | semantic-version=2.10.0=pypi_0 168 | sentencepiece=0.2.0=pypi_0 169 | sentry-sdk=1.45.0=pypi_0 170 | setproctitle=1.3.3=pypi_0 171 | setuptools=68.2.2=py310h06a4308_0 172 | shapely=2.0.4=pypi_0 173 | shellingham=1.5.4=pypi_0 174 | six=1.16.0=pypi_0 175 | smmap=5.0.1=pypi_0 176 | sniffio=1.3.1=pypi_0 177 | 
sortedcontainers=2.4.0=pypi_0 178 | soundfile=0.12.1=pypi_0 179 | sqlite=3.41.2=h5eee18b_0 180 | starlette=0.37.2=pypi_0 181 | submitit=1.5.1=pypi_0 182 | sympy=1.12=pypi_0 183 | tabulate=0.9.0=pypi_0 184 | tbb=2021.8.0=hdb19cb5_0 185 | tenacity=8.3.0=pypi_0 186 | termcolor=2.4.0=pypi_0 187 | threadpoolctl=3.5.0=pypi_0 188 | tifffile=2024.2.12=pypi_0 189 | timm=0.9.16=pypi_0 190 | tk=8.6.12=h1ccaba5_0 191 | tokenizers=0.13.3=pypi_0 192 | tomlkit=0.12.0=pypi_0 193 | toolz=0.12.1=pypi_0 194 | torch=2.2.2=pypi_0 195 | torchaudio=2.2.2=pypi_0 196 | torchvision=0.17.2=pypi_0 197 | towhee=1.1.3=pypi_0 198 | towhee-models=1.1.3=pypi_0 199 | tqdm=4.65.0=py310h2f386ee_0 200 | transformers=4.30.2=pypi_0 201 | trio=0.25.1=pypi_0 202 | trio-websocket=0.11.1=pypi_0 203 | triton=2.2.0=pypi_0 204 | twine=5.1.0=pypi_0 205 | typer=0.12.3=pypi_0 206 | typing-extensions=4.9.0=py310h06a4308_1 207 | typing_extensions=4.9.0=py310h06a4308_1 208 | tzdata=2024.1=pypi_0 209 | urllib3=2.1.0=py310h06a4308_1 210 | uvicorn=0.29.0=pypi_0 211 | visual-genome=1.1.1=pypi_0 212 | wandb=0.16.6=pypi_0 213 | wcwidth=0.2.13=pypi_0 214 | webdataset=0.2.86=pypi_0 215 | websockets=11.0.3=pypi_0 216 | webvtt-py=0.4.6=pypi_0 217 | wheel=0.41.2=py310h06a4308_0 218 | wsproto=1.2.0=pypi_0 219 | xz=5.4.6=h5eee18b_0 220 | yacs=0.1.8=pypi_0 221 | yaml=0.2.5=h7b6447c_0 222 | zipp=3.17.0=py310h06a4308_0 223 | zlib=1.2.13=h5eee18b_0 224 | -------------------------------------------------------------------------------- /run.sh: -------------------------------------------------------------------------------- 1 | python run_refavs.py --val val --train \ 2 | --data_dir '/home/user/dataset/REFAVS' \ 3 | --save_ckpt '/home/user/trained_ckpts/refavs' \ 4 | --log_path '/home/user/log/ckpt_rec.txt' \ 5 | --show_params 6 | -------------------------------------------------------------------------------- /run_refavs.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | from datetime import datetime 4 | 5 | import torch 6 | from torch import nn 7 | from torch.nn import functional as F 8 | from torch.nn.functional import threshold, normalize 9 | from torch.utils.data import Dataset, DataLoader 10 | from torch.optim.lr_scheduler import ReduceLROnPlateau 11 | 12 | from configs import args 13 | from datasets import REFAVS 14 | from models import REFAVS_Model_Base 15 | 16 | from scripts.train import train, test 17 | from logs.write_log import write_log 18 | 19 | 20 | 21 | def run(model): 22 | train_dataset = REFAVS('train', args) 23 | val_dataset = REFAVS('val', args) 24 | test_dataset_s = REFAVS('test_s', args) # seen 25 | test_dataset_u = REFAVS('test_u', args) # unseen 26 | test_dataset_n = REFAVS('test_n', args) # null 27 | 28 | train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True, num_workers=0, pin_memory=False, collate_fn=collate_fn) 29 | if args.val == 'val': 30 | val_loader = DataLoader(val_dataset, batch_size=4, shuffle=False, num_workers=0, pin_memory=False, collate_fn=collate_fn) 31 | elif args.val == 'test_s': 32 | val_loader = DataLoader(test_dataset_s, batch_size=4, shuffle=False, num_workers=0, pin_memory=False, collate_fn=collate_fn) 33 | elif args.val == 'test_u': 34 | val_loader = DataLoader(test_dataset_u, batch_size=4, shuffle=False, num_workers=0, pin_memory=False, collate_fn=collate_fn) 35 | elif args.val == 'test_n': 36 | val_loader = DataLoader(test_dataset_n, batch_size=4, shuffle=False, num_workers=0, pin_memory=False, collate_fn=collate_fn) 37 | 38 | 39 
| tuned_num = 0 40 | for name, param in model.named_parameters(): 41 | param.requires_grad = False 42 | for _n in args.train_params: 43 | if _n in name: 44 | # print('yes:', _n, name) 45 | param.requires_grad = True # finetune 46 | tuned_num += 1 47 | 48 | if args.show_params: 49 | print('>>> check params with grad:') 50 | for name, param in model.named_parameters(): 51 | if param.requires_grad: 52 | print("- Requires_grad:", name) 53 | 54 | message = f'All: {sum(p.numel() for p in model.parameters()) / 1e6}M\n' 55 | message += f'Train-able: {sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6}M\n' 56 | print(message) 57 | 58 | # optimizer 59 | params1 = [{'params': [p for name, p in model.named_parameters() if p.requires_grad], 'lr': args.lr}] 60 | params = params1 61 | optimizer = torch.optim.AdamW(params) 62 | 63 | train_losses = [] 64 | m_s, f_s, null_s = [], [], [] # miou, f1, metric_s for null 65 | max_miou = 0 66 | 67 | # model 68 | model = model.cuda() 69 | for idx_ep in range(args.epochs): 70 | print(f'[Epoch] {idx_ep}') 71 | currentDateAndTime = datetime.now().strftime("%y%m%d_%H_%M_%S_%f") 72 | 73 | if args.train: 74 | model.train() 75 | loss_train = train(model, train_loader, optimizer, idx_ep, args) 76 | train_losses.append(loss_train) 77 | 78 | if args.val: 79 | model.eval() 80 | m, f = test(model, val_loader, optimizer, idx_ep, args) 81 | m_s.append(m) 82 | f_s.append(f) 83 | 84 | print(m, currentDateAndTime) 85 | ckpt_save_path = f"{args.save_ckpt}/ckpt_best_miou.pth" 86 | 87 | with open(args.log_path, 'a') as f: 88 | f.write(f"Epoch: {idx_ep}: {m_s} | {f_s}\n") 89 | 90 | if m >= max_miou and args.val == 'val': 91 | max_miou = m 92 | torch.save(model.state_dict(), ckpt_save_path) 93 | print(f'>>> saved ckpt at {ckpt_save_path} with miou={max_miou}') 94 | with open(args.log_path, 'a') as f: 95 | f.write(f"Best miou at epoch: {idx_ep}: {max_miou}. 
Saved at {ckpt_save_path}.\n") 96 | 97 | print(f'train-losses: {train_losses} | miou: {m_s} | f-score{f_s}') 98 | 99 | def collate_fn(batch): 100 | img_recs = [] 101 | mask_recs = [] 102 | image_sizes = [] 103 | uids = [] 104 | 105 | audio_feats = [] 106 | text_feats = [] 107 | audio_recs = [] 108 | text_recs = [] 109 | 110 | for data in batch: 111 | uids.append(data[0]) 112 | mask_recs.append(data[1]) 113 | img_recs.append(data[2]) 114 | image_sizes.append(data[3]) 115 | audio_feats.append(data[4]) 116 | text_feats.append(data[5]) 117 | audio_recs.append(data[6]) 118 | text_recs.append(data[7]) 119 | 120 | return uids, mask_recs, img_recs, image_sizes, audio_feats, text_feats, audio_recs, text_recs 121 | 122 | if __name__ == '__main__': 123 | print(vars(args)) 124 | m2f_avs = REFAVS_Model_Base(cfgs=args) 125 | 126 | if str(args.val).startswith('test'): 127 | ckpt = args.checkpoint 128 | 129 | print('>>> load ckpt from:', ckpt) 130 | m2f_avs.load_state_dict(torch.load(ckpt), strict=True) 131 | 132 | run(m2f_avs) 133 | -------------------------------------------------------------------------------- /scripts/__init__.py: -------------------------------------------------------------------------------- 1 | # from train import * -------------------------------------------------------------------------------- /scripts/train.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import numpy as np 4 | from utils import pyutils 5 | from utils import utility 6 | 7 | avg_meter_miou = pyutils.AverageMeter('miou') 8 | avg_meter_F = pyutils.AverageMeter('F_score') 9 | 10 | def train(model, train_loader, optimizer, idx_ep, args): 11 | print('>>> Train start ...') 12 | model.train() 13 | 14 | losses = [] 15 | 16 | for batch_idx, batch_data in enumerate(train_loader): 17 | loss_vid, _ = model(batch_data) 18 | loss_vid = torch.mean(torch.stack(loss_vid)) 19 | 20 | optimizer.zero_grad() 21 | loss_vid.backward() 22 | optimizer.step() 23 | 24 | losses.append(loss_vid.item()) 25 | print(f'[tr] loss_{idx_ep}_{batch_idx}/{len(train_loader.dataset)//train_loader.batch_size}: {loss_vid.item()} | mean_loss: {np.mean(losses)}', end='\r') 26 | 27 | return np.mean(losses) 28 | 29 | def test(model, test_loader, optimizer, idx_ep, args): 30 | model.eval() 31 | 32 | null_s_list = [] 33 | with torch.no_grad(): 34 | for batch_idx, batch_data in enumerate(test_loader): 35 | uid, mask_recs, img_recs, image_sizes, feat_aud, feat_text, rec_audio, rec_text = batch_data 36 | _, vid_preds = model(batch_data) 37 | mask_recs = [torch.stack(mask_rec, dim=0) for mask_rec in mask_recs] 38 | vid_preds_t = torch.stack(vid_preds, dim=0).squeeze().cuda().view(-1, 1, 256, 256) 39 | vid_masks_t = torch.stack(mask_recs, dim=0).squeeze().cuda().view(-1, 1, 256, 256) 40 | 41 | if args.val == 'test_n': 42 | null_s = utility.metric_s_for_null(vid_preds_t) 43 | null_s_list.append(null_s.cpu().numpy()) 44 | print(f'[te] loss_{idx_ep}_{batch_idx}/{len(test_loader.dataset)//test_loader.batch_size}: s={null_s} | mean={np.mean(np.array(null_s_list))} ') 45 | 46 | else: 47 | miou = utility.mask_iou(vid_preds_t, vid_masks_t) 48 | avg_meter_miou.add({'miou': miou}) 49 | 50 | F_score = utility.Eval_Fmeasure(vid_preds_t, vid_masks_t, './logger', device=f'cuda:{args.gpu_id}') 51 | avg_meter_F.add({'F_score': F_score}) 52 | 53 | print(f'[te] loss_{idx_ep}_{batch_idx}/{len(test_loader.dataset)//test_loader.batch_size}: miou={miou:.03f} | F={F_score:.03f} | ', end='\r') 54 | 55 | if args.val == 'test_n': 
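# test_n is the "null" split with no target object, so the s score from
# utility.metric_s_for_null is reported in place of mIoU and reused as the
# F value only to keep the (miou, F) return shape expected by run_refavs.py.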
56 | miou_epoch = np.mean(np.array(null_s_list)) 57 | F_epoch = miou_epoch # fake name, just for null_s 58 | else: 59 | miou_epoch = (avg_meter_miou.pop('miou')).item() 60 | F_epoch = (avg_meter_F.pop('F_score')) 61 | 62 | return miou_epoch, F_epoch 63 | 64 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .metric import pyutils 2 | from .metric import utility 3 | -------------------------------------------------------------------------------- /utils/metric/pyutils.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | import time 4 | import sys 5 | 6 | class Logger(object): 7 | def __init__(self, outfile): 8 | self.terminal = sys.stdout 9 | self.log = open(outfile, "w") 10 | sys.stdout = self 11 | 12 | def write(self, message): 13 | self.terminal.write(message) 14 | self.log.write(message) 15 | 16 | def flush(self): 17 | self.terminal.flush() 18 | 19 | 20 | class AverageMeter: 21 | def __init__(self, *keys): 22 | self.__data = dict() 23 | for k in keys: 24 | self.__data[k] = [0.0, 0] 25 | 26 | def add(self, dict): 27 | for k, v in dict.items(): 28 | self.__data[k][0] += v 29 | self.__data[k][1] += 1 30 | 31 | def get(self, *keys): 32 | if len(keys) == 1: 33 | return self.__data[keys[0]][0] / self.__data[keys[0]][1] 34 | else: 35 | v_list = [self.__data[k][0] / self.__data[k][1] for k in keys] 36 | return tuple(v_list) 37 | 38 | def pop(self, key=None): 39 | if key is None: 40 | for k in self.__data.keys(): 41 | self.__data[k] = [0.0, 0] 42 | else: 43 | v = self.get(key) 44 | self.__data[key] = [0.0, 0] 45 | return v 46 | 47 | 48 | class Timer: 49 | def __init__(self, starting_msg = None): 50 | self.start = time.time() 51 | self.stage_start = self.start 52 | 53 | if starting_msg is not None: 54 | print(starting_msg, time.ctime(time.time())) 55 | 56 | 57 | def update_progress(self, progress): 58 | self.elapsed = time.time() - self.start 59 | self.est_total = self.elapsed / progress 60 | self.est_remaining = self.est_total - self.elapsed 61 | self.est_finish = int(self.start + self.est_total) 62 | 63 | 64 | def str_est_finish(self): 65 | return str(time.ctime(self.est_finish)) 66 | 67 | def get_stage_elapsed(self): 68 | return time.time() - self.stage_start 69 | 70 | def reset_stage(self): 71 | self.stage_start = time.time() 72 | 73 | 74 | from multiprocessing.pool import ThreadPool 75 | 76 | class BatchThreader: 77 | 78 | def __init__(self, func, args_list, batch_size, prefetch_size=4, processes=12): 79 | self.batch_size = batch_size 80 | self.prefetch_size = prefetch_size 81 | 82 | self.pool = ThreadPool(processes=processes) 83 | self.async_result = [] 84 | 85 | self.func = func 86 | self.left_args_list = args_list 87 | self.n_tasks = len(args_list) 88 | 89 | # initial work 90 | self.__start_works(self.__get_n_pending_works()) 91 | 92 | 93 | def __start_works(self, times): 94 | for _ in range(times): 95 | args = self.left_args_list.pop(0) 96 | self.async_result.append( 97 | self.pool.apply_async(self.func, args)) 98 | 99 | 100 | def __get_n_pending_works(self): 101 | return min((self.prefetch_size + 1) * self.batch_size - len(self.async_result) 102 | , len(self.left_args_list)) 103 | 104 | 105 | 106 | def pop_results(self): 107 | 108 | n_inwork = len(self.async_result) 109 | 110 | n_fetch = min(n_inwork, self.batch_size) 111 | rtn = [self.async_result.pop(0).get() 112 | for _ in 
113 | 
114 |         to_fill = self.__get_n_pending_works()
115 |         if to_fill == 0:
116 |             self.pool.close()
117 |         else:
118 |             self.__start_works(to_fill)
119 | 
120 |         return rtn
121 | 
122 | 
123 | 
124 | 
125 | def get_indices_of_pairs(radius, size):
126 | 
127 |     search_dist = []
128 | 
129 |     for x in range(1, radius):
130 |         search_dist.append((0, x))
131 | 
132 |     for y in range(1, radius):
133 |         for x in range(-radius + 1, radius):
134 |             if x * x + y * y < radius * radius:
135 |                 search_dist.append((y, x))
136 | 
137 |     radius_floor = radius - 1
138 | 
139 |     full_indices = np.reshape(np.arange(0, size[0]*size[1], dtype=np.int64),
140 |                               (size[0], size[1]))
141 | 
142 |     cropped_height = size[0] - radius_floor
143 |     cropped_width = size[1] - 2 * radius_floor
144 | 
145 |     indices_from = np.reshape(full_indices[:-radius_floor, radius_floor:-radius_floor],
146 |                               [-1])
147 | 
148 |     indices_to_list = []
149 | 
150 |     for dy, dx in search_dist:
151 |         indices_to = full_indices[dy:dy + cropped_height,
152 |                                   radius_floor + dx:radius_floor + dx + cropped_width]
153 |         indices_to = np.reshape(indices_to, [-1])
154 | 
155 |         indices_to_list.append(indices_to)
156 | 
157 |     concat_indices_to = np.concatenate(indices_to_list, axis=0)
158 | 
159 |     return indices_from, concat_indices_to
160 | 
161 | 
--------------------------------------------------------------------------------
/utils/metric/utility.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 | import torch
3 | from torch.nn import functional as F
4 | 
5 | import os
6 | import shutil
7 | # import logging
8 | import cv2
9 | import numpy as np
10 | from PIL import Image
11 | 
12 | import sys
13 | import time
14 | import pandas as pd
15 | import pdb
16 | from torchvision import transforms
17 | 
18 | def metric_s_for_null(pred):  # s-score for the null subset: sqrt of the predicted-foreground pixel ratio
19 |     NF, bsz, H, W = pred.shape
20 |     pred = pred.view(NF*bsz, H, W)
21 |     assert len(pred.shape) == 3
22 | 
23 |     N = pred.size(0)
24 |     num_pixels = pred.view(-1).shape[0]
25 | 
26 |     temp_pred = torch.sigmoid(pred)
27 |     pred = (temp_pred > 0.5).int()
28 | 
29 |     x = torch.sum(pred.view(-1))
30 |     s = torch.sqrt(x / num_pixels)
31 | 
32 |     return s
33 | 
34 | def mask_iou(pred, target, eps=1e-7, size_average=True):
35 |     r"""
36 |     param:
37 |         pred: size [NF x bsz x H x W], raw logits
38 |         target: size [NF x bsz x H x W], binary ground-truth masks
39 |     output:
40 |         iou: scalar, mean IoU over the NF*bsz flattened frames (size_average is currently unused)
41 |     """
42 |     # return mask_iou_224(pred, target, eps=1e-7)
43 |     NF, bsz, H, W = pred.shape
44 |     pred = pred.view(NF*bsz, H, W)
45 |     target = target.view(NF*bsz, H, W)
46 |     assert len(pred.shape) == 3 and pred.shape == target.shape
47 | 
48 |     N = pred.size(0)
49 |     num_pixels = pred.size(-1) * pred.size(-2)
50 |     no_obj_flag = (target.sum(2).sum(1) == 0)
51 | 
52 |     temp_pred = torch.sigmoid(pred)
53 |     pred = (temp_pred > 0.4).int()
54 |     inter = (pred * target).sum(2).sum(1)
55 |     union = torch.max(pred, target).sum(2).sum(1)
56 | 
57 |     inter_no_obj = ((1 - target) * (1 - pred)).sum(2).sum(1)
58 |     inter[no_obj_flag] = inter_no_obj[no_obj_flag]
59 |     union[no_obj_flag] = num_pixels
60 | 
61 |     iou = torch.sum(inter / (union + eps)) / N
62 | 
63 |     return iou
64 | 
65 | 
66 | def _eval_pr(y_pred, y, num, device='cuda'):
67 |     if device.startswith('cuda'):
68 |         prec, recall = torch.zeros(num).to(y_pred.device), torch.zeros(num).to(y_pred.device)
69 |         thlist = torch.linspace(0, 1 - 1e-10, num).to(y_pred.device)
70 |     else:
71 |         prec, recall = torch.zeros(num), torch.zeros(num)
72 |         thlist = torch.linspace(0, 1 - 1e-10, num)
73 |     for i in range(num):
74 |         y_temp = (y_pred >= thlist[i]).float()
75 |         tp = (y_temp * y).sum()
76 |         prec[i], recall[i] = tp / (y_temp.sum() + 1e-20), tp / (y.sum() + 1e-20)
77 | 
78 |     return prec, recall
79 | 
80 | 
81 | def Eval_Fmeasure(pred, gt, measure_path, pr_num=255, device='cuda'):
82 |     r"""
83 |     param:
84 |         pred: size [N x H x W], raw logits (sigmoid is applied inside)
85 |         gt: size [N x H x W], binary ground-truth masks
86 |     output:
87 |         F_score: scalar, maximum F-measure (beta^2 = 0.3) over pr_num thresholds
88 |     """
89 | 
90 |     pred = torch.sigmoid(pred)
91 |     N = pred.size(0)
92 |     beta2 = 0.3
93 |     avg_f, img_num = 0.0, 0
94 |     score = torch.zeros(pr_num)
95 | 
96 | 
97 |     for img_id in range(N):
98 |         if torch.mean(gt[img_id]) == 0.0:
99 |             continue
100 |         prec, recall = _eval_pr(pred[img_id], gt[img_id], pr_num, device=device)
101 |         f_score = (1 + beta2) * prec * recall / (beta2 * prec + recall)
102 |         f_score[f_score != f_score] = 0  # replace NaN (prec = recall = 0)
103 |         avg_f += f_score
104 |         img_num += 1
105 |         score = avg_f / img_num
106 | 
107 |     return score.max().item()
108 | 
109 | 
--------------------------------------------------------------------------------
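The snippet below is an illustrative sanity check, not a file from the repository: it exercises the metric helpers (`mask_iou`, `Eval_Fmeasure`, `metric_s_for_null`) and `AverageMeter` the same way `scripts/train.py` does, but with random dummy tensors standing in for model predictions and ground-truth masks. It assumes the repository's dependencies are installed and that it is run from the repo root so `utils` is importable.

```python
# Sanity-check sketch (illustrative only, not part of the repository).
# Shapes follow scripts/train.py, where predictions and masks are reshaped
# to (num_frames, 1, 256, 256) before scoring.
import torch

from utils import pyutils, utility

num_frames, H, W = 10, 256, 256
logits = torch.randn(num_frames, 1, H, W)                 # stand-in for raw model outputs
masks = (torch.rand(num_frames, 1, H, W) > 0.5).float()   # stand-in for binary ground-truth masks

meter = pyutils.AverageMeter('miou', 'F_score')

miou = utility.mask_iou(logits, masks)                    # thresholds sigmoid(logits) at 0.4 internally
f_score = utility.Eval_Fmeasure(logits, masks, './logger', device='cpu')
meter.add({'miou': miou, 'F_score': f_score})

# The null ('test_n') subset is scored on the prediction alone:
null_s = utility.metric_s_for_null(logits)

print(f"miou={meter.pop('miou').item():.3f} | F={meter.pop('F_score'):.3f} | s={null_s.item():.3f}")
```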