├── LICENSE
├── README.md
├── chinese_wwm_ext_pytorch
│   └── .gitignore
├── config.py
├── data
│   └── .gitignore
├── dataloader.py
├── model.py
├── model
│   └── .gitignore
├── test.py
├── train.py
└── utils.py

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2020 Jiaan Wang

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Multi-Modal Chorus Recognition
Code and dataset of the ICANN 2021 paper: *Multi-Modal Chorus Recognition for Improving Song Search*

## Dependencies
- Python 3.6+
- PyTorch 1.0+
- Transformers 2.6.0
- Others
  - python_speech_features
  - pandas

## CHORD
You can download the CHOrus Recognition Dataset [here](https://drive.google.com/file/d/1nkoDvCym3hz_qI9u6XnzE9JEkBVnOqmM/view?usp=sharing) and unzip it into the ```data/``` directory.

## Pre-trained Language Model
We use the ```BERT-wwm-ext, Chinese``` pre-trained language model. For an introduction and download links, please refer to [Chinese-BERT-wwm](https://github.com/ymcui/Chinese-BERT-wwm#%E4%B8%AD%E6%96%87%E6%A8%A1%E5%9E%8B%E4%B8%8B%E8%BD%BD).
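
Before training, the project root should look roughly like the sketch below. This layout is illustrative and simply mirrors the paths hard-coded in ```config.py``` (the exact archive contents may differ slightly):

```
.
├── chinese_wwm_ext_pytorch/   # PyTorch BERT files: config.json, pytorch_model.bin, vocab.txt
├── data/
│   ├── chord_embedding.pkl    # pre-trained chord embedding weights loaded by dataloader.py
│   ├── music/                 # one .wav file per song id referenced in the CSVs
│   ├── train.csv
│   └── test.csv
└── model/                     # train.py writes the checkpoint to model/model.pkl
```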

## Train and Test
For training, you can run:
```shell
python train.py
```

For evaluation, run:
```shell
python test.py
```

## Citation
If you find this code useful, please star our repo or consider citing:
```
@article{Wang2021MultiModalCR,
  title={Multi-Modal Chorus Recognition for Improving Song Search},
  author={Jiaan Wang and Zhixu Li and Binbin Gu and Tingyi Zhang and Qingsheng Liu and Zhigang Chen},
  journal={ArXiv},
  year={2021},
  volume={abs/2106.16153}
}
```

--------------------------------------------------------------------------------
/chinese_wwm_ext_pytorch/.gitignore:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/krystalan/MMCR/ab3ad3d5cc492c80063de88f0127768b836cfc69/chinese_wwm_ext_pytorch/.gitignore
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
import argparse
import pprint
import torch


class Config(object):
    def __init__(self, **kwargs):
        if kwargs is not None:
            for key, value in kwargs.items():
                setattr(self, key, value)
        # file paths
        self.chord_embedding_path = 'data/chord_embedding.pkl'
        self.music_path = 'data/music/'
        self.save_path = 'model/model.pkl'
        self.train_path = 'data/train.csv'
        self.test_path = 'data/test.csv'
        # device
        self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        # pre-trained language model
        self.PTM = 'chinese_wwm_ext_pytorch'

    def __str__(self):
        """Pretty-print configurations in alphabetical order"""
        config_str = 'Configurations\n'
        config_str += pprint.pformat(self.__dict__)
        return config_str


def get_config():
    parser = argparse.ArgumentParser()

    # # load setting
    # parser.add_argument('--checkpoint', type=str, default=None)

    # train setting
    parser.add_argument('--learning_rate', type=float, default=6e-4)
    parser.add_argument('--n_epoch', type=int, default=5)
    parser.add_argument('-b', '--batch_size', type=int, default=128)
    parser.add_argument('-tb', '--test_batch_size', type=int, default=1)

    kwargs = parser.parse_args()

    # Namespace => Dictionary
    kwargs = vars(kwargs)

    return Config(**kwargs)
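

if __name__ == "__main__":
    # Illustrative usage: print the resolved configuration. The same argparse
    # flags can be passed to train.py / test.py, which build their Config via
    # get_config(), e.g.
    #   python train.py --learning_rate 1e-4 -b 64 --n_epoch 10
    #   python test.py -tb 1
    print(get_config())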
--------------------------------------------------------------------------------
/data/.gitignore:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/krystalan/MMCR/ab3ad3d5cc492c80063de88f0127768b836cfc69/data/.gitignore
--------------------------------------------------------------------------------
/dataloader.py:
--------------------------------------------------------------------------------
import torch
from torch.utils.data import Dataset, DataLoader
import pandas as pd
from config import get_config
from utils import get_MFCC, get_lyric_feature
import pickle
import math

config = get_config()


class MusicDataset(Dataset):
    def __init__(self, csv_file, max_len=300, dim=768):
        self.data = pd.read_csv(csv_file, encoding='utf-8', header=None)
        self.data = self.data.values.tolist()
        # 64-dim embedding for the 10 chord symbols, initialized from pre-trained weights
        self.chord_embedding = torch.nn.Embedding(10, 64)
        with open(config.chord_embedding_path, 'rb') as f:
            pretrained_weight = pickle.load(f)
        self.chord_embedding.weight.data.copy_(pretrained_weight)
        self.C_to_N = {'A': 0, 'Am': 1, 'Bm': 2, 'C': 3, 'D': 4, 'Dm': 5, 'E': 6, 'Em': 7, 'F': 8, 'G': 9}
        # sinusoidal positional encoding added to the lyric feature
        self.pe = torch.zeros(max_len, dim)
        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, dim, 2, dtype=torch.float) * -(math.log(10000.0) / dim))
        self.pe[:, 0::2] = torch.sin(position.float() * div_term)
        self.pe[:, 1::2] = torch.cos(position.float() * div_term)
        self.len = len(self.data)

    def __getitem__(self, idx):
        # lyric feature: BERT pooled output plus the positional encoding of the line index
        position = int(self.data[idx][1])
        lyric_feature = get_lyric_feature(self.data[idx][2])
        lyric_feature = lyric_feature + self.pe[position].unsqueeze(0).to(config.device)
        # MFCC feature: truncate or zero-pad to 1280 frames of 13 coefficients
        MFCC_feature = torch.tensor(get_MFCC(config.music_path + self.data[idx][0] + '.wav', self.data[idx][3], self.data[idx][4])).to(config.device)
        length = MFCC_feature.size()[0]
        if length > 1280:
            MFCC_feature = MFCC_feature[0:1280].to(config.device)
        if length < 1280:
            padding = torch.zeros(1280 - length, 13).to(config.device)
            MFCC_feature = torch.cat((MFCC_feature, padding), 0).to(config.device)

        # chord feature: map the chord annotation to indices ('_' means no annotation)
        chord = self.data[idx][5]
        if chord != '_':
            chord = eval(chord)
            chord = [self.C_to_N[i] for i in chord]
        else:
            chord = []

        if len(chord) > 20:
            chord = chord[0:20]
        lens = len(chord)
        if lens != 0:
            chord = torch.tensor(chord)
            chord_feature = self.chord_embedding(chord)
            chord_feature = chord_feature.view(lens * 64, 1).to(config.device)
        else:
            chord_feature = torch.tensor([[]])
            chord_feature = torch.transpose(chord_feature, 1, 0).to(config.device)
        # zero-pad the flattened chord embeddings to length 1280
        length = chord_feature.size()[0]
        if length < 1280:
            padding = torch.zeros(1280 - length, 1).to(config.device)
            chord_feature = torch.cat((chord_feature, padding), 0)
        res = [lyric_feature, MFCC_feature, chord_feature]
        return res, self.data[idx][6]

    def __len__(self):
        return self.len


def get_loader(csv_file, bs):
    dataset = MusicDataset(csv_file)
    dataloader = DataLoader(dataset=dataset, batch_size=bs, drop_last=True)
    return dataloader
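

if __name__ == "__main__":
    # Illustrative sanity-check sketch, assuming data/train.csv and the
    # corresponding .wav files are already in place; this block is not part of
    # the training pipeline.
    loader = get_loader(config.train_path, bs=2)
    (lyric, mfcc_feat, chord), label = next(iter(loader))
    print(lyric.shape)      # (2, 1, 768)   BERT pooled output + positional encoding
    print(mfcc_feat.shape)  # (2, 1280, 13) MFCC frames, truncated/zero-padded to 1280
    print(chord.shape)      # (2, 1280, 1)  flattened chord embeddings, zero-padded
    print(label.shape)      # (2,)          chorus / non-chorus label from the CSV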
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import pickle
import math
from config import get_config

config = get_config()


class MusicModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.actv = nn.ReLU()
        self.music_layer = nn.Linear(13, 10)
        self.music_layer_2 = nn.Linear(10, 1)
        self.classifier1 = nn.Linear(3328, 1000)
        self.classifier2 = nn.Linear(1000, 512)
        self.classifier3 = nn.Linear(512, 1)

    def forward(self, model_input):
        lyric_feature = model_input[0]   # (B, 1, 768)
        MFCC_feature = model_input[1]    # (B, 1280, 13)
        chord_feature = model_input[2]   # (B, 1280, 1)

        lyric_feature = torch.transpose(lyric_feature, 2, 1)  # (B, 768, 1)

        # project each MFCC frame down to a single scalar: (B, 1280, 13) -> (B, 1280, 1)
        music_feature = self.music_layer(MFCC_feature)
        music_feature = self.music_layer_2(music_feature)

        # concatenate the three modalities: (B, 768 + 1280 + 1280, 1) = (B, 3328, 1)
        all_feature = torch.cat((lyric_feature, music_feature, chord_feature), 1)
        all_feature = torch.transpose(all_feature, 2, 1)  # (B, 1, 3328)

        output = self.actv(self.classifier1(all_feature))
        output = self.actv(self.classifier2(output))
        output = self.classifier3(output)
        return torch.sigmoid(output)  # (B, 1, 1) chorus probability
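

if __name__ == "__main__":
    # Illustrative shape check with random inputs (a sketch; the batch size 4 is
    # arbitrary). The three tensors mirror what MusicDataset.__getitem__ returns.
    B = 4
    lyric = torch.randn(B, 1, 768)        # BERT pooled output + positional encoding
    mfcc_feat = torch.randn(B, 1280, 13)  # padded MFCC frames
    chord = torch.randn(B, 1280, 1)       # flattened, zero-padded chord embeddings
    model = MusicModel()
    out = model([lyric, mfcc_feat, chord])
    print(out.shape)  # torch.Size([4, 1, 1])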
--------------------------------------------------------------------------------
/model/.gitignore:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/krystalan/MMCR/ab3ad3d5cc492c80063de88f0127768b836cfc69/model/.gitignore
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
from model import MusicModel
from config import get_config
from dataloader import get_loader
from tqdm import tqdm


def test(model_path, test_path, bs):
    model = torch.load(model_path)
    model.eval()

    test_loader = get_loader(test_path, bs)

    # confusion-matrix counts for the binary chorus label
    TP = 0
    FN = 0
    FP = 0
    TN = 0

    desc = ' - (Testing) - '
    for (data, label) in tqdm(test_loader, desc=desc, ncols=80):
        result = float(model(data).squeeze(-1).squeeze(-1))
        label = int(label[0])

        if label == 1:
            if result >= 0.5:
                TP += 1
            else:
                FN += 1
        else:
            if result >= 0.5:
                FP += 1
            else:
                TN += 1
    acc = float(TP + TN) / float(TP + FN + FP + TN)
    acc = round(acc * 100, 2)
    print('ACC:' + str(acc))


if __name__ == "__main__":
    config = get_config()
    test(config.save_path, config.test_path, config.test_batch_size)

--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.optim as optim
from model import MusicModel
from config import get_config
from dataloader import get_loader
from tqdm import tqdm

seed = 4
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)


def train(config):
    model = MusicModel()
    model.to(config.device)
    optimizer = optim.Adam(model.parameters(), lr=config.learning_rate)
    criterion = nn.BCELoss()

    # build the DataLoader once; it is re-iterated from the start in every epoch
    train_loader = get_loader(config.train_path, config.batch_size)

    for epoch in range(1, 1 + config.n_epoch):
        desc = ' - (Training|epoch:' + str(epoch) + ') - '
        for (data, label) in tqdm(train_loader, desc=desc, ncols=100):
            result = model(data).squeeze(-1).squeeze(-1)
            label = label.float()
            loss = criterion(result, label)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        torch.save(model, config.save_path)


if __name__ == "__main__":
    config = get_config()
    train(config)

--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
import torch
from transformers import BertConfig
from transformers import BertModel
from transformers import BertTokenizer
import pandas as pd
from config import get_config
import scipy.io.wavfile
from python_speech_features import mfcc

config = get_config()
tokenizer = BertTokenizer.from_pretrained(config.PTM)
model = BertModel.from_pretrained(config.PTM)
model.to(config.device)


def get_lyric_feature(lyric):
    # encode one lyric line and return BERT's pooled [CLS] output, shape (1, 768)
    input_id = tokenizer.encode(lyric)
    input_id = torch.tensor([input_id])
    input_id = input_id.to(config.device)
    _, pooled_output = model(input_id)
    return pooled_output


def get_MFCC(root, start, end):
    # 13 MFCC coefficients per frame for the segment between the start and end
    # timestamps (in seconds); '[end]' means the segment runs to the end of the song
    fs, sig = scipy.io.wavfile.read(root)
    start = int(float(start) * fs) - 1
    if end == '[end]':
        return mfcc(sig[start::], fs).tolist()
    else:
        end = int(float(end) * fs)
        return mfcc(sig[start:end], fs).tolist()
--------------------------------------------------------------------------------