├── dstm
│   ├── __init__.py
│   ├── model
│   │   ├── encoders
│   │   │   ├── __init__.py
│   │   │   └── transformer.py
│   │   ├── __init__.py
│   │   ├── long_term.py
│   │   ├── coded_short_term_model.py
│   │   ├── slow_weight_models.py
│   │   ├── ppm.py
│   │   ├── baseline.py
│   │   ├── module.py
│   │   ├── short_term.py
│   │   └── mc.py
│   └── util
│       ├── __init__.py
│       ├── data
│       │   └── synthetic_data.py
│       ├── metrics.py
│       ├── constants.py
│       ├── plotting.py
│       └── load_data.py
├── setup.py
├── README.md
└── dstm.py
/dstm/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /dstm/model/encoders/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /dstm/model/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | def delete_param(p, in_path, out_path=False): 3 | obj = torch.load(in_path) 4 | hparams = obj['hyper_parameters']['hparams'] 5 | del vars(hparams)[p] 6 | if out_path: 7 | torch.save(obj, out_path) 8 | def save_param(p, val, in_path, out_path=False): 9 | obj = torch.load(in_path) 10 | hparams = obj['hyper_parameters']['hparams'] 11 | vars(hparams)[p] = val 12 | #obj['hyper_parameters']['hparams'] = hparams 13 | if out_path: 14 | torch.save(obj, out_path) -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | setup( 3 | python_requires='>=3.8.0', 4 | name='dstm', 5 | version='0.0.1', 6 | packages=['dstm'], 7 | install_requires=[ 8 | 'torch >= 1.8.0', 9 | 'numpy >= 1.19.0', 10 | 'scikit-learn >= 0.24.0', 11 | 'pytorch-lightning >= 1.3.1', 12 | 'matplotlib >= 3.3.0', 13 | # 'librosa >= 0.8.0', 14 | 'pretty_midi >= 0.2.9', 15 | # 'test-tube >= 0.7.5 ', 16 | # 'nnAudio' 17 | ], 18 | # extras_require={ 19 | # 'dev': [ 20 | # 'tensorboard >= 2.2.0', 21 | # 'jupyter >= 1.0.0', 22 | # ] 23 | # } 24 | ) 25 | -------------------------------------------------------------------------------- /dstm/util/__init__.py: -------------------------------------------------------------------------------- 1 | #import yaml 2 | import torch 3 | import numpy as np 4 | # def get_config(): 5 | # conf = None 6 | # with open('config.yml', 'r') as file: 7 | # conf = yaml.safe_load(file) 8 | # return conf 9 | 10 | def entropy(h_src): 11 | return - (h_src * torch.log(h_src)).sum(-1) 12 | def efficiency(h_src): 13 | m = h_src.shape[-1] 14 | ent = entropy(h_src) 15 | return ent/np.log(m) 16 | #return 17 | 18 | def add_centered(t1, t2, pitch_relative_to, time_relative_to, relative_pitch=True): 19 | d = t2.shape[-1] 20 | r = t1.shape[-2] 21 | if relative_pitch: 22 | pitch_index = slice(d-pitch_relative_to-1, 2*d - pitch_relative_to-1) 23 | else: 24 | pitch_index = slice(None) 25 | t1[max((r-time_relative_to), 0):, pitch_index] += t2[max(0, time_relative_to-r):time_relative_to, :] 26 | def piano_to_pitch_change(label, probs): 27 | changes = torch.diff(label, prepend=torch.tensor([-1], device = label.device)) != 0 28 | return probs[changes], label[changes] -------------------------------------------------------------------------------- /dstm/util/data/synthetic_data.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from dstm.util.load_data import DataPreprocessing 3 | class
SyntheticData(DataPreprocessing): 4 | def __init__(self, nb_classes): 5 | #self.nb_classes = nb_classes 6 | super().__init__(nb_classes) 7 | def variations(self, length, n_reps, n_vars_pr_motif, n_samples): 8 | samples = [] 9 | for _ in range(n_samples): 10 | motif = torch.multinomial(1/self.nb_classes *torch.ones(self.nb_classes), length, replacement=True) 11 | sample = [motif] 12 | #Keep the last (sample indices) 13 | for _ in range(n_reps): 14 | #For now variations are not unique but we can add another while loop 15 | var_indices = torch.multinomial(1/(length-1)*torch.ones(length-1), n_vars_pr_motif, replacement=False) 16 | var = motif.clone() 17 | 18 | for index in var_indices: 19 | while True: 20 | var_symb = torch.multinomial(1/self.nb_classes *torch.ones(self.nb_classes), 1).squeeze(0) 21 | if var_symb != var[index]: 22 | var[index] = var_symb 23 | break 24 | sample.append(var) 25 | sample = torch.cat(sample) 26 | samples.append(sample) 27 | return torch.stack(samples) 28 | 29 | 30 | 31 | 32 | 33 | -------------------------------------------------------------------------------- /dstm/util/metrics.py: -------------------------------------------------------------------------------- 1 | from sklearn.metrics import precision_score 2 | import scipy 3 | import numpy as np 4 | class Metrics: 5 | def precision(labels, predictions): 6 | return precision_score(y_true=labels, y_pred=predictions, average='micro') 7 | def precision_over_time(labels, predictions,max_t=2000): 8 | precisions = [] #[[] for _ in range()] 9 | labels_t = [[] for _ in range(max_t)] 10 | predictions_t = [[] for _ in range(max_t)] 11 | for label, prediction in zip(labels, predictions): 12 | for i in range(max_t): 13 | labels_t.append(label[i]) 14 | predictions_t.append(prediction[i]) 15 | #for i in range(len(labels))#range(labels.shape[1]): 16 | for label_t, prediction_t in zip(labels_t, predictions_t): 17 | precisions.append(Metrics.precision(label_t, prediction_t)) 18 | return precisions 19 | def tp_stats(labels, predictions): 20 | tp = np.array(labels) == np.array(predictions) 21 | return tp.mean(), tp.std(ddof=1), len(tp) 22 | def t_test_from_stats(mean1, std1, nobs1, mean2, std2, nobs2): 23 | return scipy.stats.ttest_ind_from_stats(mean1, std1, nobs1, mean2, std2, nobs2, equal_var=False) 24 | def t_test(labels, predictions_a, predictions_b): 25 | tp_a = np.array(labels) == np.array(predictions_a) 26 | tp_b = np.array(labels) == np.array(predictions_b) 27 | return scipy.stats.ttest_ind(tp_a, tp_b, axis=0, equal_var=False, nan_policy='propagate') #not supported in current scipy version: , alternative='two-sided') -------------------------------------------------------------------------------- /dstm/model/long_term.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from dstm.model.slow_weight_models import DilatedConvBlock 3 | from dstm.model.encoders.transformer import StandardTransformerRNN, TransformerDecoder, RelativeEncodingTransformerRNNLayer, LinearTransformerRNNLayer 4 | from dstm.model.module import Module 5 | class LTM(Module): 6 | def __init__(self, d, hparams): 7 | """ 8 | Args: 9 | d ([type]): [description] 10 | hparams (obj): namespace of hyper parameters 11 | """ 12 | super().__init__(d, hparams) 13 | if hparams.encoder == "transformer_abs": 14 | raise NotImplementedError("Need absolute encoding") 15 | self.slow_weight_model = StandardTransformerRNN(m=d, d=d, dropout=hparams.dropout, num_layers=hparams.transformer_layers, 
nhead=hparams.transformer_n_head) 16 | elif hparams.encoder == "transformer_rel": 17 | #TODO: 18 | decoder_layer = RelativeEncodingTransformerRNNLayer(num_heads=hparams.transformer_n_head, model_dim=d, hidden_dim=None, 19 | dropout=hparams.dropout, device=None, dtype=None, rel_clip_length=hparams.seq_max_length+1, pos_enc=True) 20 | self.slow_weight_model = TransformerDecoder(d=d, decoder_layer=decoder_layer, num_layers=hparams.transformer_layers, abs_enc=False) 21 | elif hparams.encoder == "transformer_lin": 22 | decoder_layer = LinearTransformerRNNLayer(num_heads=hparams.transformer_n_head, model_dim=d, hidden_dim=None, 23 | dropout=hparams.dropout, device=None, dtype=None, rel_clip_length=hparams.seq_max_length+1, pos_enc=False) 24 | self.slow_weight_model = TransformerDecoder(d=d, decoder_layer=decoder_layer, num_layers=hparams.transformer_layers, abs_enc=True) 25 | 26 | else: 27 | self.slow_weight_model = DilatedConvBlock(m=d, d=d, filters=hparams.filters, activation=hparams.activation, dropout=hparams.dropout) 28 | 29 | @staticmethod 30 | def add_model_specific_args(parser): 31 | pass 32 | def forward(self, h_src, s_tar, W): 33 | raise NotImplementedError("forward") 34 | #return pred, W 35 | 36 | def probs(self, batch, hard=False): 37 | """Returns the probabilities for full sequence 38 | 39 | Args: 40 | batch ([type]): [description] 41 | 42 | Returns: 43 | [type]: [description] 44 | """ 45 | logits = self.slow_weight_model(batch) 46 | probs = torch.nn.functional.softmax(logits, dim=2)[:, :-1, :] 47 | return probs, -1 -------------------------------------------------------------------------------- /dstm/model/coded_short_term_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | def step(h_src, s_tar, W, method): 5 | if method in ["softmax_transition_table"]: 6 | #if method in ["softmax_transition_table", "gumbel"]: 7 | return step_transition_table(h_src, s_tar, W) 8 | elif method in ["gumbel", "softmax_normalize_after", "l2_normalized", "positive_h", "generate_target_pitch_shared_weights", "elu"]: 9 | #elif method in ["softmax_normalize_after", "l2_normalized", "positive_h"]: 10 | return step_normalize_after(h_src, s_tar, W) 11 | elif method in ["unbounded_h"]: 12 | return step_unbounded_h(h_src, s_tar, W) 13 | else: 14 | raise NotImplementedError("Short Term Method Unkown.") 15 | 16 | def step_transition_table(h_src, s_tar, W): 17 | """[summary] 18 | Args: 19 | h_src (Tensor): Bxm 20 | s_tar (Tensor): Bxd 21 | W (Tensor): Bxmxd 22 | Returns: 23 | [type]: [description] 24 | """ 25 | # predict using old W 26 | W_normed = F.normalize(W, p=1, dim=2) 27 | pred = torch.bmm(h_src.unsqueeze(1), W_normed) 28 | pred = pred.squeeze(1) 29 | # update using W 30 | W = W + torch.bmm(h_src.unsqueeze(2), s_tar.unsqueeze(1)) 31 | return pred, W 32 | 33 | 34 | def step_normalize_after(h_src, s_tar, W): 35 | """[summary] 36 | Args: 37 | h_src (Tensor): Bxm 38 | s_tar (Tensor): Bxd 39 | W (Tensor): Bxmxd 40 | Returns: 41 | [type]: [description] 42 | """ 43 | # predict using old W 44 | #print(torch.bmm(h_src.unsqueeze(1), W).min(), torch.bmm(h_src.unsqueeze(1), W).max()) 45 | pred = F.normalize(torch.bmm(h_src.unsqueeze(1), W), p=1, dim=2) 46 | pred = pred.squeeze(1) 47 | #del W_normed 48 | # update using W 49 | W = W + torch.bmm(h_src.unsqueeze(2), s_tar.unsqueeze(1)) 50 | return pred, W 51 | 52 | def step_unbounded_h(h_src, s_tar, W): 53 | #TODO: this does prboably not work to sum dot products and use 
softmax (exponential) 54 | """[summary] 55 | Args: 56 | h_src (Tensor): Bxm 57 | s_tar (Tensor): Bxd 58 | W (Tensor): Bxmxd 59 | Returns: 60 | [type]: [description] 61 | """ 62 | # predict using old W 63 | # For scaled dot product transformer... 64 | #pred = F.softmax(1/h_src.shape[1] *torch.bmm(h_src.unsqueeze(1), W), dim=2) 65 | pred = F.softmax(torch.bmm(h_src.unsqueeze(1), W), dim=2) 66 | #CHANGED 67 | pred = pred.squeeze(1) 68 | #del W_normed 69 | # update using W 70 | W = W + torch.bmm(h_src.unsqueeze(2), s_tar.unsqueeze(1)) 71 | return pred, W 72 | 73 | def matching_network(h_srcs, s_tars): 74 | shape = list(s_tars.shape) 75 | shape[1] += 1 76 | probs = torch.empty(shape).type_as(h_srcs) 77 | # todo: for now set the prior uniformly 78 | probs[:, 0, :] = 1/shape[2] 79 | for i in range(1, shape[1]): 80 | dot_product = h_srcs[:, :i :].bmm(h_srcs[:, i, :].unsqueeze(2)) 81 | softmax = F.softmax(dot_product.permute([0, 2, 1]),dim=-1) 82 | probs[:, i, :] = softmax.bmm(s_tars[:,:i,:]).squeeze(1) 83 | return probs -------------------------------------------------------------------------------- /dstm/util/constants.py: -------------------------------------------------------------------------------- 1 | import colorsys 2 | 3 | def cm_to_inch(cm): 4 | return cm/2.54 5 | 6 | class Constants: 7 | # Should it rather be depending on d (the number of outputs???) 8 | # This would then be different pr. dataset. 9 | zero_prob_add = 1e-3 10 | 11 | text_width = cm_to_inch(17.2) 12 | text_height = cm_to_inch(24.3) 13 | column_sep_ = cm_to_inch(.5) 14 | column_width = (text_width - column_sep_)/2 15 | dpi = 1000 16 | matplotlib_rcparams = { 17 | "text.usetex": True, 18 | "font.family": "Helvetica", 19 | "savefig.dpi": 1000, 20 | #"font.sans-serif": ["Helvetica"] 21 | } 22 | 23 | _linestyle_densely_dashed = (0, (5, 1)) 24 | #linestyle_densely_dashed_dotted = (0, (5, 1, 1, 1, 1, 1)) 25 | #_linestyle_densely_dashed_dotted = (0, (15, 1)) 26 | #path_effects_border_line = (pe.Stroke(linewidth=2, foreground='black'), pe.Normal()) 27 | styles = { 28 | "ccstm-elu-rev": {"name": "CCSTM-512 (stm)","linestyle": {"color": colorsys.hls_to_rgb(0, .3, 1,), "linestyle":"solid"}, "pointstyle": {"marker": r"$0$", "color": colorsys.hls_to_rgb(0, .3, 1,)}}, 29 | "ccstm-elu-32-rev": {"name": "CCSTM-32 (stm)" , "linestyle": {"color": colorsys.hls_to_rgb(0, .5, 1,), "linestyle": "solid"}, "pointstyle": {"marker": r"$1$", "color": colorsys.hls_to_rgb(0, .5, 1,)}}, 30 | #"dcstm-annealing": {"name": "DCSTM-512 (stm)", "linestyle": {"color": colorsys.hls_to_rgb(0, .7, 1,), "linestyle":"solid"}, "pointstyle": {"marker": r"$2$", "color": colorsys.hls_to_rgb(0, .7, 1,)}}, 31 | "dcstm-rev": {"name": "DCSTM-512 (stm)", "linestyle": {"color": colorsys.hls_to_rgb(0, .7, 1,), "linestyle":"solid"}, "pointstyle": {"marker": r"$2$", "color": colorsys.hls_to_rgb(0, .7, 1,)}}, 32 | #"io-mc-0": {"name": "MC-0 (stm)" , "linestyle": {"color": colorsys.hls_to_rgb(0.1, .4, 1,), "linestyle":"solid"}, "pointstyle": {"marker": r"$3$", "color": colorsys.hls_to_rgb(0.1, .4, 1,)}}, 33 | "io-mc-3": {"name": "MC-3 (stm)", "linestyle": {"color": colorsys.hls_to_rgb(0.1, .6, 1,), "linestyle":"solid"}, "pointstyle": {"marker": r"$3$", "color": colorsys.hls_to_rgb(0.1, .6, 1,)}}, 34 | 'ppm': {"name": "PPM (stm)", "linestyle": {"color": colorsys.hls_to_rgb(.7, .5, 1,), "linestyle":"solid"}, "pointstyle": {"marker": r"$4$", "color": colorsys.hls_to_rgb(.7, .5, 1,)}}, 35 | 'repetition': {"name": "Repetion (stm)", "linestyle": {"color": colorsys.hls_to_rgb(.9, .5, 
1,), "linestyle":"solid"}, "pointstyle": {"marker": r"$5$", "color": colorsys.hls_to_rgb(.9, .5, 1,)}}, 36 | 'ltm-dccnn-rev': {"name": "WaveNet-512 (ltm)", "linestyle": {"color": colorsys.hls_to_rgb(.4, .35, 1,), "linestyle": _linestyle_densely_dashed}, "pointstyle": {"marker": r"$6$", "color": colorsys.hls_to_rgb(0.4, .35, 1,)}}, 37 | 'ltm-transformer-rel': {"name":"Transformer-512 (ltm)", "linestyle": {"color": colorsys.hls_to_rgb(.6, .3, 1,), "linestyle": _linestyle_densely_dashed}, "pointstyle": {"marker": r"$7$", "color": colorsys.hls_to_rgb(0.6, .3, 1,)}}, 38 | 'ltm-transformer-rel-32-rev': {"name": "Transformer-32 (ltm)", "linestyle": {"color": colorsys.hls_to_rgb(.6, .5, 1,), "linestyle": _linestyle_densely_dashed}, "pointstyle": {"marker": r"$8$", "color": colorsys.hls_to_rgb(0.6, .5, 1,)}}, 39 | #'ltm-transformer-lin': {"name":"Transformer-lin (ltm)", "linestyle": {"color": colorsys.hls_to_rgb(.6, .7, 1,), "linestyle": _linestyle_densely_dashed}, "pointstyle": {"marker": r"$10$", "color": colorsys.hls_to_rgb(0.6, .7, 1,)}} 40 | } -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Differentiable Short-Term Model (DSTM) 2 | This repository holds code for [Differentiable Short-Term Models](https://doi.org/10.5334/tismir.123): 3 | > ## Abstract 4 | > As pieces of music are usually highly self-similar, online-learning short-term models 5 | are well-suited for musical sequence prediction tasks. Due to their simplicity and 6 | interpretability, Markov chains (MCs) are often used for such online learning, with 7 | Prediction by Partial Matching (PPM) being a more sophisticated variant of simple 8 | MCs. PPM, also used in the well-known IDyOM model, constitutes a variable-order MC 9 | that relies on exact matches between observed *n*-grams and weights more recent 10 | events higher than those further in the past. We argue that these assumptions are 11 | limiting and propose the Differentiable Short-Term Model (DSTM) that is not limited 12 | to exact matches of *n*-grams and can also learn the relative importance of events. 13 | During (offline-)training, the DSTM learns representations of *n*-grams that are useful 14 | for constructing fast weights (that resemble an MC transition matrix) in online learning 15 | of *intra-opus* pitch prediction. We propose two variants: the Discrete Code Short- 16 | Term Model and the Continuous Code Short-Term Model. We compare the models to 17 | different baselines on the [*“TheSession“*](https://github.com/IraKorshunova/folk-rnn/) dataset and find, among other things, that 18 | the Continuous Code Short-Term Model has a better performance than Prediction by 19 | Partial Matching, as it adapts faster to changes in the data distribution. We perform 20 | an extensive evaluation of the models, and we discuss some analogies of DSTMs 21 | with linear transformers. 22 | ## Install 23 | ### pytorch 24 | It is recommended to obtain pytorch using the official [install instructions](https://pytorch.org/get-started). 25 | 26 | ### dstm 27 | Install the dstm package by running: 28 | 29 | `` 30 | pip install -e . 31 | `` 32 | ## `dstm.py` 33 | `dstm.py` is used for training and evaluating models. 
The program is built with [pytorch-lightning](https://github.com/Lightning-AI/lightning) and supports the arguments of [pytorch_lightning.trainer.trainer.Trainer](https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html#trainer-flags) (e.g., (multiple) GPUs). 34 | 35 | 36 | For information on program arguments run: 37 | 38 | `` 39 | python dstm.py --help 40 | `` 41 | 42 | 43 | ### Training 44 | Training is on by default but can be switched off by using `--skip_train`. For instance, to reproduce the best performing Continuous Code Short-Term Model from the paper run: 45 | 46 | `` 47 | python dstm.py --batch_size=16 --dropout=0.1 --encoder_output_dim=512 --early_stopping --filters 512 512 512 512 512 512 512 512 --lr=0.0001 --activation=selu dstm --short_term_method=elu 48 | `` 49 | 50 | To speed up training using GPU acceleration, consider adding `` --accelerator=gpu --devices=0,1,2,3 --strategy=ddp`` and changing ``--batch_size=4``. 51 | 52 | When experiencing memory issues, try lowering the batch size (e.g., `--batch_size=4`). 53 | 54 | ### Evaluation 55 | Evaluation is on by default but can be switched off by using `--skip_test`. We provide two pretrained models from the paper: `out/session/model/ccstm.ckpt` and `out/session/model/dcstm.ckpt`. These are downloaded from the CDN on first run. The performance of the checkpoints can be evaluated by running: 56 | `` 57 | python dstm.py --batch_size=16 --skip_train --checkpoint="out/session/model/ccstm.ckpt" --encoder_output_dim=512 --no_log dstm 58 | `` 59 | 60 | ## Issues 61 | Feel free to open an issue in case something does not work. 62 | -------------------------------------------------------------------------------- /dstm/model/slow_weight_models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import pytorch_lightning as pl 4 | 5 | 6 | class DilatedConvBlock(torch.nn.Module): 7 | def __init__(self, m, d, activation="relu", filters=[(128, 128, 128, 128)], dropout=0.0, residual=True): 8 | """[summary] 9 | 10 | Args: 11 | m (int): number of discrete codes 12 | d (int): dimension of output 13 | filters list(tuple): filters in first layer.
List of tuples of number of filter and receptive field 14 | 15 | """ 16 | super().__init__() 17 | self.activation = activation 18 | self.residual =residual 19 | cnns = [] 20 | #self.r = 2 21 | self.dropout = torch.nn.Dropout(p=dropout) 22 | if activation == "gated_activation": 23 | # Gated activation 24 | cnns.append(torch.nn.Conv2d( 25 | 1, 2*filters[0], kernel_size=(2, d))) 26 | for i in range(1, len(filters)): 27 | cnns.append(torch.nn.Conv2d( 28 | filters[i-1], 2*filters[i], dilation=(2**i, 1), kernel_size=(2, 1))) 29 | cnns.append(torch.nn.Conv2d( 30 | filters[-1], 2*m, dilation=(2**(len(filters)), 1), kernel_size=(2, 1))) 31 | else: 32 | cnns.append(torch.nn.Conv2d( 33 | 1, filters[0], kernel_size=(2, d))) 34 | for i in range(1, len(filters)): 35 | #TODO: readded bug 36 | #for i in range(1, len(filters) - 1): 37 | cnns.append(torch.nn.Conv2d( 38 | filters[i-1], filters[i], dilation=(2**i, 1), kernel_size=(2, 1))) 39 | cnns.append(torch.nn.Conv2d( 40 | filters[-1], m, dilation=(2**(len(filters)), 1), kernel_size=(2, 1))) 41 | 42 | self.cnns = torch.nn.ModuleList(cnns) 43 | 44 | def forward(self, ss): 45 | """[summary] 46 | 47 | Args: 48 | ss (Tensor): BxTxd sequence of T previous tokens 49 | 50 | Returns: 51 | probs (Tensor): Bx(T+1)xd sequence of probs/one-hot 52 | logits (Tensor): Bx(T+1)xd sequence of logits 53 | """ 54 | in_ = ss.unsqueeze(1) 55 | if self.activation == "gated_activation": 56 | #Gated activation 57 | for i, cnn in enumerate(self.cnns): 58 | residual = in_ 59 | if i < len(self.cnns) -1: 60 | in_ = F.pad(in_, (0, 0, cnn.dilation[0], 0, 0, 0)) 61 | else: 62 | in_ = F.pad(in_, (0, 0, self.cnns[-1].dilation[0] + 1, 0, 0, 0)) 63 | in_ = self.dropout(in_) 64 | in_ = cnn(in_) 65 | n_filters = in_.shape[1] // 2 66 | ouput = torch.tanh(in_[:, :n_filters, :, :]) 67 | gates = torch.sigmoid(in_[:, n_filters:, :, :]) 68 | in_ = ouput * gates 69 | if i > 0 and i < len(self.cnns) -1 and self.residual: 70 | in_ += residual 71 | else: 72 | if self.activation == "relu": 73 | activation_fn = torch.nn.ReLU() 74 | elif self.activation == "selu": 75 | activation_fn = torch.nn.SELU() 76 | else: 77 | activation_fn = torch.nn.Identity() 78 | for i, cnn in enumerate(self.cnns[:-1]): 79 | residual = in_ 80 | in_ = F.pad(in_, (0, 0, cnn.dilation[0], 0, 0, 0)) 81 | in_ = self.dropout(in_) 82 | in_ = activation_fn(cnn(in_)) 83 | if i > 0 and self.residual: 84 | in_ = in_ + residual 85 | in_ = F.pad(in_, (0, 0, self.cnns[-1].dilation[0] + 1, 0, 0, 0)) 86 | in_ = self.dropout(in_) 87 | in_ = self.cnns[-1](in_) 88 | logits = in_.squeeze(-1) 89 | logits = logits.permute(0, 2, 1) 90 | return logits -------------------------------------------------------------------------------- /dstm/model/ppm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from rpy2.robjects.packages import importr 3 | from rpy2.robjects.vectors import IntVector, FloatVector 4 | import numpy as np 5 | from dstm.util.load_data import EssenPreprocessing, SessionPreprocessing 6 | from dstm.util.metrics import Metrics 7 | import copy 8 | from pathlib import Path 9 | import pickle 10 | import functools 11 | import tqdm.contrib.concurrent 12 | import tqdm 13 | ppm = importr("ppm") 14 | EPOCHS = 1 15 | LOOP_DATA = False 16 | class PPMTrainer: 17 | def __init__(self, type_="simple", io=True, dataPreprocessor=None,**kwargs): 18 | self.dataPreprocessor = dataPreprocessor 19 | self.global_time = 1 20 | #maybe max order should be receptive field of compared model 21 | self.type = type_ 
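        # `type_` selects the backend from the R `ppm` package used above: 'simple' -> ppm.new_ppm_simple, 'decay' -> ppm.new_ppm_decay.
        # `io` (presumably "intra-opus"): when True, `forward` re-creates the PPM model for every batch, so no statistics carry over between pieces.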
22 | self.io = io 23 | alphabet_size = self.dataPreprocessor.d 24 | shared_args = {"alphabet_size":alphabet_size, **kwargs} 25 | if type_ == "simple": 26 | self.model_gen = lambda: ppm.new_ppm_simple(**shared_args) 27 | elif type_ == "decay": 28 | self.model_gen = lambda: ppm.new_ppm_decay(**shared_args, ltm_half_life = 2) 29 | else: 30 | raise NotImplementedError("type is not allowed") 31 | self.model = self.model_gen() 32 | 33 | def one_hot_to_r(self, seq): 34 | return (1 + seq.argmax(dim=-1)).tolist() 35 | def forward(self, batch): 36 | nll = [] 37 | sum_of_lengths = 0 38 | preds = [] 39 | labels = [] 40 | if self.io: 41 | self.model = self.model_gen() 42 | for seq, length in zip(batch[0], batch[1]): 43 | seq_r = self.one_hot_to_r(seq[:length]) 44 | if self.type == "decay": 45 | #time = [i + 0.0 for i in range(1,length + 1)] 46 | time = list(range(self.global_time, self.global_time + length)) 47 | self.global_time += length 48 | res = ppm.model_seq(model=self.model, seq=seq_r, time=IntVector(time), train=True) 49 | ic = res[4] 50 | props = res[6] 51 | elif self.type == "simple": 52 | res = ppm.model_seq(model=self.model, seq=seq_r, train=True) 53 | ic = res[2] 54 | props = res[4] 55 | #information content 56 | nll += ic 57 | sum_of_lengths += length 58 | labels += seq_r 59 | #hard (use max) 60 | #preds += (torch.multinomial(torch.tensor(np.array(props)), 1) +1).view(-1).tolist() 61 | preds += list(np.array(props).argmax(axis=-1)+1) 62 | return nll, preds, labels, sum_of_lengths 63 | def train(self): 64 | batch_size = 16 65 | train_loader = self.dataPreprocessor.get_data_loader('train', 66 | batch_size=batch_size, 67 | num_workers=0, 68 | shuffle=True 69 | ) 70 | for _ in range(EPOCHS): 71 | for i, batch in enumerate(train_loader): 72 | nll, _, _, sum_of_lengths = self.forward(batch) 73 | print("train step {}/{}, nll: {}".format(i, len(train_loader.dataset)//batch_size+ 1, nll/sum_of_lengths)) 74 | def predict_all(self, split, max_workers): 75 | nll_all = [] 76 | sum_of_lengths_all = 0 77 | preds_all = [] 78 | labels_all = [] 79 | loader = self.dataPreprocessor.get_data_loader(split, 80 | batch_size=1, 81 | num_workers=0, 82 | shuffle=False 83 | ) 84 | 85 | for i, batch in tqdm.tqdm(enumerate(loader)): 86 | if np.prod(batch[0].shape): 87 | nll, preds, labels, sum_of_lengths = self.forward(batch) 88 | nll_all += nll 89 | sum_of_lengths_all += sum_of_lengths 90 | preds_all += preds 91 | labels_all += labels 92 | else: 93 | print("No elements at position {}".format(i)) 94 | return preds_all, nll_all, labels_all, sum_of_lengths_all 95 | # def proc_one(i, batch): 96 | # if np.prod(batch[0].shape): 97 | # #nll, preds, labels, sum_of_lengths = self.forward(batch) 98 | # return self.forward(batch) 99 | # else: 100 | # print("No elements at position {}".format(i)) 101 | # # nll, preds, labels, sum_of_lengths = zip(*tqdm.contrib.concurrent.process_map(proc_one, range(len(loader)), loader)) 102 | # # preds_all = functools.reduce(lambda a, b: a+b, nll) 103 | # # nll_all = functools.reduce(lambda a, b: a+b, preds) 104 | # packed_res = zip(*tqdm.contrib.concurrent.process_map(proc_one, range(len(loader)), loader), max_workers=max_workers) 105 | # reduced = map(lambda x: functools.reduce(lambda a, b: a+b, x), packed_res) 106 | # return reduced 107 | 108 | def validate(self, split): 109 | preds_all, nll_all, labels_all, sum_of_lengths_all = self.predict_all(split, max_workers=None) 110 | return sum(nll_all)/sum_of_lengths_all, *Metrics.tp_stats(labels=labels_all, predictions=preds_all) 
#Metrics.precision(labels=labels_all, predictions=preds_all) 111 | -------------------------------------------------------------------------------- /dstm/model/baseline.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import matplotlib.pyplot as plt 3 | from pathlib import Path 4 | from dstm.util.load_data import SessionPreprocessing 5 | from dstm.util.constants import Constants 6 | from scipy.stats import norm 7 | import numpy as np 8 | import tqdm 9 | class BaseLine: 10 | # abstract model 11 | #pre_processing = {} 12 | train_loaders = {} 13 | test_loaders = {} 14 | def __init__(self, dataset): 15 | self.dataset = dataset 16 | if dataset in BaseLine.train_loaders: 17 | pass 18 | elif dataset == "session": 19 | pre_processing = SessionPreprocessing(loop=False) 20 | BaseLine.train_loaders[self.dataset] = pre_processing.get_data_loader('train', shuffle=False) 21 | BaseLine.test_loaders[self.dataset] = pre_processing.get_data_loader('test', shuffle=False) 22 | else: 23 | raise NotImplementedError("dataset {} is not implimented".format(dataset)) 24 | @property 25 | def train_loader(self): 26 | if self.dataset in BaseLine.train_loaders: 27 | pass 28 | elif self.dataset == "session": 29 | preprocessing = SessionPreprocessing(loop=False) 30 | BaseLine.train_loaders[self.dataset] = preprocessing.get_data_loader('train', shuffle=False) 31 | return BaseLine.train_loaders[self.dataset] 32 | 33 | @property 34 | def test_loader(self): 35 | if self.dataset in BaseLine.test_loaders: 36 | pass 37 | elif self.dataset == "session": 38 | preprocessing = SessionPreprocessing(loop=False) 39 | BaseLine.test_loaders[self.dataset] = preprocessing.get_data_loader('test', shuffle=False) 40 | return BaseLine.test_loaders[self.dataset] 41 | 42 | def get_model_folder(self): 43 | folder = 'out/{}/model/{}'.format(self.dataset, self.model) 44 | Path(folder).mkdir(parents=True, exist_ok=True) 45 | return folder 46 | 47 | class Prior(BaseLine): 48 | def __init__(self, **kwargs): 49 | self.model = "Prior" 50 | super().__init__(**kwargs) 51 | self.model_file = "{}/{}".format(self.get_model_folder(), "Prior.p") 52 | def _save_model(self): 53 | torch.save(self.params, self.model_file) 54 | def fit(self): 55 | data = torch.cat([self.train_loader.dataset]).argmax(-1).tolist() 56 | d = self.train_loader.dataset[0].shape[-1] 57 | self.params = plt.hist(data, bins=list(range(0, d+1)), align='left', density=True)[0] 58 | self._save_model() 59 | def load_model(self): 60 | self.params = torch.load(self.model_file) 61 | def nll(self): 62 | d = self.test_loader.dataset[0].shape[-1] 63 | data = torch.cat([*self.test_loader.dataset]).view(-1, d) 64 | logprobs = torch.log(torch.FloatTensor(self.params)).view(d, 1) 65 | return - (data.mm(logprobs)).mean() 66 | class IOPrior(BaseLine): 67 | def __init__(self, **kwargs): 68 | self.model = "PriorIO" 69 | super().__init__(**kwargs) 70 | def predict(self): 71 | d = self.test_loader.dataset[0].shape[-1] 72 | tps = [] 73 | logprobs = [] 74 | predictions = [] 75 | for x in self.test_loader.dataset: 76 | if len(x) == 0: 77 | continue 78 | predictions.append(-1) 79 | tps.append(False) 80 | hist = torch.zeros(d) 81 | hist[x[0].argmax()] = 1 82 | logprobs.append(np.log(Constants.zero_prob_add)) 83 | for i, symbol in enumerate(x[1:], 1): 84 | #piece = x.argmax(-1) 85 | #hist = np.histogram(piece[:i], bins=list(range(0, d+1)), density=True)[0] 86 | #logprob = torch.log(x.mm(torch.FloatTensor(hist).view(d, 1))) 87 | prop = 
hist[hist.argmax(-1)]/hist.sum(-1) 88 | logprob = np.log(prop) if prop > 0 else np.log(Constants.zero_prob_add) 89 | logprobs.append(logprob.item()) 90 | prediction = hist.argmax(-1) 91 | hist[symbol.argmax(-1)] += 1 92 | predictions.append(prediction.item()) 93 | tps.append((symbol.argmax(-1) == prediction).item()) 94 | return predictions, logprobs, tps 95 | def nll_precision(self): 96 | _, logprobs, tps = self.predict() 97 | return - torch.tensor(logprobs).mean(), torch.tensor(tps).float().mean() 98 | class Repetition(BaseLine): 99 | def __init__(self, **kwargs): 100 | self.model = "Repetition" 101 | super().__init__(**kwargs) 102 | def predict(self): 103 | tps = [] 104 | predictions = [] 105 | logprobs = [] 106 | for x in tqdm.tqdm(self.test_loader.dataset): 107 | if len(x) == 0: 108 | continue 109 | piece = x.argmax(-1) 110 | prediction = torch.zeros_like(piece) 111 | prediction[0] = -1 112 | prediction[1:] = piece[:-1] 113 | predictions.append(prediction) 114 | tp = prediction == piece 115 | tps.append(tp.float()) 116 | #result = piece[:-1] == piece[1:] 117 | result = (~tp)* Constants.zero_prob_add + (tp) 118 | #d = x.shape[-1] 119 | #eps = 1e-7 120 | #result = (~result)*(1-(1/d+eps))/(d-1) + (1/d + eps )*(result) 121 | logprobs.append(torch.log(result)) 122 | 123 | return torch.cat(predictions), torch.cat(logprobs), torch.cat(tps) 124 | def nll(self): 125 | logprobs = [] 126 | for x in self.test_loader.dataset: 127 | piece = x.argmax(-1) 128 | result = piece[:-1] == piece[1:] 129 | result = (~result)* Constants.zero_prob_add + (result) 130 | logprobs.append(torch.log(result)) 131 | return - torch.cat(logprobs).mean() 132 | def precision(self): 133 | # tps = [] 134 | # for x in self.test_loader.dataset: 135 | # piece = x.argmax(-1) 136 | # result = piece[:-1] == piece[1:] 137 | # tps.append(result.float()) 138 | _, tps = self.predict() 139 | #return torch.cat(tps).mean() 140 | return tps.mean() 141 | 142 | def confidence_interval(p, n): 143 | bound = norm.ppf(0.975, loc=0, scale=1)*np.sqrt(p*(1-p)/n) 144 | return p - bound, p + bound -------------------------------------------------------------------------------- /dstm.py: -------------------------------------------------------------------------------- 1 | from dstm.model.short_term import DSTM 2 | from dstm.model.long_term import LTM 3 | from dstm.model.module import Module 4 | import pytorch_lightning as pl 5 | from pytorch_lightning.callbacks.early_stopping import EarlyStopping 6 | from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor 7 | from dstm.util.load_data import SessionPreprocessing 8 | from pytorch_lightning.loggers import WandbLogger 9 | from argparse import ArgumentParser 10 | import os 11 | import pathlib 12 | # class Tee(object): 13 | # def __init__(self, *files): 14 | # self.files = files 15 | # def write(self, obj): 16 | # for f in self.files: 17 | # f.write(obj) 18 | # f.flush() # If you want the output to be visible immediately 19 | # def flush(self) : 20 | # for f in self.files: 21 | # f.flush() 22 | # # we expect the bus_order to work 23 | 24 | if __name__ == '__main__': 25 | # freeze_support() 26 | parser = ArgumentParser() 27 | parser.add_argument('--batch_size', type=int, default=32, help="batch size") 28 | parser.add_argument('--n_train_workers', type=int, default=16, help="number of training data loader workers") 29 | parser.add_argument('--seq_max_length', type=int, default=float('inf'), help="truncate sequences at specified length") 30 | parser.add_argument('--checkpoint', type=str, 
help="initialize model with spefified checkpoint") 31 | #parser.add_argument('--dataset', type=str, choices=["essen", "nes", "session", "pop"], default="session", help="(experimental) use specified dataset") 32 | parser.add_argument('--dataset', type=str, choices=["session"], default="session", help="use specified dataset") 33 | parser.add_argument('--skip_train', action='store_true', default=False, help="skip training step") 34 | parser.add_argument('--skip_test', action='store_true', default=False, help="skip testing step") 35 | #parser.add_argument('--dataset_size', type=str, choices=["small", "medium", "large"], default="large") 36 | parser.add_argument('--loop_data', action='store_true', default=False, help="(experimental) looping data augmentation") 37 | parser.add_argument('--job_id', type=int, help="inject job id (used for naming processes)") 38 | parser.add_argument('--early_stopping', action='store_true', default=True, help="when set, training will stop after no optimization progress for 3 epochs") 39 | parser.add_argument('--save_last', action='store_true', default=False, help="force saving of last model checkpoint (oterwise best model checkpoint will be kept)") 40 | parser.add_argument("--no_log", action='store_true', default=False, help="Skip wandb logging.") 41 | #parser.add_argument('--output_folder', type=str, default="out") 42 | subparsers = parser.add_subparsers(dest="subparser_name") 43 | parser = Module.add_model_specific_args(parser, subparsers) 44 | parser_cstm = subparsers.add_parser(name='dstm', help="subcommand for Differential Short-Term Models (DSTMs) models.") 45 | DSTM.add_model_specific_args(parser_cstm) 46 | parser_ltm = subparsers.add_parser(name='ltm', help="subcommand for Long-Term Models (LTMs) ") 47 | LTM.add_model_specific_args(parser_ltm) 48 | 49 | parser = pl.Trainer.add_argparse_args(parser) 50 | args = parser.parse_args() 51 | 52 | # #NOTE: reserve all gpus 53 | # reserve_gpus(args.gpus) 54 | 55 | #basedir = "out/{}/{}".format(args.dataset,args.dataset_size) 56 | basedir = "out/{}".format(args.dataset) 57 | 58 | 59 | #NOTE: get dataset 60 | if args.dataset == "session": 61 | dataPreprocessor = SessionPreprocessing(loop=args.loop_data, max_workers=args.n_train_workers if args.n_train_workers is not None else 1) 62 | dataPreprocessor.prepare_dataset() 63 | 64 | else: 65 | raise NotImplementedError("Dataset {} is not implimented".format(args.dataset)) 66 | d = dataPreprocessor.d 67 | 68 | model_type = args.subparser_name 69 | if args.checkpoint is not None: 70 | if model_type == "ltm": 71 | mc = LTM.load_from_checkpoint(args.checkpoint) 72 | else: 73 | mc = DSTM.load_from_checkpoint(args.checkpoint) 74 | 75 | else: 76 | if model_type == "ltm": 77 | mc = LTM(d, args) 78 | else: 79 | mc = DSTM(d, args) 80 | 81 | if args.job_id is None: 82 | job_id = os.getpid() 83 | else: 84 | job_id = args.job_id 85 | # TODO: add time_stamp 86 | 87 | #print_logs 88 | # for o in ['out', 'err']: 89 | # folder = '{}/{}'.format(basedir, o) 90 | # pathlib.Path(folder).mkdir(parents=True, exist_ok=True) 91 | # f = open('{}/{}-{}.{}'.format(folder, model_type, job_id, o), 'w') 92 | # n = "std" + o 93 | # vars(sys)[n] = Tee(vars(sys)[n], f) 94 | 95 | 96 | #logging 97 | if args.no_log: 98 | logger = False 99 | else: 100 | log_name = model_type 101 | logger = WandbLogger(project = 'DSTM') 102 | logger.log_hyperparams(args) 103 | 104 | #savemodel 105 | modeldir = "{}/model".format(basedir) 106 | pathlib.Path(modeldir).mkdir(parents=True, exist_ok=True) 107 | model_name = 
"{}-{}".format(model_type, job_id) + '-{epoch:02d}-{val_loss:.2f}-{val_precision:.2f}' 108 | checkpoint_callback = ModelCheckpoint( 109 | dirpath=modeldir, 110 | filename=model_name, 111 | save_last=args.save_last 112 | ) 113 | checkpoint_callback.CHECKPOINT_NAME_LAST = "{}-last".format(model_name) 114 | lr_monitor = LearningRateMonitor(logging_interval='step') 115 | # swa_callback = StochasticWeightAveraging( 116 | # swa_epoch_start=0., 117 | # #annealing_epochs=0, 118 | 119 | # ) 120 | callbacks = [ 121 | checkpoint_callback, 122 | # swa_callback, 123 | lr_monitor 124 | ] 125 | if args.early_stopping: 126 | callbacks.append(EarlyStopping(monitor='val_loss')) 127 | trainer = pl.Trainer.from_argparse_args( 128 | args, 129 | callbacks=callbacks, 130 | default_root_dir=basedir, 131 | logger=logger, 132 | #plugins=DDPPlugin(find_unused_parameters=False) 133 | ) 134 | train_loader = dataPreprocessor.get_data_loader('train', 135 | seq_max_length = args.seq_max_length, 136 | batch_size=args.batch_size, 137 | num_workers=args.n_train_workers, 138 | shuffle=True, 139 | pin_memory=True 140 | ) 141 | if args.seq_max_length < float('inf'): 142 | val_loader = dataPreprocessor.get_data_loader('valid', 143 | batch_size = 2, 144 | num_workers=0, 145 | pin_memory=True 146 | ) 147 | else: 148 | val_loader = dataPreprocessor.get_data_loader('valid', 149 | batch_size=args.batch_size, 150 | num_workers=0, 151 | pin_memory=True 152 | ) 153 | 154 | 155 | if not args.skip_train: 156 | print("Training") 157 | trainer.fit(mc, train_loader, val_loader) 158 | if not args.skip_test: 159 | if args.seq_max_length < float('inf'): 160 | test_loader = dataPreprocessor.get_data_loader('test', 161 | batch_size = 1, 162 | num_workers=0, 163 | pin_memory=True 164 | ) 165 | else: 166 | test_loader = dataPreprocessor.get_data_loader('test', 167 | batch_size=args.batch_size, 168 | num_workers=0, 169 | pin_memory=True 170 | ) 171 | print("Testing") 172 | trainer.validate(mc, test_loader) -------------------------------------------------------------------------------- /dstm/model/module.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from dstm.util.metrics import Metrics 3 | import pytorch_lightning as pl 4 | from argparse import ArgumentParser 5 | from dstm.util.constants import Constants 6 | import sys 7 | from dstm.model.coded_short_term_model import * 8 | import tqdm 9 | 10 | class Module(pl.LightningModule): 11 | def __init__(self, d, hparams): 12 | """ 13 | Args: 14 | d ([type]): [description] 15 | hparams (obj): namespace of hyper parameters 16 | """ 17 | super().__init__() 18 | if hparams.no_one_hot: 19 | def general_cross_entropy(logprobs, labels): 20 | return -(logprobs * labels).mean() 21 | self.loss = general_cross_entropy 22 | else: 23 | self.loss = torch.nn.NLLLoss(reduction='mean') 24 | self.d = d 25 | self.lr = hparams.lr 26 | self.save_hyperparameters('d', 'hparams') 27 | @staticmethod 28 | def add_model_specific_args(parent_parser, subparsers): 29 | parser = ArgumentParser(description="model.module.Model", parents=[parent_parser], add_help=False) 30 | parser.add_argument('--lr', type=float, default=1e-3, help="learning rate") 31 | parser.add_argument('--l2_regularization_strength', type=float, default=0.0, help="use encoder l2 regularization") 32 | parser.add_argument('--dropout', type=float, default=0.0, help="use dropout") 33 | parser.add_argument('--no_one_hot', action='store_true', default=False, help="(experimental) data is not one-hot encoded") 34 | 
parser.add_argument('--encoder', type=str, choices=["transformer_rel", "transformer_lin", "transformer_abs", "dccnn"], default="dccnn", help="encoder model to be used") 35 | #TODO should only be enabled when dcccn 36 | parser.add_argument('--filters', type=int, nargs="*", 37 | default=[1024, 1024, 1024, 1024, 1024, 1024]) 38 | parser.add_argument('--activation', type=str, choices=["relu", "selu", "gated_activation", "none"], default="selu", help="activation to be used for dccnn") 39 | #TODO: should only be enabled when transformer is used 40 | parser.add_argument('--transformer_n_head', type=int, default=7, help="number of heads to be used for multi-head attention in transformer encoders") 41 | parser.add_argument('--transformer_layers', type=int, default=6, help="number of causal transformer layers to be used") 42 | parser.add_argument('--encoder_output_dim', type=int, default=512, help="encoder output dimensionality") 43 | return parser 44 | 45 | def probs(self, batch): 46 | raise NotImplementedError("Abstract method") 47 | def shared_step(self, batch): 48 | if self.hparams.hparams.no_one_hot: 49 | labels = batch[0] 50 | else: 51 | labels = batch[0].argmax(dim=-1) 52 | probs, _ = self.probs(batch[0]) 53 | #In case of zero padding for differing lengths 54 | #TODO: could be more efficient with a mask... 55 | if len(batch) == 2: 56 | lengths = batch[1] 57 | labels_flattened_list = [] 58 | probs_flattened_list = [] 59 | for label, prob, length in zip(labels, probs, lengths): 60 | labels_flattened_list.append(label[:length]) 61 | probs_flattened_list.append(prob[:length]) 62 | labels_flattened = torch.cat(labels_flattened_list) 63 | probs_flattened = torch.cat(probs_flattened_list) 64 | else: 65 | labels_flattened = labels.reshape(-1) 66 | probs_flattened = probs.reshape(-1, self.d) 67 | #Add constant for zero prob 68 | probs_flattened += Constants.zero_prob_add*(probs_flattened == 0) 69 | logprobs_flattened = torch.log(probs_flattened) 70 | loss = self.loss(logprobs_flattened, labels_flattened) 71 | return probs_flattened, loss, labels_flattened 72 | 73 | def forward(self, batch): 74 | return self.shared_step(batch) 75 | def training_step(self, batch, batch_idx): 76 | _, loss, _ = self.shared_step(batch) 77 | self.log('train_loss', loss) 78 | grad_abs_max = -1 79 | for param in filter(lambda x: x.requires_grad, self.parameters()): 80 | m = torch.max(torch.abs(param)) 81 | if m > grad_abs_max: 82 | grad_abs_max = m 83 | # sch = self.lr_schedulers() 84 | # if sch is not None: 85 | # sch.step() 86 | self.log('grad_abs_elem_max', grad_abs_max, 87 | on_step=True, prog_bar=True, logger=True) 88 | return loss 89 | 90 | def validation_step(self, batch, batch_idx): 91 | seq_length = batch[0].shape[1] 92 | if seq_length == 0: 93 | return torch.FloatTensor([]).type_as(batch[0]), torch.LongTensor([]) 94 | seq_max_length = self.hparams.hparams.seq_max_length 95 | #NOTE in this case batch_size is 1 for transformers 96 | if seq_length > seq_max_length and self.hparams.hparams.encoder != 'dccnn': 97 | probs_batch = [] 98 | labels_batch = [] 99 | for i in range(seq_max_length, seq_length + 1): 100 | #lengths = [max(0,min(seq_max_length, l-i)) for l in batch[1]] 101 | batch_small = batch[0][:, i-seq_max_length:i, :] 102 | labels = batch[0][:, i-seq_max_length:i, :].argmax(dim=-1) 103 | probs, _ = self.probs(batch_small) 104 | if i == seq_max_length: 105 | probs_batch.append(probs) 106 | labels_batch.append(labels) 107 | #break 108 | else: 109 | probs_batch.append(probs[:, -1:]) 110 | 
labels_batch.append(labels[:, -1:]) 111 | probs_batch = torch.cat(probs_batch, dim=1) 112 | labels_batch = torch.cat(labels_batch, dim=1) 113 | probs_flattened = [] 114 | labels_flattened = [] 115 | for length, prob, label in zip(batch[1], probs_batch, labels_batch): 116 | probs_flattened.append(prob[:length]) 117 | labels_flattened.append(label[:length]) 118 | probs_flattened = torch.cat(probs_flattened) 119 | labels_flattened = torch.cat(labels_flattened) 120 | else: 121 | probs_flattened, _, labels_flattened = self.shared_step(batch) 122 | return probs_flattened, labels_flattened 123 | 124 | 125 | def validation_epoch_end(self, outputs): 126 | unziped = list(zip(*outputs)) 127 | probs_flattened = torch.cat(unziped[0]) 128 | labels_flattened = torch.cat(unziped[1]) 129 | if self.hparams.hparams.no_one_hot: 130 | probs_flattened = probs_flattened 131 | else: 132 | predictions = probs_flattened.argmax(dim=-1) 133 | probs_flattened += Constants.zero_prob_add*(probs_flattened == 0) 134 | logprobs_flattened = torch.log(probs_flattened) 135 | loss = self.loss(logprobs_flattened, labels_flattened) 136 | precision, tp_std, num_obs = Metrics.tp_stats(labels=labels_flattened.cpu(), predictions=predictions.cpu()) 137 | 138 | self.log('val_precision', precision, sync_dist=True) 139 | self.log('tp_std', tp_std, sync_dist=True) 140 | self.log('val_loss', loss, sync_dist=True) 141 | 142 | def train_batch_end(self, trainer, pl_module, outputs): 143 | super().on_train_batch_end(trainer, pl_module, outputs) # don't forget this :) 144 | percent = (self.train_batch_idx / self.total_train_batches) * 100 145 | sys.stdout.flush() 146 | 147 | def configure_optimizers(self): 148 | #Only add l2 reg to weihgt parametersq 149 | optimizer = torch.optim.Adam( 150 | self.parameters(), 151 | lr=self.lr, 152 | weight_decay=self.hparams.hparams.l2_regularization_strength 153 | ) 154 | #NOTE: let's try SWA Instead... 
155 | # scheduler = torch.optim.lr_scheduler.OneCycleLR( 156 | # optimizer, 157 | # max_lr=3.5e-4, 158 | # epochs=10, 159 | # steps_per_epoch=2400 160 | # ) 161 | # scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( 162 | # optimizer, 163 | # T_max=20 164 | # ) 165 | # lr_scheduler_config = { 166 | # 'scheduler': scheduler, 167 | # 'interval': "epoch", 168 | # 'frequency': 1, 169 | # 'name': None 170 | 171 | # } 172 | #return [optimizer], [lr_scheduler_config] 173 | return optimizer 174 | 175 | def add_weight_decay(net, l2_value, skip_list=()): 176 | decay, no_decay = [], [] 177 | for name, param in net.named_parameters(): 178 | if not param.requires_grad: continue # frozen weights 179 | if len(param.shape) == 1 or name.endswith(".bias") or name in skip_list: 180 | no_decay.append(param) 181 | else: 182 | decay.append(param) 183 | return [{'params': no_decay, 'weight_decay': 0.}, {'params': decay, 'weight_decay': l2_value}] -------------------------------------------------------------------------------- /dstm/model/short_term.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from dstm.model.slow_weight_models import DilatedConvBlock 4 | import numpy as np 5 | from dstm.model.coded_short_term_model import * 6 | from dstm.model.module import Module 7 | import argparse 8 | 9 | def restricted_float(min_, max_): 10 | def restricted_float_(x): 11 | try: 12 | x = float(x) 13 | except ValueError: 14 | raise argparse.ArgumentTypeError("{} not a floating-point literal".format(x,)) 15 | if x <= min_ or x > max_: 16 | raise argparse.ArgumentTypeError("{} not in range ({}, {}]".format(x, min_, max_)) 17 | return x 18 | return restricted_float_ 19 | 20 | class DSTM(Module): 21 | def __init__(self, d, hparams): 22 | """ 23 | Args: 24 | d ([type]): [description] 25 | hparams (obj): namespace of hyper parameters 26 | """ 27 | super().__init__(d, hparams) 28 | self.m = hparams.encoder_output_dim 29 | self.W_ = None 30 | self.probs_ = None 31 | if hparams.generate_target_pitch: 32 | m = self.m + d 33 | filters = np.array(hparams.filters) + d 34 | else: 35 | m = self.m 36 | filters = np.array(hparams.filters) 37 | self.slow_weight_model = DilatedConvBlock(m=m, d=d, filters=filters, activation=hparams.activation, dropout=hparams.dropout) 38 | def _create_lambda_with_globals(s): 39 | return eval(s, globals()) 40 | @staticmethod 41 | def add_model_specific_args(parser): 42 | parser.add_argument('--gumbel_temperature', type=restricted_float(0,float('inf')), default=1, help="temperature controlling the approximation of gumbel sampels to one-hot encoded categorical samples") 43 | parser.add_argument('--gumbel_hard', action='store_true', default=False, help="straight through gumbel softmax") 44 | parser.add_argument('--simulated_annealing', type=str, default='lambda gradient_step: max(0.8, -.7/90000*gradient_step + 1.5)', help="lambda function string representation of gumbel temperature as a function of number of gradient steps.") 45 | parser.add_argument('--short_term_method', type=str, choices=["gumbel", 46 | "softmax_transition_table", 47 | "softmax_normalize_after", 48 | "unbounded_h", 49 | "l2_normalized", 50 | "positive_h", 51 | "elu", 52 | "matching_network" 53 | ], default="softmax_normalize_after", help="short-term method") 54 | # TODO: for now only use weight sharing 55 | parser.add_argument('--generate_target_pitch', action='store_true', default=False, help="generate s_vector in fast weight frequency matrix") 
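    # Illustrative sketch (toy shapes; an assumption added for clarity, not part of the original file):
    # `probs` below runs the fast-weight recurrence by repeatedly calling `step` from
    # coded_short_term_model.py, which for the "elu"/"softmax_normalize_after" methods amounts to
    #     pred_t = L1_normalize(h_t @ W_t)   and   W_{t+1} = W_t + outer(h_t, s_t)
    # For example, with B=1, m=2 codes and d=3 pitches:
    #     W = 1/(3*100) * torch.ones(1, 2, 3)   # uniform fast weights, alpha = d*100 as in `probs`
    #     h = torch.tensor([[0.9, 0.1]])        # code activations h_src at time t
    #     s = torch.tensor([[0., 1., 0.]])      # one-hot target pitch s_tar
    #     pred, W = step(h, s, W, "elu")        # pred: (1, 3) next-pitch distribution; W now favours pitch 1, mostly for code 0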
56 | def transitions(self, batch): 57 | hard=self.hparams.hparams.gumbel_hard 58 | logits = self.slow_weight_model(batch) 59 | s_tar = batch 60 | if self.hparams.hparams.short_term_method == 'gumbel': 61 | if self.training: 62 | #h_src = torch.nn.functional.gumbel_softmax(logits, tau=self.hparams.hparams.gumbel_temperature, hard=hard, eps=1e-10, dim=2) 63 | #TODO: disabled hard in training. Hard is therefore only affecting evaluation. 64 | h_src = torch.nn.functional.gumbel_softmax(logits, tau=self.hparams.hparams.gumbel_temperature, hard=False, eps=1e-10, dim=2) 65 | else: 66 | if hard: 67 | h_src = torch.nn.functional.one_hot(logits.argmax(-1), num_classes=self.hparams.hparams.encoder_output_dim).float() 68 | #h_src = torch.nn.functional.gumbel_softmax(logits, tau=self.hparams.hparams.gumbel_temperature, hard=True, eps=1e-10, dim=2) 69 | else: 70 | # TODO: READD 71 | #h_src = torch.nn.functional.gumbel_softmax(logits, tau=self.hparams.hparams.gumbel_temperature, hard=False, eps=1e-10, dim=2) 72 | h_src = torch.nn.functional.softmax(logits/self.hparams.hparams.gumbel_temperature, dim=-1) 73 | elif self.hparams.hparams.short_term_method in ["softmax_transition_table", "softmax_normalize_after"]: 74 | logits /= self.hparams.hparams.gumbel_temperature 75 | h_src = torch.nn.functional.softmax(logits, dim=2) 76 | if hard: 77 | #greedy 78 | if not self.training: 79 | h_src = torch.nn.functional.one_hot(h_src.max(dim=-1)[1], num_classes=self.hparams.hparams.encoder_output_dim).float() 80 | else: 81 | og_shape = h_src.shape 82 | index = torch.multinomial(h_src.reshape(-1, self.m), 1).view(-1) 83 | h_src_hard = torch.nn.functional.one_hot(index, num_classes=self.hparams.hparams.encoder_output_dim).float() 84 | h_src_hard = h_src_hard.view(og_shape) 85 | h_src = h_src_hard - h_src.detach() + h_src 86 | elif self.hparams.hparams.short_term_method in ['unbounded_h', "matching_network"]: 87 | h_src = logits 88 | elif self.hparams.hparams.short_term_method == 'positive_h': 89 | #used in combination with ordinary normalization 90 | h_src = torch.nn.functional.relu(logits) 91 | elif self.hparams.hparams.short_term_method == 'l2_normalized': 92 | #used in combination with ordinary normalization 93 | positive = torch.nn.functional.relu(logits) 94 | h_src = F.normalize(positive, p=2, dim=-1) 95 | elif self.hparams.hparams.short_term_method == 'elu': 96 | h_src = torch.nn.functional.elu(logits) + 1 97 | else: 98 | raise NotImplementedError('short_term method: {} is not implimented'.format(self.hparams.hparams.short_term_method)) 99 | if self.hparams.hparams.generate_target_pitch: 100 | s_tar = torch.nn.functional.softmax(logits[:, :, -self.d:], dim=2)[:, 1:, :] 101 | return h_src, s_tar 102 | 103 | 104 | def probs(self, batch): 105 | """Returns the probabilities for full sequence 106 | 107 | Args: 108 | batch ([type]): [description] 109 | 110 | Returns: 111 | [type]: [description] 112 | """ 113 | probs = torch.empty(batch.shape).type_as(batch) 114 | h_src, s_tar = self.transitions(batch) 115 | #W = 1/(self.d)*torch.ones(batch.shape[0], self.m, self.d).type_as(batch) 116 | if self.hparams.hparams.generate_target_pitch: 117 | probs = matching_network(h_src, s_tar) 118 | return probs, None 119 | else: 120 | alpha = (self.d*100) 121 | W = 1/alpha * \ 122 | torch.ones(batch.shape[0], self.m, self.d).type_as(batch) 123 | for i in range(batch.shape[1]): 124 | prob, W = step(h_src[:, i, :], s_tar[:, i, :], W, self.hparams.hparams.short_term_method) 125 | probs[:, i, :] = prob 126 | return probs, W 127 | 128 | def 
generate_(self, samples, W, start_i, s_tar, update_W=True, hard=False): 129 | with torch.no_grad(): 130 | for i in range(start_i, samples.shape[1]): 131 | #TODO: avoid recalc 132 | h_src, _ = self.transitions(samples, hard=hard) 133 | prob, W_new = step(h_src[:, i, :], s_tar, W, self.hparams.hparams.short_term_method) 134 | W = W_new if update_W else W 135 | s_tar = torch.multinomial(prob, num_samples=1).squeeze(-1) 136 | s_tar = torch.nn.functional.one_hot(s_tar, num_classes=self.d).float() 137 | samples[:, i] = s_tar 138 | return samples 139 | 140 | def generate(self, init_tone, W, steps, hard=False): 141 | samples = torch.empty(W.shape[0], steps, self.d, dtype=torch.float32) 142 | s_tar = init_tone 143 | samples[:,0] = init_tone 144 | 145 | return self.generate_(samples, W, 1, s_tar, hard) 146 | def generate_prime(self, piece, steps, n_samples, update_W, hard=False): 147 | with torch.no_grad(): 148 | _, W_init = self.probs(piece.unsqueeze(0)) 149 | W_init = W_init.detach() 150 | shape = list(W_init.shape) 151 | shape[0] = n_samples 152 | length = piece.shape[0] 153 | W = torch.empty(shape, dtype=torch.float32) 154 | samples = torch.empty(n_samples, length + steps, piece.shape[1]) 155 | for i in range(n_samples): 156 | W[i] = W_init.clone() 157 | samples[i,:length] = piece.detach().clone() 158 | s_tar = samples[:, -1] 159 | samples = self.generate_(samples, W, length, s_tar, hard) 160 | return samples 161 | def training_step(self, batch, batch_idx): 162 | loss = super().training_step(batch, batch_idx) 163 | self.log('annealing_temperature', torch.as_tensor(self.hparams.hparams.gumbel_temperature)) 164 | #NOTE: Simulated anealing 165 | if self.hparams.hparams.simulated_annealing: 166 | if self.global_step % 500 == 0: 167 | simmulated_annealing = DSTM._create_lambda_with_globals(self.hparams.hparams.simulated_annealing) 168 | self.hparams.hparams.gumbel_temperature = simmulated_annealing(self.global_step) 169 | return loss -------------------------------------------------------------------------------- /dstm/util/plotting.py: -------------------------------------------------------------------------------- 1 | from dstm.util import add_centered 2 | import numpy as np 3 | from matplotlib import pyplot as plt 4 | from matplotlib.colors import ListedColormap 5 | from dstm.util.constants import Constants 6 | import torch 7 | from pathlib import Path 8 | from dstm.model.baseline import confidence_interval 9 | import matplotlib 10 | from dstm.util.constants import Constants 11 | import tqdm 12 | plt.rcParams.update(Constants.matplotlib_rcparams) 13 | 14 | from dstm.util import add_centered 15 | #fontsize = 16 16 | #titlefontsize = 20 17 | 18 | def plot_piano_roll(piano_roll, colors = ["w","g"],time_sig="4/4"): 19 | bounds = np.arange(len(colors)+1) 20 | bounds = np.array(bounds) -0.5 21 | cmap = ListedColormap(colors) 22 | norm = matplotlib.colors.BoundaryNorm(bounds, cmap.N) 23 | fig, ax = plt.subplots(figsize=(Constants.text_height-0.1, Constants.text_height/4)) 24 | ax.set_yticks(np.arange(0.5, 48, 12)) 25 | ax.set_yticklabels(np.arange(48, 96, 12)) 26 | if time_sig == "4/4": 27 | ticks_pr_2 = 16*2 28 | elif time_sig == "3/4": 29 | ticks_pr_2 = 12*2 30 | else: 31 | raise NotImplementedError("time signature not implimented.") 32 | x_ticks = np.arange(-0.5, len(piano_roll), ticks_pr_2) 33 | ax.set_xticks(x_ticks) 34 | 35 | ax.set_xticklabels((x_ticks + 0.5).astype(int)) 36 | ax.xaxis.set_minor_locator(matplotlib.ticker.AutoMinorLocator(n=ticks_pr_2)) 37 | 
ax.yaxis.set_minor_locator(matplotlib.ticker.AutoMinorLocator(n=12)) 38 | ax.grid(which='both', linewidth=0.1) 39 | 40 | plt.imshow(piano_roll.T, cmap=cmap, norm=norm, interpolation="none", origin="lower", aspect="equal") 41 | plt.xlabel("time ($t$)") 42 | plt.ylabel("\\textit{pitch}") 43 | fig.tight_layout() 44 | return fig, ax 45 | def plot_prediction(label, prediction, d, window=(None, None), time_sig="4/4"): 46 | eye = np.eye(d) 47 | preds_one_hot = eye[:, prediction] 48 | summed = preds_one_hot + 2*eye[:, label] 49 | colors = ['w', 'r', 'g', 'g'] 50 | piano_roll = summed[:, slice(window[0], window[1])].T 51 | fig, ax = plot_piano_roll(piano_roll, colors,time_sig=time_sig) 52 | plt.colorbar(ticks=[1., 2.], shrink=0.2, boundaries=[.5, 1.5, 2.5], format=matplotlib.ticker.FuncFormatter(lambda val, pos: "Error" if val == 1. else "True"), orientation="horizontal",pad=0.3) 53 | return fig, ax 54 | 55 | 56 | def plot_prediction_dmc(mc, sample, window=(None, 57 | None), time_sig="4/4"): 58 | with torch.no_grad(): 59 | sample = sample.unsqueeze(0) 60 | probs = mc.eval().probs(sample) 61 | preds = probs[0].argmax(-1).cpu().squeeze() 62 | d = mc.d 63 | return plot_prediction(sample.argmax(-1)[0].detach().cpu().numpy(), preds, d, window, time_sig=time_sig) 64 | 65 | def plot_saliency_map(mc, piano_roll, time, output): 66 | # freeze weights 67 | piano_roll = piano_roll.unsqueeze(0) 68 | for param in mc.parameters(): 69 | param.requires_grad = False 70 | piano_roll.requires_grad = True 71 | out1 = mc.probs(piano_roll, hard=False)[0, time, output] 72 | piano_roll.grad = None 73 | out1.backward() 74 | plt.matshow(piano_roll.grad.data[0][:time].T) 75 | 76 | def compute_aggregate_saliency_map(mc, data_loader, path="out/session/large/results", device="cpu", method="log", model="Continuous", relative_pitch=True): 77 | # freeze weights 78 | for param in mc.parameters(): 79 | param.requires_grad = False 80 | r = 2**(len(mc.hparams.hparams.filters) + 1) 81 | if relative_pitch: 82 | height_saliency_map = mc.d * 2 -1 83 | else: 84 | height_saliency_map = mc.d 85 | 86 | saliency_map = torch.zeros(r, height_saliency_map).to(device) 87 | for batch_id, (batch, lengths) in enumerate(tqdm.tqdm(data_loader, desc=f'Computing aggregate saliency map {model}')): 88 | if batch.shape[1] == 0: 89 | continue 90 | batch = batch.to(device) 91 | batch.requires_grad = True 92 | if method == "log": 93 | probs, _ = torch.log(mc.probs(batch)) 94 | elif method == "normal": 95 | probs, _ = mc.probs(batch) 96 | else: 97 | raise NotImplementedError("method {} is not defined") 98 | n_grads_pr_piece = 10 99 | # def jac(batch): 100 | # probs, _ = mc.probs(batch) 101 | # max = probs.max(-1) 102 | # time_indexes = torch.randperm(batch.shape[1]) 103 | # #ziped_indexes = torch.cat([torch.arange(batch.).unsqueeze(0), indexes.unsqueeze(0)]).tolist() 104 | # return max[:,time_indexes[:n_grads_pr_piece]] 105 | # jac = torch.autograd.functional.jacobian(jac, batch) 106 | # if relative_pitch: 107 | # pitch_index = slice(mc.d-output-1, 2*mc.d - output-1) 108 | # else: 109 | # pitch_index = slice(None) 110 | # saliency_map[max((r-time), 0):, pitch_index] += jac.sum((0,1))[max(0, time-r):time, :] 111 | for prob, length in zip(probs, lengths): 112 | 113 | 114 | #time_pitch = set([]) 115 | # for i in range(n_grads_pr_piece): 116 | # while True : 117 | # time = np.random.random_integers(0, length-1) 118 | # #output = np.random.random_integers(0, mc.d-1) 119 | # #if (time, output) not in time_pitch: 120 | # if time not in time_pitch: 121 | # break 122 | # 
#time_pitch.add((time,output)) 123 | # time_pitch.add((time)) 124 | for time in np.random.choice(np.arange(0, length, 1), min(n_grads_pr_piece, length.item()), replace=False): 125 | #pick always the note which is predicted 126 | output = prob[time].argmax(-1) 127 | out1 = prob[time, output] 128 | batch.grad = None 129 | out1.backward(retain_graph=True) 130 | summed = torch.abs(batch.grad.data).sum(axis=0) 131 | add_centered(saliency_map, summed, pitch_relative_to=output, time_relative_to=time, relative_pitch=relative_pitch) 132 | torch.save({"saliency_map": saliency_map.cpu(), "r":r, "d": mc.d}, "{}/saliency_map_{}_{}{}.p".format(path, method, model,"_rp" if relative_pitch else "")) 133 | def plot_aggregate_saliency_map(path="fig",method="log",model="Continuous", relative_pitch=True): 134 | data = torch.load("{}/saliency_map_{}_{}{}.p".format(path, method, model,"_rp" if relative_pitch else "")) 135 | r = data["r"] 136 | d = data["d"] 137 | saliency_map = data["saliency_map"] 138 | # for t_name, t, pdf_path in [("Sensitivity salency map (log transformed)", torch.log, "{}/saliency_map_{}_t_{}_{}.pdf".format(path, method, "log", model)), 139 | # ("Sensitivity salency map",lambda x: x, "{}/saliency_map_{}_{}.pdf".format(path, method, model))]: 140 | # fig, ax = plt.subplots(1, 1) 141 | # plt.imshow(t(saliency_map.T).cpu()) 142 | # #plt.imshow(saliency_map.T.cpu()) 143 | # ax.set_xticks(list(range(0, r-1, int(r/(16*1))))) 144 | # ax.set_xticklabels(list(range(r,1, int(-r/(16*1))))) 145 | # #plt.title(t_name, fontsize=titlefontsize) 146 | # #plt.xlabel("\\textit{time lag}", fontsize=fontsize) 147 | # #plt.ylabel("\\textit{relative pitch}", fontsize=fontsize) 148 | # plt.title(t_name) 149 | # plt.xlabel("\\textit{time lag}") 150 | # plt.ylabel("\\textit{relative pitch}") 151 | # y_ticks = np.arange(0, 2*(d) - 1, 16) 152 | # ax.set_yticks(y_ticks) 153 | # ax.set_yticklabels(y_ticks - d + 1) 154 | # #fig.savefig(pdf_path) 155 | #fig, ax = plt.subplots(figsize=(5.78, 4 )) 156 | #fig, ax = plt.subplots(figsize=(6, 2.5 )) 157 | fig, ax = plt.subplots(figsize=(Constants.text_width, Constants.text_width/2)) 158 | #ax.xaxis.tick_top() 159 | # #plt.imshow(saliency_map.T.cpu()) 160 | ax.set_xticks(list(range(0, r-1, int(r/(16*1))))) 161 | ax.set_xticklabels(list(range(r,1, int(-r/(16*1))))) 162 | #ax.xaxis.grid(True, which='minor') 163 | ax.xaxis.set_minor_locator(matplotlib.ticker.AutoMinorLocator()) 164 | ax.yaxis.set_minor_locator(matplotlib.ticker.AutoMinorLocator()) 165 | #ax.tick_params(which='both', width=2) 166 | ax.tick_params(which='major', length=7) 167 | ax.tick_params(which='minor', length=4) 168 | 169 | # ax.set_xticks( np.arange(int(r/16-1), r, int(r/(16*1))) ) 170 | # ax.set_xticklabels(np.arange(int(r - r/16) + 1, 0, int(-r/(16*1)))) 171 | 172 | #plt.title(t_name, fontsize=titlefontsize) 173 | #plt.xlabel("\\textit{time lag}", fontsize=fontsize) 174 | #plt.ylabel("\\textit{relative pitch}", fontsize=fontsize) 175 | plt.title("\\textbf{{{} sensitivity}}".format(Constants.styles[model]['name'][:-6])) 176 | #plt.title("DCSTM sensitivity".format(model)) 177 | plt.xlabel("time lag") 178 | if relative_pitch: 179 | plt.ylabel("relative pitch") 180 | y_ticks = np.arange(0, 2*(d) - 1, 24) 181 | ax.set_yticks(y_ticks) 182 | ax.set_yticklabels(y_ticks - d + 1) 183 | else: 184 | plt.ylabel("\\textit{absolute pitch}") 185 | y_ticks = np.arange(1, d, 12) 186 | ax.set_yticks(y_ticks) 187 | ax.set_yticklabels(y_ticks+ 1) 188 | im = ax.imshow(saliency_map.T.cpu() + 1e-5, cmap="jet", interpolation=None, 
norm=matplotlib.colors.LogNorm(),origin="lower" ) 189 | fig.tight_layout() 190 | 191 | fig.colorbar(im, orientation="horizontal",pad=0.25) 192 | #ax.set_xlim(511-290,511-280) 193 | fig.show() 194 | pdf_path = "fig/saliency_map_{}_{}{}.pdf".format(method, model, "_rp" if relative_pitch else "") 195 | fig.savefig(pdf_path) 196 | 197 | def p_confidence_interval(ps, n): 198 | bounds = list(map(lambda p: confidence_interval(p, n)[1] - p, ps)) 199 | plt.errorbar(x = list(range(len(ps))), 200 | y=np.array([ps]).T, 201 | yerr=np.array(bounds), 202 | capsize=5.0, 203 | marker='o', 204 | linestyle="" 205 | ) -------------------------------------------------------------------------------- /dstm/model/mc.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from pathlib import Path 4 | import pickle 5 | import random 6 | import numpy as np 7 | from dstm.util.metrics import Metrics 8 | from dstm.util.constants import Constants 9 | import tqdm 10 | class MarkovChain: 11 | def __init__(self, d, order): 12 | """[summary] 13 | 14 | Args: 15 | d (int): dimension of output 16 | order (int): order of markov chain 17 | """ 18 | self.d = d 19 | self.order = order 20 | #self.transition_matrix = torch.zeros((order + 1) * [d]) 21 | self.transition_matrix = torch.sparse.FloatTensor(*((order + 1) * [d])) 22 | #self.transition_matrix.spadd() 23 | def add(transition_matrix, context_index, label_index): 24 | indices = tuple(list(context_index) + [label_index]) 25 | #transition_matrix[indices] = transition_matrix[indices] + 1 26 | #note: bug in add_ doesn't update the object 27 | #transition_matrix.spadd(1) 28 | # might be super ineficient but sparse api seems very beta, and adding 1 i 29 | transition_matrix = transition_matrix + torch.sparse.FloatTensor(torch.LongTensor(indices).reshape(-1,1), torch.FloatTensor([1]), transition_matrix.shape) 30 | return transition_matrix#= transition_matrix[indices] + 1 31 | def fit(self, dataset): 32 | """[summary] 33 | 34 | Args: 35 | dataset (Tensor): NxTxd 36 | """ 37 | for context_batch, label_index_batch in ContextLabelIterator(dataset, self.order): 38 | context_index_batch = context_batch.argmax(dim = -1) 39 | for context, label in zip(context_index_batch, label_index_batch): 40 | self.transition_matrix = MarkovChain.add(self.transition_matrix, context, label) 41 | #if self.transition_matrix.device.type is not "cpu": 42 | # TODO: might be needed to sum indices. 
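# Sketch (using the modern torch.sparse_coo_tensor API for illustration): the repeated "+1" updates
# above accumulate duplicate COO entries that are not guaranteed to be merged until the tensor is
# coalesced, e.g.
#   t = torch.sparse_coo_tensor([[0], [1]], [1.0], (4, 4))
#   t = t + torch.sparse_coo_tensor([[0], [1]], [1.0], (4, 4))
#   t.coalesce().values()   # -> tensor([2.]); both updates of entry (0, 1) are summed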
For now do nothing as 43 | self.transition_matrix = self.transition_matrix.coalesce() 44 | #samples = torch.argmax(dataset, dim = -1) 45 | # for piece in samples: 46 | # for i in range(len(piece) - self.order): 47 | # e = tuple(piece[i:(i+self.order+1)]) 48 | # self.transition_matrix[e] = self.transition_matrix[e] + 1 49 | # Only normalize when predicting 50 | #self.transition_matrix = F.normalize(self.transition_matrix, p=1, dim=-1) 51 | def save_model(self): 52 | folder = 'out/model/mc' 53 | path = '{}/mc_o{}.pkl'.format(folder, self.order) 54 | Path(folder).mkdir(parents=True, exist_ok=True) 55 | with open(path, 'wb') as f: 56 | pickle.dump([self.d, self.order, self.transition_matrix], f) 57 | 58 | def load_model(order): 59 | mc = MarkovChain(0, order) 60 | path = 'out/model/mc/mc_o{}.pkl'.format(order) 61 | with open(path, 'rb') as f: 62 | mc.d, mc.order, mc.transition_matrix = pickle.load(f) 63 | return mc 64 | def forward1D(self, context): 65 | """One-step dist 66 | 67 | Args: 68 | context (Tensor): Txd 69 | """ 70 | indices = torch.argmax(context, dim=-1) 71 | #return self.transition_matrix[tuple(indices)] 72 | #dense = self.transition_matrix[tuple(indices)].to_dense() 73 | row = self.transition_matrix[tuple(indices)].to_dense() 74 | normalized = F.normalize(row, p=1, dim=-1) 75 | return normalized 76 | def forward(self, context_batch): 77 | return torch.stack(list(map(self.forward1D, context_batch)), dim = 0) 78 | def predict1D(self, context): 79 | """One-step prediction 80 | 81 | Args: 82 | context (Tensor): orderxd 83 | """ 84 | indices = torch.argmax(context, dim = -1) 85 | probs = self.forward1D(context) 86 | #TODO: bug 87 | if probs.sum() == 0: 88 | prediction = torch.Tensor([-1]) 89 | # prediction = torch.zeros(self.d) 90 | # prediction[-1] = -1 91 | else: 92 | #multinomial 93 | prediction = torch.multinomial(probs, 1) #.unsqueeze(0) #torch.Tensor(random.choices(np.eye(self.d), weights=dist, k=1)[0]) 94 | return prediction 95 | 96 | def predict(self, context_batch): 97 | """Batched one-step prediction 98 | 99 | Args: 100 | context (Tensor): Bxorderxd 101 | """ 102 | return torch.stack(list(map(self.predict1D, context_batch)), dim = 0) 103 | def nll(self, samples, reduce=True): 104 | """[summary] 105 | 106 | Args: 107 | samples ([type]): BxTxd 108 | 109 | Returns: 110 | [type]: [description] 111 | """ 112 | 113 | #nll_loss = torch.nn.NLLLoss() 114 | nll = 0 115 | dim = samples.shape 116 | for context_batch, label_indices_batch in ContextLabelIterator(samples, self.order): 117 | prob = self.forward(context_batch) + Constants.zero_prob_add #1e-3 118 | logprob = torch.log(prob) 119 | indices = (list(range(dim[0])), label_indices_batch) 120 | nll -= float(logprob[indices].sum()) 121 | #nll -= nll_loss(logprob, label_batch) 122 | if reduce: 123 | nll /= (dim[0] * dim[1]) 124 | print('WARNING: nll of first notes are not computed.') 125 | return nll 126 | def precision(self, samples): 127 | """[summary] 128 | 129 | Args: 130 | samples ([type]): BxTxd 131 | 132 | Returns: 133 | [type]: [description] 134 | """ 135 | predictions = [] 136 | labels = [] 137 | for context_batch, label_indices_batch in ContextLabelIterator(samples, self.order): 138 | #note: hack for fixing return [-1,...,-1] when distribtuion is not well defined. 
In this case label -1 is returned 
139 | #m, i = self.predict(context_batch).max(-1)
140 | #prediction = m * i
141 | prediction = self.predict(context_batch)
142 | predictions += prediction
143 | labels += label_indices_batch
144 | print('WARNING: precision of the first `order` notes is not computed.')
145 | return Metrics.precision(labels, predictions)
146 | class ContextLabelIterator:
147 | def __init__(self, samples, order):
148 | self.samples = samples
149 | self.order = order
150 | def __iter__(self):
151 | self.i = 0
152 | return self
153 | def __next__(self):
154 | dim = self.samples.shape
155 | if self.i < dim[1] - self.order:
156 | context_batch = self.samples[:, self.i:(self.i+self.order), :]
157 | label_indices_batch = self.samples[:, (self.i + self.order), :].argmax(dim=-1)
158 | self.i += 1
159 | return context_batch, label_indices_batch
160 | else:
161 | raise StopIteration
162 | class IOFitContextLabelIterator(ContextLabelIterator):
163 | def __init__(self, sample_batch, order, mc):
164 | #sample_batch = sample.unsqueeze(0)
165 | super().__init__(sample_batch, order)
166 | self.mc = mc
167 | self.old_context_indices = None
168 | self.old_label_indices = None
169 | def __next__(self):
170 | if self.old_context_indices is not None:
171 | #old_context_indices = self.old_context_indices.argmax(dim=-1).reshape(-1)
172 | self.mc.transition_matrix = MarkovChain.add(self.mc.transition_matrix, self.old_context_indices, self.old_label_indices)
173 | context_batch, label_indices_batch = super().__next__()
174 | 
175 | self.old_context_indices = context_batch.argmax(-1).reshape(-1)
176 | self.old_label_indices = label_indices_batch.squeeze(0).reshape(-1)
177 | return context_batch, label_indices_batch
178 | 
179 | 
180 | class IOMarkovChain(MarkovChain):
181 | def __init__(self, d, order):
182 | """Markov chain whose transition counts are updated online, piece by piece, via IOFitContextLabelIterator.
183 | 
184 | Args:
185 | d (int): dimension of output
186 | order (int): order of markov chain
187 | """
188 | super().__init__(d, order)
189 | # self.d = d
190 | # self.order = order
191 | # self.transition_matrix = torch.zeros((order + 1) * [d])
192 | def nll_prec(data_loader, order):
193 | # nll = 0
194 | # predictions = []
195 | # labels = []
196 | # #B = len(data_loader) #samples.shape[0]
197 | # #T = data_loader[0].shape[1]
198 | # #d = data_loader[0].shape[2]
199 | # num_elem = 0
200 | # for batch in data_loader:
201 | # sample = batch[0]
202 | # d = sample.shape[2]
203 | # iomc = IOMarkovChain(d, order)
204 | # for context, label_index in IOFitContextLabelIterator(sample, order, iomc):
205 | # #note: hack for fixing return [0,...,-1] when distribution is not well defined. 
In this case label -1 is returned 206 | # #m, i = iomc.predict(context).max(-1) 207 | # #prediction = m * i 208 | # prediction = iomc.predict(context) 209 | # prediction = prediction.squeeze(-1) 210 | # predictions.append(prediction) 211 | # labels.append(label_index) 212 | # prob = iomc.forward(context) 213 | # p = prob[0, label_index] 214 | # p = p if p> 0 else torch.Tensor([Constants.log_zero_prob]) 215 | # nll -= torch.log(p) 216 | # num_elem +=1 217 | # precision = Metrics.precision(labels, predictions) 218 | # return nll/ num_elem, precision 219 | _, logprops, tps = IOMarkovChain.predict_all(data_loader, order) 220 | return -logprops.mean(), tps.float().mean() 221 | 222 | def forward(self, context_batch): 223 | row = super().forward(context_batch) 224 | return F.normalize(row, p=1) 225 | def predict_all(data_loader, order): 226 | predictions = [] 227 | tps = [] 228 | logprobs = [] 229 | labels = [] 230 | for batch in tqdm.tqdm(data_loader): 231 | sample = batch[0] 232 | d = sample.shape[2] 233 | iomc = IOMarkovChain(d, order) 234 | if sample.shape[1] == 0: 235 | continue 236 | for _ in range(order): 237 | predictions.append(torch.tensor(-1)) 238 | tps.append(False) 239 | logprobs.append(np.log(Constants.zero_prob_add)) 240 | for context, label_index in IOFitContextLabelIterator(sample, order, iomc): 241 | prediction = iomc.predict(context) 242 | prediction = prediction.squeeze(-1) 243 | predictions.append(prediction) 244 | labels.append(label_index) 245 | tps.append(prediction == label_index) 246 | prob = iomc.forward(context) 247 | p = prob[0, label_index] 248 | p = p if p> 0 else torch.Tensor([Constants.zero_prob_add]) 249 | logprobs.append(torch.log(p)) 250 | 251 | return torch.tensor(predictions), torch.tensor(logprobs), torch.tensor(tps) 252 | 253 | 254 | 255 | -------------------------------------------------------------------------------- /dstm/util/load_data.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import numpy as np 3 | import torch 4 | from pathlib import Path 5 | import re 6 | import pretty_midi 7 | from functools import reduce 8 | import urllib 9 | import tarfile 10 | import tqdm.contrib.concurrent 11 | def loop_data(dataset, max_length): 12 | """[summary] 13 | 14 | Args: 15 | split (str): train/test/valid 16 | max_length (int): Filters and truncates to this size . Defaults to 512. 
17 | 18 | Returns: 19 | [type]: [description] 20 | """ 21 | # dataset = [] 22 | # dataset = load_nes(split) 23 | n_samples = len(dataset) 24 | n_outs = dataset[0].shape[1] 25 | piano_rolls = np.empty((n_samples, max_length, n_outs), dtype="float32") 26 | 27 | def inner_loop(i): 28 | l = 0 29 | while True: 30 | # Augment data by looping until data is specified length 31 | for s in dataset[i]: 32 | # Use only channels 0, 1 (Pulse 1, Pulse 2) 33 | piano_rolls[i, l] = s 34 | l += 1 35 | if l == max_length: 36 | return 37 | 38 | for i, _ in enumerate(dataset): 39 | inner_loop(i) 40 | return torch.Tensor(piano_rolls) 41 | 42 | 43 | def load_nes(split): 44 | """Reads NES dataset 45 | 46 | Args: 47 | split (str): train/test/valid 48 | 49 | Returns: 50 | [type]: [description] 51 | """ 52 | dataset = [] 53 | for file in Path('data/nesmdb/nesmdb24_seprsco/{}/'.format(split)).rglob("*.pkl"): 54 | with open(file, "rb") as f: 55 | song = pickle.load(f) 56 | dataset.append(song[2]) 57 | return dataset 58 | 59 | 60 | class CustomDataset(torch.utils.data.Dataset): 61 | def __init__(self, data): 62 | self.data = data 63 | def pin_memory(self): 64 | self.data = list(map(lambda x: x.pin_memory(), self.data)) 65 | return self 66 | def __getitem__(self, index): 67 | return torch.Tensor(self.data[index]) 68 | 69 | def __len__(self): 70 | return len(self.data) 71 | def get_longest_seq(self): 72 | return reduce(lambda a, b: a if a > int(b.shape[0]) else int(b.shape[0]), self.data, -1) 73 | 74 | class PicklableSeqCollate: 75 | def __init__(self, seq_max_length) -> None: 76 | self.seq_max_length = seq_max_length 77 | def __call__(self, x): 78 | lengths = [] 79 | y = len(x) * [None] 80 | for i, t in enumerate(x): 81 | seq_length = t.shape[0] 82 | if seq_length > self.seq_max_length: 83 | lengths.append(self.seq_max_length) 84 | #NOTE: randomly sample seq of max length 85 | end_index = np.random.randint(self.seq_max_length, seq_length) 86 | #x[i] = t[(end_index-seq_max_length):end_index, :] 87 | y[i] = t[(end_index-self.seq_max_length):end_index, :] 88 | else: 89 | lengths.append(seq_length) 90 | y[i] = t 91 | lengths = torch.tensor(lengths) 92 | #padded = torch.nn.utils.rnn.pad_sequence(x, batch_first=True) 93 | padded = torch.nn.utils.rnn.pad_sequence(y, batch_first=True) 94 | return padded, lengths 95 | 96 | 97 | class DataPreprocessing: 98 | """Abstract class impliment the 99 | """ 100 | 101 | def __init__(self, nb_classes, loop=False): 102 | self.nb_classes = nb_classes 103 | self.eye = np.eye(nb_classes) 104 | self.max_length = 1024 105 | self.loop = loop 106 | #self.dataset_size= dataset_size 107 | 108 | 109 | def preprocessed_cache(self, split, num_samples=None): 110 | folder = f'{self.folder}/cache' 111 | path = '{}/{}'.format(folder, split) 112 | path += ('_loop.pt' if self.loop else '.pt') 113 | #TODO: For now we store whole piano roll, though could be compressed by argmax 114 | if self.loop: 115 | raise Exception("Deprecated Parameter. 
Should be removed.") 116 | if Path(path).exists(): 117 | dataset = torch.load(path) 118 | return dataset[:num_samples] 119 | else: 120 | print("Creating cache for {} data partion".format(split)) 121 | Path(f'{folder}').mkdir(parents=True, exist_ok=True) 122 | dataset = self.get_encoded(split, num_samples) 123 | if self.loop: 124 | dataset = loop_data(dataset, self.max_length) 125 | torch.save(dataset, path) 126 | return dataset 127 | 128 | 129 | def get_data_loader(self, split, dataset=None, num_samples=None, seq_max_length=float('inf'), **kwargs): 130 | if dataset is None: 131 | dataset = self.preprocessed_cache(split, num_samples) 132 | if self.loop: 133 | collate_fn = None 134 | else: 135 | # def collate_fn(x): 136 | # lengths = [] 137 | # y = len(x) * [None] 138 | # for i, t in enumerate(x): 139 | # seq_length = t.shape[0] 140 | # if seq_length > seq_max_length: 141 | # lengths.append(seq_max_length) 142 | # #NOTE: randomly sample seq of max length 143 | # end_index = np.random.randint(seq_max_length, seq_length) 144 | # #x[i] = t[(end_index-seq_max_length):end_index, :] 145 | # y[i] = t[(end_index-seq_max_length):end_index, :] 146 | # else: 147 | # lengths.append(seq_length) 148 | # y[i] = t 149 | # lengths = torch.tensor(lengths) 150 | # #padded = torch.nn.utils.rnn.pad_sequence(x, batch_first=True) 151 | # padded = torch.nn.utils.rnn.pad_sequence(y, batch_first=True) 152 | # return padded, lengths 153 | collate_fn = PicklableSeqCollate(seq_max_length) 154 | cds = CustomDataset(dataset) 155 | data_loader = torch.utils.data.DataLoader( 156 | cds, 157 | collate_fn=collate_fn, 158 | #sampler = DistributedSampler(cds) 159 | **kwargs 160 | ) 161 | return data_loader 162 | # TODO: this is only for midi with specified file_regex 163 | 164 | def masked_one_hot_encode(self, s, mask): 165 | encoded = self.one_hot_encode(s) 166 | return encoded[..., mask] 167 | 168 | def one_hot_encode(self, s): 169 | return self.eye[s] 170 | 171 | 172 | class NesPreprocessing(DataPreprocessing): 173 | def __init__(self, **kwargs): 174 | self.folder = 'out/nes_mdb/data_cache' 175 | self.masks = [ 176 | [i for i in range(0, 109) if i not in range(1, 32)], 177 | [i for i in range(0, 109) if i not in range(1, 32)], 178 | [i for i in range(0, 109) if i not in range(1, 21)], 179 | [i for i in range(0, 109) if i not in range(17, 109)] 180 | ] 181 | # self.nb_classes = 109 182 | self.d = len(self.masks[0]) 183 | super().__init__(nb_classes=109, **kwargs) 184 | 185 | def get_encoded(self, split): 186 | data_list = load_nes(split) 187 | one_hot_encoded = [] 188 | for sample in data_list: 189 | one_hot_encoded.append( 190 | self.masked_one_hot_encode(sample[:, 0], self.masks[0])) 191 | one_hot_encoded.append( 192 | self.masked_one_hot_encode(sample[:, 1], self.masks[1])) 193 | return one_hot_encoded 194 | 195 | 196 | class EssenPreprocessing(DataPreprocessing): 197 | def __init__(self, **kwargs): 198 | self.folder = 'out/essen/data_cache' 199 | self.data_folder = 'data/essen_all' 200 | # todo: compute auto 201 | self.mask = [1, 2, 3, 4, 5, 45, 46, 47, 48, 49, 50, 51, 52, 202 | 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 203 | 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 204 | 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 205 | 92, 93, 94, 96, 255, 256] 206 | # self.nb_classes = 257 207 | self.d = len(self.mask) 208 | self.file_regex = re.compile(r'^([a-z]+)([0-9]+)\.mid$') 209 | super().__init__(nb_classes=257, **kwargs) 210 | 211 | def get_encoded(self, split): 212 | data_list = 
self.load_essen(split) 213 | one_hot_encoded = [] 214 | for sample in data_list: 215 | one_hot_encoded.append( 216 | self.masked_one_hot_encode(sample, self.mask)) 217 | return one_hot_encoded 218 | 219 | def load_essen_midi(path): 220 | pattern = midi.read_midifile(str(path)) 221 | # Turn everything into 16th nodes 222 | track = pattern[0] 223 | res = pattern.resolution 224 | # ts = pattern[0][5] 225 | step_size = res // 4 226 | notes = [] 227 | midinotes = track[8:-1] 228 | global_step = 0 229 | next_note = -1 230 | next_change = 0 231 | rest_number = 256 232 | can_shorten = False 233 | for event in midinotes: 234 | next_change = next_change + event.tick 235 | current_note = next_note 236 | if isinstance(event, midi.NoteOffEvent): 237 | next_note = rest_number 238 | # if next_change - global_step > step_size: 239 | # can_shorten = True 240 | # else: 241 | # can_storten = False 242 | else: 243 | next_note = event.data[0] 244 | note_on_time = global_step 245 | # Shorten merging note if possible 246 | # Does not work for 3 consecutive 16th notes as the mid is removed can_shorten 247 | if global_step < next_change and len(notes) > 1 and notes[-1] == current_note and notes[-2] == current_note: 248 | # silence 249 | notes[-1] = rest_number 250 | 251 | while global_step < next_change: 252 | notes.append(current_note) 253 | global_step += step_size 254 | return notes 255 | 256 | def load_essen(self, split): 257 | """Reads NES dataset 258 | 259 | Args: 260 | split (str): train/test/valid 261 | 262 | Returns: 263 | [type]: [description] 264 | """ 265 | dataset = [] 266 | for path in Path('{}/{}/'.format(self.data_folder, split)).glob("*.mid"): 267 | # for path in Path('data/essen_all/').rglob("*.mid"): 268 | dataset.append(EssenPreprocessing.load_essen_midi(path)) 269 | return dataset 270 | 271 | 272 | class SessionPreprocessing(DataPreprocessing): 273 | N_FILES = 45849 274 | def __init__(self, max_workers=10,**kwargs): 275 | self.folder = 'out/session' 276 | self.data_folder = 'data/session' 277 | self.max_workers=max_workers 278 | #self.file_regex = re.compile(r'^([a-z]+)([0-9]+)\.mid$') 279 | # todo: compute auto 280 | # self.mask = list(range(255)) 281 | self.mask = [48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 282 | 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95] 283 | # self.nb_classes = 257 284 | self.d = len(self.mask) + 1 285 | #self.file_regex = re.compile(r'^sessiontune([0-9]+)\.mid$') 286 | #TODO: nb_classes is not really meaningfull any longer 287 | super().__init__(nb_classes=self.d, **kwargs) 288 | # def set_dataset_files(self): 289 | # if self.dataset_size == "small": 290 | # self.dataset_files = {'train': "train_6000.pt", "valid": "valid_600.pt", "test": "test_600.pt"} 291 | # elif self.dataset_size == "medium": 292 | # self.dataset_files = {'train': "train_15000.pt", "valid": "valid_1500.pt", "test": "test_1500.pt"} 293 | # elif self.dataset_size == "large": 294 | # self.dataset_files = {'train': "train_36679.pt", "valid": "valid_4585.pt", "test": "test_4585.pt"} 295 | # else: 296 | # raise NotImplementedError("Dataset size {} is not implimented.".format(self.dataset_size)) 297 | def get_encoded(self, split, num_samples): 298 | data_list = self.load(split) 299 | one_hot_encoded = [] 300 | for i, sample in enumerate(data_list): 301 | one_hot_encoded.append(sample) 302 | if num_samples is not None and i == num_samples-1: 303 | break 304 | return one_hot_encoded 305 | 306 | def 
load_midi(self, path): 307 | # TODO: for other dataset (change 120 to load from file) 308 | pm = pretty_midi.PrettyMIDI(path) 309 | pps = 120/60 310 | sixtheensthps = pps/(16) 311 | sixtheensthps 312 | piano_roll = pm.get_piano_roll(fs=1/sixtheensthps) 313 | masked = piano_roll.transpose()[:, self.mask] 314 | masked_with_rest = np.concatenate((np.zeros((masked.shape[0], 1)), masked), axis=1) 315 | #TODO: might be inefficient: 316 | #forces output to be unison (by sellecting only highest pitch) 317 | for i, outs in enumerate(masked_with_rest): 318 | highest_pitch = 0 319 | for j, pitch in enumerate(outs): 320 | if pitch != 0.0: 321 | highest_pitch = j 322 | masked_with_rest[i, j] = 0.0 323 | masked_with_rest[i, highest_pitch] = 1.0 324 | return torch.FloatTensor(masked_with_rest) 325 | def get_files_split(self, split): 326 | return sorted(Path(f'{self.data_folder}/{split}/').rglob("*.mid"), key=lambda x: int(x.name[11:-4])) 327 | def load(self, split): 328 | """Reads midi dataset 329 | 330 | Args: 331 | split (str): train/test/valid 332 | 333 | Returns: 334 | [type]: [description] 335 | """ 336 | #dataset = [] 337 | # instead we return a generator 338 | #for path in Path('{}/{}/'.format(self.data_folder, split)).glob("*.mid"): 339 | # dataset.append(SessionPreprocessing.load_midi(str(path))) 340 | files = self.get_files_split(split) 341 | return tqdm.contrib.concurrent.thread_map( 342 | lambda path: self.load_midi(str(path)), 343 | files, 344 | max_workers=self.max_workers, 345 | chunksize=int(len(files)/(10*self.max_workers))) 346 | #def download_dataset(self): 347 | 348 | def prepare_dataset(self): 349 | #TODO: should be more robust (checksum) 350 | download_folder = Path(self.folder).parent.joinpath("download") 351 | if not download_folder.exists(): 352 | download_folder.mkdir(parents=True, exist_ok=True) 353 | for compressed_archive, url, out_dir in [ 354 | (f'{download_folder}/dataset.tar.gz', 355 | 'https://github.com/IraKorshunova/folk-rnn/raw/master/data/midi.tgz', 356 | 'data' 357 | ), 358 | (f'{download_folder}/checkpoints.tar.gz', 359 | 'https://drive.jku.at/ssf/s/readFile/share/44884/1121046473828436352/publicLink/checkpoints.tar.gz', 360 | f'{self.folder}/model' 361 | )]: 362 | if not Path(compressed_archive).exists(): 363 | print(f"Downloading {compressed_archive} from CDN") 364 | urllib.request.urlretrieve(url, compressed_archive) 365 | Path(out_dir).mkdir(parents=True, exist_ok=True) 366 | tar = tarfile.open(compressed_archive) 367 | print(f"Extracting {compressed_archive}") 368 | tar.extractall(out_dir) 369 | splits = ["train", "valid", "test"] 370 | if not np.all([Path(f'{self.folder}/cache/{split}.pt').exists() for split in splits]): 371 | self.create_data_split() 372 | for split in splits: 373 | self.preprocessed_cache(split) 374 | def get_path_titles(self): 375 | paths_sorted = [Path(f'{self.data_folder}/sessiontune{i}.mid') for i in range(SessionPreprocessing.N_FILES)] 376 | with open(f'{self.data_folder}/allabcwrepeats_parsed', 'rt', encoding='UTF-8') as f: 377 | #with open(f'{Path(self.folder).parent.joinpath("download")}/allabcwrepeats_parsed', 'rt', encoding='UTF-8') as f: 378 | str_ = f.read() 379 | # NOTE: remove empty str and the last file is not included in midi. 380 | abc_strs = str_.split('\n\n')[:-2] 381 | regex = re.compile('T:(.+)\n') 382 | titles = list(map(lambda abc_str: re.sub(r'[^\w\-_\. 
]', '_', regex.search(abc_str)[1]), abc_strs)) 383 | return paths_sorted, titles 384 | def create_data_split(self): 385 | print('Creating datasplit.') 386 | paths_sorted, titles = self.get_path_titles() 387 | assert len(paths_sorted) == len(titles), "Abc and midi files should match." 388 | data_dict = {} 389 | for title, path in zip(titles, paths_sorted): 390 | if title not in data_dict: 391 | data_dict[title] = [] 392 | data_dict[title].append(path) 393 | 394 | 395 | data_dict_keys = list(data_dict.keys()) 396 | np.random.default_rng(42).shuffle(data_dict_keys) 397 | #n_files = sum(map(len, data_dict)) 398 | n_train = int(10/12 * SessionPreprocessing.N_FILES) 399 | n_valid = int(11/12 * SessionPreprocessing.N_FILES) 400 | #n_test = n_files - n_train - n_valid 401 | n_splits_acc = [n_train, n_valid, SessionPreprocessing.N_FILES] 402 | #n_splits = np.diff(n_splits_acc, prepend=0) 403 | n_splits = [38213, 3818, 3818] 404 | partition_keys = [ 405 | [], 406 | [], 407 | [] 408 | ] 409 | split = 0 410 | file_counter = 0 411 | for key in data_dict_keys: 412 | if file_counter > n_splits_acc[split]: 413 | split += 1 414 | partition_keys[split].append(key) 415 | file_counter += len(data_dict[key]) 416 | folders = ['train', 'valid', 'test'] 417 | # for split in folders: 418 | file_counts = [] 419 | for split_str, split_keys, n_split in zip(folders, partition_keys, n_splits): 420 | split_folder = Path('{}/{}'.format(self.data_folder, split_str)) 421 | if split_folder.exists(): 422 | if len(list(split_folder.glob('**/*.mid'))) == n_split: 423 | file_counts.append(n_split) 424 | continue 425 | else: 426 | raise Exception('Unexpected number of files.') 427 | else: 428 | split_folder.mkdir(parents=True, exist_ok=True) 429 | file_count = 0 430 | for key in split_keys: 431 | paths = data_dict[key] 432 | folder = Path(f'{self.data_folder}/{split_str}/{key}') 433 | folder.mkdir() 434 | for path in paths: 435 | p = Path(f"{folder}/{path.name}").resolve() 436 | p.symlink_to(path.resolve()) 437 | file_count+=1 438 | file_counts.append(file_count) 439 | print('Train/Valid/Test file count: {}/{}/{}'.format(*file_counts)) 440 | -------------------------------------------------------------------------------- /dstm/model/encoders/transformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import MultiheadAttention, LayerNorm, Linear, TransformerDecoderLayer, Dropout 3 | from torch.nn.parameter import Parameter 4 | from torch.nn.modules.transformer import Transformer, _get_activation_fn, _get_clones 5 | import copy 6 | import torch.nn.functional as F 7 | import math 8 | import pytorch_lightning as pl 9 | 10 | def lin_proj(X, Wq, Wk, Wv): 11 | batch_size, seq_length, model_dim = X.shape 12 | num_heads, _ , hidden_dim = Wq.shape 13 | HX = torch.cat([ X for _ in range(num_heads)]).reshape(num_heads, batch_size, seq_length, model_dim) 14 | 15 | BWq = torch.cat([Wq for _ in range(batch_size)]).reshape(batch_size, num_heads, model_dim, hidden_dim) 16 | HWqT = BWq.permute(1, 0, 3, 2) # num_heads, batch_size, elem_dim, elem_dim 17 | 18 | q_proj_flat = HX.reshape(-1, seq_length, model_dim).bmm(HWqT.reshape(-1, model_dim, hidden_dim)) # flattened batch mult 19 | q_proj = q_proj_flat.view(num_heads, batch_size, seq_length, hidden_dim) 20 | 21 | BWk = torch.cat([Wk for _ in range(batch_size)]).reshape(batch_size, num_heads, model_dim, hidden_dim) 22 | HWkT = BWk.permute(1, 0, 3, 2) # num_heads, batch_size, elem_dim, seq_length 23 | 24 | k_proj_flat = 
HX.reshape(-1, seq_length, model_dim).bmm(HWkT.reshape(-1, model_dim, hidden_dim)) # flattened batch mult 25 | k_proj = k_proj_flat.view(num_heads, batch_size, seq_length, hidden_dim) 26 | 27 | 28 | 29 | BWv = torch.cat([Wv for _ in range(batch_size)]).reshape(batch_size, num_heads, model_dim, hidden_dim) 30 | HWvT = BWv.permute(1, 0, 3, 2) # num_heads, batch_size, elem_dim, seq_length 31 | 32 | v_proj_flat = HX.reshape(-1, seq_length, model_dim).bmm(HWvT.reshape(-1, model_dim, hidden_dim)) # flattened batch mult 33 | b_proj = v_proj_flat.view(num_heads, batch_size, seq_length, hidden_dim) 34 | 35 | 36 | return q_proj, k_proj, b_proj 37 | #TODO: should we use the non_flattened versions instead? 38 | def att_logits(q_proj, k_proj): #, num_heads, batch_size): 39 | # ATT prior to softmax 40 | n_heads, batch_size, seq_length, elem_dim = q_proj.shape 41 | q_proj_flat = q_proj.view(-1, seq_length, elem_dim) 42 | k_proj_flat = k_proj.view(-1, seq_length, elem_dim) 43 | att_logits_flat = q_proj_flat.reshape(-1, seq_length, elem_dim).bmm(k_proj_flat.transpose(-1, -2)) 44 | #att = attflat.reshape(num_heads, batch_size, seq_length, seq_length) 45 | att_logits = att_logits_flat.view(n_heads, batch_size, seq_length, seq_length) 46 | return att_logits 47 | 48 | # # associciate elemts i,j with vector of size elem_dim # This is however only pr. head... according to Huang 49 | # Rt = torch.rand(seq_length, elem_dim, seq_length) 50 | 51 | # # For each projected sequenence element x_i -> z_i dot it by the relative interactions of the position a_{i,j} 52 | # Sflat = resqflat.transpose(1,0).bmm(Rt).transpose(0, 1) 53 | # S = Sflat.reshape(num_heads, batch_size, seq_length, seq_length) 54 | 55 | # i = 3 56 | # j = 2 57 | # print(S[h, b, i, j ] - resq[h, b][i:i+1].mm(Rt[i,:,j].unsqueeze(1))) 58 | # #print(S[h, b, i, j] - S[h, b, j, i]) 59 | # Relative with head (Shaw 2018) 60 | def rel_pos_enc(q_proj, Rt): # Rt: num_heads, seq_length, elem_dim, seq_length 61 | num_heads = q_proj.shape[0] 62 | batch_size = q_proj.shape[1] 63 | seq_length = q_proj.shape[2] 64 | elem_dim = q_proj.shape[3] 65 | Rtflat = Rt.view(-1, elem_dim, seq_length) 66 | #resqflat = resq.view(-1, seq_length, elem_dim) 67 | Sflat = q_proj.transpose(1,2).reshape(-1, batch_size, elem_dim).bmm(Rtflat) 68 | #Sflat = resqflat.transpose(1,2).reshape(-1, batch_size, elem_dim).bmm(Rtflat) 69 | S = Sflat.view(num_heads, seq_length, batch_size, seq_length).transpose(1,2) # num_heads, batch_size, seq_length, seq_length 70 | return S 71 | 72 | # Music Transformer (CZA Huang 2019) 73 | def rel_pos_enc_eff(q_proj, ErT): 74 | num_heads, batch_size, seq_length, hidden_dim = q_proj.shape 75 | BErTflat = torch.cat([ErT for _ in range(batch_size)]) # batch_size*num_heads, elem_dim seq_length 76 | # TODO: might be inefficient 77 | HErT = BErTflat.reshape(batch_size, num_heads, hidden_dim, seq_length).transpose(0,1) #num_heads, batch_size, elem_dim seq_length 78 | HErTflat = HErT.reshape(-1, hidden_dim, seq_length) 79 | resqflat = q_proj.view(-1, seq_length, hidden_dim) 80 | relflat = resqflat.bmm(HErTflat) 81 | 82 | padded = torch.nn.functional.pad(relflat, (1, 0, 0, 0, 0, 0)) 83 | reshaped = padded.reshape(-1, seq_length+1, seq_length) 84 | Sflat = reshaped[:, 1:, :] 85 | S = Sflat.view(num_heads, batch_size, seq_length, seq_length) 86 | return S 87 | 88 | 89 | class MultiheadAttentionRelativeEncoding(torch.nn.Module): 90 | #TODO: now we have fixed legnths.... 
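# Usage sketch (illustrative sizes; the additive causal mask is the usual upper-triangular -inf matrix):
#   mha = MultiheadAttentionRelativeEncoding(rel_clip_length=512, num_heads=7, model_dim=49)
#   X = torch.randn(4, 128, 49)                                     # (batch, seq_len, model_dim)
#   mask = torch.triu(torch.full((128, 128), float('-inf')), diagonal=1)
#   out = mha(X, mask)                                              # -> (4, 128, 49)
# Internally the q/k/v projections have shape (num_heads, batch, seq_len, hidden_dim) and the
# attention logits (num_heads, batch, seq_len, seq_len); rel_clip_length must be at least seq_len
# because the relative embeddings are sliced down to the last seq_len positions.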
91 | def __init__(self, rel_clip_length, num_heads=7, model_dim=49, hidden_dim=None, dropout=0.0, device=None, dtype=None, pos_enc=True): 92 | super().__init__() 93 | factory_kwargs={'device': device, 'dtype': dtype} 94 | self.num_heads = num_heads 95 | self.model_dim = model_dim 96 | self.dropout = dropout 97 | self.rel_clip_length = rel_clip_length 98 | self.pos_enc = pos_enc 99 | if hidden_dim is None: 100 | if model_dim % num_heads: 101 | raise Exception("model_dim should be divisible with num_heads") 102 | hidden_dim = model_dim // num_heads 103 | self.hidden_dim = hidden_dim 104 | self.Wq = Parameter(torch.empty((self.num_heads, self.model_dim, self.hidden_dim), **factory_kwargs)) 105 | self.Wk = Parameter(torch.empty((self.num_heads, self.model_dim, self.hidden_dim), **factory_kwargs)) 106 | self.Wv = Parameter(torch.empty((self.num_heads, self.model_dim, self.hidden_dim), **factory_kwargs)) 107 | self.Wo = Parameter(torch.empty((self.model_dim, self.model_dim), **factory_kwargs)) 108 | if self.pos_enc: 109 | self.att_rel_emb = Parameter(torch.empty((self.num_heads, self.hidden_dim, self.rel_clip_length), **factory_kwargs)) # num_heads, elem_dim 110 | self._reset_parameters() 111 | 112 | def _reset_parameters(self): 113 | torch.nn.init.xavier_uniform_(self.Wq) 114 | torch.nn.init.xavier_uniform_(self.Wk) 115 | torch.nn.init.xavier_uniform_(self.Wv) 116 | torch.nn.init.xavier_uniform_(self.Wo) 117 | if self.pos_enc: 118 | torch.nn.init.xavier_uniform_(self.att_rel_emb) 119 | 120 | 121 | def calc_heads_parallel(self, X, mask, scale=True): 122 | q_proj, k_proj, v_proj = lin_proj(X, self.Wq, self.Wk, self.Wv) 123 | _, batch_size ,seq_length, _ = q_proj.shape 124 | 125 | #scale 126 | if scale: 127 | q_proj /= math.sqrt(self.hidden_dim) 128 | #ordinary att 129 | logits = att_logits(q_proj, k_proj) 130 | #Srel = rel_pos_enc_eff(q_proj, self.att_rel_emb) 131 | if self.pos_enc: 132 | Srel = rel_pos_enc_eff(q_proj, self.att_rel_emb[:, :, (self.rel_clip_length-seq_length):]) 133 | logits = logits + Srel + mask 134 | else: 135 | logits = logits + mask 136 | attn = torch.nn.functional.softmax(logits, dim=-1) 137 | if self.dropout > 0.0 and self.training: 138 | # TODO: change to module 139 | attn = torch.nn.functional.dropout(attn, p=self.dropout) 140 | #TODO add relative to values 141 | attn_flat = attn.view(-1, seq_length, seq_length) 142 | v_proj_flat = v_proj.view(-1, seq_length, self.hidden_dim) 143 | output_flat = attn_flat.bmm(v_proj_flat) 144 | output = output_flat.view(self.num_heads, batch_size, seq_length, self.hidden_dim) 145 | return output, attn 146 | 147 | def combine_heads(self, output): 148 | _, batch_size, seq_length, _ = output.shape 149 | output_perm = output.permute(1, 2, 0, 3) 150 | output_perm_flattened = output_perm.reshape(-1, self.model_dim) 151 | output_summed = output_perm_flattened.mm(self.Wo) 152 | return output_summed.view(batch_size, seq_length, self.model_dim) 153 | 154 | def forward(self, X, mask): 155 | _, seq_length, _ = X.shape 156 | #mask = Transformer.generate_square_subsequent_mask(None, seq_length) 157 | output, _ = self.calc_heads_parallel(X, mask) 158 | return self.combine_heads(output) 159 | 160 | 161 | 162 | 163 | class AbstractTransformerRNNLayer(torch.nn.Module): 164 | def __init__(self, num_heads, model_dim, hidden_dim, dim_feedforward=2048, dropout=0.1, activation="relu", 165 | layer_norm_eps=1e-5, device=None, dtype=None, **kwargs) -> None: 166 | factory_kwargs = {'device': device, 'dtype': dtype} 167 | #super(TransformerDecoderLayer, 
self).__init__() 168 | super().__init__() 169 | self.model_dim = model_dim 170 | #NOTE: Calling like this gives a receptive field of d_model i.e. m 171 | self.self_attn = self.get_multi_head_attention(num_heads=num_heads, model_dim=model_dim, hidden_dim=hidden_dim, dropout=dropout, 172 | **factory_kwargs, **kwargs) 173 | # Implementation of Feedforward model 174 | self.linear1 = Linear(model_dim, dim_feedforward, **factory_kwargs) 175 | self.dropout = Dropout(dropout) 176 | self.linear2 = Linear(dim_feedforward, model_dim, **factory_kwargs) 177 | 178 | self.norm1 = LayerNorm(model_dim, eps=layer_norm_eps, **factory_kwargs) 179 | self.norm3 = LayerNorm(model_dim, eps=layer_norm_eps, **factory_kwargs) 180 | self.dropout1 = Dropout(dropout) 181 | self.dropout3 = Dropout(dropout) 182 | 183 | self.activation = _get_activation_fn(activation) 184 | 185 | def get_multi_head_attention(self, num_heads=7, model_dim=49, hidden_dim=None, dropout=0.0, device=None, dtype=None, **kwargs): 186 | raise NotImplementedError("Abstract method should be implimented") 187 | 188 | def forward_self_attention(self, tgt, tgt_mask, **kwargs): 189 | raise NotImplementedError("Abstract method should be implimented") 190 | def forward(self, tgt, tgt_mask = None, **kwargs): 191 | tgt2 = self.forward_self_attention(tgt, tgt_mask, **kwargs) 192 | tgt = tgt + self.dropout1(tgt2) 193 | tgt = self.norm1(tgt) 194 | tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt)))) 195 | tgt = tgt + self.dropout3(tgt2) 196 | tgt = self.norm3(tgt) 197 | return tgt 198 | 199 | class RelativeEncodingTransformerRNNLayer(AbstractTransformerRNNLayer): 200 | #NOTE: adapted from torch.nn.TransformerDecoderLayer 201 | def get_multi_head_attention(self, num_heads=7, model_dim=49, hidden_dim=None, dropout=0.0, device=None, dtype=None, rel_clip_length=512, pos_enc=True): 202 | return MultiheadAttentionRelativeEncoding(rel_clip_length=rel_clip_length, num_heads=num_heads, model_dim=model_dim, 203 | hidden_dim=hidden_dim, dropout=dropout, device=device, dtype=dtype, pos_enc=pos_enc) 204 | 205 | def forward_self_attention(self, tgt, tgt_mask, **kwargs): 206 | return self.self_attn.forward(X=tgt, mask=tgt_mask,) 207 | 208 | # https://pytorch.org/tutorials/beginner/transformer_tutorial.html 209 | class PositionalEncoding(torch.nn.Module): 210 | 211 | def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000): 212 | super().__init__() 213 | self.dropout = torch.nn.Dropout(p=dropout) 214 | #self.scaling_embedding = torch.nn.Parameter(torch.rand(1)) #, **factory_kwargs) 215 | #torch.nn.init.xavier_uniform_(self.scaling_embedding) 216 | position = torch.arange(max_len).unsqueeze(1) 217 | div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)) 218 | pe = torch.zeros(max_len, 1, d_model) 219 | pe[:, 0, 0::2] = torch.sin(position * div_term) 220 | pe[:, 0, 1::2] = (torch.cos(position * div_term))[:, :d_model//2] 221 | self.register_buffer('pe', pe) 222 | 223 | def forward(self, x, scaling_embedding=0.1): 224 | """ 225 | Args: 226 | x: Tensor, shape [seq_len, batch_size, embedding_dim] 227 | """ 228 | #x = x + scaling_embedding*self.pe[:x.size(0)] 229 | x = x + x * scaling_embedding * self.pe[:x.size(0)] 230 | return x 231 | #return self.dropout(x) 232 | 233 | class TransformerDecoder(torch.nn.Module): 234 | 235 | __constants__ = ['norm'] 236 | 237 | def __init__(self, d, decoder_layer, num_layers=6, norm=None,abs_enc=False): 238 | super().__init__() 239 | 240 | self.layers = _get_clones(decoder_layer, 
num_layers) 241 | 242 | #decoder_layer, num_layers, d, m, norm=None) 243 | 244 | 245 | self.num_layers = num_layers 246 | self.norm = norm 247 | #self.m = m 248 | self.d = d 249 | if abs_enc: 250 | self.abs_enc = PositionalEncoding(d, decoder_layer.dropout.p, max_len=10000) 251 | else: 252 | self.abs_enc = False 253 | 254 | 255 | def forward(self, tgt): 256 | tgt = torch.nn.functional.pad(tgt, (self.layers[0].model_dim - self.d, 0, 1, 0, 0, 0)) 257 | if self.abs_enc: 258 | tgt = self.abs_enc(tgt) 259 | #TODO: don't generate every time but use max length and then just index 260 | tgt_mask = Transformer.generate_square_subsequent_mask(tgt.shape[1]).type_as(tgt) # .cuda( ) 261 | output = tgt 262 | for mod in self.layers: 263 | output = mod(tgt=output, tgt_mask=tgt_mask) 264 | if self.norm is not None: 265 | output = self.norm(output) 266 | return output 267 | 268 | class StandardTransformerRNNLayer(TransformerDecoderLayer): 269 | #NOTE: adapted from torch.nn.TransformerDecoderLayer 270 | __constants__ = ['batch_first'] 271 | def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu", 272 | layer_norm_eps=1e-5, batch_first=False, device=None, dtype=None) -> None: 273 | factory_kwargs = {'device': device, 'dtype': dtype} 274 | super(TransformerDecoderLayer, self).__init__() 275 | #NOTE: Calling like this gives a receptive field of d_model i.e. m 276 | self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first, 277 | **factory_kwargs) 278 | # Implementation of Feedforward model 279 | self.linear1 = Linear(d_model, dim_feedforward, **factory_kwargs) 280 | self.dropout = Dropout(dropout) 281 | self.linear2 = Linear(dim_feedforward, d_model, **factory_kwargs) 282 | 283 | self.norm1 = LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs) 284 | self.norm3 = LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs) 285 | self.dropout1 = Dropout(dropout) 286 | self.dropout3 = Dropout(dropout) 287 | 288 | self.activation = _get_activation_fn(activation) 289 | 290 | def forward(self, tgt, memory, tgt_mask = None, memory_mask= None, 291 | tgt_key_padding_mask= None, memory_key_padding_mask= None): 292 | tgt2 = self.self_attn(tgt, tgt, tgt, attn_mask=tgt_mask, 293 | key_padding_mask=tgt_key_padding_mask)[0] 294 | tgt = tgt + self.dropout1(tgt2) 295 | tgt = self.norm1(tgt) 296 | tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt)))) 297 | tgt = tgt + self.dropout3(tgt2) 298 | tgt = self.norm3(tgt) 299 | return tgt 300 | 301 | # def _get_clones(module, N): 302 | # return torch.nn.ModuleList([copy.deepcopy(module) for i in range(N)]) 303 | 304 | # def _get_activation_fn(activation): 305 | # if activation == "relu": 306 | # return F.relu 307 | # elif activation == "gelu": 308 | # return F.gelu 309 | 310 | # raise RuntimeError("activation should be relu/gelu, not {}".format(activation)) 311 | 312 | # def generate_square_subsequent_mask(sz): 313 | # #TODO: will be depreacted as should be static torch.nn.Transformer.generate_square_subsequent_mask, however, is not yet in pytorch "1.9" 314 | # return torch.triu(torch.full((sz, sz), float('-inf')), diagonal=1) 315 | 316 | class StandardTransformerRNN(torch.nn.Module): 317 | 318 | def __init__(self, m, d, activation='relu', dropout=0.0, nhead=7, num_layers=6): 319 | super().__init__() 320 | self.m = m 321 | self.d = d 322 | transformer_layer = StandardTransformerRNNLayer(d_model=m, nhead=nhead, dim_feedforward=2048, dropout=dropout, activation='relu', layer_norm_eps=1e-05, 
batch_first=True, device=None, dtype=None) 323 | self.transformer_decoder = torch.nn.TransformerDecoder(transformer_layer, num_layers=num_layers) 324 | 325 | def forward(self, ss): 326 | tgt = torch.nn.functional.pad(ss, (self.m - self.d, 0, 1, 0, 0, 0)) 327 | #TODO: memory have no effect but needs to be supplied when using TransformerDecoder. Alternatively write own TransformerDecoder 328 | memory = torch.zeros(ss.shape[0], 1, self.m) #.cuda() 329 | #TODO: don't generate every time but use max length and then just index 330 | mask = Transformer.generate_square_subsequent_mask(None, tgt.shape[1]).type_as(tgt) #.cuda() 331 | return self.transformer_decoder(tgt, memory, tgt_mask=mask, memory_mask=None, tgt_key_padding_mask=None, memory_key_padding_mask=None) 332 | #NOTE: since we need to project to a larger plane for now we just pad 333 | 334 | class LinearTransformer(torch.autograd.Function): 335 | """https://arxiv.org/pdf/2006.16236.pdf""" 336 | 337 | @staticmethod 338 | def forward(ctx, phi_Q, phi_K, V): 339 | """ 340 | In the forward pass we receive a Tensor containing the input and return 341 | a Tensor containing the output. ctx is a context object that can be used 342 | to stash information for backward computation. You can cache arbitrary 343 | objects for use in the backward pass using the ctx.save_for_backward method. 344 | """ 345 | num_heads, batch_size, seq_length, hidden_dim = phi_Q.shape 346 | #_, batch_size ,seq_length, _ = q_proj.shape 347 | #S = torch.zeros(num_heads, batch_size, hidden_dim, hidden_dim, dtype=phi_Q.dtype, device=phi_Q.device) 348 | #Z = torch.zeros(num_heads, batch_size, hidden_dim, dtype=phi_Q.dtype, device=phi_Q.device) 349 | #V_bar = torch.empty(num_heads, batch_size, seq_length, hidden_dim, dtype=phi_Q.dtype, device=phi_Q.device) 350 | # for i in range(seq_length): 351 | # S = S + torch.matmul(phi_K[:, :, i, :, None], V[:,:,i, None, :]) 352 | # #Z = Z + phi_K[:,:,i, :] 353 | # #output[:, :, i, :] = torch.matmul(phi_Q[:, :, i, None, :], S).squeeze(-2) / Z 354 | # V_bar[:, :, i, :] = torch.matmul(phi_Q[:, :, i, None, :], S).squeeze(-2) 355 | S = torch.matmul(phi_K[:, :, :, :, None], V[:,:,:, None, :]) 356 | S = S.cumsum(2) 357 | V_bar = torch.matmul(phi_Q[:, :, :, None, :], S).squeeze(-2) 358 | ctx.save_for_backward(phi_Q, phi_K, V, S) 359 | return V_bar 360 | 361 | @staticmethod 362 | def backward(ctx, G): 363 | """ 364 | In the backward pass we receive a Tensor containing the gradient of the loss 365 | with respect to the output, and we need to compute the gradient of the loss 366 | with respect to the input. 367 | """ 368 | #phi_Q, phi_K, V, = ctx.saved_tensors 369 | phi_Q, phi_K, V, S = ctx.saved_tensors 370 | #num_heads, batch_size, seq_length, hidden_dim = phi_Q.shape 371 | #S = torch.zeros(num_heads, batch_size, hidden_dim, hidden_dim, dtype=phi_Q.dtype, device=phi_Q.device) 372 | #TODO: we could try to accumualte S, and store so we don't need to recompute... 
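# What the vectorised code below computes, in short: the forward pass is the causal, unnormalised
# linear-attention numerator V_bar[i] = phi_q_i^T S_i with S_i = sum_{j<=i} phi_k_j v_j^T (a cumsum
# over the sequence dimension); the backward pass needs the *reverse* cumulative sum
# R_i = sum_{j>=i} phi_q_j g_j^T, giving grad_phi_K[i] = R_i v_i and grad_V[i] = R_i^T phi_k_i.
# Quadratic-time reference for checking the forward pass on small inputs (sketch):
#   A = torch.einsum('hbie,hbje->hbij', phi_Q, phi_K).tril()   # causal similarity scores
#   ref = torch.matmul(A, V)                                   # equals linear_transformer(phi_Q, phi_K, V)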
373 | # grad_phi_Q = torch.empty(num_heads, batch_size, seq_length, hidden_dim, dtype=phi_Q.dtype, device=phi_Q.device) 374 | # for i in range(seq_length): 375 | # S = S + torch.matmul(phi_K[:, :, i, :, None], V[:,:,i, None, :]) 376 | # grad_phi_Q[:, :, i, :] = torch.matmul(G[:, :, i, None, :], S.permute(0,1,3,2)).squeeze(-2) 377 | grad_phi_Q = torch.matmul(G[:, :, :, None, :], S.permute(0, 1, 2, 4, 3)).squeeze(-2) 378 | #S = torch.zeros(num_heads, batch_size, hidden_dim, hidden_dim, dtype=phi_Q.dtype, device=phi_Q.device) 379 | #grad_phi_K = torch.empty(num_heads, batch_size, seq_length, hidden_dim, dtype=phi_Q.dtype, device=phi_Q.device) 380 | #grad_V = torch.empty(num_heads, batch_size, seq_length, hidden_dim, dtype=phi_Q.dtype, device=phi_Q.device) 381 | # for i in range(seq_length-1, -1, -1): 382 | # S = S + torch.matmul(phi_Q[:, :, i, :, None], G[:, :, i, None, :]) 383 | # #print(S.permute(0,1,3,2).shape) 384 | # #print(phi_K[:, :, i, :, None].shape) 385 | # grad_V[:, :, i, :] = torch.matmul(S.permute(0,1,3,2), phi_K[:, :, i, :, None]).squeeze(-1) 386 | # grad_phi_K[:, :, i, :] = torch.matmul(S, V[:, :, i, :, None]).squeeze(-1) 387 | S = torch.matmul(phi_Q[:, :, :, :, None], G[:, :, :, None, :]) 388 | # reverse cumsum https://github.com/pytorch/pytorch/issues/33520 389 | S = S + torch.sum(S, dim=2, keepdims=True) - torch.cumsum(S, dim=2) 390 | grad_V = torch.matmul(S.permute(0, 1, 2, 4, 3), phi_K[:, :, :, :, None]).squeeze(-1) 391 | grad_phi_K = torch.matmul(S, V[:, :, :, :, None]).squeeze(-1) 392 | return grad_phi_Q, grad_phi_K, grad_V 393 | 394 | linear_transformer = LinearTransformer.apply 395 | 396 | class MultiheadAttentionLinear(torch.nn.Module): 397 | #TODO: now we have fixed legnths.... 398 | def __init__(self, rel_clip_length, num_heads=7, model_dim=49, hidden_dim=None, dropout=0.0, device=None, dtype=None, pos_enc=True): 399 | super().__init__() 400 | factory_kwargs={'device': device, 'dtype': dtype} 401 | self.num_heads = num_heads 402 | self.model_dim = model_dim 403 | self.dropout = dropout 404 | self.rel_clip_length = rel_clip_length 405 | #self.pos_enc = pos_enc 406 | if hidden_dim is None: 407 | if model_dim % num_heads: 408 | raise Exception("model_dim should be divisible with num_heads") 409 | hidden_dim = model_dim // num_heads 410 | self.hidden_dim = hidden_dim 411 | self.Wq = Parameter(torch.empty((self.num_heads, self.model_dim, self.hidden_dim), **factory_kwargs)) 412 | self.Wk = Parameter(torch.empty((self.num_heads, self.model_dim, self.hidden_dim), **factory_kwargs)) 413 | self.Wv = Parameter(torch.empty((self.num_heads, self.model_dim, self.hidden_dim), **factory_kwargs)) 414 | self.Wo = Parameter(torch.empty((self.model_dim, self.model_dim), **factory_kwargs)) 415 | # if self.pos_enc: 416 | # self.att_rel_emb = Parameter(torch.empty((self.num_heads, self.hidden_dim, self.rel_clip_length), **factory_kwargs)) # num_heads, elem_dim 417 | self._reset_parameters() 418 | 419 | def _reset_parameters(self): 420 | torch.nn.init.xavier_uniform_(self.Wq) 421 | torch.nn.init.xavier_uniform_(self.Wk) 422 | torch.nn.init.xavier_uniform_(self.Wv) 423 | torch.nn.init.xavier_uniform_(self.Wo) 424 | # if self.pos_enc: 425 | # torch.nn.init.xavier_uniform_(self.att_rel_emb) 426 | 427 | 428 | def calc_heads_parallel(self, X, mask, scale=True): 429 | q_proj, k_proj, v_proj = lin_proj(X, self.Wq, self.Wk, self.Wv) 430 | q_proj = torch.nn.functional.elu(q_proj) + 1 431 | k_proj = torch.nn.functional.elu(k_proj) + 1 432 | #num_heads, batch_size, seq_length, hidden_dim = 
q_proj.shape 433 | #_, batch_size ,seq_length, _ = q_proj.shape 434 | # S = torch.zeros(num_heads, batch_size, hidden_dim, hidden_dim, dtype=X.dtype, device=X.device) 435 | # Z = torch.zeros(num_heads, batch_size, hidden_dim, dtype=X.dtype, device=X.device) 436 | # output = torch.empty(num_heads, batch_size, seq_length, hidden_dim, dtype=X.dtype, device=X.device) 437 | # for i in range(seq_length): 438 | # S = S + torch.matmul(k_proj[:, :, i, :, None], v_proj[:,:,i, None, :]) 439 | # Z = Z + k_proj[:,:,i, :] 440 | # output[:, :, i, :] = torch.matmul(q_proj[:, :, i, None, :], S).squeeze(-2) / Z 441 | # return output 442 | #TODO: dropout should actucally be inside linear_transformer function 443 | 444 | 445 | V_bar = linear_transformer(q_proj, k_proj, v_proj) 446 | Z = k_proj.cumsum(2) 447 | Z = torch.matmul(q_proj.unsqueeze(-2), Z.unsqueeze(-1)).squeeze(-1) 448 | return V_bar / Z 449 | #TODO: refactor to avoid code, dubl 450 | def combine_heads(self, output): 451 | _, batch_size, seq_length, _ = output.shape 452 | output_perm = output.permute(1, 2, 0, 3) 453 | output_perm_flattened = output_perm.reshape(-1, self.model_dim) 454 | output_summed = output_perm_flattened.mm(self.Wo) 455 | return output_summed.view(batch_size, seq_length, self.model_dim) 456 | 457 | def forward(self, X, mask): 458 | _, seq_length, _ = X.shape 459 | if self.training: 460 | #TODO: change to non functional 461 | X = torch.nn.functional.dropout(X, p=self.dropout) 462 | output = self.calc_heads_parallel(X, mask) 463 | output = self.combine_heads(output) 464 | 465 | return output 466 | 467 | class LinearTransformerRNNLayer(AbstractTransformerRNNLayer): 468 | #NOTE: adapted from torch.nn.TransformerDecoderLayer 469 | def get_multi_head_attention(self, num_heads=7, model_dim=49, hidden_dim=None, dropout=0.0, device=None, dtype=None, rel_clip_length=512, pos_enc=True): 470 | return MultiheadAttentionLinear(rel_clip_length=rel_clip_length, num_heads=num_heads, model_dim=model_dim, 471 | hidden_dim=hidden_dim, dropout=dropout, device=device, dtype=dtype, pos_enc=pos_enc) 472 | 473 | def forward_self_attention(self, tgt, tgt_mask, **kwargs): 474 | return self.self_attn.forward(X=tgt, mask=tgt_mask,) 475 | --------------------------------------------------------------------------------
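A minimal usage sketch for the linear-attention path above (illustrative sizes; the `mask` argument is accepted but unused, since the cumulative-sum formulation is causal by construction):

    import torch
    from dstm.model.encoders.transformer import MultiheadAttentionLinear

    mha = MultiheadAttentionLinear(rel_clip_length=512, num_heads=7, model_dim=49)
    X = torch.randn(2, 64, 49)   # (batch, seq_len, model_dim)
    out = mha(X, mask=None)      # -> (2, 64, 49)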