├── PAM
│   ├── __init__.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── audio.py
│   │   ├── utils.py
│   │   ├── clap.py
│   │   ├── config.py
│   │   ├── pytorch_utils.py
│   │   └── htsat.py
│   ├── config.yml
│   └── PAM.py
├── requirements.txt
├── LICENSE
├── run.py
├── pcc.py
├── README.md
└── dataset.py
/PAM/__init__.py: -------------------------------------------------------------------------------- 1 | from .PAM import PAM --------------------------------------------------------------------------------
/PAM/models/__init__.py: -------------------------------------------------------------------------------- 1 | from . import clap 2 | from . import audio 3 | from . import htsat 4 | from . import config 5 | from . import pytorch_utils --------------------------------------------------------------------------------
/PAM/models/audio.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torchlibrosa.stft import Spectrogram, LogmelFilterBank 5 | from .htsat import HTSATWrapper 6 | 7 | def get_audio_encoder(name: str): 8 | if name == "HTSAT": 9 | return HTSATWrapper 10 | else: 11 | raise Exception('The audio encoder name {} is incorrect or not supported'.format(name)) --------------------------------------------------------------------------------
/PAM/config.yml: -------------------------------------------------------------------------------- 1 | # TEXT ENCODER CONFIG 2 | text_model: 'gpt2' 3 | text_len: 77 4 | transformer_embed_dim: 768 5 | freeze_text_encoder_weights: True 6 | 7 | # AUDIO ENCODER CONFIG 8 | audioenc_name: 'HTSAT' 9 | out_emb: 768 10 | sampling_rate: 44100 11 | duration: 7 12 | fmin: 50 13 | fmax: 8000 #14000 14 | n_fft: 1024 # 1028 15 | hop_size: 320 16 | mel_bins: 64 17 | window_size: 1024 18 | 19 | # PROJECTION SPACE CONFIG 20 | d_proj: 1024 21 | temperature: 0.003 22 | 23 | # TRAINING AND EVALUATION CONFIG 24 | num_classes: 527 25 | batch_size: 1024 26 | demo: False --------------------------------------------------------------------------------
/PAM/models/utils.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import yaml 3 | import sys 4 | 5 | def read_config_as_args(config_path,args=None,is_config_str=False): 6 | return_dict = {} 7 | 8 | if config_path is not None: 9 | if is_config_str: 10 | yml_config = yaml.load(config_path, Loader=yaml.FullLoader) 11 | else: 12 | with open(config_path, "r") as f: 13 | yml_config = yaml.load(f, Loader=yaml.FullLoader) 14 | 15 | if args != None: 16 | for k, v in yml_config.items(): 17 | if k in args.__dict__: 18 | args.__dict__[k] = v 19 | else: 20 | sys.stderr.write("Ignored unknown parameter {} in yaml.\n".format(k)) 21 | else: 22 | for k, v in yml_config.items(): 23 | return_dict[k] = v 24 | 25 | args = args if args != None else return_dict 26 | return argparse.Namespace(**args) 27 | --------------------------------------------------------------------------------
/requirements.txt: -------------------------------------------------------------------------------- 1 | appdirs==1.4.4 2 | audioread==3.0.0 3 | certifi==2022.12.7 4 | cffi==1.15.1 5 | charset-normalizer==3.0.1 6 | colorama==0.4.6 7 | decorator==5.1.1 8 | filelock==3.9.0 9 | flit_core==3.6.0 10 | huggingface-hub==0.12.1 11 | idna==3.4 12 | importlib-metadata==6.0.0 13 | importlib-resources==5.12.0 14 | jaraco.classes==3.2.3 15 | joblib==1.2.0 16 | lazy_loader==0.1 17 | librosa==0.10.0 18 | llvmlite==0.39.1 19 |
mkl-service==2.4.0 20 | more-itertools==9.0.0 21 | msgpack==1.0.4 22 | numba==0.56.4 23 | numpy==1.23.5 24 | packaging==23.0 25 | pandas==1.4.2 26 | pooch==1.6.0 27 | pycparser==2.21 28 | pywin32-ctypes==0.2.0 29 | PyYAML==6.0 30 | regex==2022.10.31 31 | requests==2.28.2 32 | scikit-learn==1.2.1 33 | scipy==1.10.1 34 | setuptools==65.6.3 35 | six==1.16.0 36 | soundfile==0.12.1 37 | soxr==0.3.3 38 | threadpoolctl==3.1.0 39 | tokenizers==0.13.2 40 | torch==1.13.1 41 | torchaudio==0.13.1 42 | torchlibrosa==0.1.0 43 | torchvision==0.14.1 44 | tqdm==4.64.1 45 | transformers==4.26.1 46 | typing_extensions==4.4.0 47 | urllib3==1.26.14 48 | wheel==0.38.4 49 | wincertstore==0.2 50 | zipp==3.14.0 -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Soham 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | from torch.utils.data import DataLoader 4 | from tqdm import tqdm 5 | from PAM import PAM 6 | from dataset import ExampleDatasetFolder 7 | 8 | if __name__ == '__main__': 9 | parser = argparse.ArgumentParser(description = "PAM") 10 | parser.add_argument('--folder', type=str, help='Folder path to evaluate') 11 | parser.add_argument('--batch_size', type=int, default=10, help='Number of examples per batch') 12 | parser.add_argument('--num_workers', type=int, default=0, help='Number of workers for dataloader') 13 | args = parser.parse_args() 14 | 15 | # initialize PAM 16 | pam = PAM(use_cuda=torch.cuda.is_available()) 17 | 18 | # Create Dataset and Dataloader 19 | dataset = ExampleDatasetFolder( 20 | src=args.folder, 21 | ) 22 | dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle = False, 23 | num_workers = args.num_workers, 24 | pin_memory = False, drop_last=False, collate_fn=dataset.collate) 25 | 26 | # Evaluate and print PAM score 27 | collect_pam, collect_pam_segment = [], [] 28 | for files, audios, sample_index in tqdm(dataloader): 29 | pam_score, pam_segment_score = pam.evaluate(audios, sample_index) 30 | collect_pam += pam_score 31 | collect_pam_segment += pam_segment_score 32 | 33 | print(f"PAM Score: {sum(collect_pam)/len(collect_pam)}") -------------------------------------------------------------------------------- /pcc.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | from torch.utils.data import DataLoader 4 | from tqdm import tqdm 5 | from PAM import PAM 6 | from dataset import ExampleDatasetFiles 7 | import os 8 | import glob 9 | import pandas as pd 10 | import numpy as np 11 | 12 | def evaluate_pam(dataloader, pam): 13 | """Evaluate PAM score using the provided dataloader""" 14 | collect_pam, collect_pam_segment = [], [] 15 | for _, audios, sample_index in tqdm(dataloader): 16 | pam_score, pam_segment_score = pam.evaluate(audios, sample_index) 17 | collect_pam += pam_score 18 | collect_pam_segment += pam_segment_score 19 | return collect_pam, collect_pam_segment 20 | 21 | def load_task_dataframe(task, model): 22 | """Load and return human listening scores""" 23 | df = pd.read_csv(os.path.join(task, "scores.csv")) 24 | model_df = df[df["Model"] == model.split(os.path.sep)[-1]] 25 | files = [os.path.join(model, x) + ".wav" for x in list(model_df["File Name"])] 26 | OVLs, RELs = model_df["OVL"], model_df["REL"] 27 | return files, OVLs, RELs 28 | 29 | def evaluate_task(task, model, pam): 30 | """Evaluate files generated by a model for particular task""" 31 | print(f"\nTask: {task}, Model: {model}") 32 | 33 | # Load human listening scores 34 | files, OVLs, RELs = load_task_dataframe(task, model) 35 | 36 | # Create Dataset and Dataloader 37 | dataset = ExampleDatasetFiles( 38 | src=files, 39 | repro=args.repro, 40 | ) 41 | dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle = False, 42 | num_workers = args.num_workers, 43 | pin_memory = False, drop_last=False, collate_fn=dataset.collate) 44 | 45 | # Evaluate and print PAM score 46 | collect_pam, _ = evaluate_pam(dataloader, pam) 47 | 48 | print(f"Average PAM Score: {sum(collect_pam)/len(collect_pam)}") 49 | print(f"PCC PAM / OVL: {np.corrcoef(collect_pam, OVLs)[0,1]}") 50 | print(f"PCC PAM / REL: 
{np.corrcoef(collect_pam, RELs)[0,1]}") 51 | 52 | if __name__ == '__main__': 53 | parser = argparse.ArgumentParser(description = "PAM") 54 | parser.add_argument('--folder', type=str, default="human_eval", help='Folder path to evaluate') 55 | parser.add_argument('--batch_size', type=int, default=10, help='Number of examples per batch') 56 | parser.add_argument('--num_workers', type=int, default=0, help='Number of workers for dataloader') 57 | parser.add_argument('--repro', type=bool, default=True, help='Reproduce paper setup and evaluation') 58 | args = parser.parse_args() 59 | 60 | # initialize PAM 61 | pam = PAM(use_cuda=torch.cuda.is_available()) 62 | 63 | # Run evaluation on tasks 64 | tasks = glob.glob(os.path.join(args.folder,"**")) 65 | for task in tasks: 66 | models = glob.glob(os.path.join(task,"**")) 67 | models = [m for m in models if ".csv" not in m] 68 | for model in models: 69 | evaluate_task(task, model, pam) 70 | --------------------------------------------------------------------------------
/README.md: -------------------------------------------------------------------------------- 1 | # PAM: Prompting Audio-Language Models for Audio Quality Assessment 2 | [[`Paper`](https://arxiv.org/abs/2402.00282)] [[`data`](https://github.com/soham97/PAM/tree/main?tab=readme-ov-file#data)] 3 | 4 | PAM is a no-reference metric for assessing audio quality across different audio processing tasks. It prompts Audio-Language Models (ALMs) using an antonym prompt strategy to calculate an audio quality score. It does not require reference data or task-specific models and correlates well with human perception. 5 | ![PAM_9 (1)](https://github.com/soham97/PAM/assets/28994673/3c0754ac-636a-4fc6-8045-d06282121ea4) 6 | 7 | ## Setup 8 | Open the [Anaconda](https://www.anaconda.com) terminal and run: 9 | ```shell 10 | > git clone https://github.com/soham97/PAM.git 11 | > cd PAM 12 | > conda create -n pam python=3.10 13 | > conda activate pam 14 | > pip install -r requirements.txt 15 | ``` 16 | 17 | ## Compute PAM 18 | #### Folder evaluation 19 | To compute PAM on a folder containing audio files, you can directly run: 20 | ```shell 21 | > python run.py --folder {folder_path} 22 | ``` 23 | The symbol `{..}` indicates user input. 24 | 25 | #### Custom evaluation 26 | To compute PAM on a hierarchy of folders or multiple directories, we recommend creating a custom dataset (see the sketch below). 27 | - In `dataset.py`, create a custom dataset by inheriting from `AudioDataset`, similar to `ExampleDatasetFolder` 28 | - Modify the `get_files` function to fit your directory structure 29 | - Update `run.py` with your custom dataset and make changes to the evaluation if needed
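For illustration, a minimal sketch of such a custom dataset is shown below. The class name `MyNestedDataset` and the list-of-folders form of `src` are purely illustrative and not part of the repository; the parent `AudioDataset` already handles loading, resampling to 44.1 kHz, and chunking into 7-second segments.
```python
import glob
import os

from dataset import AudioDataset


class MyNestedDataset(AudioDataset):
    """Gather .wav files from several root folders (illustrative example)."""

    def get_files(self):
        # `self.src` is whatever was passed as `src`; here it is assumed to be
        # a list of directories, each searched recursively for .wav files.
        files = []
        for root in self.src:
            files += glob.glob(os.path.join(root, "**", "*.wav"), recursive=True)
        return files
```
In `run.py`, pass the new class in place of `ExampleDatasetFolder` (e.g. `dataset = MyNestedDataset(src=["gen_audio/", "ref_audio/"])`, where the folder names are placeholders) and keep `collate_fn=dataset.collate` so the chunk pointers are still produced.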
30 | 31 | ## Data 32 | The manuscript uses data from multiple sources. It can be obtained as follows: 33 | - For the text-to-audio and text-to-music generation, we conducted the human listening test using Amazon Turk. The audio generated by models and human listening scores are available at: [Zenodo](https://zenodo.org/records/10737388) 34 | - For text-to-music generation with FAD comparison (Figure 6), we used the data and human listening scores from [Adapting Frechet Audio Distance for Generative Music Evaluation 35 | (ICASSP 24)](https://arxiv.org/abs/2311.01616). The website is [here](https://github.com/microsoft/fadtk) 36 | - For text-to-speech generation, we used the data and human listening scores from [Evaluating speech synthesis by training recognizers on synthetic speech (2023)](https://arxiv.org/abs/2310.00706) 37 | - For distortions (Figure 4), we sourced the data from NISQA. The data, with human listening scores, can be downloaded from the GitHub repo: [here](https://github.com/gabrielmittag/NISQA). 38 | - For voice conversion, we used the voice conversion subset from the VoiceMOS Challenge data. The data can be downloaded at: [Zenodo](https://zenodo.org/records/10691660) 39 | 40 | ## Paper reproduction 41 | This section covers reproducing the numbers for text-to-audio and text-to-music generation. First download the human listening test data by following the instructions listed above. The download should contain a folder titled `human_eval`. 42 | 43 | Then run the following command: 44 | ```shell 45 | > python pcc.py --folder {folder_path} 46 | ``` 47 | where `{folder_path}` points to the `human_eval` folder. 48 | 49 | ## Citation 50 | ```BibTeX 51 | @article{deshmukh2024pam, 52 | title={PAM: Prompting Audio-Language Models for Audio Quality Assessment}, 53 | author={Soham Deshmukh and Dareen Alharthi and Benjamin Elizalde and Hannes Gamper and Mahmoud Al Ismail and Rita Singh and Bhiksha Raj and Huaming Wang}, 54 | journal={arXiv preprint arXiv:2402.00282}, 55 | year={2024} 56 | } 57 | ``` 58 | --------------------------------------------------------------------------------
/PAM/models/clap.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn.functional as F 4 | from torch import nn 5 | from transformers import AutoModel 6 | from .audio import get_audio_encoder 7 | 8 | class Projection(nn.Module): 9 | def __init__(self, d_in: int, d_out: int, p: float=0.5) -> None: 10 | super().__init__() 11 | self.linear1 = nn.Linear(d_in, d_out, bias=False) 12 | self.linear2 = nn.Linear(d_out, d_out, bias=False) 13 | self.layer_norm = nn.LayerNorm(d_out) 14 | self.drop = nn.Dropout(p) 15 | 16 | def forward(self, x: torch.Tensor) -> torch.Tensor: 17 | embed1 = self.linear1(x) 18 | embed2 = self.drop(self.linear2(F.gelu(embed1))) 19 | embeds = self.layer_norm(embed1 + embed2) 20 | return embeds 21 | 22 | class AudioEncoder(nn.Module): 23 | def __init__(self, audioenc_name:str, d_in: int, d_out: int, sample_rate: int, window_size: int, 24 | hop_size: int, mel_bins: int, fmin: int, fmax: int, classes_num: int) -> None: 25 | super().__init__() 26 | 27 | audio_encoder = get_audio_encoder(audioenc_name) 28 | 29 | self.base = audio_encoder( 30 | sample_rate, window_size, 31 | hop_size, mel_bins, fmin, fmax, 32 | classes_num, d_in) 33 | 34 | self.projection = Projection(d_in, d_out) 35 | 36 | def forward(self, x): 37 | out_dict = self.base(x) 38 | audio_features, audio_classification_output = out_dict['embedding'], out_dict['clipwise_output'] 39 | projected_vec = self.projection(audio_features) 40 | return projected_vec, audio_classification_output 41 | 42 | class TextEncoder(nn.Module): 43 | def __init__(self, d_out: int, text_model: str, transformer_embed_dim: int) -> None: 44 | super().__init__() 45 | self.text_model = text_model 46 | self.base = AutoModel.from_pretrained(text_model) 47 | 48 | self.projection = Projection(transformer_embed_dim, d_out) 49 | 50 | def forward(self, x): 51 | batch_size = x['input_ids'].shape[0] 52 | hidden_states = self.base(**x)[0] # (batch_size=4, seq_len,
768) 53 | 54 | sequence_lengths = torch.ne(x['input_ids'], 0).sum(-1) - 1 # tensor([13, 14, 18, 17]) 55 | out = hidden_states[torch.arange(batch_size, device=hidden_states.device), sequence_lengths] # [batch_size, 768] = [4, 768] 56 | 57 | projected_vec = self.projection(out) 58 | 59 | return projected_vec 60 | 61 | class CLAP(nn.Module): 62 | def __init__(self, 63 | # audio 64 | audioenc_name: str, 65 | sample_rate: int, 66 | window_size: int, 67 | hop_size: int, 68 | mel_bins: int, 69 | fmin: int, 70 | fmax: int, 71 | classes_num: int, 72 | out_emb: int, 73 | # text 74 | text_model: str, 75 | transformer_embed_dim: int, 76 | # common 77 | d_proj: int, 78 | ): 79 | super().__init__() 80 | 81 | 82 | self.audio_encoder = AudioEncoder( 83 | audioenc_name, out_emb, d_proj, 84 | sample_rate, window_size, hop_size, mel_bins, fmin, fmax, classes_num) 85 | 86 | self.caption_encoder = TextEncoder( 87 | d_proj, text_model, transformer_embed_dim 88 | ) 89 | 90 | self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) 91 | 92 | def forward(self, audio, text): 93 | audio_embed, _ = self.audio_encoder(audio) 94 | caption_embed = self.caption_encoder(text) 95 | 96 | return caption_embed, audio_embed, self.logit_scale.exp() -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | import os 3 | import numpy as np 4 | import torch 5 | import glob 6 | import torchaudio 7 | import torchaudio.transforms as T 8 | 9 | RESAMPLE_RATE = 44100 10 | AUDIO_DURATION = 7 11 | SAMPLES = RESAMPLE_RATE*AUDIO_DURATION 12 | 13 | class AudioDataset(Dataset): 14 | def __init__(self, src, repro = False): 15 | self.src = src 16 | self.repro = repro 17 | self.filelist = self.get_files() 18 | 19 | def get_files(self): 20 | r"""Return a list of filepaths to evaluate PAM on. 
User implemented.""" 21 | raise NotImplementedError 22 | 23 | def __getitem__(self, index): 24 | r"""Retrieve audio file and return processed audio.""" 25 | file = self.filelist[index] 26 | audio = self.readaudio(file) 27 | return file, audio 28 | 29 | def process_audio(self, audio): 30 | r"""Process audio to be a multiple of 7 seconds.""" 31 | audio = audio.reshape(-1) 32 | if SAMPLES >= audio.shape[0]: 33 | repeat_factor = int(np.ceil((SAMPLES) / 34 | audio.shape[0])) 35 | # Repeat audio_time_series by repeat_factor to match audio_duration 36 | audio = audio.repeat(repeat_factor) 37 | # remove excess part of audio_time_series 38 | audio = audio[0:SAMPLES] 39 | else: 40 | if self.repro: 41 | # retain only first 7 seconds 42 | start_index = 0 43 | audio = audio[start_index:start_index + SAMPLES] 44 | else: 45 | cutoff = int(np.floor(audio.shape[0]/SAMPLES)) 46 | # cutoff audio 47 | initial_audio_series = audio[0:cutoff*SAMPLES] 48 | # remaining audio repeat and cut off 49 | remaining = audio[cutoff*SAMPLES:] 50 | if remaining.shape[0] != 0: 51 | repeat_factor = int(np.ceil((SAMPLES) / remaining.shape[0])) 52 | remaining = remaining.repeat(repeat_factor) 53 | remaining = remaining[0:SAMPLES] 54 | audio = torch.cat([initial_audio_series, remaining]) 55 | else: 56 | audio = initial_audio_series 57 | 58 | return audio 59 | 60 | def readaudio(self, file): 61 | r"""Loads audio file and returns raw audio.""" 62 | audio, sample_rate = torchaudio.load(file) 63 | 64 | # Resample audio if needed 65 | if RESAMPLE_RATE != sample_rate: 66 | resampler = T.Resample(sample_rate, RESAMPLE_RATE) 67 | audio = resampler(audio) 68 | 69 | # process audio to be a multiple of 7 seconds 70 | audio = self.process_audio(audio) 71 | return audio 72 | 73 | def collate(self, batch): 74 | r"""Collate batch and generate chunk pointers.""" 75 | # Assign a reference variable to identify the file associated with each chunk 76 | files = [x[0] for x in batch] 77 | sample_len = [0] + [int(len(x[1])/SAMPLES) for x in batch] 78 | sample_index = [sum(sample_len[0:i+1]) for i in range(len(sample_len))] 79 | 80 | # Create chunks 81 | batch = torch.cat([x[1] for x in batch]) 82 | batch_chunks = [batch[SAMPLES*i:SAMPLES*i+SAMPLES].reshape(1,-1) for i in range(0,int(batch.shape[0]/SAMPLES))] 83 | batch_chunks = torch.cat(batch_chunks,axis=0) 84 | 85 | return files, batch_chunks, sample_index 86 | 87 | def __len__(self): 88 | r"""Size of dataset.""" 89 | return len(self.filelist) 90 | 91 | class ExampleDatasetFolder(AudioDataset): 92 | def __init__(self, src, repro = False): 93 | self.src = src 94 | self.repro = repro 95 | self.filelist = self.get_files() 96 | super().__init__(src,repro) 97 | 98 | def get_files(self): 99 | return glob.glob(os.path.join(self.src,"**/*.wav"), recursive=True) 100 | 101 | class ExampleDatasetFiles(AudioDataset): 102 | def __init__(self, src, repro = False): 103 | self.src = src 104 | self.repro = repro 105 | super().__init__(src,repro) 106 | 107 | def get_files(self): 108 | return self.src 109 | -------------------------------------------------------------------------------- /PAM/models/config.py: -------------------------------------------------------------------------------- 1 | # Ke Chen 2 | # knutchen@ucsd.edu 3 | # HTS-AT: A HIERARCHICAL TOKEN-SEMANTIC AUDIO TRANSFORMER FOR SOUND CLASSIFICATION AND DETECTION 4 | # The configuration for training the model 5 | 6 | exp_name = "exp_htsat_pretrain" # the saved ckpt prefix name of the model 7 | workspace = "/home/kechen/Research/HTSAT" # the folder of your 
code 8 | dataset_path = "/home/Research/audioset" # the dataset path 9 | desed_folder = "/home/Research/DESED" # the desed file 10 | 11 | dataset_type = "audioset" # "audioset" "esc-50" "scv2" 12 | index_type = "full_train" # only works for audioset 13 | balanced_data = True # only works for audioset 14 | 15 | loss_type = "clip_bce" # 16 | # AudioSet & SCV2: "clip_bce" | ESC-50: "clip_ce" 17 | 18 | # trained from a checkpoint, or evaluate a single model 19 | resume_checkpoint = None 20 | # "/home/Research/model_backup/AudioSet/HTSAT_AudioSet_Saved_1.ckpt" 21 | 22 | esc_fold = 0 # just for esc dataset, select the fold you need for evaluation and (+1) validation 23 | 24 | 25 | debug = False 26 | 27 | random_seed = 970131 # 19970318 970131 12412 127777 1009 34047 28 | batch_size = 32 * 4 # batch size per GPU x GPU number , default is 32 x 4 = 128 29 | learning_rate = 1e-3 # 1e-4 also workable 30 | max_epoch = 100 31 | num_workers = 3 32 | 33 | lr_scheduler_epoch = [10,20,30] 34 | lr_rate = [0.02, 0.05, 0.1] 35 | 36 | # these data preparation optimizations do not bring many improvements, so deprecated 37 | enable_token_label = False # token label 38 | class_map_path = "class_hier_map.npy" 39 | class_filter = None 40 | retrieval_index = [15382, 9202, 130, 17618, 17157, 17516, 16356, 6165, 13992, 9238, 5550, 5733, 1914, 1600, 3450, 13735, 11108, 3762, 41 | 9840, 11318, 8131, 4429, 16748, 4992, 16783, 12691, 4945, 8779, 2805, 9418, 2797, 14357, 5603, 212, 3852, 12666, 1338, 10269, 2388, 8260, 4293, 14454, 7677, 11253, 5060, 14938, 8840, 4542, 2627, 16336, 8992, 15496, 11140, 446, 6126, 10691, 8624, 10127, 9068, 16710, 10155, 14358, 7567, 5695, 2354, 8057, 17635, 133, 16183, 14535, 7248, 4560, 14429, 2463, 10773, 113, 2462, 9223, 4929, 14274, 4716, 17307, 4617, 2132, 11083, 1039, 1403, 9621, 13936, 2229, 2875, 17840, 9359, 13311, 9790, 13288, 4750, 17052, 8260, 14900] 42 | token_label_range = [0.2,0.6] 43 | enable_time_shift = False # shift time 44 | enable_label_enhance = False # enhance hierarchical label 45 | enable_repeat_mode = False # repeat the spectrogram / reshape the spectrogram 46 | 47 | 48 | 49 | # for model's design 50 | enable_tscam = True # enbale the token-semantic layer 51 | 52 | # for signal processing 53 | sample_rate = 32000 # 16000 for scv2, 32000 for audioset and esc-50 54 | clip_samples = sample_rate * 10 # audio_set 10-sec clip 55 | window_size = 1024 56 | hop_size = 320 # 160 for scv2, 320 for audioset and esc-50 57 | mel_bins = 64 58 | fmin = 50 59 | fmax = 14000 60 | shift_max = int(clip_samples * 0.5) 61 | 62 | # for data collection 63 | classes_num = 527 # esc: 50 | audioset: 527 | scv2: 35 64 | patch_size = (25, 4) # deprecated 65 | crop_size = None # int(clip_samples * 0.5) deprecated 66 | 67 | # for htsat hyperparamater 68 | htsat_window_size = 8 69 | htsat_spec_size = 256 70 | htsat_patch_size = 4 71 | htsat_stride = (4, 4) 72 | htsat_num_head = [4,8,16,32] 73 | htsat_dim = 96 74 | htsat_depth = [2,2,6,2] 75 | 76 | swin_pretrain_path = None 77 | # "/home/Research/model_backup/pretrain/swin_tiny_c24_patch4_window8_256.pth" 78 | 79 | # Some Deprecated Optimization in the model design, check the model code for details 80 | htsat_attn_heatmap = False 81 | htsat_hier_output = False 82 | htsat_use_max = False 83 | 84 | 85 | # for ensemble test 86 | 87 | ensemble_checkpoints = [] 88 | ensemble_strides = [] 89 | 90 | 91 | # weight average folder 92 | wa_folder = "/home/version_0/checkpoints/" 93 | # weight average output filename 94 | wa_model_path = 
"HTSAT_AudioSet_Saved_x.ckpt" 95 | 96 | esm_model_pathes = [ 97 | "/home/Research/model_backup/AudioSet/HTSAT_AudioSet_Saved_1.ckpt", 98 | "/home/Research/model_backup/AudioSet/HTSAT_AudioSet_Saved_2.ckpt", 99 | "/home/Research/model_backup/AudioSet/HTSAT_AudioSet_Saved_3.ckpt", 100 | "/home/Research/model_backup/AudioSet/HTSAT_AudioSet_Saved_4.ckpt", 101 | "/home/Research/model_backup/AudioSet/HTSAT_AudioSet_Saved_5.ckpt", 102 | "/home/Research/model_backup/AudioSet/HTSAT_AudioSet_Saved_6.ckpt" 103 | ] 104 | 105 | # for framewise localization 106 | heatmap_dir = "/home/Research/heatmap_output" 107 | test_file = "htsat-test-ensemble" 108 | fl_local = False # indicate if we need to use this dataset for the framewise detection 109 | fl_dataset = "/home/Research/desed/desed_eval.npy" 110 | fl_class_num = [ 111 | "Speech", "Frying", "Dishes", "Running_water", 112 | "Blender", "Electric_shaver_toothbrush", "Alarm_bell_ringing", 113 | "Cat", "Dog", "Vacuum_cleaner" 114 | ] 115 | 116 | # map 527 classes into 10 classes 117 | fl_audioset_mapping = [ 118 | [0,1,2,3,4,5,6,7], 119 | [366, 367, 368], 120 | [364], 121 | [288, 289, 290, 291, 292, 293, 294, 295, 296, 297], 122 | [369], 123 | [382], 124 | [310, 388, 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 399, 400, 401, 402], 125 | [81, 82, 83, 84, 85], 126 | [74, 75, 76, 77, 78, 79], 127 | [377] 128 | ] -------------------------------------------------------------------------------- /PAM/models/pytorch_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | def move_data_to_device(x, device): 6 | if 'float' in str(x.dtype): 7 | x = torch.Tensor(x) 8 | elif 'int' in str(x.dtype): 9 | x = torch.LongTensor(x) 10 | else: 11 | return x 12 | 13 | return x.to(device) 14 | 15 | 16 | def do_mixup(x, mixup_lambda): 17 | """Mixup x of even indexes (0, 2, 4, ...) with x of odd indexes 18 | (1, 3, 5, ...). 19 | Args: 20 | x: (batch_size * 2, ...) 21 | mixup_lambda: (batch_size * 2,) 22 | Returns: 23 | out: (batch_size, ...) 24 | """ 25 | out = (x[0 :: 2].transpose(0, -1) * mixup_lambda[0 :: 2] + \ 26 | x[1 :: 2].transpose(0, -1) * mixup_lambda[1 :: 2]).transpose(0, -1) 27 | return out 28 | 29 | 30 | def append_to_dict(dict, key, value): 31 | if key in dict.keys(): 32 | dict[key].append(value) 33 | else: 34 | dict[key] = [value] 35 | 36 | 37 | def interpolate(x, ratio): 38 | """Interpolate data in time domain. This is used to compensate the 39 | resolution reduction in downsampling of a CNN. 40 | 41 | Args: 42 | x: (batch_size, time_steps, classes_num) 43 | ratio: int, ratio to interpolate 44 | Returns: 45 | upsampled: (batch_size, time_steps * ratio, classes_num) 46 | """ 47 | (batch_size, time_steps, classes_num) = x.shape 48 | upsampled = x[:, :, None, :].repeat(1, 1, ratio, 1) 49 | upsampled = upsampled.reshape(batch_size, time_steps * ratio, classes_num) 50 | return upsampled 51 | 52 | 53 | def pad_framewise_output(framewise_output, frames_num): 54 | """Pad framewise_output to the same length as input frames. The pad value 55 | is the same as the value of the last frame. 
56 | Args: 57 | framewise_output: (batch_size, frames_num, classes_num) 58 | frames_num: int, number of frames to pad 59 | Outputs: 60 | output: (batch_size, frames_num, classes_num) 61 | """ 62 | pad = framewise_output[:, -1 :, :].repeat(1, frames_num - framewise_output.shape[1], 1) 63 | """tensor for padding""" 64 | 65 | output = torch.cat((framewise_output, pad), dim=1) 66 | """(batch_size, frames_num, classes_num)""" 67 | 68 | return output 69 | 70 | 71 | def count_parameters(model): 72 | return sum(p.numel() for p in model.parameters() if p.requires_grad) 73 | 74 | 75 | def count_flops(model, audio_length): 76 | """Count flops. Code modified from others' implementation. 77 | """ 78 | multiply_adds = True 79 | list_conv2d=[] 80 | def conv2d_hook(self, input, output): 81 | batch_size, input_channels, input_height, input_width = input[0].size() 82 | output_channels, output_height, output_width = output[0].size() 83 | 84 | kernel_ops = self.kernel_size[0] * self.kernel_size[1] * (self.in_channels / self.groups) * (2 if multiply_adds else 1) 85 | bias_ops = 1 if self.bias is not None else 0 86 | 87 | params = output_channels * (kernel_ops + bias_ops) 88 | flops = batch_size * params * output_height * output_width 89 | 90 | list_conv2d.append(flops) 91 | 92 | list_conv1d=[] 93 | def conv1d_hook(self, input, output): 94 | batch_size, input_channels, input_length = input[0].size() 95 | output_channels, output_length = output[0].size() 96 | 97 | kernel_ops = self.kernel_size[0] * (self.in_channels / self.groups) * (2 if multiply_adds else 1) 98 | bias_ops = 1 if self.bias is not None else 0 99 | 100 | params = output_channels * (kernel_ops + bias_ops) 101 | flops = batch_size * params * output_length 102 | 103 | list_conv1d.append(flops) 104 | 105 | list_linear=[] 106 | def linear_hook(self, input, output): 107 | batch_size = input[0].size(0) if input[0].dim() == 2 else 1 108 | 109 | weight_ops = self.weight.nelement() * (2 if multiply_adds else 1) 110 | bias_ops = self.bias.nelement() 111 | 112 | flops = batch_size * (weight_ops + bias_ops) 113 | list_linear.append(flops) 114 | 115 | list_bn=[] 116 | def bn_hook(self, input, output): 117 | list_bn.append(input[0].nelement() * 2) 118 | 119 | list_relu=[] 120 | def relu_hook(self, input, output): 121 | list_relu.append(input[0].nelement() * 2) 122 | 123 | list_pooling2d=[] 124 | def pooling2d_hook(self, input, output): 125 | batch_size, input_channels, input_height, input_width = input[0].size() 126 | output_channels, output_height, output_width = output[0].size() 127 | 128 | kernel_ops = self.kernel_size * self.kernel_size 129 | bias_ops = 0 130 | params = output_channels * (kernel_ops + bias_ops) 131 | flops = batch_size * params * output_height * output_width 132 | 133 | list_pooling2d.append(flops) 134 | 135 | list_pooling1d=[] 136 | def pooling1d_hook(self, input, output): 137 | batch_size, input_channels, input_length = input[0].size() 138 | output_channels, output_length = output[0].size() 139 | 140 | kernel_ops = self.kernel_size[0] 141 | bias_ops = 0 142 | 143 | params = output_channels * (kernel_ops + bias_ops) 144 | flops = batch_size * params * output_length 145 | 146 | list_pooling2d.append(flops) 147 | 148 | def foo(net): 149 | childrens = list(net.children()) 150 | if not childrens: 151 | if isinstance(net, nn.Conv2d): 152 | net.register_forward_hook(conv2d_hook) 153 | elif isinstance(net, nn.Conv1d): 154 | net.register_forward_hook(conv1d_hook) 155 | elif isinstance(net, nn.Linear): 156 | net.register_forward_hook(linear_hook) 
157 | elif isinstance(net, nn.BatchNorm2d) or isinstance(net, nn.BatchNorm1d): 158 | net.register_forward_hook(bn_hook) 159 | elif isinstance(net, nn.ReLU): 160 | net.register_forward_hook(relu_hook) 161 | elif isinstance(net, nn.AvgPool2d) or isinstance(net, nn.MaxPool2d): 162 | net.register_forward_hook(pooling2d_hook) 163 | elif isinstance(net, nn.AvgPool1d) or isinstance(net, nn.MaxPool1d): 164 | net.register_forward_hook(pooling1d_hook) 165 | else: 166 | print('Warning: flop of module {} is not counted!'.format(net)) 167 | return 168 | for c in childrens: 169 | foo(c) 170 | 171 | # Register hook 172 | foo(model) 173 | 174 | device = device = next(model.parameters()).device 175 | input = torch.rand(1, audio_length).to(device) 176 | 177 | out = model(input) 178 | 179 | total_flops = sum(list_conv2d) + sum(list_conv1d) + sum(list_linear) + \ 180 | sum(list_bn) + sum(list_relu) + sum(list_pooling2d) + sum(list_pooling1d) 181 | 182 | return total_flops -------------------------------------------------------------------------------- /PAM/PAM.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from pathlib import Path 4 | import warnings 5 | warnings.filterwarnings("ignore") 6 | import re 7 | from transformers import AutoTokenizer, logging 8 | from .models.clap import CLAP 9 | import os 10 | import torch 11 | import argparse 12 | import yaml 13 | import sys 14 | from huggingface_hub.file_download import hf_hub_download 15 | logging.set_verbosity_error() 16 | import torch.nn.functional as F 17 | import collections 18 | 19 | HF_REPO = "microsoft/msclap" 20 | CLAP_VERSION = "CLAP_weights_2023.pth" 21 | PAM_PROMPTS = ['the sound is clear and clean.','the sound is noisy and with artifacts.'] 22 | 23 | class PAM(): 24 | """ 25 | A class for PAM metric. 
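    Usage sketch (mirrors `run.py`; the variable names are illustrative):
    `audio_chunks` and `sample_index` are the chunked audio batch and chunk
    pointers produced by `AudioDataset.collate` in `dataset.py`.

        pam = PAM(use_cuda=torch.cuda.is_available())
        avg_score_per_file, scores_per_segment = pam.evaluate(audio_chunks, sample_index)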
26 | """ 27 | def __init__(self, model_fp: Path | str | None = None, use_cuda=False): 28 | self.np_str_obj_array_pattern = re.compile(r'[SaUO]') 29 | self.file_path = os.path.realpath(__file__) 30 | self.default_collate_err_msg_format = ( 31 | "default_collate: batch must contain tensors, numpy arrays, numbers, " 32 | "dicts or lists; found {}") 33 | self.config_as_str = (Path(__file__).parent / f"config.yml").read_text() 34 | 35 | # Automatically download model if not provided 36 | if not model_fp: 37 | model_fp = hf_hub_download(HF_REPO, CLAP_VERSION) 38 | 39 | self.model_fp = model_fp 40 | self.use_cuda = use_cuda 41 | self.clap, self.tokenizer, self.args = self.load_clap() 42 | 43 | # Two prompt strategy 44 | self.pam_prompts = PAM_PROMPTS 45 | self.get_text_embeddings() 46 | 47 | def read_config_as_args(self,config_path,args=None,is_config_str=False): 48 | return_dict = {} 49 | 50 | if config_path is not None: 51 | if is_config_str: 52 | yml_config = yaml.load(config_path, Loader=yaml.FullLoader) 53 | else: 54 | with open(config_path, "r") as f: 55 | yml_config = yaml.load(f, Loader=yaml.FullLoader) 56 | 57 | if args != None: 58 | for k, v in yml_config.items(): 59 | if k in args.__dict__: 60 | args.__dict__[k] = v 61 | else: 62 | sys.stderr.write("Ignored unknown parameter {} in yaml.\n".format(k)) 63 | else: 64 | for k, v in yml_config.items(): 65 | return_dict[k] = v 66 | 67 | args = args if args != None else return_dict 68 | return argparse.Namespace(**args) 69 | 70 | def load_clap(self): 71 | r"""Load CLAP model with args from config file""" 72 | 73 | args = self.read_config_as_args(self.config_as_str, is_config_str=True) 74 | 75 | self.token_keys = ['input_ids', 'attention_mask'] 76 | 77 | clap = CLAP( 78 | audioenc_name=args.audioenc_name, 79 | sample_rate=args.sampling_rate, 80 | window_size=args.window_size, 81 | hop_size=args.hop_size, 82 | mel_bins=args.mel_bins, 83 | fmin=args.fmin, 84 | fmax=args.fmax, 85 | classes_num=args.num_classes, 86 | out_emb=args.out_emb, 87 | text_model=args.text_model, 88 | transformer_embed_dim=args.transformer_embed_dim, 89 | d_proj=args.d_proj 90 | ) 91 | 92 | # Load pretrained weights for model 93 | model_state_dict = torch.load(self.model_fp, map_location=torch.device('cpu'))['model'] 94 | 95 | # We unwrap the DDP model and save. 
If the model is not unwrapped and saved, then the model needs to unwrapped before `load_state_dict`: 96 | # Reference link: https://discuss.pytorch.org/t/how-to-load-dataparallel-model-which-trained-using-multiple-gpus/146005 97 | clap.load_state_dict(model_state_dict, strict=False) 98 | 99 | clap.eval() # set clap in eval mode 100 | tokenizer = AutoTokenizer.from_pretrained(args.text_model) 101 | tokenizer.add_special_tokens({'pad_token': '!'}) 102 | 103 | if self.use_cuda and torch.cuda.is_available(): 104 | clap = clap.cuda() 105 | 106 | return clap, tokenizer, args 107 | 108 | def default_collate(self, batch): 109 | r"""Puts each data field into a tensor with outer dimension batch size""" 110 | elem = batch[0] 111 | elem_type = type(elem) 112 | if isinstance(elem, torch.Tensor): 113 | out = None 114 | if torch.utils.data.get_worker_info() is not None: 115 | # If we're in a background process, concatenate directly into a 116 | # shared memory tensor to avoid an extra copy 117 | numel = sum([x.numel() for x in batch]) 118 | storage = elem.storage()._new_shared(numel) 119 | out = elem.new(storage) 120 | return torch.stack(batch, 0, out=out) 121 | elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \ 122 | and elem_type.__name__ != 'string_': 123 | if elem_type.__name__ == 'ndarray' or elem_type.__name__ == 'memmap': 124 | # array of string classes and object 125 | if self.np_str_obj_array_pattern.search(elem.dtype.str) is not None: 126 | raise TypeError( 127 | self.default_collate_err_msg_format.format(elem.dtype)) 128 | 129 | return self.default_collate([torch.as_tensor(b) for b in batch]) 130 | elif elem.shape == (): # scalars 131 | return torch.as_tensor(batch) 132 | elif isinstance(elem, float): 133 | return torch.tensor(batch, dtype=torch.float64) 134 | elif isinstance(elem, int): 135 | return torch.tensor(batch) 136 | elif isinstance(elem, str): 137 | return batch 138 | elif isinstance(elem, collections.abc.Mapping): 139 | return {key: self.default_collate([d[key] for d in batch]) for key in elem} 140 | elif isinstance(elem, tuple) and hasattr(elem, '_fields'): # namedtuple 141 | return elem_type(*(self.default_collate(samples) for samples in zip(*batch))) 142 | elif isinstance(elem, collections.abc.Sequence): 143 | # check to make sure that the elements in batch have consistent size 144 | it = iter(batch) 145 | elem_size = len(next(it)) 146 | if not all(len(elem) == elem_size for elem in it): 147 | raise RuntimeError( 148 | 'each element in list of batch should be of equal size') 149 | transposed = zip(*batch) 150 | return [self.default_collate(samples) for samples in transposed] 151 | 152 | raise TypeError(self.default_collate_err_msg_format.format(elem_type)) 153 | 154 | def preprocess_text(self, text_queries): 155 | r"""Load list of class labels and return tokenized text""" 156 | tokenized_texts = [] 157 | for ttext in text_queries: 158 | if 'gpt' in self.args.text_model: 159 | ttext = ttext + ' <|endoftext|>' 160 | tok = self.tokenizer.encode_plus( 161 | text=ttext, add_special_tokens=True, max_length=self.args.text_len, padding='max_length', return_tensors="pt") 162 | for key in self.token_keys: 163 | tok[key] = tok[key].reshape(-1).cuda() if self.use_cuda and torch.cuda.is_available() else tok[key].reshape(-1) 164 | tokenized_texts.append(tok) 165 | 166 | tokenized_texts_batch = {key: torch.cat([d[key].reshape(1,-1) for d in tokenized_texts]) for key in tokenized_texts[0]} 167 | return tokenized_texts_batch 168 | 169 | def get_text_embeddings(self): 170 | 
r"""Save text embeddings of PAM prompts""" 171 | preprocessed_text = self.preprocess_text(self.pam_prompts) 172 | self.pam_embeddings = self._get_text_embeddings(preprocessed_text) 173 | 174 | def _get_text_embeddings(self, preprocessed_text): 175 | r"""Load preprocessed text and return text embeddings""" 176 | with torch.no_grad(): 177 | return self.clap.caption_encoder(preprocessed_text) 178 | 179 | def _get_audio_embeddings(self, preprocessed_audio): 180 | r"""Load preprocessed audio and return a audio embeddings""" 181 | with torch.no_grad(): 182 | return self.clap.audio_encoder(preprocessed_audio)[0] 183 | 184 | def compute_similarity(self, audio_embeddings): 185 | r"""Compute similarity between text and audio embeddings""" 186 | audio_embeddings = audio_embeddings/torch.norm(audio_embeddings, dim=-1, keepdim=True) 187 | text_embeddings = self.pam_embeddings/torch.norm(self.pam_embeddings, dim=-1, keepdim=True) 188 | 189 | logit_scale = self.clap.logit_scale.exp() 190 | similarity = logit_scale*text_embeddings @ audio_embeddings.T 191 | return similarity.T 192 | 193 | def evaluate(self, audio_tensors, sample_index=None): 194 | r"""Compute PAM score using audio tensors""" 195 | if self.use_cuda and torch.cuda.is_available(): 196 | audio_tensors = audio_tensors.cuda() 197 | 198 | audio_embedddings = self._get_audio_embeddings(audio_tensors) 199 | sim = self.compute_similarity(audio_embedddings) 200 | prob = F.softmax(sim, dim=1) 201 | pam_score = prob[:,0] 202 | 203 | pam_score = pam_score.detach().cpu() 204 | if sample_index is not None: 205 | per_file_scores = [pam_score[sample_index[i]:sample_index[i+1]] for i in range(len(sample_index)-1)] 206 | avg_per_file_scores = [sum(x).item()/len(x) for x in per_file_scores] 207 | 208 | return avg_per_file_scores, per_file_scores -------------------------------------------------------------------------------- /PAM/models/htsat.py: -------------------------------------------------------------------------------- 1 | # Ke Chen 2 | # knutchen@ucsd.edu 3 | # HTS-AT: A HIERARCHICAL TOKEN-SEMANTIC AUDIO TRANSFORMER FOR SOUND CLASSIFICATION AND DETECTION 4 | # Model Core 5 | # below codes are based and referred from https://github.com/microsoft/Swin-Transformer 6 | # Swin Transformer for Computer Vision: https://arxiv.org/pdf/2103.14030.pdf 7 | 8 | 9 | import math 10 | import random 11 | import torch 12 | import torch.nn as nn 13 | import torch.utils.checkpoint as checkpoint 14 | 15 | from torchlibrosa.stft import Spectrogram, LogmelFilterBank 16 | from torchlibrosa.augmentation import SpecAugmentation 17 | 18 | from itertools import repeat 19 | 20 | from .pytorch_utils import do_mixup, interpolate 21 | from . import config 22 | 23 | import collections.abc 24 | import warnings 25 | 26 | from torch.nn.init import _calculate_fan_in_and_fan_out 27 | 28 | def _ntuple(n): 29 | def parse(x): 30 | if isinstance(x, collections.abc.Iterable): 31 | return x 32 | return tuple(repeat(x, n)) 33 | return parse 34 | 35 | to_1tuple = _ntuple(1) 36 | to_2tuple = _ntuple(2) 37 | to_3tuple = _ntuple(3) 38 | to_4tuple = _ntuple(4) 39 | to_ntuple = _ntuple 40 | 41 | 42 | def drop_path(x, drop_prob: float = 0., training: bool = False): 43 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 44 | This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, 45 | the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... 
46 | See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for 47 | changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 48 | 'survival rate' as the argument. 49 | """ 50 | if drop_prob == 0. or not training: 51 | return x 52 | keep_prob = 1 - drop_prob 53 | shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets 54 | random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) 55 | random_tensor.floor_() # binarize 56 | output = x.div(keep_prob) * random_tensor 57 | return output 58 | 59 | 60 | class DropPath(nn.Module): 61 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 62 | """ 63 | def __init__(self, drop_prob=None): 64 | super(DropPath, self).__init__() 65 | self.drop_prob = drop_prob 66 | 67 | def forward(self, x): 68 | return drop_path(x, self.drop_prob, self.training) 69 | 70 | class PatchEmbed(nn.Module): 71 | """ 2D Image to Patch Embedding 72 | """ 73 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True, patch_stride = 16): 74 | super().__init__() 75 | img_size = to_2tuple(img_size) 76 | patch_size = to_2tuple(patch_size) 77 | patch_stride = to_2tuple(patch_stride) 78 | self.img_size = img_size 79 | self.patch_size = patch_size 80 | self.patch_stride = patch_stride 81 | self.grid_size = (img_size[0] // patch_stride[0], img_size[1] // patch_stride[1]) 82 | self.num_patches = self.grid_size[0] * self.grid_size[1] 83 | self.flatten = flatten 84 | self.in_chans = in_chans 85 | self.embed_dim = embed_dim 86 | 87 | padding = ((patch_size[0] - patch_stride[0]) // 2, (patch_size[1] - patch_stride[1]) // 2) 88 | 89 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_stride, padding=padding) 90 | self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() 91 | 92 | def forward(self, x): 93 | B, C, H, W = x.shape 94 | assert H == self.img_size[0] and W == self.img_size[1], \ 95 | f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." 96 | x = self.proj(x) 97 | if self.flatten: 98 | x = x.flatten(2).transpose(1, 2) # BCHW -> BNC 99 | x = self.norm(x) 100 | return x 101 | 102 | class Mlp(nn.Module): 103 | """ MLP as used in Vision Transformer, MLP-Mixer and related networks 104 | """ 105 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 106 | super().__init__() 107 | out_features = out_features or in_features 108 | hidden_features = hidden_features or in_features 109 | self.fc1 = nn.Linear(in_features, hidden_features) 110 | self.act = act_layer() 111 | self.fc2 = nn.Linear(hidden_features, out_features) 112 | self.drop = nn.Dropout(drop) 113 | 114 | def forward(self, x): 115 | x = self.fc1(x) 116 | x = self.act(x) 117 | x = self.drop(x) 118 | x = self.fc2(x) 119 | x = self.drop(x) 120 | return x 121 | 122 | def _no_grad_trunc_normal_(tensor, mean, std, a, b): 123 | # Cut & paste from PyTorch official master until it's in a few official releases - RW 124 | # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf 125 | def norm_cdf(x): 126 | # Computes standard normal cumulative distribution function 127 | return (1. + math.erf(x / math.sqrt(2.))) / 2. 
128 | 129 | if (mean < a - 2 * std) or (mean > b + 2 * std): 130 | warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " 131 | "The distribution of values may be incorrect.", 132 | stacklevel=2) 133 | 134 | with torch.no_grad(): 135 | # Values are generated by using a truncated uniform distribution and 136 | # then using the inverse CDF for the normal distribution. 137 | # Get upper and lower cdf values 138 | l = norm_cdf((a - mean) / std) 139 | u = norm_cdf((b - mean) / std) 140 | 141 | # Uniformly fill tensor with values from [l, u], then translate to 142 | # [2l-1, 2u-1]. 143 | tensor.uniform_(2 * l - 1, 2 * u - 1) 144 | 145 | # Use inverse cdf transform for normal distribution to get truncated 146 | # standard normal 147 | tensor.erfinv_() 148 | 149 | # Transform to proper mean, std 150 | tensor.mul_(std * math.sqrt(2.)) 151 | tensor.add_(mean) 152 | 153 | # Clamp to ensure it's in the proper range 154 | tensor.clamp_(min=a, max=b) 155 | return tensor 156 | 157 | 158 | def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): 159 | # type: (Tensor, float, float, float, float) -> Tensor 160 | r"""Fills the input Tensor with values drawn from a truncated 161 | normal distribution. The values are effectively drawn from the 162 | normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` 163 | with values outside :math:`[a, b]` redrawn until they are within 164 | the bounds. The method used for generating the random values works 165 | best when :math:`a \leq \text{mean} \leq b`. 166 | Args: 167 | tensor: an n-dimensional `torch.Tensor` 168 | mean: the mean of the normal distribution 169 | std: the standard deviation of the normal distribution 170 | a: the minimum cutoff value 171 | b: the maximum cutoff value 172 | Examples: 173 | >>> w = torch.empty(3, 5) 174 | >>> nn.init.trunc_normal_(w) 175 | """ 176 | return _no_grad_trunc_normal_(tensor, mean, std, a, b) 177 | 178 | 179 | def variance_scaling_(tensor, scale=1.0, mode='fan_in', distribution='normal'): 180 | fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor) 181 | if mode == 'fan_in': 182 | denom = fan_in 183 | elif mode == 'fan_out': 184 | denom = fan_out 185 | elif mode == 'fan_avg': 186 | denom = (fan_in + fan_out) / 2 187 | 188 | variance = scale / denom 189 | 190 | if distribution == "truncated_normal": 191 | # constant is stddev of standard normal truncated to (-2, 2) 192 | trunc_normal_(tensor, std=math.sqrt(variance) / .87962566103423978) 193 | elif distribution == "normal": 194 | tensor.normal_(std=math.sqrt(variance)) 195 | elif distribution == "uniform": 196 | bound = math.sqrt(3 * variance) 197 | tensor.uniform_(-bound, bound) 198 | else: 199 | raise ValueError(f"invalid distribution {distribution}") 200 | 201 | 202 | def lecun_normal_(tensor): 203 | variance_scaling_(tensor, mode='fan_in', distribution='truncated_normal') 204 | 205 | 206 | # below codes are based and referred from https://github.com/microsoft/Swin-Transformer 207 | # Swin Transformer for Computer Vision: https://arxiv.org/pdf/2103.14030.pdf 208 | 209 | def window_partition(x, window_size): 210 | """ 211 | Args: 212 | x: (B, H, W, C) 213 | window_size (int): window size 214 | Returns: 215 | windows: (num_windows*B, window_size, window_size, C) 216 | """ 217 | B, H, W, C = x.shape 218 | x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) 219 | windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) 220 | return windows 221 | 222 | 223 | def 
window_reverse(windows, window_size, H, W): 224 | """ 225 | Args: 226 | windows: (num_windows*B, window_size, window_size, C) 227 | window_size (int): Window size 228 | H (int): Height of image 229 | W (int): Width of image 230 | Returns: 231 | x: (B, H, W, C) 232 | """ 233 | B = int(windows.shape[0] / (H * W / window_size / window_size)) 234 | x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) 235 | x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) 236 | return x 237 | 238 | 239 | class WindowAttention(nn.Module): 240 | r""" Window based multi-head self attention (W-MSA) module with relative position bias. 241 | It supports both of shifted and non-shifted window. 242 | Args: 243 | dim (int): Number of input channels. 244 | window_size (tuple[int]): The height and width of the window. 245 | num_heads (int): Number of attention heads. 246 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True 247 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set 248 | attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 249 | proj_drop (float, optional): Dropout ratio of output. Default: 0.0 250 | """ 251 | 252 | def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): 253 | 254 | super().__init__() 255 | self.dim = dim 256 | self.window_size = window_size # Wh, Ww 257 | self.num_heads = num_heads 258 | head_dim = dim // num_heads 259 | self.scale = qk_scale or head_dim ** -0.5 260 | 261 | # define a parameter table of relative position bias 262 | self.relative_position_bias_table = nn.Parameter( 263 | torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH 264 | 265 | # get pair-wise relative position index for each token inside the window 266 | coords_h = torch.arange(self.window_size[0]) 267 | coords_w = torch.arange(self.window_size[1]) 268 | coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww 269 | coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww 270 | relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww 271 | relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 272 | relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 273 | relative_coords[:, :, 1] += self.window_size[1] - 1 274 | relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 275 | relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww 276 | self.register_buffer("relative_position_index", relative_position_index) 277 | 278 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 279 | self.attn_drop = nn.Dropout(attn_drop) 280 | self.proj = nn.Linear(dim, dim) 281 | self.proj_drop = nn.Dropout(proj_drop) 282 | 283 | trunc_normal_(self.relative_position_bias_table, std=.02) 284 | self.softmax = nn.Softmax(dim=-1) 285 | 286 | def forward(self, x, mask=None): 287 | """ 288 | Args: 289 | x: input features with shape of (num_windows*B, N, C) 290 | mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None 291 | """ 292 | B_, N, C = x.shape 293 | qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 294 | q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) 295 | 296 | q = q * self.scale 297 | attn = (q @ k.transpose(-2, -1)) 298 | 299 | relative_position_bias = 
self.relative_position_bias_table[self.relative_position_index.view(-1)].view( 300 | self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH 301 | relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww 302 | attn = attn + relative_position_bias.unsqueeze(0) 303 | 304 | if mask is not None: 305 | nW = mask.shape[0] 306 | attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) 307 | attn = attn.view(-1, self.num_heads, N, N) 308 | attn = self.softmax(attn) 309 | else: 310 | attn = self.softmax(attn) 311 | 312 | attn = self.attn_drop(attn) 313 | 314 | x = (attn @ v).transpose(1, 2).reshape(B_, N, C) 315 | x = self.proj(x) 316 | x = self.proj_drop(x) 317 | return x, attn 318 | 319 | def extra_repr(self): 320 | return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' 321 | 322 | 323 | # We use the model based on Swintransformer Block, therefore we can use the swin-transformer pretrained model 324 | class SwinTransformerBlock(nn.Module): 325 | r""" Swin Transformer Block. 326 | Args: 327 | dim (int): Number of input channels. 328 | input_resolution (tuple[int]): Input resulotion. 329 | num_heads (int): Number of attention heads. 330 | window_size (int): Window size. 331 | shift_size (int): Shift size for SW-MSA. 332 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 333 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True 334 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. 335 | drop (float, optional): Dropout rate. Default: 0.0 336 | attn_drop (float, optional): Attention dropout rate. Default: 0.0 337 | drop_path (float, optional): Stochastic depth rate. Default: 0.0 338 | act_layer (nn.Module, optional): Activation layer. Default: nn.GELU 339 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm 340 | """ 341 | 342 | def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, 343 | mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., 344 | act_layer=nn.GELU, norm_layer=nn.LayerNorm, norm_before_mlp='ln'): 345 | super().__init__() 346 | self.dim = dim 347 | self.input_resolution = input_resolution 348 | self.num_heads = num_heads 349 | self.window_size = window_size 350 | self.shift_size = shift_size 351 | self.mlp_ratio = mlp_ratio 352 | self.norm_before_mlp = norm_before_mlp 353 | if min(self.input_resolution) <= self.window_size: 354 | # if window size is larger than input resolution, we don't partition windows 355 | self.shift_size = 0 356 | self.window_size = min(self.input_resolution) 357 | assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" 358 | 359 | self.norm1 = norm_layer(dim) 360 | self.attn = WindowAttention( 361 | dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, 362 | qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) 363 | 364 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 365 | if self.norm_before_mlp == 'ln': 366 | self.norm2 = nn.LayerNorm(dim) 367 | elif self.norm_before_mlp == 'bn': 368 | self.norm2 = lambda x: nn.BatchNorm1d(dim)(x.transpose(1, 2)).transpose(1, 2) 369 | else: 370 | raise NotImplementedError 371 | mlp_hidden_dim = int(dim * mlp_ratio) 372 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 373 | 374 | if self.shift_size > 0: 375 | # calculate attention mask for SW-MSA 376 | H, W = self.input_resolution 377 | img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 378 | h_slices = (slice(0, -self.window_size), 379 | slice(-self.window_size, -self.shift_size), 380 | slice(-self.shift_size, None)) 381 | w_slices = (slice(0, -self.window_size), 382 | slice(-self.window_size, -self.shift_size), 383 | slice(-self.shift_size, None)) 384 | cnt = 0 385 | for h in h_slices: 386 | for w in w_slices: 387 | img_mask[:, h, w, :] = cnt 388 | cnt += 1 389 | 390 | mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 391 | mask_windows = mask_windows.view(-1, self.window_size * self.window_size) 392 | attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) 393 | attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) 394 | else: 395 | attn_mask = None 396 | 397 | self.register_buffer("attn_mask", attn_mask) 398 | 399 | def forward(self, x): 400 | # pdb.set_trace() 401 | H, W = self.input_resolution 402 | # print("H: ", H) 403 | # print("W: ", W) 404 | # pdb.set_trace() 405 | B, L, C = x.shape 406 | # assert L == H * W, "input feature has wrong size" 407 | 408 | shortcut = x 409 | x = self.norm1(x) 410 | x = x.view(B, H, W, C) 411 | 412 | # cyclic shift 413 | if self.shift_size > 0: 414 | shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) 415 | else: 416 | shifted_x = x 417 | 418 | # partition windows 419 | x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C 420 | x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C 421 | 422 | # W-MSA/SW-MSA 423 | attn_windows, attn = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C 424 | 425 | # merge windows 426 | attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) 427 | shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C 428 | 429 | # reverse cyclic shift 430 | if self.shift_size > 0: 431 | x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) 432 | else: 433 | x = shifted_x 434 | x = x.view(B, H * W, C) 435 | 436 | # FFN 437 | x = shortcut + self.drop_path(x) 438 | x = x + self.drop_path(self.mlp(self.norm2(x))) 439 | 440 | return x, attn 441 | 442 | def extra_repr(self): 443 | return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ 444 | f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" 445 | 446 | 447 | 448 | class PatchMerging(nn.Module): 449 | r""" Patch Merging Layer. 450 | Args: 451 | input_resolution (tuple[int]): Resolution of input feature. 452 | dim (int): Number of input channels. 453 | norm_layer (nn.Module, optional): Normalization layer. 
Default: nn.LayerNorm 454 | """ 455 | 456 | def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): 457 | super().__init__() 458 | self.input_resolution = input_resolution 459 | self.dim = dim 460 | self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) 461 | self.norm = norm_layer(4 * dim) 462 | 463 | def forward(self, x): 464 | """ 465 | x: B, H*W, C 466 | """ 467 | H, W = self.input_resolution 468 | B, L, C = x.shape 469 | assert L == H * W, "input feature has wrong size" 470 | assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." 471 | 472 | x = x.view(B, H, W, C) 473 | 474 | x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C 475 | x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C 476 | x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C 477 | x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C 478 | x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C 479 | x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C 480 | 481 | x = self.norm(x) 482 | x = self.reduction(x) 483 | 484 | return x 485 | 486 | def extra_repr(self): 487 | return f"input_resolution={self.input_resolution}, dim={self.dim}" 488 | 489 | 490 | class BasicLayer(nn.Module): 491 | """ A basic Swin Transformer layer for one stage. 492 | Args: 493 | dim (int): Number of input channels. 494 | input_resolution (tuple[int]): Input resolution. 495 | depth (int): Number of blocks. 496 | num_heads (int): Number of attention heads. 497 | window_size (int): Local window size. 498 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 499 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True 500 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. 501 | drop (float, optional): Dropout rate. Default: 0.0 502 | attn_drop (float, optional): Attention dropout rate. Default: 0.0 503 | drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 504 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm 505 | downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None 506 | use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 
507 | """ 508 | 509 | def __init__(self, dim, input_resolution, depth, num_heads, window_size, 510 | mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., 511 | drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False, 512 | norm_before_mlp='ln'): 513 | 514 | super().__init__() 515 | self.dim = dim 516 | self.input_resolution = input_resolution 517 | self.depth = depth 518 | self.use_checkpoint = use_checkpoint 519 | 520 | # build blocks 521 | self.blocks = nn.ModuleList([ 522 | SwinTransformerBlock(dim=dim, input_resolution=input_resolution, 523 | num_heads=num_heads, window_size=window_size, 524 | shift_size=0 if (i % 2 == 0) else window_size // 2, 525 | mlp_ratio=mlp_ratio, 526 | qkv_bias=qkv_bias, qk_scale=qk_scale, 527 | drop=drop, attn_drop=attn_drop, 528 | drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, 529 | norm_layer=norm_layer, norm_before_mlp=norm_before_mlp) 530 | for i in range(depth)]) 531 | 532 | # patch merging layer 533 | if downsample is not None: 534 | self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) 535 | else: 536 | self.downsample = None 537 | 538 | def forward(self, x): 539 | attns = [] 540 | for blk in self.blocks: 541 | if self.use_checkpoint: 542 | x = checkpoint.checkpoint(blk, x) 543 | else: 544 | x, attn = blk(x) 545 | if not self.training: 546 | attns.append(attn.unsqueeze(0)) 547 | if self.downsample is not None: 548 | x = self.downsample(x) 549 | if not self.training: 550 | attn = torch.cat(attns, dim = 0) 551 | attn = torch.mean(attn, dim = 0) 552 | return x, attn 553 | 554 | def extra_repr(self): 555 | return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" 556 | 557 | 558 | # The Core of HTSAT 559 | class HTSAT_Swin_Transformer(nn.Module): 560 | r"""HTSAT based on the Swin Transformer 561 | Args: 562 | spec_size (int | tuple(int)): Input Spectrogram size. Default 256 563 | patch_size (int | tuple(int)): Patch size. Default: 4 564 | path_stride (iot | tuple(int)): Patch Stride for Frequency and Time Axis. Default: 4 565 | in_chans (int): Number of input image channels. Default: 1 (mono) 566 | num_classes (int): Number of classes for classification head. Default: 527 567 | embed_dim (int): Patch embedding dimension. Default: 96 568 | depths (tuple(int)): Depth of each HTSAT-Swin Transformer layer. 569 | num_heads (tuple(int)): Number of attention heads in different layers. 570 | window_size (int): Window size. Default: 8 571 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 572 | qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True 573 | qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None 574 | drop_rate (float): Dropout rate. Default: 0 575 | attn_drop_rate (float): Attention dropout rate. Default: 0 576 | drop_path_rate (float): Stochastic depth rate. Default: 0.1 577 | norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. 578 | ape (bool): If True, add absolute position embedding to the patch embedding. Default: False 579 | patch_norm (bool): If True, add normalization after patch embedding. Default: True 580 | use_checkpoint (bool): Whether to use checkpointing to save memory. 
Default: False 581 | config (module): The configuration Module from config.py 582 | """ 583 | 584 | def __init__(self, spec_size=256, patch_size=4, patch_stride=(4,4), 585 | in_chans=1, num_classes=527, 586 | embed_dim=96, depths=[2, 2, 6, 2], num_heads=[4, 8, 16, 32], 587 | window_size=8, mlp_ratio=4., qkv_bias=True, qk_scale=None, 588 | drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, 589 | norm_layer=nn.LayerNorm, 590 | ape=False, patch_norm=True, 591 | use_checkpoint=False, norm_before_mlp='ln', config = None, **kwargs): 592 | super(HTSAT_Swin_Transformer, self).__init__() 593 | 594 | self.config = config 595 | self.spec_size = spec_size 596 | self.patch_stride = patch_stride 597 | self.patch_size = patch_size 598 | self.window_size = window_size 599 | self.embed_dim = embed_dim 600 | self.depths = depths 601 | self.ape = ape 602 | self.in_chans = in_chans 603 | self.num_classes = num_classes 604 | self.num_heads = num_heads 605 | self.num_layers = len(self.depths) 606 | self.num_features = int(self.embed_dim * 2 ** (self.num_layers - 1)) 607 | 608 | self.drop_rate = drop_rate 609 | self.attn_drop_rate = attn_drop_rate 610 | self.drop_path_rate = drop_path_rate 611 | 612 | self.qkv_bias = qkv_bias 613 | self.qk_scale = None 614 | 615 | self.patch_norm = patch_norm 616 | self.norm_layer = norm_layer if self.patch_norm else None 617 | self.norm_before_mlp = norm_before_mlp 618 | self.mlp_ratio = mlp_ratio 619 | 620 | self.use_checkpoint = use_checkpoint 621 | 622 | # process mel-spec ; used only once 623 | self.freq_ratio = self.spec_size // self.config.mel_bins 624 | window = 'hann' 625 | center = True 626 | pad_mode = 'reflect' 627 | ref = 1.0 628 | amin = 1e-10 629 | top_db = None 630 | self.interpolate_ratio = 32 # Downsampled ratio 631 | # Spectrogram extractor 632 | self.spectrogram_extractor = Spectrogram(n_fft=config.window_size, hop_length=config.hop_size, 633 | win_length=config.window_size, window=window, center=center, pad_mode=pad_mode, 634 | freeze_parameters=True) 635 | # Logmel feature extractor 636 | self.logmel_extractor = LogmelFilterBank(sr=config.sample_rate, n_fft=config.window_size, 637 | n_mels=config.mel_bins, fmin=config.fmin, fmax=config.fmax, ref=ref, amin=amin, top_db=top_db, 638 | freeze_parameters=True) 639 | # Spec augmenter 640 | self.spec_augmenter = SpecAugmentation(time_drop_width=64, time_stripes_num=2, 641 | freq_drop_width=8, freq_stripes_num=2) # 2 2 642 | self.bn0 = nn.BatchNorm2d(self.config.mel_bins) 643 | 644 | 645 | # split spctrogram into non-overlapping patches 646 | self.patch_embed = PatchEmbed( 647 | img_size=self.spec_size, patch_size=self.patch_size, in_chans=self.in_chans, 648 | embed_dim=self.embed_dim, norm_layer=self.norm_layer, patch_stride = patch_stride) 649 | 650 | num_patches = self.patch_embed.num_patches 651 | patches_resolution = self.patch_embed.grid_size 652 | self.patches_resolution = patches_resolution 653 | 654 | # absolute position embedding 655 | if self.ape: 656 | self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, self.embed_dim)) 657 | trunc_normal_(self.absolute_pos_embed, std=.02) 658 | 659 | self.pos_drop = nn.Dropout(p=self.drop_rate) 660 | 661 | # stochastic depth 662 | dpr = [x.item() for x in torch.linspace(0, self.drop_path_rate, sum(self.depths))] # stochastic depth decay rule 663 | 664 | # build layers 665 | self.layers = nn.ModuleList() 666 | for i_layer in range(self.num_layers): 667 | layer = BasicLayer(dim=int(self.embed_dim * 2 ** i_layer), 668 | 
input_resolution=(patches_resolution[0] // (2 ** i_layer), 669 | patches_resolution[1] // (2 ** i_layer)), 670 | depth=self.depths[i_layer], 671 | num_heads=self.num_heads[i_layer], 672 | window_size=self.window_size, 673 | mlp_ratio=self.mlp_ratio, 674 | qkv_bias=self.qkv_bias, qk_scale=self.qk_scale, 675 | drop=self.drop_rate, attn_drop=self.attn_drop_rate, 676 | drop_path=dpr[sum(self.depths[:i_layer]):sum(self.depths[:i_layer + 1])], 677 | norm_layer=self.norm_layer, 678 | downsample=PatchMerging if (i_layer < self.num_layers - 1) else None, 679 | use_checkpoint=use_checkpoint, 680 | norm_before_mlp=self.norm_before_mlp) 681 | self.layers.append(layer) 682 | 683 | self.norm = self.norm_layer(self.num_features) 684 | self.avgpool = nn.AdaptiveAvgPool1d(1) 685 | self.maxpool = nn.AdaptiveMaxPool1d(1) 686 | 687 | if self.config.enable_tscam: 688 | SF = self.spec_size // (2 ** (len(self.depths) - 1)) // self.patch_stride[0] // self.freq_ratio 689 | self.tscam_conv = nn.Conv2d( 690 | in_channels = self.num_features, 691 | out_channels = self.num_classes, 692 | kernel_size = (SF,3), 693 | padding = (0,1) 694 | ) 695 | self.head = nn.Linear(num_classes, num_classes) 696 | else: 697 | self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() 698 | 699 | self.apply(self._init_weights) 700 | 701 | def _init_weights(self, m): 702 | if isinstance(m, nn.Linear): 703 | trunc_normal_(m.weight, std=.02) 704 | if isinstance(m, nn.Linear) and m.bias is not None: 705 | nn.init.constant_(m.bias, 0) 706 | elif isinstance(m, nn.LayerNorm): 707 | nn.init.constant_(m.bias, 0) 708 | nn.init.constant_(m.weight, 1.0) 709 | 710 | @torch.jit.ignore 711 | def no_weight_decay(self): 712 | return {'absolute_pos_embed'} 713 | 714 | @torch.jit.ignore 715 | def no_weight_decay_keywords(self): 716 | return {'relative_position_bias_table'} 717 | 718 | def forward_features(self, x): 719 | frames_num = x.shape[2] 720 | x = self.patch_embed(x) 721 | if self.ape: 722 | x = x + self.absolute_pos_embed 723 | x = self.pos_drop(x) 724 | for i, layer in enumerate(self.layers): 725 | x, attn = layer(x) 726 | 727 | if self.config.enable_tscam: 728 | # for x 729 | x = self.norm(x) 730 | B, N, C = x.shape 731 | SF = frames_num // (2 ** (len(self.depths) - 1)) // self.patch_stride[0] 732 | ST = frames_num // (2 ** (len(self.depths) - 1)) // self.patch_stride[1] 733 | x = x.permute(0,2,1).contiguous().reshape(B, C, SF, ST) 734 | B, C, F, T = x.shape 735 | # group 2D CNN 736 | c_freq_bin = F // self.freq_ratio 737 | x = x.reshape(B, C, F // c_freq_bin, c_freq_bin, T) 738 | x = x.permute(0,1,3,2,4).contiguous().reshape(B, C, c_freq_bin, -1) 739 | 740 | # get latent_output 741 | latent_output = self.avgpool(torch.flatten(x,2)) 742 | latent_output = torch.flatten(latent_output, 1) 743 | 744 | # display the attention map, if needed 745 | if self.config.htsat_attn_heatmap: 746 | # for attn 747 | attn = torch.mean(attn, dim = 1) 748 | attn = torch.mean(attn, dim = 1) 749 | attn = attn.reshape(B, SF, ST) 750 | c_freq_bin = SF // self.freq_ratio 751 | attn = attn.reshape(B, SF // c_freq_bin, c_freq_bin, ST) 752 | attn = attn.permute(0,2,1,3).contiguous().reshape(B, c_freq_bin, -1) 753 | attn = attn.mean(dim = 1) 754 | attn_max = torch.max(attn, dim = 1, keepdim = True)[0] 755 | attn_min = torch.min(attn, dim = 1, keepdim = True)[0] 756 | attn = ((attn * 0.15) + (attn_max * 0.85 - attn_min)) / (attn_max - attn_min) 757 | attn = attn.unsqueeze(dim = 2) 758 | 759 | x = self.tscam_conv(x) 760 | x = torch.flatten(x, 
2) # B, C, T 761 | 762 | if self.config.htsat_attn_heatmap: 763 | fpx = interpolate(torch.sigmoid(x).permute(0,2,1).contiguous() * attn, 8 * self.patch_stride[1]) 764 | else: 765 | fpx = interpolate(torch.sigmoid(x).permute(0,2,1).contiguous(), 8 * self.patch_stride[1]) 766 | 767 | x = self.avgpool(x) 768 | x = torch.flatten(x, 1) 769 | 770 | if self.config.loss_type == "clip_ce": 771 | output_dict = { 772 | 'framewise_output': fpx, # already sigmoided 773 | 'clipwise_output': x, 774 | 'latent_output': latent_output 775 | } 776 | else: 777 | output_dict = { 778 | 'framewise_output': fpx, # already sigmoided 779 | 'clipwise_output': torch.sigmoid(x), 780 | 'latent_output': latent_output 781 | } 782 | 783 | else: 784 | x = self.norm(x) # B N C 785 | B, N, C = x.shape 786 | 787 | fpx = x.permute(0,2,1).contiguous().reshape(B, C, frames_num // (2 ** (len(self.depths) + 1)), frames_num // (2 ** (len(self.depths) + 1)) ) 788 | B, C, F, T = fpx.shape 789 | c_freq_bin = F // self.freq_ratio 790 | fpx = fpx.reshape(B, C, F // c_freq_bin, c_freq_bin, T) 791 | fpx = fpx.permute(0,1,3,2,4).contiguous().reshape(B, C, c_freq_bin, -1) 792 | fpx = torch.sum(fpx, dim = 2) 793 | fpx = interpolate(fpx.permute(0,2,1).contiguous(), 8 * self.patch_stride[1]) 794 | x = self.avgpool(x.transpose(1, 2)) # B C 1 795 | x = torch.flatten(x, 1) 796 | if self.num_classes > 0: 797 | x = self.head(x) 798 | fpx = self.head(fpx) 799 | output_dict = {'framewise_output': torch.sigmoid(fpx), 800 | 'clipwise_output': torch.sigmoid(x)} 801 | return output_dict 802 | 803 | def crop_wav(self, x, crop_size, spe_pos = None): 804 | time_steps = x.shape[2] 805 | tx = torch.zeros(x.shape[0], x.shape[1], crop_size, x.shape[3]).to(x.device) 806 | for i in range(len(x)): 807 | if spe_pos is None: 808 | crop_pos = random.randint(0, time_steps - crop_size - 1) 809 | else: 810 | crop_pos = spe_pos 811 | tx[i][0] = x[i, 0, crop_pos:crop_pos + crop_size,:] 812 | return tx 813 | 814 | # Reshape the wavform to a img size, if you want to use the pretrained swin transformer model 815 | def reshape_wav2img(self, x): 816 | B, C, T, F = x.shape 817 | target_T = int(self.spec_size * self.freq_ratio) 818 | target_F = self.spec_size // self.freq_ratio 819 | assert T <= target_T and F <= target_F, "the wav size should less than or equal to the swin input size" 820 | # to avoid bicubic zero error 821 | if T < target_T: 822 | x = nn.functional.interpolate(x, (target_T, x.shape[3]), mode="bicubic", align_corners=True) 823 | if F < target_F: 824 | x = nn.functional.interpolate(x, (x.shape[2], target_F), mode="bicubic", align_corners=True) 825 | x = x.permute(0,1,3,2).contiguous() 826 | x = x.reshape(x.shape[0], x.shape[1], x.shape[2], self.freq_ratio, x.shape[3] // self.freq_ratio) 827 | # print(x.shape) 828 | x = x.permute(0,1,3,2,4).contiguous() 829 | x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3], x.shape[4]) 830 | return x 831 | 832 | # Repeat the wavform to a img size, if you want to use the pretrained swin transformer model 833 | def repeat_wat2img(self, x, cur_pos): 834 | B, C, T, F = x.shape 835 | target_T = int(self.spec_size * self.freq_ratio) 836 | target_F = self.spec_size // self.freq_ratio 837 | assert T <= target_T and F <= target_F, "the wav size should less than or equal to the swin input size" 838 | # to avoid bicubic zero error 839 | if T < target_T: 840 | x = nn.functional.interpolate(x, (target_T, x.shape[3]), mode="bicubic", align_corners=True) 841 | if F < target_F: 842 | x = nn.functional.interpolate(x, (x.shape[2], 
target_F), mode="bicubic", align_corners=True) 843 | x = x.permute(0,1,3,2).contiguous() # B C F T 844 | x = x[:,:,:,cur_pos:cur_pos + self.spec_size] 845 | x = x.repeat(repeats = (1,1,4,1)) 846 | return x 847 | 848 | def forward(self, x: torch.Tensor, mixup_lambda = None, infer_mode = False):# out_feat_keys: List[str] = None): 849 | x = self.spectrogram_extractor(x) # (batch_size, 1, time_steps, freq_bins) 850 | x = self.logmel_extractor(x) # (batch_size, 1, time_steps, mel_bins) 851 | 852 | 853 | x = x.transpose(1, 3) 854 | x = self.bn0(x) 855 | x = x.transpose(1, 3) 856 | if self.training: 857 | x = self.spec_augmenter(x) 858 | if self.training and mixup_lambda is not None: 859 | x = do_mixup(x, mixup_lambda) 860 | 861 | if infer_mode: 862 | # in infer mode. we need to handle different length audio input 863 | frame_num = x.shape[2] 864 | target_T = int(self.spec_size * self.freq_ratio) 865 | repeat_ratio = math.floor(target_T / frame_num) 866 | x = x.repeat(repeats=(1,1,repeat_ratio,1)) 867 | x = self.reshape_wav2img(x) 868 | output_dict = self.forward_features(x) 869 | elif self.config.enable_repeat_mode: 870 | if self.training: 871 | cur_pos = random.randint(0, (self.freq_ratio - 1) * self.spec_size - 1) 872 | x = self.repeat_wat2img(x, cur_pos) 873 | output_dict = self.forward_features(x) 874 | else: 875 | output_dicts = [] 876 | for cur_pos in range(0, (self.freq_ratio - 1) * self.spec_size + 1, self.spec_size): 877 | tx = x.clone() 878 | tx = self.repeat_wat2img(tx, cur_pos) 879 | output_dicts.append(self.forward_features(tx)) 880 | clipwise_output = torch.zeros_like(output_dicts[0]["clipwise_output"]).float().to(x.device) 881 | framewise_output = torch.zeros_like(output_dicts[0]["framewise_output"]).float().to(x.device) 882 | for d in output_dicts: 883 | clipwise_output += d["clipwise_output"] 884 | framewise_output += d["framewise_output"] 885 | clipwise_output = clipwise_output / len(output_dicts) 886 | framewise_output = framewise_output / len(output_dicts) 887 | 888 | output_dict = { 889 | 'framewise_output': framewise_output, 890 | 'clipwise_output': clipwise_output 891 | } 892 | else: 893 | if x.shape[2] > self.freq_ratio * self.spec_size: 894 | if self.training: 895 | x = self.crop_wav(x, crop_size=self.freq_ratio * self.spec_size) 896 | x = self.reshape_wav2img(x) 897 | output_dict = self.forward_features(x) 898 | else: 899 | # Change: Hard code here 900 | overlap_size = 344 #(x.shape[2] - 1) // 4 901 | output_dicts = [] 902 | crop_size = 689 #(x.shape[2] - 1) // 2 903 | for cur_pos in range(0, x.shape[2] - crop_size - 1, overlap_size): 904 | tx = self.crop_wav(x, crop_size = crop_size, spe_pos = cur_pos) 905 | tx = self.reshape_wav2img(tx) 906 | output_dicts.append(self.forward_features(tx)) 907 | clipwise_output = torch.zeros_like(output_dicts[0]["clipwise_output"]).float().to(x.device) 908 | framewise_output = torch.zeros_like(output_dicts[0]["framewise_output"]).float().to(x.device) 909 | latent_output = torch.zeros_like(output_dicts[0]["latent_output"]).float().to(x.device) 910 | for d in output_dicts: 911 | clipwise_output += d["clipwise_output"] 912 | framewise_output += d["framewise_output"] 913 | latent_output += d["latent_output"] 914 | clipwise_output = clipwise_output / len(output_dicts) 915 | framewise_output = framewise_output / len(output_dicts) 916 | latent_output = latent_output / len(output_dicts) 917 | output_dict = { 918 | 'framewise_output': framewise_output, 919 | 'clipwise_output': clipwise_output, 920 | 'latent_output': latent_output, 921 | } 922 | 
else: # this part is typically used, and the easiest one 923 | x = self.reshape_wav2img(x) 924 | output_dict = self.forward_features(x) 925 | # x = self.head(x) 926 | return output_dict 927 | 928 | class HTSATWrapper(nn.Module): 929 | def __init__(self, sample_rate, window_size, hop_size, mel_bins, fmin, 930 | fmax, classes_num, out_emb): 931 | super().__init__() 932 | 933 | # print("parameters are being overridden when using HTSAT") 934 | # print("HTSAT only supports loading a pretrained model on AudioSet") 935 | # @TODO later look at which parameters are the same and can be merged 936 | 937 | self.htsat = HTSAT_Swin_Transformer(config=config) 938 | 939 | def forward(self, x): 940 | out_dict = self.htsat(x) 941 | out_dict['embedding'] = out_dict['latent_output'] 942 | return out_dict --------------------------------------------------------------------------------
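A minimal sketch of how HTSAT_Swin_Transformer defined above might be instantiated and run directly on raw audio, assuming the file is importable as PAM.models.htsat and torchlibrosa is installed. The config attribute names are exactly the ones the constructor and forward pass read (sample_rate, window_size, hop_size, mel_bins, fmin, fmax, enable_tscam, htsat_attn_heatmap, enable_repeat_mode, loss_type); the numeric values, the SimpleNamespace container, and the batch shape are illustrative assumptions, not the project's settings.

from types import SimpleNamespace

import torch

from PAM.models.htsat import HTSAT_Swin_Transformer  # assumed import path

# Ad-hoc config holding only the attributes the model reads; values are examples.
cfg = SimpleNamespace(
    sample_rate=32000,        # example sampling rate
    window_size=1024,         # STFT window / n_fft
    hop_size=320,
    mel_bins=64,
    fmin=50,
    fmax=14000,
    enable_tscam=True,        # enables the tscam head, so 'latent_output' is produced
    htsat_attn_heatmap=False,
    enable_repeat_mode=False,
    loss_type="clip_bce",     # anything other than "clip_ce" yields sigmoid clipwise output
)

model = HTSAT_Swin_Transformer(spec_size=256, patch_size=4, patch_stride=(4, 4),
                               num_classes=527, config=cfg).eval()

with torch.no_grad():
    waveform = torch.randn(2, cfg.sample_rate * 7)  # (batch, samples), roughly 7 s of audio
    out = model(waveform)

print(out["clipwise_output"].shape)   # (2, 527) clip-level class probabilities
print(out["latent_output"].shape)     # (2, 768) embedding; HTSATWrapper exposes this as 'embedding'

With enable_tscam set to True, the forward pass returns 'latent_output', which is what HTSATWrapper copies into the 'embedding' key consumed by the rest of the PAM pipeline.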