├── requirements.txt ├── data ├── val_idx.npy ├── test_idx.npy ├── tp_source.npy ├── train_idx.npy ├── mcn_mtt_source.npy ├── mcn_msd_big_source.npy └── moods.txt ├── .gitattributes ├── model.py ├── LICENSE ├── README.md └── run.py /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.4.0 2 | pytorch-ignite==0.4.0 3 | numpy 4 | wandb 5 | tqdm 6 | scikit-learn 7 | -------------------------------------------------------------------------------- /data/val_idx.npy: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:6d448d63e9c7851a1a32f25f5d05f7d1d5a3c8025a837f124e69b91072b6028f 3 | size 53688 4 | -------------------------------------------------------------------------------- /data/test_idx.npy: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:aff0f3dbaf58d823c67b0b7461c495ee06061855f3ab5d4d61bfaf196eabf246 3 | size 53832 4 | -------------------------------------------------------------------------------- /data/tp_source.npy: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:5bc1af2e06d016099956109875eb4c7897ed37e6aab92b6dbd4f6dcc6c4842d7 3 | size 53594528 4 | -------------------------------------------------------------------------------- /data/train_idx.npy: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:6b806f81e8f10b5f03971078bc3393340388e663fa0eb57b4813c5ddd081863e 3 | size 428808 4 | -------------------------------------------------------------------------------- /data/mcn_mtt_source.npy: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid 
sha256:063485edd896fb2bf720db7b79bc2a2dde34c72a5efabacca203e78183a0b01f 3 | size 53594528 4 | -------------------------------------------------------------------------------- /data/mcn_msd_big_source.npy: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:fae4e8dd38a4136d4b87beed9c101445f5b7c01e67753d42893ebadf0f4e5dbd 3 | size 133986128 4 | -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | data/val_idx.npy filter=lfs diff=lfs merge=lfs -text 2 | data/tp_source.npy filter=lfs diff=lfs merge=lfs -text 3 | data/mcn_msd_big_source.npy filter=lfs diff=lfs merge=lfs -text 4 | data/mcn_mtt_source.npy filter=lfs diff=lfs merge=lfs -text 5 | data/test_idx.npy filter=lfs diff=lfs merge=lfs -text 6 | data/train_idx.npy filter=lfs diff=lfs merge=lfs -text -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class MultiLayerPerceptron(nn.Module): 6 | """Feed-forward network: standardizes inputs as (x - shift) / scale, applies n_layers hidden blocks of Linear -> ReLU (-> Dropout, only when dropout > 0.0), then a final Linear layer with out_dim outputs. forward() returns these raw outputs without any activation. shift/scale default to zeros/ones (identity) when not given; presumably they are (in_dim,) float32 tensors of training-set statistics -- TODO confirm with callers.""" 7 | def __init__(self, in_dim, out_dim, n_layers, n_units, dropout, shift=None, 8 | scale=None): 9 | super().__init__() 10 | # Stored as Parameters with requires_grad=False: never updated by the optimizer, but kept in state_dict() and moved along with .to(device). 
11 | self.shift = nn.Parameter(torch.Tensor(in_dim), requires_grad=False) 12 | self.scale = nn.Parameter(torch.Tensor(in_dim), requires_grad=False) 13 | torch.nn.init.zeros_(self.shift) 14 | torch.nn.init.ones_(self.scale) 15 | # NOTE(review): caller-supplied tensors replace the data as-is; their shape (expected in_dim) and dtype are not validated here. 16 | if shift is not None: 17 | self.shift.data = shift 18 | if scale is not None: 19 | self.scale.data = scale 20 | # Hidden blocks use Kaiming-uniform init with the ReLU gain; the output layer uses Kaiming-normal with the sigmoid gain. 21 | prev_dim = in_dim 22 | layers = [] 23 | for _ in range(n_layers): 24 | layer = torch.nn.Linear(prev_dim, n_units) 25 | torch.nn.init.kaiming_uniform_(layer.weight, nonlinearity='relu') 26 | layers.append(layer) 27 | layers.append(torch.nn.ReLU()) 28 | if dropout > 0.0: 29 |
layers.append(torch.nn.Dropout(dropout)) 30 | prev_dim = n_units 31 | out_layer = torch.nn.Linear(prev_dim, out_dim) 32 | torch.nn.init.kaiming_normal_(out_layer.weight, nonlinearity='sigmoid') 33 | layers.append(out_layer) 34 | self.layers = nn.Sequential(*layers) 35 | # forward: standardize with the stored shift/scale, then run the layer stack; returns raw logits (no sigmoid is applied here). 36 | def forward(self, x): 37 | x = (x - self.shift) / self.scale 38 | return self.layers(x) 39 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright 2020 Pandora Media, LLC. 2 | 3 | Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: 4 | 5 | 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 6 | 7 | 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 8 | 9 | 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. 10 | 11 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -------------------------------------------------------------------------------- /data/moods.txt: -------------------------------------------------------------------------------- 1 | Acerbic 2 | Aggressive 3 | Agreeable 4 | Ambitious 5 | Amiable/Good-Natured 6 | Angry 7 | Angst-Ridden 8 | Anguished/Distraught 9 | Animated 10 | Atmospheric 11 | Austere 12 | Autumnal 13 | Belligerent 14 | Bitter 15 | Bittersweet 16 | Bleak 17 | Boisterous 18 | Brash 19 | Brassy 20 | Bravado 21 | Bright 22 | Brittle 23 | Brooding 24 | Calm/Peaceful 25 | Campy 26 | Carefree 27 | Cathartic 28 | Celebratory 29 | Cerebral 30 | Cheerful 31 | Circular 32 | Clinical 33 | Cold 34 | Complex 35 | Confident 36 | Confrontational 37 | Crunchy 38 | Cynical/Sarcastic 39 | Delicate 40 | Detached 41 | Difficult 42 | Dramatic 43 | Dreamy 44 | Driving 45 | Druggy 46 | Earnest 47 | Earthy 48 | Eccentric 49 | Eerie 50 | Effervescent 51 | Elaborate 52 | Elegant 53 | Energetic 54 | Enigmatic 55 | Epic 56 | Ethereal 57 | Exciting 58 | Exuberant 59 | Fierce 60 | Fiery 61 | Flowing 62 | Fractured 63 | Freewheeling 64 | Fun 65 | Gentle 66 | Giddy 67 | Gleeful 68 | Gloomy 69 | Greasy 70 | Gritty 71 | Gutsy 72 | Happy 73 | Harsh 74 | Hedonistic 75 | Hostile 76 | Humorous 77 | Hypnotic 78 | Indulgent 79 | Innocent 80 | Insular 81 | Intense 82 | Intimate 83 | Introspective 84 | Ironic 85 | Irreverent 86 | Joyous 87 | Knotty 88 | Laid-Back/Mellow 89 | Light 90 | Literate 91 | Lively 92 | Lush 93 | Malevolent 94 
| Manic 95 | Meandering 96 | Melancholy 97 | Melodic 98 | Menacing 99 | Messy 100 | Naive 101 | Nihilistic 102 | Nocturnal 103 | Nostalgic 104 | Ominous 105 | Optimistic 106 | Organic 107 | Outrageous 108 | Paranoid 109 | Passionate 110 | Pastoral 111 | Plaintive 112 | Playful 113 | Poignant 114 | Positive 115 | Powerful 116 | Precious 117 | Provocative 118 | Quirky 119 | Rambunctious 120 | Ramshackle 121 | Raucous 122 | Rebellious 123 | Reckless 124 | Refined 125 | Reflective 126 | Relaxed 127 | Reserved 128 | Restrained 129 | Reverent 130 | Rollicking 131 | Romantic 132 | Rousing 133 | Rowdy 134 | Rustic 135 | Sad 136 | Searching 137 | Self-Conscious 138 | Sensual 139 | Sentimental 140 | Sexual 141 | Sexy 142 | Silly 143 | Sleazy 144 | Slick 145 | Smooth 146 | Snide 147 | Soft/Quiet 148 | Somber 149 | Soothing 150 | Sophisticated 151 | Spacey 152 | Sparkling 153 | Sparse 154 | Spicy 155 | Spiritual 156 | Spooky 157 | Sprawling 158 | Springlike 159 | Stately 160 | Street-Smart 161 | Strong 162 | Stylish 163 | Suffocating 164 | Sugary 165 | Summery 166 | Swaggering 167 | Sweet 168 | Tender 169 | Tense/Anxious 170 | Theatrical 171 | Thoughtful 172 | Tough 173 | Trashy 174 | Trippy 175 | Uncompromising 176 | Unsettling 177 | Uplifting 178 | Urgent 179 | Visceral 180 | Volatile 181 | Warm 182 | Weary 183 | Whimsical 184 | Wintry 185 | Wistful 186 | Witty 187 | Wry 188 | Yearning -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | Mood Classification using Listening Data 2 | ======================================== 3 | 4 | This repository contains the data and code to reproduce the results in the paper 5 | 6 | **Filip Korzeniowski**, **Oriol Nieto**, Matthew C. McCallum, Minz Won, Sergio Oramas, Erik M. Schmidt. 
7 | “Mood Classification Using Listening Data”, 21st International Society for Music Information 8 | Retrieval Conference, Montréal, Canada, 2020 ([PDF](https://ccrma.stanford.edu/~urinieto/MARL/publications/ISMIR2020_MoodPrediction.pdf)). *(Authors in bold contributed equally.)* 9 | 10 | The AllMusic Mood Subset 11 | ------------------------ 12 | 13 | We provide a list of track ids from the Million Song Dataset (MSD), with train/val/test splits and a number of input 14 | features in this repository. All files can be found in `data`. 15 | 16 | **Note:** The data files are stored on `git lfs`, but you can download them [here](https://drive.google.com/file/d/1ecA1N1Mp1mOpwbntfWNQIMMrwPBrYzvl/view?usp=sharing) if you get any quota errors. 17 | 18 | ### Meta-Data 19 | 20 | * Track metadata (`metadata.csv`): MSD artist id, song id, and track id. Album ids are consecutive numbers and do not 21 | point to any database. Further, we provide artist names, album names, and track names. All rows in NumPy files 22 | correspond to this ordering. 23 | * AllMusic Moods (`moods.txt`): Set of mood names used in this dataset. This is a subset of all moods available on 24 | AllMusic, selected by frequency of annotations. The original IDs of these moods can be found on the official [Rovi website](http://prod-doc.rovicorp.com/mashery/index.php/MusicMoods). 25 | * Data Splits (`{train,val,test}_idx.npy`): NumPy arrays containing the indices of tracks used in the respective set. 26 | 27 | ### Features 28 | 29 | We provide the following features: 30 | 31 | * Taste Profile (`tp_source.npy`): Listening-based embeddings computed using weighted alternating least-squares on the complete Taste-Profile dataset. 32 | * Musicnn-MSD (`mcn_msd_big_source.npy`): Audio-based embeddings given by the penultimate layer of the [Musicnn](https://github.com/jordipons/musicnn) model on 33 | the 30-second 7-digital snippets from the MSD. Here, we used the large Musicnn model trained on the MSD.
34 | * Musicnn-MTT (`mcn_mtt_source.npy`): Same as before, but using a smaller Musicnn model trained on the MagnaTagATune dataset. 35 | 36 | ### Ground Truth 37 | 38 | For legal reasons, we cannot provide the moods from AllMusic. However, the moods for an album can be obtained from 39 | [allmusic.com](https://allmusic.com), for example for [this Bob Dylan album](https://www.allmusic.com/album/mw0000198752). 40 | We do not encourage the research community to collect and publish the data, but if they do, we accept pull requests. 41 | 42 | After collecting the data, convert it into multi-hot vectors (where 1 indicates the presence of a 43 | mood, and 0 its absence) and store them as `data/mood_target.npy`. Each row should represent the ground truth 44 | for the corresponding track found in `data/metadata.csv`. 45 | 46 | Running the experiments 47 | ----------------------- 48 | 49 | The `run.py` script trains a model, reports validation results, and computes test set predictions for further evaluation. 50 | It logs the training progress to the console and to [Weights & Biases](http://wandb.ai). You can either create a free 51 | account or disable the corresponding lines in the script. Make sure you have all requirements installed (see `requirements.txt`). 52 | 53 | Model hyper-parameters can be set using command line arguments. The standard values correspond to the best parameters found 54 | for Taste-Profile embeddings. Here are the explicit CLI calls for the two types of embeddings (listening-based and audio-based). 55 | To run on a GPU, set its id by adding `--gpu_id <id>`: 56 | 57 | ```bash 58 | # listening-based embeddings, e.g. taste-profile 59 | python run.py --n_layers 4 --n_units 3909 --lr 4e-4 --dropout 0.25 --weight_decay 0.0 --feature tp 60 | 61 | # audio-based embeddings, e.g.
musicnn msd-trained embeddings 62 | python run.py --n_layers 4 --n_units 3933 --lr 5e-5 --dropout 0.25 --weight_decay 1e-6 --feature mcn_msd_big 63 | ``` 64 | 65 | We provide the following features in this repo: 66 | * Taste-Profile (`--feature tp`) 67 | * Large MusiCnn trained on the Million Song Dataset (`--feature mcn_msd_big`) 68 | * Regular MusiCnn trained on the MagnaTagATune Dataset (`--feature mcn_mtt`) 69 | 70 | You can easily add your own features by storing a NumPy file in the `data` directory called `yourfeature_source.npy` 71 | and calling the script using `--feature yourfeature`. Make sure that the rows correspond to the MSD track ids found in 72 | `msd_track_ids.txt`. 73 | -------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import random 3 | import warnings 4 | from os.path import dirname, join 5 | 6 | import numpy as np 7 | import torch 8 | import wandb 9 | from ignite.contrib.handlers import CosineAnnealingScheduler, create_lr_scheduler_with_warmup 10 | from ignite.contrib.metrics import AveragePrecision 11 | from ignite.engine import Events, create_supervised_evaluator, create_supervised_trainer 12 | from ignite.handlers import ModelCheckpoint, global_step_from_engine 13 | from ignite.metrics import Accuracy, Fbeta, Loss, Precision, Recall, RunningAverage 14 | from sklearn.preprocessing import StandardScaler 15 | from torch import nn 16 | from torch.utils.data import DataLoader, TensorDataset 17 | 18 | from model import MultiLayerPerceptron 19 | 20 | WANDB_PROJECT = 'Mood Prediction' 21 | DATA_DIR = join(dirname(__file__), 'data') 22 | 23 | # Load the given feature matrix and the mood targets from DATA_DIR and return (train, val, test) TensorDatasets built from the fixed index splits. 
24 | def load_data(feature): 25 | features = np.load(join(DATA_DIR, f'{feature}_source.npy')) 26 | moods = np.load(join(DATA_DIR, f'mood_target.npy')) 27 | 28 | train_idxs = np.load(join(DATA_DIR, f'train_idx.npy')) 29 | val_idxs = np.load(join(DATA_DIR, f'val_idx.npy'))
30 | test_idxs = np.load(join(DATA_DIR, f'test_idx.npy')) 31 | # NOTE(review): torch.from_numpy keeps the .npy dtypes; the model's shift/scale are float32 and BCEWithLogitsLoss needs float targets -- confirm both files are stored as float32. 32 | train_set = TensorDataset( 33 | torch.from_numpy(features[train_idxs]), 34 | torch.from_numpy(moods[train_idxs])) 35 | val_set = TensorDataset( 36 | torch.from_numpy(features[val_idxs]), 37 | torch.from_numpy(moods[val_idxs])) 38 | test_set = TensorDataset( 39 | torch.from_numpy(features[test_idxs]), 40 | torch.from_numpy(moods[test_idxs])) 41 | 42 | return train_set, val_set, test_set 43 | 44 | # Return a new dict whose keys are the metric names prefixed with '<tag>/' (e.g. 'val/ap'); values are unchanged. 45 | def add_tag(metrics, tag): 46 | return {f'{tag}/{k}': v for k, v in metrics.items()} 47 | 48 | # Metric output transform: (logits, targets) -> (sigmoid probabilities, targets). 49 | def activate_output(output): 50 | y_pred, y = output 51 | return torch.sigmoid(y_pred), y 52 | 53 | # Metric output transform: (logits, targets) -> (0/1 predictions, targets). torch.round rounds half to even, so a probability of exactly 0.5 maps to 0. 54 | def threshold_output(output): 55 | y_pred, y = activate_output(output) 56 | return torch.round(y_pred), y 57 | 58 | # Seed Python's random, NumPy and torch (plus all CUDA devices when available). When seed is None a seed in [1, 1e9) is drawn first; the seed actually used is returned. 59 | def set_random_seed(seed): 60 | seed = seed if seed is not None else np.random.randint(1, int(1e9)) 61 | random.seed(seed) 62 | np.random.seed(seed) 63 | torch.manual_seed(seed) 64 | if torch.cuda.is_available(): 65 | torch.cuda.manual_seed_all(seed) 66 | return seed 67 | 68 | # Parse command-line options; per the README, the defaults are the best settings found for Taste-Profile embeddings. 
69 | def parse_args(): 70 | parser = argparse.ArgumentParser( 71 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 72 | parser.add_argument('--exp_tags', nargs='*', default=None, 73 | help='Tags to use for W&B run') 74 | parser.add_argument('--seed', type=int, default=None, 75 | help='Random seed to set') 76 | parser.add_argument('--gpu_id', type=int, default=None, 77 | help='GPU to use') 78 | parser.add_argument('--n_workers', type=int, default=4, 79 | help='Number of workers for data loading.') 80 | parser.add_argument('--feature', default='tp', 81 | help='Input embedding to use (tp=taste profile)') 82 | parser.add_argument('--batch_size', type=int, default=128, 83 | help='Mini-batch size for training') 84 | parser.add_argument('--n_epochs', type=int, default=100, 85 | help='Number of epochs to train.') 86 | parser.add_argument('--n_layers', type=int, default=4, 87 | help='Number of neural network
layers.') 88 | parser.add_argument('--n_units', type=int, default=3909, 89 | help='Number of units per neural network layer.') 90 | parser.add_argument('--dropout', type=float, default=0.25, 91 | help='Dropout probability for all layers.') 92 | parser.add_argument('--weight_decay', type=float, default=0.0, 93 | help='Weight decay factor.') 94 | parser.add_argument('--lr', type=float, default=4e-4, 95 | help='Initial learning rate.') 96 | config = parser.parse_args() 97 | return config 98 | 99 | # Entry point: train the MLP, validate after every epoch, keep the checkpoint with the best validation AP, then save sigmoid test-set predictions into the W&B run directory. 100 | if __name__ == '__main__': 101 | cfg = parse_args() 102 | cfg.seed = set_random_seed(cfg.seed) 103 | wandb.init( 104 | project=WANDB_PROJECT, 105 | tags=cfg.exp_tags, 106 | config=cfg, 107 | config_exclude_keys=['exp_tags']) 108 | wandb.run.save() 109 | # Use the requested GPU only when an id was given and CUDA is available; otherwise run on the CPU. 110 | device = torch.device( 111 | f'cuda:{cfg.gpu_id}' 112 | if torch.cuda.is_available() and cfg.gpu_id is not None 113 | else 'cpu') 114 | print(f'\nUsing device: {device}') 115 | # Only train/val get DataLoaders; the test set is predicted in a single forward pass at the end. 116 | train_set, val_set, test_set = load_data(cfg.feature) 117 | train_loader = DataLoader( 118 | train_set, 119 | batch_size=cfg.batch_size, 120 | shuffle=True, 121 | num_workers=cfg.n_workers) 122 | val_loader = DataLoader( 123 | val_set, 124 | batch_size=cfg.batch_size, 125 | shuffle=False, 126 | num_workers=cfg.n_workers, 127 | drop_last=False) 128 | print(f'\nNo. Train: {len(train_set):6d}') 129 | print(f'No. Val: {len(val_set):6d}') 130 | print(f'No.
Test: {len(test_set):6d}') 131 | # Standardization statistics are fit on the training split only and stored inside the model as shift/scale (sklearn uses scale 1 for zero-variance features). 132 | scaler = StandardScaler().fit(train_set[:][0]) 133 | model = MultiLayerPerceptron( 134 | in_dim=train_set[0][0].shape[0], 135 | out_dim=train_set[0][1].shape[0], 136 | n_layers=cfg.n_layers, 137 | n_units=cfg.n_units, 138 | dropout=cfg.dropout, 139 | shift=torch.from_numpy(scaler.mean_.astype(np.float32)), 140 | scale=torch.from_numpy(scaler.scale_.astype(np.float32)) 141 | ).to(device) 142 | print('\nModel:\n') 143 | print(model) 144 | wandb.watch(model) 145 | # The model outputs logits; BCEWithLogitsLoss applies the sigmoid internally. 146 | loss = nn.BCEWithLogitsLoss() 147 | optimizer = torch.optim.Adam( 148 | model.parameters(), 149 | lr=cfg.lr, 150 | weight_decay=cfg.weight_decay) 151 | trainer = create_supervised_trainer(model, optimizer, loss, device) 152 | RunningAverage(output_transform=lambda x: x).attach(trainer, 'loss') 153 | # Per-iteration LR schedule: linear warm-up from 0 to cfg.lr over the first epoch, then cosine annealing from cfg.lr toward 0. NOTE(review): the cosine cycle spans n_epochs epochs but starts after the warm-up epoch, so the LR presumably does not reach 0 by the last iteration -- confirm whether that matters. 154 | trainer.add_event_handler( 155 | Events.ITERATION_COMPLETED, 156 | create_lr_scheduler_with_warmup( 157 | CosineAnnealingScheduler( 158 | optimizer, 159 | param_name='lr', 160 | start_value=cfg.lr, 161 | end_value=0, 162 | cycle_size=len(train_loader) * cfg.n_epochs, 163 | start_value_mult=0, 164 | end_value_mult=0), 165 | warmup_start_value=0.0, 166 | warmup_end_value=cfg.lr, 167 | warmup_duration=len(train_loader) 168 | ) 169 | ) 170 | # Validation metrics: BCE loss on the logits, 0/1 predictions for accuracy/precision/recall/F1, sigmoid probabilities for average precision. 
171 | evaluator = create_supervised_evaluator( 172 | model, metrics={ 173 | 'loss': Loss(loss), 174 | 'acc_smpl': Accuracy(threshold_output, is_multilabel=True), 175 | 'p': Precision(threshold_output, average=True), 176 | 'r': Recall(threshold_output, average=True), 177 | 'f1': Fbeta(1.0, output_transform=threshold_output), 178 | 'ap': AveragePrecision(output_transform=activate_output) 179 | }, 180 | device=device) 181 | # After each validation run, keep the checkpoint with the highest validation AP, written into the W&B run directory. 182 | model_checkpoint = ModelCheckpoint( 183 | dirname=wandb.run.dir, 184 | filename_prefix='best', 185 | require_empty=False, 186 | score_function=lambda e: e.state.metrics['ap'], 187 | global_step_transform=global_step_from_engine(trainer)) 188 | evaluator.add_event_handler( 189 | Events.COMPLETED,
model_checkpoint, {'model': model}) 190 | 191 | # After each training epoch: run validation, log train metrics, val metrics and the current LR to W&B, and print a one-line summary. 192 | @trainer.on(Events.EPOCH_COMPLETED) 193 | def validate(trainer): 194 | evaluator.run(val_loader) 195 | wandb.log(trainer.state.metrics, step=trainer.state.epoch) 196 | wandb.log(add_tag(evaluator.state.metrics, 'val'), step=trainer.state.epoch) 197 | wandb.log({'Lr': optimizer.param_groups[0]['lr']}, step=trainer.state.epoch) 198 | print( 199 | f'Epoch {trainer.state.epoch:3d}:' 200 | f' Tr [{" ".join(f"{m}={v:.3f}" for m, v in trainer.state.metrics.items())}]' 201 | f' Va [{" ".join(f"{m}={v:.3f}" for m, v in evaluator.state.metrics.items())}]' 202 | ) 203 | 204 | 205 | print('\nTraining:\n') 206 | # ignore warnings from metrics 207 | with warnings.catch_warnings(): 208 | warnings.simplefilter('ignore') 209 | trainer.run(train_loader, max_epochs=cfg.n_epochs) 210 | # Reload the checkpoint retained by model_checkpoint (best validation AP) and predict sigmoid probabilities for the whole test set in one batch. 211 | model.load_state_dict(torch.load(model_checkpoint.last_checkpoint)) 212 | model.eval() 213 | with torch.no_grad(): 214 | preds = torch.sigmoid(model(test_set[:][0].to(device))).cpu().numpy() 215 | np.save(join(wandb.run.dir, 'test_predictions.npy'), preds) 216 | --------------------------------------------------------------------------------