├── .gitignore ├── dataloader └── dataset.py ├── notebook └── EDA.ipynb ├── preprocessing ├── audio_utils.py ├── constants.py ├── main.py └── msd_preprocessor.py ├── readme.md └── requirments.txt /.gitignore: -------------------------------------------------------------------------------- 1 | error/ 2 | wav/ 3 | npy/ 4 | songs/ 5 | dataset/ 6 | *.yaml 7 | *.ckpt 8 | *.pt 9 | *.png 10 | env/ 11 | grid_search/ 12 | lightning_logs/ 13 | .ipynb_checkpoints/ 14 | scripts/ 15 | # Byte-compiled / optimized / DLL files 16 | __pycache__/ 17 | *.py[cod] 18 | *$py.class 19 | 20 | # C extensions 21 | *.so 22 | 23 | # Distribution / packaging 24 | .Python 25 | build/ 26 | develop-eggs/ 27 | dist/ 28 | downloads/ 29 | eggs/ 30 | .eggs/ 31 | lib/ 32 | lib64/ 33 | parts/ 34 | sdist/ 35 | var/ 36 | wheels/ 37 | pip-wheel-metadata/ 38 | share/python-wheels/ 39 | *.egg-info/ 40 | .installed.cfg 41 | *.egg 42 | MANIFEST 43 | -------------------------------------------------------------------------------- /dataloader/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import random 4 | import pickle 5 | import numpy as np 6 | import pandas as pd 7 | import torch 8 | from typing import Callable, List, Dict, Any 9 | from torch.utils.data import Dataset 10 | 11 | class ECALS_Dataset(Dataset): 12 | def __init__(self, data_path, split, sr, duration, num_chunks, text_preprocessor=None, text_type="bert", text_rep="stochastic"): 13 | self.data_path = data_path 14 | self.split = split 15 | self.sr = sr 16 | self.text_preprocessor = text_preprocessor 17 | self.input_length = int(sr * duration) 18 | self.num_chunks = num_chunks 19 | self.text_type = text_type 20 | self.text_rep = text_rep 21 | self.msd_to_id = pickle.load(open(os.path.join(data_path, "lastfm_annotation", "MSD_id_to_7D_id.pkl"), 'rb')) 22 | self.id_to_path = pickle.load(open(os.path.join(data_path, "lastfm_annotation", "7D_id_to_path.pkl"), 'rb')) 23 | self.get_split() 24 | self.get_file_list() 25 | 26 | def get_split(self): 27 | track_split = json.load(open(os.path.join(self.data_path, "ecals_annotation", "ecals_track_split.json"), "r")) 28 | self.train_track = track_split['train_track'] + track_split['extra_track'] 29 | self.valid_track = track_split['valid_track'] 30 | self.test_track = track_split['test_track'] 31 | 32 | def get_file_list(self): 33 | annotation = json.load(open(os.path.join(self.data_path, "ecals_annotation", "annotation.json"), 'r')) 34 | self.list_of_label = json.load(open(os.path.join(self.data_path, "ecals_annotation", "ecals_tags.json"), 'r')) 35 | self.tag_to_idx = {i:idx for idx, i in enumerate(self.list_of_label)} 36 | if self.split == "TRAIN": 37 | self.fl = [annotation[i] for i in self.train_track] 38 | elif self.split == "VALID": 39 | self.fl = [annotation[i] for i in self.valid_track] 40 | elif self.split == "TEST": 41 | self.fl = [annotation[i] for i in self.test_track] 42 | else: 43 | raise ValueError(f"Unexpected split name: {self.split}") 44 | del annotation 45 | 46 | def audio_load(self, msd_id): 47 | audio_path = self.id_to_path[self.msd_to_id[msd_id]] 48 | audio = np.load(os.path.join(self.data_path, "npy", audio_path.replace(".mp3",".npy")), mmap_mode='r') 49 | random_idx = random.randint(0, audio.shape[-1]-self.input_length) 50 | audio = torch.from_numpy(np.array(audio[random_idx:random_idx+self.input_length])) 51 | return audio 52 | 53 | def tag_to_binary(self, tag_list): 54 | bainry = np.zeros([len(self.list_of_label),], dtype=np.float32) 55 | for tag in tag_list: 56 | bainry[self.tag_to_idx[tag]] = 1.0 57 | return bainry 58 | 59 | def text_load(self, tag_list): 60 | """ 61 | input: tag_list = list of tag 62 | output: text = string of text 63 | """ 64 | if self.text_rep == "caption": 65 | if self.split == "TRAIN": 66 | random.shuffle(tag_list) 67 | text = tag_list 68 | elif self.text_rep == "tag": 69 | text = [random.choice(tag_list)] 70 | elif self.text_rep == "stochastic": 71 | k = random.choice(range(1, len(tag_list)+1)) 72 | text = random.sample(tag_list, k) 73 | return text 74 | 75 | def get_train_item(self, index): 76 | item = self.fl[index] 77 | tag_list = item['tag'] 78 | binary = self.tag_to_binary(tag_list) 79 | text = self.text_load(tag_list) 80 | audio_tensor = self.audio_load(item['track_id']) 81 | return { 82 | "audio":audio_tensor, 83 | "binary":binary, 84 | "text":text 85 | } 86 | 87 | def get_eval_item(self, index): 88 | item = self.fl[index] 89 | tag_list = item['tag'] 90 | binary = self.tag_to_binary(tag_list) 91 | text = self.text_load(tag_list) 92 | tags = self.list_of_label 93 | track_id = item['track_id'] 94 | audio_path = self.id_to_path[self.msd_to_id[track_id]] 95 | audio = np.load(os.path.join(self.data_path, "npy", audio_path.replace(".mp3",".npy")), mmap_mode='r') 96 | hop = (len(audio) - self.input_length) // self.num_chunks 97 | audio = np.stack([np.array(audio[i * hop : i * hop + self.input_length]) for i in range(self.num_chunks)]).astype('float32') 98 | return { 99 | "audio":audio, 100 | "track_id":track_id, 101 | "tags":tags, 102 | "binary":binary, 103 | "text":text 104 | } 105 | 106 | def __getitem__(self, index): 107 | if (self.split=='TRAIN') or (self.split=='VALID'): 108 | return self.get_train_item(index) 109 | else: 110 | return self.get_eval_item(index) 111 | 112 | def batch_processor(self, batch): 113 | # batch = list of dcitioanry 114 | audio = [item_dict['audio'] for item_dict in batch] 115 | binary = [item_dict['binary'] for item_dict in batch] 116 | audios = torch.stack(audio) 117 | binarys = torch.tensor(np.stack(binary)) 118 | text, text_mask = self._text_preprocessor(batch, "text") 119 | return {"audio":audios, "binary":binarys, "text":text, "text_mask":text_mask} 120 | 121 | def _text_preprocessor(self, batch, target_text): 122 | if self.text_type == "bert": 123 | batch_text = [", ".join(item_dict[target_text]) for item_dict in batch] 124 | encoding = self.text_preprocessor.batch_encode_plus(batch_text, padding='longest', max_length=64, truncation=True, return_tensors="pt") 125 | text = encoding['input_ids'] 126 | text_mask = encoding['attention_mask'] 127 | elif self.text_type == "glove": 128 | batch_emb = [] 129 | batch_text = [item_dict[target_text] for item_dict in batch] 130 | for tag_seq in batch_text: 131 | tag_seq_emb = [np.array(self.text_preprocessor[token]).astype('float32') for token in tag_seq] 132 | batch_emb.append(torch.from_numpy(np.mean(tag_seq_emb, axis=0))) 133 | text = torch.stack(batch_emb) 134 | text_mask = None 135 | return text, text_mask 136 | 137 | def __len__(self): 138 | return len(self.fl) 139 | -------------------------------------------------------------------------------- /preprocessing/audio_utils.py: -------------------------------------------------------------------------------- 1 | STR_CLIP_ID = 'clip_id' 2 | STR_AUDIO_SIGNAL = 'audio_signal' 3 | STR_TARGET_VECTOR = 'target_vector' 4 | 5 | 6 | STR_CH_FIRST = 'channels_first' 7 | STR_CH_LAST = 'channels_last' 8 | 9 | import io 10 | import os 11 | import tqdm 12 | import logging 13 | import subprocess 14 | from typing import Tuple 15 | from pathlib import Path 16 | 17 | # import librosa 18 | import numpy as np 19 | import soundfile as sf 20 | 21 | import itertools 22 | from numpy.fft import irfft 23 | 24 | def _resample_load_ffmpeg(path: str, sample_rate: int, downmix_to_mono: bool) -> Tuple[np.ndarray, int]: 25 | """ 26 | Decoding, downmixing, and downsampling by librosa. 27 | Returns a channel-first audio signal. 28 | 29 | Args: 30 | path: 31 | sample_rate: 32 | downmix_to_mono: 33 | 34 | Returns: 35 | (audio signal, sample rate) 36 | """ 37 | 38 | def _decode_resample_by_ffmpeg(filename, sr): 39 | """decode, downmix, and resample audio file""" 40 | channel_cmd = '-ac 1 ' if downmix_to_mono else '' # downmixing option 41 | resampling_cmd = f'-ar {str(sr)}' if sr else '' # downsampling option 42 | cmd = f"ffmpeg -i \"{filename}\" {channel_cmd} {resampling_cmd} -f wav -" 43 | p = subprocess.Popen(cmd, shell=True, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE) 44 | out, err = p.communicate() 45 | return out 46 | 47 | src, sr = sf.read(io.BytesIO(_decode_resample_by_ffmpeg(path, sr=sample_rate))) 48 | return src.T, sr 49 | 50 | 51 | def _resample_load_librosa(path: str, sample_rate: int, downmix_to_mono: bool, **kwargs) -> Tuple[np.ndarray, int]: 52 | """ 53 | Decoding, downmixing, and downsampling by librosa. 54 | Returns a channel-first audio signal. 55 | """ 56 | src, sr = librosa.load(path, sr=sample_rate, mono=downmix_to_mono, **kwargs) 57 | return src, sr 58 | 59 | 60 | def load_audio( 61 | path: str or Path, 62 | ch_format: str, 63 | sample_rate: int = None, 64 | downmix_to_mono: bool = False, 65 | resample_by: str = 'ffmpeg', 66 | **kwargs, 67 | ) -> Tuple[np.ndarray, int]: 68 | """A wrapper of librosa.load that: 69 | - forces the returned audio to be 2-dim, 70 | - defaults to sr=None, and 71 | - defaults to downmix_to_mono=False. 72 | 73 | The audio decoding is done by `audioread` or `soundfile` package and ultimately, often by ffmpeg. 74 | The resampling is done by `librosa`'s child package `resampy`. 75 | 76 | Args: 77 | path: audio file path 78 | ch_format: one of 'channels_first' or 'channels_last' 79 | sample_rate: target sampling rate. if None, use the rate of the audio file 80 | downmix_to_mono: 81 | resample_by (str): 'librosa' or 'ffmpeg'. it decides backend for audio decoding and resampling. 82 | **kwargs: keyword args for librosa.load - offset, duration, dtype, res_type. 83 | 84 | Returns: 85 | (audio, sr) tuple 86 | """ 87 | if ch_format not in (STR_CH_FIRST, STR_CH_LAST): 88 | raise ValueError(f'ch_format is wrong here -> {ch_format}') 89 | 90 | if os.stat(path).st_size > 22050: 91 | if resample_by == 'librosa': 92 | src, sr = _resample_load_librosa(path, sample_rate, downmix_to_mono, **kwargs) 93 | elif resample_by == 'ffmpeg': 94 | src, sr = _resample_load_ffmpeg(path, sample_rate, downmix_to_mono) 95 | else: 96 | raise NotImplementedError(f'resample_by: "{resample_by}" is not supposred yet') 97 | else: 98 | raise ValueError('Given audio is too short!') 99 | return src, sr 100 | 101 | # if src.ndim == 1: 102 | # src = np.expand_dims(src, axis=0) 103 | # # now always 2d and channels_first 104 | 105 | # if ch_format == STR_CH_FIRST: 106 | # return src, sr 107 | # else: 108 | # return src.T, sr 109 | 110 | def ms(x): 111 | """Mean value of signal `x` squared. 112 | :param x: Dynamic quantity. 113 | :returns: Mean squared of `x`. 114 | """ 115 | return (np.abs(x)**2.0).mean() 116 | 117 | def normalize(y, x=None): 118 | """normalize power in y to a (standard normal) white noise signal. 119 | Optionally normalize to power in signal `x`. 120 | #The mean power of a Gaussian with :math:`\\mu=0` and :math:`\\sigma=1` is 1. 121 | """ 122 | if x is not None: 123 | x = ms(x) 124 | else: 125 | x = 1.0 126 | return y * np.sqrt(x / ms(y)) 127 | 128 | def noise(N, color='white', state=None): 129 | """Noise generator. 130 | :param N: Amount of samples. 131 | :param color: Color of noise. 132 | :param state: State of PRNG. 133 | :type state: :class:`np.random.RandomState` 134 | """ 135 | try: 136 | return _noise_generators[color](N, state) 137 | except KeyError: 138 | raise ValueError("Incorrect color.") 139 | 140 | def white(N, state=None): 141 | """ 142 | White noise. 143 | :param N: Amount of samples. 144 | :param state: State of PRNG. 145 | :type state: :class:`np.random.RandomState` 146 | White noise has a constant power density. It's narrowband spectrum is therefore flat. 147 | The power in white noise will increase by a factor of two for each octave band, 148 | and therefore increases with 3 dB per octave. 149 | """ 150 | state = np.random.RandomState() if state is None else state 151 | return state.randn(N) 152 | 153 | def pink(N, state=None): 154 | """ 155 | Pink noise. 156 | :param N: Amount of samples. 157 | :param state: State of PRNG. 158 | :type state: :class:`np.random.RandomState` 159 | Pink noise has equal power in bands that are proportionally wide. 160 | Power density decreases with 3 dB per octave. 161 | """ 162 | state = np.random.RandomState() if state is None else state 163 | uneven = N % 2 164 | X = state.randn(N // 2 + 1 + uneven) + 1j * state.randn(N // 2 + 1 + uneven) 165 | S = np.sqrt(np.arange(len(X)) + 1.) # +1 to avoid divide by zero 166 | y = (irfft(X / S)).real 167 | if uneven: 168 | y = y[:-1] 169 | return normalize(y) 170 | 171 | def blue(N, state=None): 172 | """ 173 | Blue noise. 174 | :param N: Amount of samples. 175 | :param state: State of PRNG. 176 | :type state: :class:`np.random.RandomState` 177 | Power increases with 6 dB per octave. 178 | Power density increases with 3 dB per octave. 179 | """ 180 | state = np.random.RandomState() if state is None else state 181 | uneven = N % 2 182 | X = state.randn(N // 2 + 1 + uneven) + 1j * state.randn(N // 2 + 1 + uneven) 183 | S = np.sqrt(np.arange(len(X))) # Filter 184 | y = (irfft(X * S)).real 185 | if uneven: 186 | y = y[:-1] 187 | return normalize(y) 188 | 189 | def brown(N, state=None): 190 | """ 191 | Violet noise. 192 | :param N: Amount of samples. 193 | :param state: State of PRNG. 194 | :type state: :class:`np.random.RandomState` 195 | Power decreases with -3 dB per octave. 196 | Power density decreases with 6 dB per octave. 197 | """ 198 | state = np.random.RandomState() if state is None else state 199 | uneven = N % 2 200 | X = state.randn(N // 2 + 1 + uneven) + 1j * state.randn(N // 2 + 1 + uneven) 201 | S = (np.arange(len(X)) + 1) # Filter 202 | y = (irfft(X / S)).real 203 | if uneven: 204 | y = y[:-1] 205 | return normalize(y) 206 | 207 | def violet(N, state=None): 208 | """ 209 | Violet noise. Power increases with 6 dB per octave. 210 | :param N: Amount of samples. 211 | :param state: State of PRNG. 212 | :type state: :class:`np.random.RandomState` 213 | Power increases with +9 dB per octave. 214 | Power density increases with +6 dB per octave. 215 | """ 216 | state = np.random.RandomState() if state is None else state 217 | uneven = N % 2 218 | X = state.randn(N // 2 + 1 + uneven) + 1j * state.randn(N // 2 + 1 + uneven) 219 | S = (np.arange(len(X))) # Filter 220 | y = (irfft(X * S)).real 221 | if uneven: 222 | y = y[:-1] 223 | return normalize(y) 224 | 225 | _noise_generators = { 226 | 'white': white, 227 | 'pink': pink, 228 | 'blue': blue, 229 | 'brown': brown, 230 | 'violet': violet, 231 | } 232 | 233 | def noise_generator(N=44100, color='white', state=None): 234 | """Noise generator. 235 | :param N: Amount of unique samples to generate. 236 | :param color: Color of noise. 237 | Generate `N` amount of unique samples and cycle over these samples. 238 | """ 239 | #yield from itertools.cycle(noise(N, color)) # Python 3.3 240 | for sample in itertools.cycle(noise(N, color, state)): 241 | yield sample 242 | 243 | def heaviside(N): 244 | """Heaviside. 245 | Returns the value 0 for `x < 0`, 1 for `x > 0`, and 1/2 for `x = 0`. 246 | """ 247 | return 0.5 * (np.sign(N) + 1) -------------------------------------------------------------------------------- /preprocessing/constants.py: -------------------------------------------------------------------------------- 1 | DATASET="../dataset" 2 | INT_RANDOM_SEED = 42 3 | # MUSIC_SAMPLE_RATE = 22050 4 | MUSIC_SAMPLE_RATE = 16000 5 | STR_CH_FIRST = 'channels_first' 6 | STR_CH_LAST = 'channels_last' 7 | DATA_LENGTH = MUSIC_SAMPLE_RATE * 30 8 | INPUT_LENGTH = MUSIC_SAMPLE_RATE * 10 9 | CHUNK_SIZE = 16 10 | # METADATA = ['title','artist_name','release','year'] 11 | METADATA = ['artist_name','year'] 12 | BLACK_LIST = ['d-i-v-o-r-c-e', 'n metal'] 13 | MIDFEATURE = ['key','tempo'] 14 | CONTEXTUAL = ['theme','mood','decade'] 15 | MUSICAL = ['genre','style','instrument','vocal'] 16 | CULTUREAL = ['language','location'] 17 | 18 | TOKEN_DICT = { 19 | "artist_name":'[ARTIST_NAME]', 20 | "release":'[RELEASE]', 21 | "title":'[TITLE]', 22 | "year":'[YEAR]', 23 | "theme":'[THEME]', 24 | "mood":'[MOOD]', 25 | "genre":'[GENRE]', 26 | "style":'[STYLE]', 27 | "instrument":'[INSTRUMENT]', 28 | "decade":'[DECADE]', 29 | "language":'[LANGUAGE]', 30 | "location":'[LOCATION]', 31 | "vocal":'[VOCAL]', 32 | "tempo":'[TEMPO]', 33 | "key":'[KEY]' 34 | } 35 | 36 | LASTFM_TAG_INFO = { 37 | '00s': "decade", 38 | '60s': "decade", 39 | '70s': "decade", 40 | '80s': "decade", 41 | '90s': "decade", 42 | 'acoustic': "instrument", 43 | 'alternative': "genre", 44 | 'alternative rock': "genre", 45 | 'ambient': "genre", 46 | 'beautiful': "mood", 47 | 'blues': "genre", 48 | 'catchy': "mood", 49 | 'chill': "mood", 50 | 'chillout': "mood", 51 | 'classic rock': "genre", 52 | 'country': "genre", 53 | 'dance': "mood", 54 | 'easy listening': "mood", 55 | 'electro': "genre", 56 | 'electronic': "genre", 57 | 'electronica': "genre", 58 | 'experimental': "genre", 59 | 'female vocalist': "vocal", 60 | 'female vocalists': "vocal", 61 | 'folk': "genre", 62 | 'funk': "genre", 63 | 'guitar': "instrument", 64 | 'happy': "mood", 65 | 'hard rock': "genre", 66 | 'heavy metal': "genre", 67 | 'hip-hop': "genre", 68 | 'house': "genre", 69 | 'indie': "genre", 70 | 'indie pop': "genre", 71 | 'indie rock': "genre", 72 | 'instrumental': "genre", 73 | 'jazz': "genre", 74 | 'male vocalists': "vocal", 75 | 'mellow': "mood", 76 | 'metal': "genre", 77 | 'oldies': "mood", 78 | 'party': "mood", 79 | 'pop': "genre", 80 | 'progressive rock': "genre", 81 | 'punk': "genre", 82 | 'rnb': "genre", 83 | 'rock': "genre", 84 | 'sad': "mood", 85 | 'sexy': "mood", 86 | 'soul': "genre" 87 | } -------------------------------------------------------------------------------- /preprocessing/main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pandas as pd 3 | from msd_preprocessor import MSD_processor 4 | from constants import DATASET 5 | 6 | def main(): 7 | MSD_processor(msd_path= os.path.join(DATASET)) 8 | 9 | if __name__ == '__main__': 10 | main() -------------------------------------------------------------------------------- /preprocessing/msd_preprocessor.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | import random 4 | import sqlite3 5 | import torch 6 | import json 7 | import pandas as pd 8 | import numpy as np 9 | import multiprocessing 10 | from collections import Counter 11 | from functools import partial 12 | from contextlib import contextmanager 13 | from sklearn import preprocessing 14 | from sklearn.model_selection import train_test_split 15 | from audio_utils import load_audio 16 | from sklearn.preprocessing import MultiLabelBinarizer 17 | 18 | from constants import DATASET, DATA_LENGTH, STR_CH_FIRST, MUSIC_SAMPLE_RATE, BLACK_LIST, LASTFM_TAG_INFO 19 | 20 | NaN_to_emptylist = lambda d: d if isinstance(d, list) or isinstance(d, str) else [] 21 | flatten_list_of_list = lambda l: [item for sublist in l for item in sublist] 22 | 23 | def tag_normalize(tag): 24 | tag = tag.replace("'n'","").replace("'","").replace("(","").replace(")","").replace("/"," ").replace("-"," ").replace(" & ","n").replace("&", "n") 25 | tag = unique_word(tag) 26 | return tag 27 | 28 | def unique_word(tag): 29 | unique_tag, remove_dix = [], None 30 | token = tag.split() 31 | for idx, i in enumerate(token): 32 | if len(i) == 1: 33 | unique_tag.append(token[idx] + token[idx+1]) 34 | remove_dix = idx + 1 35 | else: 36 | unique_tag.append(i) 37 | if remove_dix: 38 | unique_tag.remove(token[remove_dix]) 39 | return " ".join(unique_tag) 40 | 41 | def _remove(tag_list): 42 | return [i for i in tag_list if i not in BLACK_LIST] 43 | 44 | def getMsdInfo(msd_path): 45 | con = sqlite3.connect(msd_path) 46 | msd_db = pd.read_sql_query("SELECT * FROM songs", con) 47 | msd_db = msd_db.set_index('track_id') 48 | return msd_db 49 | 50 | def _json_dump(path, item): 51 | with open(path, mode="w") as io: 52 | json.dump(item, io, indent=4) 53 | 54 | @contextmanager 55 | def poolcontext(*args, **kwargs): 56 | pool = multiprocessing.Pool(*args, **kwargs) 57 | yield pool 58 | pool.terminate() 59 | 60 | def msd_resampler(_id, path): 61 | save_name = os.path.join(DATASET,'npy', path.replace(".mp3",".npy")) 62 | try: 63 | src, _ = load_audio( 64 | path=os.path.join(DATASET,'songs',path), 65 | ch_format= STR_CH_FIRST, 66 | sample_rate= MUSIC_SAMPLE_RATE, 67 | downmix_to_mono= True) 68 | if src.shape[-1] < DATA_LENGTH: # short case 69 | pad = np.zeros(DATA_LENGTH) 70 | pad[:src.shape[-1]] = src 71 | src = pad 72 | elif src.shape[-1] > DATA_LENGTH: # too long case 73 | src = src[:DATA_LENGTH] 74 | 75 | if not os.path.exists(os.path.dirname(save_name)): 76 | os.makedirs(os.path.dirname(save_name)) 77 | np.save(save_name, src.astype(np.float32)) 78 | except: 79 | os.makedirs(os.path.join(DATASET,"error"), exist_ok=True) 80 | np.save(os.path.join(DATASET,"error", _id + ".npy"), _id) # check black case 81 | 82 | def binary_df_to_list(binary, tags, indices, data_type): 83 | list_of_tag = [] 84 | for bool_tags in binary: 85 | list_of_tag.append([tags[idx] for idx, i in enumerate(bool_tags) if i] ) 86 | df_tag_list = pd.DataFrame(index=indices, columns=[data_type]) 87 | df_tag_list.index.name = "track_id" 88 | df_tag_list[data_type] = list_of_tag 89 | df_tag_list['is_'+ data_type] = [True for i in range(len(df_tag_list))] 90 | return df_tag_list 91 | 92 | def lastfm_processor(lastfm_path): 93 | """ 94 | input: lastfm_path 95 | return: pandas.DataFrame => index: msd trackid, columns: list of tag 96 | TRAAAAK128F9318786 [rock, alternative rock, hard rock] 97 | TRAAAAW128F429D538 [hip-hop] 98 | """ 99 | lastfm_tags = open(os.path.join(lastfm_path, "50tagList.txt"),'r').read().splitlines() 100 | lastfm_tags = [i.lower() for i in lastfm_tags] 101 | # lastfm split and 102 | train_list = pickle.load(open(os.path.join(lastfm_path, "filtered_list_train.cP"), 'rb')) 103 | test_list = pickle.load(open(os.path.join(lastfm_path, "filtered_list_test.cP"), 'rb')) 104 | msd_id_to_tag_vector = pickle.load(open(os.path.join(lastfm_path, "msd_id_to_tag_vector.cP"), 'rb')) 105 | total_list = train_list + test_list 106 | binary = [msd_id_to_tag_vector[msdid].astype(np.int16).squeeze(-1) for msdid in total_list] 107 | track_split = { 108 | "train_track": train_list[0:201680], 109 | "valid_track": train_list[201680:], 110 | "test_track": test_list, 111 | } 112 | 113 | _json_dump(os.path.join(lastfm_path, "lastfm_tags.json"), lastfm_tags) 114 | _json_dump(os.path.join(lastfm_path, "lastfm_tag_info.json"), LASTFM_TAG_INFO) 115 | _json_dump(os.path.join(lastfm_path, "lastfm_track_split.json"), track_split) 116 | 117 | lastfm_binary = pd.DataFrame(binary, index=total_list, columns=lastfm_tags) 118 | df_lastfm = binary_df_to_list(binary=binary, tags=lastfm_tags, indices=total_list, data_type="lastfm") 119 | return df_lastfm, track_split 120 | 121 | def cals_processor(cals_path): 122 | train_ids = np.load(os.path.join(cals_path, "train_ids.npy")) 123 | train_binary = np.load(os.path.join(cals_path, "train_binaries.npy")) 124 | valid_ids = np.load(os.path.join(cals_path, "valid_ids.npy")) 125 | valid_binary = np.load(os.path.join(cals_path, "valid_binaries.npy")) 126 | test_ids = np.load(os.path.join(cals_path, "test_ids.npy")) 127 | test_binary = np.load(os.path.join(cals_path, "test_binaries.npy")) 128 | cals_ids = list(train_ids) + list(valid_ids) + list(test_ids) 129 | cals_binary = np.vstack([train_binary,valid_binary,test_binary]) 130 | cals_tags = list(np.load(os.path.join(cals_path, "cals_tags.npy"))) 131 | ids_to_tag = {} 132 | for ids, binary in zip(cals_ids, cals_binary): 133 | ids_to_tag[ids] = { 134 | "cals":[cals_tags[idx] for idx, i in enumerate(binary) if i], 135 | "is_cals":True 136 | } 137 | df_cals = pd.DataFrame(ids_to_tag).T 138 | df_cals.index.name = "track_id" 139 | return df_cals 140 | 141 | def allmusic_processor(allmusic_path): 142 | """ 143 | input: allmusic_path 144 | return: pandas.DataFrame => index: msd trackid, columns: list of tag 145 | TRWYIGP128F1454835 [Pop/Rock, Electronic, Adult Alternative Pop/R... 146 | TRGFXIU128F1454832 [Pop/Rock, Electronic, Adult Alternative Pop/R... 147 | """ 148 | df_all = pd.read_hdf(os.path.join(allmusic_path, 'ground_truth_assignments/AMG_Multilabel_tagsets/msd_amglabels_all.h5')) 149 | tag_stats, tag_dict = {}, {} 150 | for category in df_all.columns: 151 | df_all[category] = df_all[category].apply(NaN_to_emptylist) 152 | df_all[category] = df_all[category].map(lambda x: list(map(str.lower, x))) 153 | tag_stats[category[:-1]] = {i:j for i,j in Counter(flatten_list_of_list(df_all[category])).most_common()} 154 | for tag in set(flatten_list_of_list(df_all[category])): 155 | tag_dict[tag] = category[:-1] 156 | _json_dump(os.path.join(allmusic_path, "allmusic_tags.json"), list(tag_dict.keys())) 157 | _json_dump(os.path.join(allmusic_path, "allmusic_tag_info.json"), tag_dict) 158 | _json_dump(os.path.join(allmusic_path, "allmusic_tag_stats.json"), tag_stats) 159 | 160 | tag_list = df_all['genres']+df_all['styles']+df_all['moods']+df_all['themes'] 161 | df_allmusic = pd.DataFrame(index=df_all.index, columns=["allmusic"]) 162 | df_allmusic["allmusic"] = tag_list 163 | df_allmusic['is_allmusic'] = [True for i in range(len(df_allmusic))] 164 | return df_allmusic 165 | 166 | def msd500_processor(msd500_path): 167 | msd500_tags = pd.read_csv(os.path.join(msd500_path,"selected_tags.tsv"), sep='\t', header=None) 168 | msd500_map = {'mood':'mood', 'instrument':'instrument', 'activity':'theme', 169 | 'language':'language', 'location':'location', 'decade':'decade', 'genre':'genre'} 170 | msd500_tag_info = {i:msd500_map[j.split("/")[0]] for i,j in zip(msd500_tags[0], msd500_tags[1])} 171 | msd500_anno = pd.read_csv(os.path.join(msd500_path,"track_tags.tsv"), sep="\t", header=None) 172 | use_tag = list(msd500_tag_info.keys()) 173 | msd500_anno = msd500_anno.set_index(2) 174 | msd500_anno = msd500_anno.loc[use_tag] 175 | item_dict = {i:[] for i in msd500_anno[0]} 176 | for _id, tag in zip(msd500_anno[0], msd500_anno.index): 177 | item = item_dict[_id].copy() 178 | item.append(tag) 179 | item_dict[_id] = list(set(item)) 180 | 181 | df_msd500 = pd.DataFrame(index=item_dict.keys()) 182 | df_msd500['msd500'] = item_dict.values() 183 | df_msd500['is_msd500'] = [True for i in range(len(df_msd500))] 184 | df_msd500.index.name = "track_id" 185 | msd500_tag_stat = {i:j for i,j in Counter(flatten_list_of_list(df_msd500['msd500'])).most_common()} 186 | _json_dump(os.path.join(msd500_path, "msd500_tag_info.json"), msd500_tag_info) 187 | _json_dump(os.path.join(msd500_path, "msd500_tags.json"), list(msd500_tags[0])) 188 | _json_dump(os.path.join(msd500_path, "msd500_tag_stats.json"), msd500_tag_stat) 189 | return df_msd500 190 | 191 | def _check_mp3_file(df_msd, id_to_path, MSD_id_to_7D_id): 192 | mp3_path, error_id = {}, [] 193 | for msdid in df_msd.index: 194 | try: 195 | mp3_path[msdid] = id_to_path[MSD_id_to_7D_id[msdid]] 196 | except: 197 | error_id.append(msdid) 198 | df_msd = df_msd.drop(error_id) 199 | return df_msd, mp3_path 200 | 201 | def _track_split(df_target, msd_path, types = "ecals"): 202 | track_split = {} 203 | if types == "ecals": 204 | df_target = df_target[df_target['tag'].apply(lambda x: len(x) != 0)] 205 | for i in set(df_target['splits']): 206 | track_list = list(df_target[df_target['splits'] == i].index) 207 | if i == "TRAIN": 208 | track_split['train_track'] = track_list 209 | elif i == "VALID": 210 | track_split['valid_track'] = track_list 211 | elif i == "TEST": 212 | track_split['test_track'] = track_list 213 | elif i == "STUDENT": 214 | track_split['extra_track'] = track_list 215 | _tag_stat = {i:j for i,j in Counter(flatten_list_of_list(list(df_target['tag']))).most_common()} 216 | track_list = track_split['train_track'] + track_split['valid_track']+ track_split['test_track'] 217 | print("finish msd extraction", len(track_list), "extra_track: ", len(track_split['extra_track']), "tag: ", len(_tag_stat)) 218 | 219 | _json_dump(os.path.join(msd_path, f"{types}_track_split.json"), track_split) 220 | _json_dump(os.path.join(msd_path, f"{types}_tags.json"), list(_tag_stat.keys())) 221 | _json_dump(os.path.join(msd_path, f"{types}_tag_stats.json"), _tag_stat) 222 | return track_split 223 | 224 | def _check_stat(df, track_list): 225 | df_test = df.loc[track_list] 226 | save_tag = set(df_test.T.loc[df_test.sum() > 2].index) 227 | return save_tag 228 | 229 | def filtering(df_tags, tr_track, va_track, te_track): 230 | merge_tag = df_tags['cals'] + df_tags['lastfm'] + df_tags['msd500'] + df_tags['allmusic'] 231 | merge_tag = merge_tag.apply(lambda x: _remove(x)) 232 | merge_tag = merge_tag.apply(lambda x: list(map(tag_normalize, x))) 233 | tag_list = merge_tag.apply(set).apply(list) 234 | mlb = MultiLabelBinarizer() 235 | binary = mlb.fit_transform(tag_list) 236 | df = pd.DataFrame(binary, index=list(merge_tag.index), columns=mlb.classes_) 237 | tr_save = _check_stat(df, tr_track) 238 | va_save = _check_stat(df, va_track) 239 | te_save = _check_stat(df, te_track) 240 | tag = list(tr_save & va_save & te_save) 241 | df_all = df[tag] 242 | df_binary = df_all.loc[df_all.sum(axis=1) > 0] 243 | filtered_tag = [] 244 | for idx in range(len(df_all)): 245 | item = df_all.iloc[idx] 246 | filtered_tag.append(list(item[item == 1].index)) 247 | return filtered_tag, df_binary 248 | 249 | 250 | def MSD_processor(msd_path): 251 | meta_path = os.path.join(msd_path, "track_metadata.db") 252 | lastfm_path = os.path.join(msd_path, "lastfm_annotation") 253 | allmusic_path = os.path.join(msd_path, "allmusic_annotation") 254 | msd500_path = os.path.join(msd_path, "msd500_annotation") 255 | cals_path = os.path.join(msd_path, "cals_annotation") 256 | ecals_path = os.path.join(msd_path, "ecals_annotation") 257 | os.makedirs(ecals_path, exist_ok=True) 258 | 259 | MSD_id_to_7D_id = pickle.load(open(os.path.join(lastfm_path, "MSD_id_to_7D_id.pkl"), 'rb')) 260 | id_to_path = pickle.load(open(os.path.join(lastfm_path, "7D_id_to_path.pkl"), 'rb')) 261 | lastfm_tags = [i.lower() for i in open(os.path.join(lastfm_path, "50tagList.txt"),'r').read().splitlines()] 262 | cals_split = pd.read_csv(os.path.join(cals_path, "msd_splits.tsv"), sep="\t").rename(columns={"clip_ids":"track_id"}).set_index("track_id") 263 | df_msdmeta = getMsdInfo(meta_path) 264 | df_cals = cals_processor(cals_path) 265 | df_lastfm, _ = lastfm_processor(lastfm_path) 266 | df_msd500 = msd500_processor(msd500_path) 267 | df_allmusic = allmusic_processor(allmusic_path) 268 | cals_lastfm = pd.merge(df_cals, df_lastfm, how='outer',on='track_id') 269 | cals_lastfm_msd500 = pd.merge(cals_lastfm, df_msd500, how='outer',on='track_id') 270 | cals_lastfm_msd500_allmusic = pd.merge(cals_lastfm_msd500, df_allmusic, how='outer',on='track_id') 271 | df_tags = pd.merge(cals_split, cals_lastfm_msd500_allmusic, how='outer',on='track_id') 272 | df_tags['length'] = df_tags['length'] / 22050 273 | 274 | for column in ["cals","lastfm","msd500","allmusic","is_cals","is_lastfm","is_msd500","is_allmusic"]: 275 | if "is_" in column: 276 | df_tags[column] = df_tags[column].fillna(False) 277 | else: 278 | df_tags[column] = df_tags[column].apply(NaN_to_emptylist) 279 | df_merge = pd.merge(df_tags, df_msdmeta, how='left',on='track_id') 280 | df_merge['splits'] = df_merge['splits'].fillna("NONE") 281 | df_final = df_merge[df_merge['splits'] != "NONE"] 282 | 283 | target_col = ["splits","length",'cals',"lastfm","msd500","allmusic","is_cals","is_lastfm","is_msd500","is_allmusic","release","artist_name","year","title"] 284 | df_target = df_final[target_col] 285 | df_target, mp3_path = _check_mp3_file(df_target, id_to_path, MSD_id_to_7D_id) 286 | 287 | with poolcontext(processes=multiprocessing.cpu_count()) as pool: 288 | pool.starmap(msd_resampler, zip(list(mp3_path.keys()),list(mp3_path.values()))) 289 | print("finish extract") 290 | 291 | error_ids = [msdid.replace(".npy","") for msdid in os.listdir(os.path.join(msd_path, 'error'))] 292 | df_target = df_target.drop(error_ids) # drop errors 293 | tr_track = list(df_target[df_target['splits'] == "TRAIN"].index) 294 | va_track = list(df_target[df_target['splits'] == "VALID"].index) 295 | te_track = list(df_target[df_target['splits'] == "TEST"].index) 296 | 297 | filtered_tag, df_binary = filtering(df_target, tr_track, va_track, te_track) 298 | df_target['tag'] = filtered_tag 299 | binary_error = [i for i in error_ids if i in df_binary.index] 300 | df_binary = df_binary.drop(binary_error) # drop errors 301 | 302 | df_binary.to_csv(os.path.join(ecals_path, 'ecals_binary.csv')) 303 | df_target['track_id'] = df_target.index 304 | track_split = _track_split(df_target, ecals_path, types = "ecals") 305 | ecals_track = track_split['train_track'] + track_split['valid_track'] + track_split['test_track'] + track_split['extra_track'] 306 | annotation_dict = df_target[["tag","release","artist_name","year","title",'track_id']].to_dict('index') # for small 307 | target_anotation_dict = {i:annotation_dict[i] for i in ecals_track} 308 | print(len(target_anotation_dict)) 309 | with open(os.path.join(ecals_path, f"annotation.json"), mode="w") as io: 310 | json.dump(target_anotation_dict, io) -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 | # Extended Cleaned tag and Artist-Level Stratified split (eCALS) 2 | 3 | For quick start, Just donwload preprocessed dataset from zenodo and check `ecals_annotation` folder 4 | - [Zenodo-Link](https://zenodo.org/record/7107130) 5 | 6 | We introduce the extended tag version of CALS split (cleaned and artist-level stratified) for the Million Song Dataset (MSD). Different from the previously CALS dataset split, we provide 1054 vocabulary and caption level tag sequences instead of 50 small vocabs. Since we inherit the existing cals split, there is no difference in the test dataset. However, we use all tag annotations in the existing tag annotations `msd50, msd500, and allmusic`. This is the dataset repository for the paper: [Toward Universal Text-to-Music Retrieval](https://arxiv.org/abs/2211.14558) 7 | 8 |

9 | 10 |

11 | 12 | 13 | ### Example of data annotation 14 | 15 | Dataset is `key-value` type. The key is `msdid`, and value is python dictionary (item). Item consists of tag annotation data, and artist name, track title, and year meta data. 16 | 17 | ``` 18 | { 19 | "TRSAGNY128F425391E": 20 | { 21 | 'tag': 22 | ['aggressive', 'confrontational', 'energetic', 'alternative indie rock', 'self conscious', 'rowdy', 'bravado', 'pop rock', 'hardcore punk', 'passionate', 'confident', 'gutsy', 'swaggering', 'earnest', 'urgent', 'anguished distraught', 'straight edge', 'cathartic', 'punk', 'brash', 'rebellious', 'dramatic', 'alternative pop rock', 'street smart', 'summery', 'knotty', 'volatile', 'fiery', 'punk new wave', 'angry'], 23 | 'release': 'Sink With Kalifornija', 24 | 'artist_name': 'Youth Brigade', 25 | 'year': 1984, 26 | 'title': 'What Are You Fighting For?', 27 | 'track_id': 'TRSAGNY128F425391E' 28 | } 29 | } 30 | ``` 31 | 32 | ### Dataset Loader For Classification 33 | If you want to use this dataset for audio-language representation learning or captioning, please refer to the this [repository](https://github.com/SeungHeonDoh/music-text-representation/blob/main/mtr/contrastive/dataset.py). 34 | 35 | ```python 36 | class ECALS_Dataset(Dataset): 37 | """ 38 | data_path (str): location of msu-benchmark 39 | split (str): one of {TRAIN, VALID, TEST} 40 | sr (int): sampling rate of waveform - 16000 41 | num_chunks (int): chunk size of inference audio 42 | """ 43 | def __init__(self, data_path, split, sr, duration, num_chunks): 44 | self.data_path = data_path 45 | self.split = split 46 | self.sr = sr 47 | self.input_length = int(sr * duration) 48 | self.num_chunks = num_chunks 49 | self.msd_to_id = pickle.load(open(os.path.join(data_path, "lastfm_annotation", "MSD_id_to_7D_id.pkl"), 'rb')) 50 | self.id_to_path = pickle.load(open(os.path.join(data_path, "lastfm_annotation", "7D_id_to_path.pkl"), 'rb')) 51 | self.get_split() 52 | self.get_file_list() 53 | 54 | def get_split(self): 55 | track_split = json.load(open(os.path.join(self.data_path, "ecals_annotation", "ecals_track_split.json"), "r")) 56 | self.train_track = track_split['train_track'] + track_split['extra_track'] 57 | self.valid_track = track_split['valid_track'] 58 | self.test_track = track_split['test_track'] 59 | 60 | def get_file_list(self): 61 | annotation = json.load(open(os.path.join(self.data_path, "ecals_annotation", "annotation.json"), 'r')) 62 | self.list_of_label = json.load(open(os.path.join(self.data_path, "ecals_annotation", "ecals_tags.json"), 'r')) 63 | self.tag_to_idx = {i:idx for idx, i in enumerate(self.list_of_label)} 64 | if self.split == "TRAIN": 65 | self.fl = [annotation[i] for i in self.train_track] 66 | elif self.split == "VALID": 67 | self.fl = [annotation[i] for i in self.valid_track] 68 | elif self.split == "TEST": 69 | self.fl = [annotation[i] for i in self.test_track] 70 | else: 71 | raise ValueError(f"Unexpected split name: {self.split}") 72 | del annotation 73 | 74 | def audio_load(self, msd_id): 75 | audio_path = self.id_to_path[self.msd_to_id[msd_id]] 76 | audio = np.load(os.path.join(self.data_path, "npy", audio_path.replace(".mp3",".npy")), mmap_mode='r') 77 | random_idx = random.randint(0, audio.shape[-1]-self.input_length) 78 | audio = torch.from_numpy(np.array(audio[random_idx:random_idx+self.input_length])) 79 | return audio 80 | 81 | def tag_to_binary(self, tag_list): 82 | bainry = np.zeros([len(self.list_of_label),], dtype=np.float32) 83 | for tag in tag_list: 84 | bainry[self.tag_to_idx[tag]] = 1.0 85 | return bainry 86 | 87 | def __getitem__(self, index): 88 | item = self.fl[index] 89 | tag_list = item['tag'] 90 | binary = self.tag_to_binary(tag_list) 91 | audio_tensor = self.audio_load(item['track_id']) 92 | return { 93 | "audio":audio_tensor, 94 | "binary":binary, 95 | "tag_list":tag_list 96 | } 97 | 98 | def __len__(self): 99 | return len(self.fl) 100 | ``` 101 | 102 | ### Dataset stat 103 | 104 | ``` 105 | Train Track: 444865 (== CALS train + student) 106 | Valid Track: 34481 (== CALS valid track) 107 | Test Track: 34631 (== CALS test track) 108 | Unique Tag: 1054 109 | Unique Tag Caption: 139541 110 | Unique Artist: 32650 111 | Unique Album : 89920 112 | Unique Year: 90 113 | ``` 114 | 115 | ### Download Source Dataset from Zenodo 116 | 117 | ``` 118 | wget https://zenodo.org/record/7107130/files/dataset.tar.gz 119 | tar -xvf dataset.tar.gz 120 | cd dataset 121 | wget http://millionsongdataset.com/sites/default/files/AdditionalFiles/track_metadata.db 122 | ``` 123 | 124 | 125 | ``` 126 | └── dataset 127 | track_metadata.db 128 | ├── allmusic_annotation 129 | │ └── ground_truth_assignments 130 | │ ├── AMG_Multilabel_tagsets 131 | │ │ ├── msd_amglabels_all.h5 132 | ... 133 | ├── cals_annotation 134 | │ ├── cals_error.npy 135 | │ ├── cals_tags.npy 136 | ... 137 | ├── lastfm_annotation 138 | │ ├── 50tagList.txt 139 | │ ├── filtered_list_test.cP 140 | │ ├── filtered_list_train.cP 141 | │ ├── msd_id_to_tag_vector.cP 142 | ... 143 | ├── msd500_annotation 144 | │ ├── dataset_stats.txt 145 | │ ├── selected_tags.tsv 146 | │ └── track_tags.tsv 147 | └── ecals_annotation 148 | ├── annotation.json 149 | ├── ecals_tags.json 150 | ├── ecals_tag_stats.json 151 | ├── ecals_track_split.json 152 | └── multiquery_samples.json 153 | ``` 154 | 155 | ### Run Preprocessing Code 156 | ``` 157 | cd preprocessing 158 | python main.py 159 | ``` 160 | 161 | ### MSD audio 162 | Due to copyright issue, we don't provide audio data in this page. 163 | 164 | ### Citation 165 | Please consider citing our paper in your publications if the project helps your research. BibTeX reference is as follow. 166 | ``` 167 | @inproceedings{doh2023toward, 168 | title={Toward Universal Text-to-Music Retrieval}, 169 | author={Doh, SeungHeon and Won, Minz and Choi, Keunwoo and Nam, Juhan}, 170 | booktitle={ICASSP 2023-2023 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP)}, 171 | year={2023} 172 | } 173 | ``` 174 | -------------------------------------------------------------------------------- /requirments.txt: -------------------------------------------------------------------------------- 1 | youtube-dl==2021.6.6 2 | pandas>=1.3.0 3 | cython==0.29.28 4 | tables==3.6.1 5 | scikit-multilearn==0.2.0 6 | bson==0.5.10 7 | ffmpeg-python==0.2.0 8 | numpy==1.22.4 9 | librosa==0.9.1 10 | tqdm==4.64.0 11 | sklearn --------------------------------------------------------------------------------