├── requirements.txt
├── README.md
├── Dockerfile
└── src
    ├── loss.py
    ├── about.txt
    ├── mxresnet.py
    ├── spectrograms.py
    ├── utils.py
    └── app.py

/requirements.txt:
--------------------------------------------------------------------------------
1 | requests
2 | spotipy
3 | streamlit
4 | numpy
5 | pandas
6 | matplotlib
7 | plotly
8 | pillow
9 | fastai
10 | torch
11 | typing
12 | librosa
13 | pydub
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Learning to Feel
2 | 
3 | "Learning to Feel" is a Streamlit web app that uses deep learning to identify and extract emotions and moods from music. With this app you can both explore the data and classify your own songs.
4 | 
5 | [Check it out here](http://167.172.220.53:8501/).
--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM python:3.7
2 | EXPOSE 8501
3 | 
4 | RUN apt-get update
5 | RUN apt-get install -y libsndfile1-dev
6 | RUN apt-get install -y ffmpeg
7 | 
8 | WORKDIR /usr/src/app
9 | COPY requirements.txt ./requirements.txt
10 | RUN pip install -r requirements.txt
11 | COPY . .
12 | 
13 | CMD [ "streamlit", "run", "src/app.py" ]
--------------------------------------------------------------------------------
/src/loss.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from torch import *
4 | from torch.autograd import *
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | import torch.utils.data as data_utils
8 | import torchvision
9 | import torchvision.transforms as transforms
10 | 
11 | class BCELoss(nn.Module):
12 |     def __init__(self, reduce=False):
13 |         super().__init__()
14 |         self.reduce = reduce
15 | 
16 |     def forward(self, logit, target):
17 |         target = target.float()
18 |         loss = nn.BCEWithLogitsLoss(reduction='none')(logit, target)  # element-wise, so it can be summed per sample below
19 |         if len(loss.size())==2:
20 |             loss = loss.sum(dim=1)
21 |         if not self.reduce:
22 |             return loss
23 |         else:
24 |             return loss.mean()
25 | 
26 | 
27 | # Adapted from https://www.kaggle.com/c/human-protein-atlas-image-classification/discussion/78109
28 | class FocalLoss(nn.Module):
29 |     def __init__(self, gamma=2, reduce=False):
30 |         super().__init__()
31 |         self.gamma = gamma
32 |         self.reduce = reduce
33 | 
34 |     def forward(self, logit, target):
35 |         target = target.float()
36 |         max_val = (-logit).clamp(min=0)
37 |         loss = logit - logit * target + max_val + \
38 |                ((-max_val).exp() + (-logit - max_val).exp()).log()
39 | 
40 |         invprobs = F.logsigmoid(-logit * (target * 2.0 - 1.0))
41 |         loss = (invprobs * self.gamma).exp() * loss
42 |         if len(loss.size())==2:
43 |             loss = loss.sum(dim=1)
44 |         if not self.reduce:
45 |             return loss
46 |         else:
47 |             return loss.mean()
48 | 
49 | class MixupBCELoss(BCELoss):
50 |     def forward(self, x, y):
51 |         if isinstance(y, dict):
52 |             y0, y1, a = y['y0'], y['y1'], y['a']
53 |             loss = a*super().forward(x, y0) + (1-a)*super().forward(x, y1)
54 |         else:
55 |             loss = super().forward(x, y)
56 |         return 100*loss.mean()
57 | 
58 | 
59 | class MixupFocalLoss(FocalLoss):
60 |     def forward(self, x, y):
61 |         if isinstance(y, dict):
62 |             y0, y1, a = y['y0'], y['y1'], y['a']
63 |             loss = a*super().forward(x, y0) + (1-a)*super().forward(x, y1)
64 |         else:
65 |             loss = super().forward(x, y)
66 |         return loss.mean()
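# Usage sketch (illustrative only, not part of the original training code): these
# losses accept either a plain multi-hot target tensor or the dict target
# ({'y0', 'y1', 'a'}) built by the mixup callbacks in src/utils.py, e.g.:
#
#     crit = MixupBCELoss()
#     logits  = torch.randn(4, 250)                    # batch of 4 tracks x ~250 emotion labels
#     targets = torch.randint(0, 2, (4, 250)).float()
#     loss_plain = crit(logits, targets)
#     loss_mixed = crit(logits, {'y0': targets, 'y1': targets.flip(0), 'a': torch.rand(4)})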
--------------------------------------------------------------------------------
/src/about.txt:
--------------------------------------------------------------------------------
1 | Machines have learned to recognize faces, drive cars, play Go, create art, and even smell. But what about learning how to listen to music? Listening to and enjoying music is an experience that is uniquely human. Nothing but vibrations in the air, music can be both incredibly simple and infinitely complex at the same time.
2 | 
3 | Can a machine identify, isolate, and extract the emotional response we feel when listening to music? Which emotions can be extracted from music? Can humans even do this? Empirical research suggests that even within a single person, there are many factors such as setting, current temperament, recent events, or even present mindfulness that may influence the emotional response to the same song.
4 | 
5 | As both a data scientist and a passionate consumer of all types of music, I was very curious about the answers to these questions and wanted to see if deep learning was powerful enough to capture this relationship. To do this, I enlisted the help of Spotify, completing nearly half a billion searches over the course of several months — using user-generated playlists to crowdsource our collective emotional experience of music. This data was used to identify 500 different emotions or moods that can be experienced when listening to music. This list was later consolidated to roughly 250 — which you can explore in this app.
6 | 
7 | As you’re exploring the data, keep in mind that this is what the majority of people feel when listening to a particular song. Millions of playlists have been analyzed using the latest breakthroughs in natural language processing, building up a list of tracks that repeatedly appear for a given emotion. The more often a track appeared, the more likely it was to be included in the final algorithm. While not all results are perfect and there are some biases in the data (be sure to read the forthcoming blog post outlining the technical details of the project), it is truly incredible to see not only the ability of an algorithm to learn from audio data alone, but also how we as humans can reach a consensus on how music makes us feel.
8 | 
9 | In Latin, "sentire" means both "to listen" and "to feel" -- something that, in my opinion, sums up perfectly the musical experience and this project. This was truly a labor of love and I hope you enjoy exploring the data as much as I did creating it.
10 | 11 | 12 | 13 | -------------------------------------------------------------------------------- /src/mxresnet.py: -------------------------------------------------------------------------------- 1 | from fastai.torch_core import * 2 | import torch.nn as nn 3 | import torch, math, sys 4 | import torch.utils.model_zoo as model_zoo 5 | from functools import partial 6 | from fastai.torch_core import Module 7 | import torch.nn.functional as F 8 | 9 | 10 | # https://arxiv.org/abs/1908.08681 11 | class Mish(nn.Module): 12 | 13 | def __init__(self): 14 | super().__init__() 15 | 16 | def forward(self, x): 17 | return x * (torch.tanh(F.softplus(x))) 18 | 19 | 20 | #Unmodified from https://github.com/fastai/fastai/blob/5c51f9eabf76853a89a9bc5741804d2ed4407e49/fastai/layers.py 21 | def conv1d(ni:int, no:int, ks:int = 1, stride:int = 1, padding:int = 0, bias:bool = False): 22 | '''Create and initialize a `nn.Conv1d` layer with spectral normalization.''' 23 | conv = nn.Conv1d(ni, no, ks, stride = stride, padding = padding, bias = bias) 24 | nn.init.kaiming_normal_(conv.weight) 25 | if bias: conv.bias.data.zero_() 26 | return spectral_norm(conv) 27 | 28 | 29 | # Adapted from SelfAttention layer at https://github.com/fastai/fastai/blob/5c51f9eabf76853a89a9bc5741804d2ed4407e49/fastai/layers.py 30 | # Inspired by https://arxiv.org/pdf/1805.08318.pdf 31 | class SimpleSelfAttention(nn.Module): 32 | 33 | def __init__(self, n_in:int, ks = 1, sym = False): 34 | super().__init__() 35 | self.conv = conv1d(n_in, n_in, ks, padding = ks//2, bias = False) 36 | self.gamma = nn.Parameter(tensor([0.])) 37 | self.sym = sym 38 | self.n_in = n_in 39 | 40 | def forward(self, x): 41 | 42 | if self.sym: 43 | # symmetry hack by https://github.com/mgrankin 44 | c = self.conv.weight.view(self.n_in, self.n_in) 45 | c = (c + c.t())/2 46 | self.conv.weight = c.view(self.n_in, self.n_in, 1) 47 | 48 | size = x.size() 49 | x = x.view(*size[:2], -1) 50 | convx = self.conv(x) 51 | xxT = torch.bmm(x, x.permute(0, 2, 1).contiguous()) 52 | o = torch.bmm(xxT, convx) 53 | o = self.gamma * o + x 54 | 55 | return o.view(*size).contiguous() 56 | 57 | # sizes 58 | __all__ = ['MXResNet', 'mxresnet18', 'mxresnet34', 'mxresnet50', 'mxresnet101', 'mxresnet152'] 59 | 60 | # activation function 61 | act_fn = Mish() #nn.ReLU(inplace = True) 62 | 63 | 64 | class Flatten(Module): 65 | def forward(self, x): return x.view(x.size(0), -1) 66 | 67 | def init_cnn(m): 68 | if getattr(m, 'bias', None) is not None: nn.init.constant_(m.bias, 0) 69 | if isinstance(m, (nn.Conv2d, nn.Linear)): nn.init.kaiming_normal_(m.weight) 70 | for l in m.children(): init_cnn(l) 71 | 72 | def conv(ni, nf, ks = 3, stride = 1, bias = False): 73 | return nn.Conv2d(ni, nf, kernel_size = ks, stride = stride, padding = ks//2, bias = bias) 74 | 75 | def noop(x): return x 76 | 77 | def conv_layer(ni, nf, ks = 3, stride = 1, zero_bn = False, act = True): 78 | bn = nn.BatchNorm2d(nf) 79 | nn.init.constant_(bn.weight, 0. if zero_bn else 1.) 
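    # zero_bn: initialising this BatchNorm's weight (gamma) to zero makes the residual
    # branch start out as an identity mapping; ResBlock below passes zero_bn = True for
    # the final conv_layer of each branch (the "Bag of Tricks" zero-gamma initialisation).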
80 | layers = [conv(ni, nf, ks, stride = stride), bn] 81 | if act: layers.append(act_fn) 82 | return nn.Sequential(*layers) 83 | 84 | class ResBlock(Module): 85 | 86 | def __init__(self, expansion, ni, nh, stride = 1, sa = False, sym = False): 87 | nf, ni = nh * expansion, ni * expansion 88 | layers = [conv_layer(ni, nh, 3, stride = stride), 89 | conv_layer(nh, nf, 3, zero_bn = True, act = False) 90 | ] if expansion == 1 else [ 91 | conv_layer(ni, nh, 1), 92 | conv_layer(nh, nh, 3, stride = stride), 93 | conv_layer(nh, nf, 1, zero_bn = True, act = False) 94 | ] 95 | self.sa = SimpleSelfAttention(nf, ks = 1, sym = sym) if sa else noop 96 | self.convs = nn.Sequential(*layers) 97 | self.idconv = noop if ni == nf else conv_layer(ni, nf, 1, act = False) 98 | self.pool = noop if stride == 1 else nn.AvgPool2d(2, ceil_mode = True) 99 | 100 | def forward(self, x): 101 | return act_fn(self.sa(self.convs(x)) + self.idconv(self.pool(x))) 102 | 103 | def filt_sz(recep): return min(64, 2 ** math.floor(math.log2(recep * 0.75))) 104 | 105 | class MXResNet(nn.Sequential): 106 | 107 | def __init__(self, expansion, layers, c_in = 3, c_out = 1000, sa = False, sym = False): 108 | stem = [] 109 | sizes = [c_in, 32, 64, 64] #modified per Grankin 110 | for i in range(3): 111 | stem.append(conv_layer(sizes[i], sizes[i+1], stride = 2 if i == 0 else 1)) 112 | 113 | block_szs = [64//expansion, 64, 128, 256, 512] 114 | blocks = [self._make_layer(expansion, block_szs[i], block_szs[i+1], l, 1 if i == 0 else 2, sa = sa if i in[len(layers)-4] else False, sym = sym) 115 | for i, l in enumerate(layers)] 116 | super().__init__( 117 | *stem, 118 | nn.MaxPool2d(kernel_size = 3, stride = 2, padding = 1), 119 | *blocks, 120 | nn.AdaptiveAvgPool2d(1), Flatten(), 121 | nn.Linear(block_szs[-1] * expansion, c_out), 122 | ) 123 | init_cnn(self) 124 | 125 | def _make_layer(self, expansion, ni, nf, blocks, stride, sa = False, sym = False): 126 | return nn.Sequential( 127 | *[ResBlock(expansion, ni if i == 0 else nf, nf, stride if i == 0 else 1, sa if i in [blocks -1] else False, sym) 128 | for i in range(blocks)]) 129 | 130 | def mxresnet(expansion, n_layers, name, pretrained = False, **kwargs): 131 | model = MXResNet(expansion, n_layers, **kwargs) 132 | return model 133 | 134 | me = sys.modules[__name__] 135 | for n, e, l in [ 136 | [ 18 , 1, [2, 2, 2 , 2] ], 137 | [ 34 , 1, [3, 4, 6 , 3] ], 138 | [ 50 , 4, [3, 4, 6 , 3] ], 139 | ]: 140 | name = f'mxresnet{n}' 141 | setattr(me, name, partial(mxresnet, expansion = e, n_layers = l, name = name)) 142 | -------------------------------------------------------------------------------- /src/spectrograms.py: -------------------------------------------------------------------------------- 1 | import os, io, torch, matplotlib 2 | from typing import List, Dict, Tuple 3 | import pandas as pd 4 | import numpy as np 5 | import matplotlib.pyplot as plt 6 | from copy import deepcopy 7 | from random import random, shuffle 8 | from time import time 9 | from PIL import Image 10 | from PIL.ImageOps import crop 11 | from random import randint 12 | from math import ceil 13 | from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas 14 | import pylab, librosa 15 | from pydub import AudioSegment 16 | import librosa.display 17 | 18 | class Spectrogram(object): 19 | 20 | '''Build and sample (partition) spectrograms for a given audio file for use in a CNN. 
21 | 
22 |     General Workflow
23 |     -----------------
24 |     - Gather .mp3 files
25 |     - Convert to mono-channel .wav
26 |     - Build full spectrogram
27 |     - Get n square partitions as needed (sampled in a sliding manner or randomly)
28 |     '''
29 | 
30 |     def __init__(self, aspect_ratio:int = 3, min_dpi:int = 100):
31 |         self.aspect_ratio = aspect_ratio # the aspect ratio for the full-size spectrogram
32 |         self.min_dpi = min_dpi # the min dpi for the output image (increased as necessary to fit aspect ratio)
33 | 
34 |     def stereo2mono(self, mp3_path:str, export_path:str):
35 |         '''Convert a stereo .mp3 file to a mono .wav'''
36 |         sound = AudioSegment.from_mp3(mp3_path).set_channels(1)
37 |         sound.export(export_path, format = 'wav')
38 | 
39 |     def load(self, fname:str, sr:int = 44100):
40 |         '''Get the audio time series and sampling rate for a given audio file'''
41 |         return librosa.load(fname, sr = sr)
42 | 
43 |     def show(self, signal:np.ndarray, sr:int, window_sz:int, n_mels:int, hop_length:int, ref_max:bool, figsize:Tuple):
44 |         '''Display a Mel Spectrogram for a given audio file in notebook'''
45 |         S = librosa.feature.melspectrogram(signal, sr, n_fft = window_sz, n_mels = n_mels, hop_length = hop_length)
46 |         plt.figure(figsize = figsize)
47 |         librosa.display.specshow(
48 |             librosa.power_to_db(S, ref=np.max) if ref_max else librosa.power_to_db(abs(S)),
49 |             sr = sr, hop_length = hop_length, x_axis = 'time', y_axis = 'mel'
50 |         )
51 | 
52 |     def _calc_fig_size_res(self, h:int):
53 |         '''Helper function to calculate the new figsize and image resolution (in DPI) for a given height'''
54 |         w = h * self.aspect_ratio
55 |         assert w % ceil(w) == 0
56 |         for i in range(self.min_dpi, 301):
57 |             if w % i == 0 and h % i == 0:
58 |                 dpi = i
59 |                 figsize = (w / dpi, h / dpi)
60 |                 return figsize, dpi
61 | 
62 |     def export(self, export_path:str, img_height:int, signal:np.ndarray, sr:int, window_sz:int,
63 |                n_mels:int, hop_length:int, top_db:int = 80, cmap:str = 'coolwarm', to_disk:bool = True):
64 |         '''Export a Mel Spectrogram as a .png file for a given audio file. 
65 | 66 | Parameters 67 | ---------- 68 | export_path : the path and file name of the created spectrogram 69 | img_height : the height of the spectrogram, the width will be calculated automatically based on `dpi` 70 | signal : the audio signal 71 | sr : the sample rate of the audio 72 | window_sz : n_fft, the number of samples used in each Fourier Transform (the width of the window) 73 | n_mels : how many mel bins are used, this will determine how many pixels tall the spectrogram is 74 | hop_length : the number of samples the Fourier Transform window slides (too large: compressing data, too small: blurring) 75 | top_db : distance between the loudest and softest sound displayed in spectrogram 76 | cmap : the color map used for the spectrogram, default for librosa: "magma" 77 | ''' 78 | 79 | # generate spectro 80 | S = librosa.power_to_db( 81 | abs(librosa.feature.melspectrogram(signal, sr, n_fft = window_sz, 82 | n_mels = n_mels, hop_length = hop_length)), 83 | top_db = top_db 84 | ) 85 | 86 | # build fig 87 | self.img = None 88 | figsize, dpi = self._calc_fig_size_res(img_height) 89 | fig = plt.Figure(figsize=figsize) 90 | canvas = FigureCanvas(fig) 91 | ax = fig.add_subplot(111) 92 | ax.imshow(torch.from_numpy(S).flip(0), cmap = cmap) 93 | fig.subplots_adjust(left = 0, right = 1, bottom = 0, top = 1) 94 | ax.axis('tight'); ax.axis('off') 95 | 96 | # save to disk 97 | if to_disk: 98 | fig.savefig(export_path, dpi = dpi) 99 | 100 | # keep in memory 101 | else: 102 | buf = io.BytesIO() 103 | fig.savefig(buf, format='png', dpi = dpi) 104 | buf.seek(0) 105 | self.img = deepcopy(Image.open(buf)) 106 | buf.close() 107 | 108 | print('Successfully built spectrogram.') 109 | 110 | 111 | def crop_partitions(self, fname:str, how:str, n:int = 1): 112 | '''Crop n square partitions from a spectrogram image file for CNN. 
113 | "how" should be one of the following: 114 | - "slide" : start from the left edge and slide n times slide_window distance 115 | - "center": crop from the center of the spectrogram (should only be used when n=1) 116 | - "random": take partions n times from random places of the spectrogram 117 | ''' 118 | assert how == 'slide' or how == 'center' or how == 'random' 119 | 120 | if self.img: img = deepcopy(self.img) 121 | else : img = Image.open(fname) 122 | w, h = img.size 123 | 124 | if how == 'center': 125 | margin = w - ((w // 2) + (h // 2)) 126 | crop(img, border = (margin, 0, margin, 0)).save(fname.replace('.png', '_1.png')) 127 | 128 | elif how == 'slide': 129 | slide_window = (w - h) // (n - 1) 130 | for i in range(n): 131 | crop(img, border = (0+(slide_window*i), 0, w-h-(slide_window*i), 0) 132 | ).save(fname.replace('.png', f'_{i+1}.png')) 133 | 134 | elif how == 'random': 135 | for i in range(n): 136 | left = randint(0, w-h) 137 | crop(img, border = (left, 0, w-(left+h), 0)).save(fname.replace('.png', f'_{i+1}.png')) 138 | 139 | def load_and_partition(self, wav_fname:str, png_fname:str, img_sz:int, window_sz:int, n_mels:int, hop_length:int, 140 | top_db:int, cmap:str, how:str, n:int, del_full_spectro:bool = False, to_disk:bool = True): 141 | '''Load a .wav file, generate spectrogram, and partition''' 142 | 143 | sig, sr = self.load(wav_fname) 144 | self.export(png_fname, img_sz, sig, sr, window_sz, n_mels, hop_length, top_db, cmap, to_disk) 145 | self.crop_partitions(png_fname, how, n) 146 | if del_full_spectro and to_disk: os.remove(png_fname) -------------------------------------------------------------------------------- /src/utils.py: -------------------------------------------------------------------------------- 1 | import fastai, pickle, sklearn 2 | from fastai import * 3 | from fastai.vision import * 4 | import torch 5 | from torch.utils.data.dataloader import default_collate 6 | from torch.utils.data import Sampler, SequentialSampler, RandomSampler 7 | from random import shuffle 8 | 9 | 10 | ################################ HELPERS ################################ 11 | 12 | def write_pkl(obj, fname): 13 | '''Save a Pickle file''' 14 | with open(fname, 'wb') as f: pickle.dump(obj, f, protocol = pickle.HIGHEST_PROTOCOL) 15 | 16 | def read_pkl(fname): 17 | '''Open a Pickle file''' 18 | with open(fname, 'rb') as f: 19 | file = pickle.load(f) 20 | return file 21 | 22 | def flat(multi_dim_list): 23 | ''' flatten a multi-dimensional list ''' 24 | return [l for t in multi_dim_list for l in t] 25 | 26 | 27 | ################################ TRAINING METRICS ################################ 28 | 29 | def top_k(input:Tensor, targs:Tensor, k:int, n:int)->Rank0Tensor: 30 | '''Computes the Top-k accuracy (target is in the top k predictions).''' 31 | input = input.topk(k, dim = -1)[1] 32 | targs_dict = {i: [] for i in range(targs.size()[0])} 33 | for t in targs.nonzero(): targs_dict[t[0].item()].append(t[1].item()) 34 | 35 | match = 0 36 | for i, INP in enumerate(input): 37 | for true in targs_dict[i]: 38 | for inp in INP: 39 | if inp.item() == true: match += 1 40 | 41 | return torch.tensor(match / n) 42 | 43 | def top_5(input:Tensor, targs:Tensor)->Rank0Tensor: 44 | '''Static Top 5 Accuracy metric (get around fastai init issues)''' 45 | k, n = 5, targs.size()[0] 46 | return top_k(input, targs, k, n) 47 | 48 | def avg_label_rank(input:Tensor, targs:Tensor)->Rank0Tensor: 49 | '''Computes average rank of multi-label prediction (1 being the best).''' 50 | n_batches, n_labels, ranks = 
targs.size()[0], targs.size()[1], [] 51 | for i in range(n_batches): 52 | concat = torch.stack([input[i], targs[i]]).T 53 | ranks.append(n_labels - concat[concat[:,0].argsort()][:,1].nonzero().float().mean().item()) 54 | return torch.tensor(np.mean(ranks)).float().mean() 55 | 56 | 57 | ################################ MIXUP ################################ 58 | 59 | # https://github.com/mnpinto/audiotagging2019 60 | class AudioMixup(LearnerCallback): 61 | 62 | def __init__(self, learn): 63 | super().__init__(learn) 64 | 65 | def on_batch_begin(self, last_input, last_target, train, **kwargs): 66 | 67 | if train: 68 | bs = last_input.size()[0] 69 | lambd = np.random.uniform(0, 0.5, bs) 70 | shuffle = torch.randperm(last_target.size(0)).to(last_input.device) 71 | x1, y1 = last_input[shuffle], last_target[shuffle] 72 | a = tensor(lambd).float().view(-1, 1, 1, 1).to(last_input.device) 73 | last_input = a*last_input + (1-a)*x1 74 | last_target = {'y0':last_target, 'y1':y1, 'a':a.view(-1)} 75 | return {'last_input': last_input, 'last_target': last_target} 76 | 77 | class SpecMixUp(LearnerCallback): 78 | 79 | def __init__(self, learn:Learner): 80 | super().__init__(learn) 81 | self.masking_max_percentage = 0.5 82 | self.alpha = .6 83 | 84 | def _spec_augment(self, last_input, last_target): 85 | shuffle = torch.randperm(last_target.size(0)).to(last_input.device) 86 | x1, y1 = last_input[shuffle], last_target[shuffle] 87 | batch_size, channels, height, width = last_input.size() 88 | h_percentage = np.random.uniform(low=0., high=self.masking_max_percentage, size=batch_size) 89 | w_percentage = np.random.uniform(low=0., high=self.masking_max_percentage, size=batch_size) 90 | alpha = (h_percentage + w_percentage) - (h_percentage * w_percentage) 91 | alpha = last_input.new(alpha) 92 | alpha = alpha.unsqueeze(1) 93 | new_input = last_input.clone() 94 | 95 | for i in range(batch_size): 96 | h_mask = int(h_percentage[i] * height) 97 | h = int(np.random.uniform(0.0, height - h_mask)) 98 | new_input[i, :, h:h + h_mask, :] = x1[i, :, h:h + h_mask, :] 99 | w_mask = int(w_percentage[i] * width) 100 | w = int(np.random.uniform(0.0, width - w_mask)) 101 | new_input[i, :, :, w:w + w_mask] = x1[i, :, :, w:w + w_mask] 102 | 103 | new_target = (1-alpha) * last_target + alpha*y1 104 | 105 | return new_input, new_target 106 | 107 | def _mixup(self, last_input, last_target): 108 | lambd = np.random.beta(self.alpha, self.alpha, last_target.size(0)) 109 | lambd = np.concatenate([lambd[:,None], 1-lambd[:,None]], 1).max(1) 110 | lambd = last_input.new(lambd) 111 | shuffle = torch.randperm(last_target.size(0)).to(last_input.device) 112 | x1, y1 = last_input[shuffle], last_target[shuffle] 113 | new_input = (last_input * lambd.view(lambd.size(0),1,1,1) + x1 * (1-lambd).view(lambd.size(0),1,1,1)) 114 | if len(last_target.shape) == 2: 115 | lambd = lambd.unsqueeze(1).float() 116 | new_target = last_target.float() * lambd + y1.float() * (1-lambd) 117 | return new_input, new_target 118 | 119 | def on_batch_begin(self, last_input, last_target, train, **kwargs): 120 | if not train: return 121 | new_input, new_target = self._mixup(last_input, last_target) 122 | new_input, new_target = self._spec_augment(new_input, new_target) 123 | return {'last_input': new_input, 'last_target': new_target} 124 | 125 | 126 | class StandardMixUp(LearnerCallback): 127 | 128 | def __init__(self, learn:Learner): 129 | super().__init__(learn) 130 | self.masking_max_percentage=0.25 131 | self.alpha = .4 132 | 133 | def _mixup(self, last_input, 
last_target): 134 | lambd = np.random.beta(self.alpha, self.alpha, last_target.size(0)) 135 | lambd = np.concatenate([lambd[:,None], 1-lambd[:,None]], 1).max(1) 136 | lambd = last_input.new(lambd) 137 | shuffle = torch.randperm(last_target.size(0)).to(last_input.device) 138 | x1, y1 = last_input[shuffle], last_target[shuffle] 139 | new_input = (last_input * lambd.view(lambd.size(0),1,1,1) + x1 * (1-lambd).view(lambd.size(0),1,1,1)) 140 | if len(last_target.shape) == 2: 141 | lambd = lambd.unsqueeze(1).float() 142 | new_target = last_target.float() * lambd + y1.float() * (1-lambd) 143 | return new_input, new_target 144 | 145 | def on_batch_begin(self, last_input, last_target, train, **kwargs): 146 | if not train: return 147 | new_input, new_target = self._mixup(last_input, last_target) 148 | return {'last_input': new_input, 'last_target': new_target} 149 | 150 | 151 | ################################ OBJECTS ################################ 152 | 153 | LABELS = [ 154 | 'acid', 'acoustic', 'africa', 'afternoon', 'airplane', 'ambient', 'angelic', 'anger', 155 | 'angst', 'arab', 'asia', 'atmospheric', 'autumn', 'bad mood', 'bbq', 'beach', 'beautiful', 156 | 'bed', 'beer', 'berlin', 'biking', 'bleak', 'bliss', 'breakfast', 'breakup', 'bright', 157 | 'broken', 'cabin', 'cafe', 'calm', 'caribbean', 'celebration', 'champagne', 'chaotic', 158 | 'chill', 'choir', 'cinematic', 'city', 'city night', 'cleaning', 'clubbing', 'cocaine', 159 | 'cocktail', 'coffee', 'cold', 'commuting', 'concentration', 'cool', 'cosmic', 'cowboy', 160 | 'cozy', 'cry', 'cuddle', 'dance', 'dark', 'date', 'dawn', 'day', 'deep', 'depressed', 161 | 'desire', 'despair', 'dirty', 'dramatic', 'dream', 'drinking', 'driving', 'drug', 'dusk', 162 | 'early', 'ecstasy', 'eerie', 'encouraging', 'energetic', 'epic', 'erotic', 'ethereal', 163 | 'euphoric', 'evening', 'excited', 'fearless', 'fight', 'fitness', 'flapper', 'focus', 164 | 'forest', 'forgiveness', 'free spirit', 'friday', 'fuck', 'fucked up', 'fun', 'gaming', 165 | 'garden', 'gentle', 'gloom', 'gloomy', 'goa', 'going out', 'good mood', 'good vibes', 166 | 'grill', 'grimy', 'gritty', 'groovy', 'grunge', 'guilt', 'guitar', 'gym', 'hallucinating', 167 | 'happiness', 'happy', 'havana', 'hazy', 'healing', 'heartbreak', 'heavy', 'hipster', 'home', 168 | 'hopeless', 'horny', 'hype', 'hypnotic', 'ibiza', 'india', 'insane', 'inspiring', 'intense', 169 | 'intimate', 'introspective', 'iran', 'irish', 'island', 'italy', 'jamaica', 'japan', 170 | 'kaleidoscope', 'kiss', 'lake', 'late', 'latin america', 'lazy', 'lit', 'london', 171 | 'loneliness', 'lonely', 'loud', 'love', 'lsd', 'magic', 'marathon', 'massage', 'meditation', 172 | 'meditative', 'melancholic', 'mellow', 'memphis', 'memserizing', 'meth', 'mexico', 173 | 'middle east', 'misery', 'monday', 'moon', 'morning', 'motivational', 'motorcycle', 174 | 'mountain', 'moving on', 'mystical', 'nashville', 'nature', 'new orleans', 'new york', 175 | 'night', 'nocturnal', 'noise', 'nomad', 'ocean', 'office', 'optimistic', 'painting', 176 | 'paradise', 'paris', 'party', 'passion', 'peaceful', 'pensive', 'piano', 'porch', 'powerful', 177 | 'psychedelic', 'rainy', 'rave', 'reading', 'reflective', 'relax', 'roadtrip', 'romantic', 178 | 'running', 'sad', 'saturday', 'sea', 'sedating', 'seductive', 'sensual', 'sentimental', 179 | 'serene', 'sex', 'sky', 'sleep', 'slow', 'slumber', 'smoking', 'smooth', 'snow', 'soft', 180 | 'somber', 'soothing', 'sorrow', 'soulful', 'southern', 'space', 'spacey', 'spiritual', 181 | 'spring', 'steamy', 'stimulating', 
'stoner', 'storm', 'strings', 'stroll', 'study', 'summer', 182 | 'sun', 'sunday', 'sunny', 'sunrise', 'sunset', 'sunshine', 'surf', 'surreal', 'tender', 183 | 'tequila', 'thursday', 'tranquil', 'tranquilizer', 'travel', 'tribal', 'trip', 'trippy', 184 | 'tropical', 'tuesday', 'upbeat', 'uplifting', 'urban', 'vibrant', 'violin', 'visceral', 185 | 'vodka', 'walking', 'wandering', 'wanderlust', 'warm', 'wedding', 'wednesday', 'weed', 186 | 'weekend', 'whiskey', 'wine', 'winter', 'woods', 'work', 'workout', 'writing', 'yacht', 187 | 'yoga', 'zen' 188 | ] 189 | 190 | 191 | 192 | -------------------------------------------------------------------------------- /src/app.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | import urllib, requests, spotipy, os, math, shutil, gc 17 | import streamlit as st 18 | import numpy as np 19 | import pandas as pd 20 | import matplotlib.pyplot as plt 21 | import plotly.express as px 22 | import plotly.graph_objects as go 23 | from spotipy.oauth2 import SpotifyClientCredentials 24 | from copy import deepcopy 25 | from io import BytesIO 26 | from PIL import Image, ImageFile 27 | from fastai.vision import * 28 | from fastai.callbacks import * 29 | from utils import * 30 | from spectrograms import * 31 | from loss import * 32 | from mxresnet import * 33 | 34 | 35 | 36 | 37 | #--------------------------------------------------------------------------------# 38 | # # 39 | # Main # 40 | # ::: Handles the navigation / routing and data loading / caching. # 41 | # # 42 | #--------------------------------------------------------------------------------# 43 | 44 | 45 | def main(): 46 | '''Set main() function. Includes sidebar navigation and respective routing.''' 47 | 48 | st.sidebar.title("Explore") 49 | app_mode = st.sidebar.selectbox( "Choose an Action", [ 50 | "About", 51 | "Choose an Emotion", 52 | "Choose an Artist", 53 | "Classify a Song", 54 | "Emotional Spectrum", 55 | "Show Source Code" 56 | ]) 57 | 58 | # clear tmp 59 | clear_tmp() 60 | 61 | # nav 62 | if app_mode == "About": show_about() 63 | elif app_mode == "Choose an Emotion": explore_classified() 64 | elif app_mode == 'Choose an Artist': explore_artists() 65 | elif app_mode == "Classify a Song": classify_song() 66 | elif app_mode == "Emotional Spectrum": display_emotional_spectrum() 67 | elif app_mode == "Show Source Code": st.code(get_file_content_as_string()) 68 | 69 | 70 | @st.cache 71 | def load_data(): 72 | ''' Load main data source with all labels and values. ''' 73 | return read_pkl(path('data/final_scores_meta.pkl')) 74 | 75 | 76 | @st.cache 77 | def load_tsne(): 78 | ''' Load TSNE data for plotting / viz ''' 79 | return read_pkl(path('data/tsne.pkl')) 80 | 81 | 82 | def clear_tmp(): 83 | ''' Clear /tmp on load. Used for new song classification. 
''' 84 | shutil.rmtree(path('tmp')) 85 | for d in [path('tmp'), path('tmp/png'), path('tmp/wav')]: os.mkdir(d) 86 | 87 | 88 | def path(orig_path): 89 | ''' Path handler for local or production ''' 90 | if len(ROOT_DIR) == 0: return orig_path 91 | return f'{ROOT_DIR}/{orig_path}' 92 | 93 | 94 | @st.cache(show_spinner = False) 95 | def get_file_content_as_string(): 96 | ''' Download a single file and make its content available as a string. ''' 97 | url = 'https://raw.githubusercontent.com/zacheberhart/Learning-to-Feel/master/src/app.py' 98 | response = urllib.request.urlopen(url) 99 | return response.read().decode("utf-8") 100 | 101 | 102 | @st.cache(show_spinner = False) 103 | def read_text(fname): 104 | ''' Display copy from a .txt file. ''' 105 | with open(fname, 'r') as f: 106 | text = f.readlines() 107 | return text 108 | 109 | 110 | def show_about(): 111 | ''' Home / About page ''' 112 | st.title('Learning to Listen, to Feel') 113 | for line in read_text(path('about.txt')): 114 | st.write(line) 115 | 116 | 117 | 118 | 119 | #--------------------------------------------------------------------------------# 120 | # # 121 | # Choose an Emotion # 122 | # ::: Allow the user to pick one or more labels to get a list of the top songs # 123 | # ::: classified with the respective label(s). Limit the list of songs returned # 124 | # ::: to 100, but allow the user to choose the quantity and the "Popularity", a # 125 | # ::: a metric provided by Spotify's API. Also allow the user to leave the app # 126 | # ::: and listen to the song on Spotify's Web App using the provided link. # 127 | # # 128 | #--------------------------------------------------------------------------------# 129 | 130 | 131 | def explore_classified(): 132 | 133 | # load all data 134 | df = load_data() 135 | non_label_cols = ['track_id', 'track_title', 'artist_name', 'track_popularity', 'artist_popularity'] 136 | dims = [c for c in df.columns.tolist() if c not in non_label_cols] 137 | 138 | # Mood or Emotion Selection 139 | st.title('Explore All Moods & Emotions') 140 | st.write(''' 141 | Select a mood, an emotion, or a few of each! However, keep in mind that results are best when 142 | you choose as few as possible -- though you will definitely get some pretty funky results the more you add. 143 | ''') 144 | 145 | # filters 146 | labels = st.multiselect("Choose:", dims) 147 | n_songs = st.slider('How many songs?', 1, 100, 20) 148 | popularity = st.slider('How popular?', 0, 100, (0, 100)) 149 | 150 | try: 151 | 152 | # filter data to the labels the user specified 153 | cols = (non_label_cols, labels) 154 | df = filter_data(df, cols, n_songs, popularity) 155 | 156 | # show data 157 | if st.checkbox('Include Preview URLs', value = True): 158 | df['preview'] = add_stream_url(df.track_id) 159 | df['preview'] = df['preview'].apply(make_clickable, args = ('Listen',)) 160 | data = df.drop('track_id', 1) 161 | data = data.to_html(escape = False) 162 | st.write(data, unsafe_allow_html = True) 163 | else: 164 | data = df.drop('track_id', 1) 165 | st.write(data) 166 | 167 | except: pass 168 | 169 | def norm_and_combine(df, labels): 170 | ''' Normalize and log transform scores for better query combination results. ''' 171 | tdf = pd.DataFrame() 172 | for label in labels: 173 | tdf[label] = np.log1p(df[label]) 174 | tdf[label] = tdf[label] / tdf[label].max() 175 | return tdf[labels].sum(1) 176 | 177 | def filter_data(df, cols, n_songs, popularity): 178 | ''' Filter the df based on user-selected label, quantity, and popularity selections. 
'''
179 |     non_label_cols, label_cols = cols
180 |     tdf = deepcopy(df[non_label_cols + label_cols])
181 |     tdf = deepcopy(tdf[(tdf.track_popularity >= popularity[0]) & (tdf.track_popularity <= popularity[1])
182 |                        ].drop(['track_popularity', 'artist_popularity'], 1))
183 |     tdf = deepcopy(tdf.drop_duplicates(['track_title', 'artist_name']))
184 |     if len(label_cols) > 1:
185 |         tdf['combo'] = norm_and_combine(tdf, label_cols)
186 |         for label in label_cols:
187 |             tdf = deepcopy(tdf[tdf[label] >= 0.04])
188 |         return tdf.sort_values('combo', ascending = False)[:n_songs].reset_index(drop = True)
189 |     else:
190 |         return tdf.sort_values(label_cols[0], ascending = False)[:n_songs].reset_index(drop = True)
191 | 
192 | def add_stream_url(track_ids):
193 |     ''' Build Spotify Track URL given its Track ID. '''
194 |     return [f'https://open.spotify.com/track/{t}' for t in track_ids]
195 | 
196 | def make_clickable(url, hyperlink_text):
197 |     ''' Convert URL to clickable HTML link. '''
198 |     return f'<a href="{url}">{hyperlink_text}</a>'
199 | 
200 | 
201 | 
202 | 
203 | #--------------------------------------------------------------------------------#
204 | #                                                                                  #
205 | #                               Choose an Artist                                   #
206 | #  ::: Display the top 10 labels for an artist specified by the user and a         #
207 | #  ::: list of tracks that we have for that artist in our db.                      #
208 | #                                                                                  #
209 | #--------------------------------------------------------------------------------#
210 | 
211 | 
212 | def explore_artists():
213 | 
214 |     # load all data
215 |     df = load_data()
216 |     non_label_cols = ['track_id', 'track_title', 'artist_name', 'track_popularity', 'artist_popularity']
217 |     dims = [c for c in df.columns.tolist() if c not in non_label_cols]
218 | 
219 |     # user input
220 |     st.title('Explore Artists')
221 |     selected_artist = st.text_input('Search for an Artist:', 'Bon Iver')
222 |     search_results = df[df.artist_name.str.lower() == selected_artist.lower()]
223 | 
224 |     # display results
225 |     if len(search_results) > 0:
226 |         label_weights = get_top_labels(search_results, dims, 10)
227 |         st.plotly_chart(artist_3d_scatter(label_weights), width = 0)
228 |         search_results['top_label'] = search_results.iloc[:,5:].astype(float).idxmax(1)
229 |         st.write(search_results[['track_title', 'artist_name', 'track_popularity', 'top_label']
230 |                  ].sort_values(['track_title', 'track_popularity']).drop_duplicates('track_title').reset_index(drop = True))
231 |     else:
232 |         st.write('Sorry, there are no results for that artist in our database :(')
233 | 
234 | 
235 | def artist_3d_scatter(label_weights):
236 |     '''Display an artist's position in the emotional spectrum.'''
237 | 
238 |     # get data
239 |     label_weights = label_weights.merge(TSNE, on = 'label')
240 |     tdf = label_weights.rename(columns = {'1d0': 'color', '3d0': 'energy', '3d1': 'style', '3d2': 'acousticness'})
241 | 
242 |     # build fig
243 |     fig = go.Figure(data = [go.Scatter3d(
244 |         x = tdf['energy'], y = tdf['style'], z = tdf['acousticness'],
245 |         mode = 'markers+text',
246 |         text = tdf['label'], textfont = dict(size = 16),
247 |         marker = dict(
248 |             size = tdf['weight'] * 50,
249 |             color = tdf['color'],
250 |             opacity = 0.6,
251 |             colorscale = 'RdBu',
252 |         )
253 |     )])
254 | 
255 |     # layout modifications
256 |     fig.update_layout(margin = dict(l = 0, r = 0, b = 0, t = 0), scene = dict(
257 |         xaxis_title = 'Energy',
258 |         yaxis_title = 'Style',
259 |         zaxis_title = 'Acousticness',
260 |         xaxis = dict(showticklabels = False, nticks = 5, range = [-200, 200]),
261 |         yaxis = dict(showticklabels = False, nticks = 5, range = [-200, 200]),
262 |         zaxis = 
dict(showticklabels = False, nticks = 5, range = [-200, 200]), 263 | )) 264 | 265 | return fig 266 | 267 | 268 | def get_top_labels(df, dims, n): 269 | ''' Get the top n labels for a given artist. ''' 270 | label_weights = pd.DataFrame(df[dims].sum().sort_values(ascending = False)).reset_index() 271 | label_weights.columns = ['label', 'weight'] 272 | label_weights.weight = label_weights.weight / label_weights.weight.max() 273 | return label_weights[:n] 274 | 275 | 276 | 277 | 278 | #--------------------------------------------------------------------------------# 279 | # # 280 | # Classify a Song # 281 | # ::: Get Top/Bottom 5 Labels for a track specified by the user. If the track # 282 | # ::: has already been classified, pull from db. Otherwise, pull audio from # 283 | # ::: Spotify and classify using a distilled version of the model. # 284 | # # 285 | #--------------------------------------------------------------------------------# 286 | 287 | 288 | def classify_song(): 289 | ''' 290 | Potential additional features: 291 | - Similar tracks (based on emotional signature) 292 | ''' 293 | 294 | # load all data 295 | df = load_data() 296 | 297 | # copy 298 | st.title('Classify a Song') 299 | st.markdown('Want to analyze a specific track? Enter the Spotify URL (or Track ID if you know it):') 300 | st.markdown('To get the track\'s URL from the Spotify app: \n - Drag the track into the search box below, or \n - Click on Share >> Copy Song Link and paste below.', unsafe_allow_html = True) 301 | 302 | # user input 303 | track_id = st.text_input('Enter Track URL:') 304 | st.markdown('*Note: Unfortunately, due to licensing restrictions, many songs from some of the more popular artists are unavailable.*', unsafe_allow_html = True) 305 | if len(track_id) > 22: 306 | track_id = track_id.split('?')[0].split('/track/')[1] 307 | show_spectros = st.checkbox('Show Spectrograms', value = False) 308 | 309 | # check if a track_id has been entered 310 | if len(track_id) > 0: 311 | 312 | # get track from Spotify API 313 | track = get_spotify_track(track_id) 314 | st.subheader('Track Summary') 315 | st.table(get_track_summmary_df(track)) 316 | 317 | # check if there is track preview available from Spotify 318 | if track['preview_url']: 319 | 320 | # display 30 second track preview 321 | st.subheader('Track Preview (What the Algorithm "Hears")') 322 | st.write('') 323 | preview = get_track_preview(track_id) 324 | st.audio(preview) 325 | 326 | # get top and bottom labels for the track 327 | st.subheader('Track Analysis') 328 | track_df = deepcopy(DF[DF.track_id == track_id].reset_index(drop = True)) 329 | 330 | # return values from db if already classified, otherwise classify 331 | if len(track_df) > 0: 332 | track_df = deepcopy(track_df.iloc[:,5:].T.rename(columns = {0: 'score'}).sort_values('score', ascending = False)) 333 | st.table(pd.DataFrame({'Top 5': track_df[:5].index.tolist(), 'Bottom 5': track_df[-5:].index.tolist()})) 334 | if show_spectros: generate_spectros(preview) 335 | else: 336 | generate_spectros(preview) 337 | track_df = get_predictions() 338 | st.table(pd.DataFrame({'Top 5': track_df[:5].index.tolist(), 'Bottom 5': track_df[-5:].index.tolist()})) 339 | 340 | if show_spectros: 341 | st.subheader('Spectrograms (What the Algorithm "Sees")') 342 | generate_grid() 343 | st.image(image = path('tmp/png/grid.png'), use_column_width = True) 344 | 345 | # Spotify doesn't have preview for track 346 | else: 347 | st.write('Preview unavailable for this track :(') 348 | 349 | 350 | def 
get_spotify_track(track_id): 351 | ''' Get track from Spotify, given its Track ID. ''' 352 | return SP.track(track_id) 353 | 354 | 355 | def get_track_preview(track_id): 356 | ''' Get a 30 Second Preview, if available. ''' 357 | return requests.get(get_spotify_track(track_id)['preview_url']).content 358 | 359 | 360 | def get_track_summmary_df(track): 361 | ''' Build a summary for a given track for display ''' 362 | return pd.DataFrame([{ 363 | 'Track': track['name'], 364 | 'Artist': track['artists'][0]['name'], 365 | 'Album': track['album']['name'], 366 | 'Popularity': track['popularity'], 367 | }], index = [' '])[['Track', 'Artist', 'Album', 'Popularity']] 368 | 369 | 370 | class SpecMixUpINCR(SpecMixUp): 371 | '''Spectral MixUp for Conv Net''' 372 | def __init__(self, learn:Learner): 373 | super().__init__(learn) 374 | self.masking_max_percentage = 0.5 375 | self.alpha = 0.8 376 | 377 | 378 | def get_predictions(): 379 | '''Get predictions for a given song. Note that this is just a distilled version 380 | of the model. Currently, it averages predictions on all spectros found in /tmp''' 381 | 382 | model_weights = { 383 | 'mxrn18_partition4_multi-mixup-4_bce_448-sz_32-bs_6-ep_0.0001-lr_2': 0.526, 384 | 'mxrn18_partition5_multi-mixup-4_bce_448-sz_32-bs_5-ep_0.0001-lr_2': 0.474, 385 | } 386 | 387 | # build DataBunch for FastAI 388 | data = (ImageList.from_folder(path('tmp/png')) 389 | .split_none() 390 | .label_from_folder() 391 | .transform(size = IMG_SZ) 392 | .databunch(bs = 1) 393 | .normalize(imagenet_stats) 394 | ) 395 | 396 | # get predictions for track with models (total of 8 predictions: [4 spectros x 2 models]) 397 | all_preds = pd.DataFrame() 398 | for model, weight in model_weights.items(): 399 | print('Predicting with:', model) 400 | learn = load_learner(MODEL_PATH, f'{model}_export.pkl') 401 | model_preds = [[item.item() for item in torch.sigmoid(learn.predict(data.train_ds[i][0])[2])] for i in range(4)] 402 | all_preds = all_preds.append(deepcopy(pd.DataFrame(pd.DataFrame(model_preds, columns = LABELS).mean() * weight))) 403 | del learn, model_preds; gc.collect() 404 | all_preds = deepcopy(all_preds.reset_index().groupby('index').mean().sort_values(0).rename(columns = {0: 'score'})) 405 | 406 | return all_preds.sort_values('score', ascending = False) 407 | 408 | 409 | def generate_spectros(audio): 410 | '''Generate spectrograms of a given audio for input into conv net.''' 411 | 412 | # convert mp3 to wav 413 | spec = Spectrogram() 414 | sound = AudioSegment.from_file(BytesIO(audio)).set_channels(1) 415 | sound.export(path('tmp/wav/user_classify.wav'), format = 'wav') 416 | 417 | # generate spectrograms 418 | spec.load_and_partition( 419 | wav_fname = path('tmp/wav/user_classify.wav'), 420 | png_fname = path('tmp/png/user_classify.png'), 421 | img_sz = 448, 422 | window_sz = 8192, 423 | n_mels = 512, 424 | hop_length = 128, 425 | top_db = 90, 426 | cmap = 'magma', 427 | how = 'slide', 428 | n = 4, 429 | to_disk = False, 430 | ) 431 | 432 | 433 | def generate_grid(): 434 | '''Generate a grid of images from a set of images in a directory (used to display 435 | spectrograms in lieu of CSS styling.''' 436 | 437 | # Config: 438 | images_dir = path('tmp/png') 439 | result_grid_filename = f'{images_dir}/grid.png' 440 | result_figsize_resolution = 30 # default: 40 441 | images_list = os.listdir(images_dir) 442 | images_count = len(images_list) 443 | 444 | # Calculate the grid size: 445 | grid_size = math.ceil(math.sqrt(images_count)) 446 | 447 | # Create plt plot: 448 | fig, axes = 
plt.subplots(grid_size, grid_size, figsize=(result_figsize_resolution, result_figsize_resolution)) 449 | 450 | current_file_number = 0 451 | for image_filename in images_list: 452 | x_position = current_file_number % grid_size 453 | y_position = current_file_number // grid_size 454 | plt_image = plt.imread(images_dir + '/' + images_list[current_file_number]) 455 | axes[x_position, y_position].imshow(plt_image) 456 | print((current_file_number + 1), '/', images_count, ': ', image_filename) 457 | current_file_number += 1 458 | 459 | plt.subplots_adjust(left=0.0, right=1.0, bottom=0.0, top=1.0) 460 | plt.savefig(result_grid_filename) 461 | 462 | 463 | 464 | 465 | #--------------------------------------------------------------------------------# 466 | # # 467 | # Display Emotional Spectrum # 468 | # ::: Display all labels as a 3D Scatter chart for the user to explore. # 469 | # # 470 | #--------------------------------------------------------------------------------# 471 | 472 | 473 | def display_emotional_spectrum(): 474 | st.title('Emotional Spectrum') 475 | st.write( 476 | """Visually explore the algorithm's mapping of the emotional spectrum of the musical experience below. 477 | Use the \"Expand\" button in the top right corner of the chart for the best view!""" 478 | ) 479 | st.plotly_chart(all_labels_scatter(), width = 0) 480 | 481 | 482 | def all_labels_scatter(): 483 | ''' Display a 3D Scatter Plot using all of the labels in the algorithm. ''' 484 | 485 | tdf = TSNE.rename(columns = {'1d0': 'color', '3d0': 'energy', '3d1': 'style', '3d2': 'acousticness'}) 486 | 487 | fig = go.Figure(data = [go.Scatter3d( 488 | x = tdf['energy'], y = tdf['style'], z = tdf['acousticness'], 489 | mode = 'markers+text', 490 | text = tdf['label'], 491 | marker = dict( 492 | color = tdf['color'], 493 | opacity = 0.6, 494 | colorscale = 'RdBu', 495 | ) 496 | )]) 497 | 498 | fig.update_layout(margin = dict(l = 0, r = 0, b = 0, t = 0), scene = dict( 499 | xaxis_title = 'Energy', 500 | yaxis_title = 'Style', 501 | zaxis_title = 'Acousticness', 502 | xaxis = dict(showticklabels = False, nticks = 5, range = [-200, 200]), 503 | yaxis = dict(showticklabels = False, nticks = 5, range = [-200, 200]), 504 | zaxis = dict(showticklabels = False, nticks = 5, range = [-200, 200]), 505 | )) 506 | 507 | return fig 508 | 509 | 510 | 511 | 512 | #--------------------------------------------------------------------------------# 513 | # # 514 | # Execute # 515 | # # 516 | #--------------------------------------------------------------------------------# 517 | 518 | if __name__ == "__main__": 519 | 520 | # display and machine options 521 | pd.set_option('display.max_colwidth', -1) 522 | defaults.device = torch.device('cpu') 523 | ImageFile.LOAD_TRUNCATED_IMAGES = True 524 | 525 | # Spotify API 526 | SP = spotipy.Spotify(client_credentials_manager = SpotifyClientCredentials( 527 | client_id = 'redacted', 528 | client_secret = 'redacted', 529 | )) 530 | 531 | # data & constants 532 | ROOT_DIR = '' 533 | MODEL_PATH = path('models/exported') 534 | IMG_SZ = 448 535 | DF = load_data() 536 | TSNE = load_tsne() 537 | 538 | # execute 539 | main() 540 | 541 | 542 | 543 | --------------------------------------------------------------------------------
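The repository contains the model (mxresnet.py), the losses (loss.py), the augmentation callbacks and metrics (utils.py), and the spectrogram tooling (spectrograms.py), but no training script; src/app.py only loads already-exported learners. The sketch below is a rough illustration of how these pieces could be wired together with the fastai v1 API used throughout the repo; the data layout (a hypothetical data/labels.csv plus a png/ folder of spectrogram partitions), the architecture choice, and every hyper-parameter are illustrative assumptions, not the author's actual training configuration.

from fastai.vision import *                       # fastai v1, as in src/utils.py
from mxresnet import mxresnet18
from loss import MixupBCELoss
from utils import SpecMixUp, top_5, LABELS

# assumed layout: data/labels.csv maps spectrogram file names to space-delimited labels,
# with the partitioned .png spectrograms stored under data/png/
data = (ImageList.from_csv('data', 'labels.csv', folder = 'png', suffix = '.png')
        .split_by_rand_pct(0.2)
        .label_from_df(label_delim = ' ')
        .transform(size = 448)
        .databunch(bs = 32)
        .normalize(imagenet_stats))

learn = Learner(
    data,
    mxresnet18(c_out = len(LABELS), sa = True),    # self-attention MXResNet over the emotion labels
    loss_func = MixupBCELoss(),
    metrics = [top_5],
)
learn.callbacks.append(SpecMixUp(learn))           # per-batch mixup + spectrogram masking
learn.fit_one_cycle(5, 1e-4)
learn.export('mxrn18_example_export.pkl')          # hypothetical export name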