├── LICENSE
├── README.md
└── code
    ├── LyricsCommentData.py
    ├── attention_modules.py
    ├── data.py
    ├── eval.py
    ├── model.py
    ├── model_fusion.py
    ├── modeling_bart.py
    ├── music_encoder.py
    ├── train.py
    └── train_fusion.py
/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Zhang Yixiao 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Introduction 2 | The code repository for our paper "Interpreting Song Lyrics with an Audio-Informed Pre-trained Language Model", which was accepted at ISMIR 2022. 3 | 4 | ## Dataset 5 | 6 | Please refer to https://zenodo.org/record/7429711. 7 | 8 | ## Checkpoints 9 | 10 | Please refer to https://drive.google.com/drive/folders/18EUUx-KT9xGJ1uq2UoOgj0X9BpngNn_T?usp=sharing. 11 | 12 | ## Data Structure 13 | 14 | You can simply use `pickle` to load this dataset. It contains a list of `MusicData` objects; a minimal loading sketch is shown below. 15 | 16 | TBD 17 | 18 |
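For example (assumptions: the archive ships a pickle file, called `dataset.pkl` here, and `MusicData` exposes `lyrics` and `comment` attributes, mirroring `LyricsCommentData` below; adjust the names to the released files):

```python
import pickle

# MusicData must be importable for unpickling (it is referenced in code/data.py).
with open("dataset.pkl", "rb") as f:
    data = pickle.load(f)  # a list of MusicData objects

print(len(data))
print(data[0].lyrics[:100])    # assumed attribute names
print(data[0].comment[:100])
```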
-------------------------------------------------------------------------------- /code/LyricsCommentData.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | import os 3 | 4 | 5 | @dataclass 6 | class LyricsCommentData(object): 7 | music4all_id: str 8 | songmeanings_id: str 9 | lyrics: str 10 | comment: str 11 | 12 | def get_audio_path(self): # get audio path from id 13 | self.audio_path = os.path.join("Music4All/music4all/audios", 14 | self.music4all_id + '.mp3' 15 | ) 16 | return self.audio_path -------------------------------------------------------------------------------- /code/attention_modules.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | # Code adapted from https://github.com/huggingface/pytorch-pretrained-BERT 3 | 4 | import math 5 | import copy 6 | import torch 7 | import torch.nn as nn 8 | import numpy as np 9 | 10 | 11 | # Gelu 12 | def gelu(x): 13 | """Implementation of the gelu activation function. 14 | For information: OpenAI GPT's gelu is slightly different (and gives slightly different results): 15 | 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) 16 | Also see https://arxiv.org/abs/1606.08415 17 | """ 18 | return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))) 19 | 20 | 21 | # LayerNorm 22 | try: 23 | from apex.normalization.fused_layer_norm import FusedLayerNorm as BertLayerNorm 24 | except ImportError: 25 | # print("Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex.") 26 | class BertLayerNorm(nn.Module): 27 | def __init__(self, hidden_size, eps=1e-12): 28 | """Construct a layernorm module in the TF style (epsilon inside the square root). 29 | """ 30 | super(BertLayerNorm, self).__init__() 31 | self.weight = nn.Parameter(torch.ones(hidden_size)) 32 | self.bias = nn.Parameter(torch.zeros(hidden_size)) 33 | self.variance_epsilon = eps 34 | 35 | def forward(self, x): 36 | u = x.mean(-1, keepdim=True) 37 | s = (x - u).pow(2).mean(-1, keepdim=True) 38 | x = (x - u) / torch.sqrt(s + self.variance_epsilon) 39 | return self.weight * x + self.bias 40 | 41 | 42 | class BertConfig(object): 43 | def __init__(self, 44 | vocab_size, 45 | hidden_size=768, 46 | num_hidden_layers=12, 47 | num_attention_heads=12, 48 | intermediate_size=3072, 49 | hidden_act="gelu", 50 | hidden_dropout_prob=0.1, 51 | max_position_embeddings=512, 52 | attention_probs_dropout_prob=0.1, 53 | type_vocab_size=2): 54 | self.vocab_size = vocab_size 55 | self.hidden_size = hidden_size 56 | self.num_hidden_layers = num_hidden_layers 57 | self.num_attention_heads = num_attention_heads 58 | self.hidden_act = hidden_act 59 | self.intermediate_size = intermediate_size 60 | self.hidden_dropout_prob = hidden_dropout_prob 61 | self.max_position_embeddings = max_position_embeddings 62 | self.attention_probs_dropout_prob = attention_probs_dropout_prob 63 | self.type_vocab_size = type_vocab_size 64 | 65 | 66 | class BertSelfAttention(nn.Module): 67 | def __init__(self, config): 68 | super(BertSelfAttention, self).__init__() 69 | if config.hidden_size % config.num_attention_heads != 0: 70 | raise ValueError( 71 | "The hidden size (%d) is not a multiple of the number of attention " 72 | "heads (%d)" % (config.hidden_size, config.num_attention_heads)) 73 | self.num_attention_heads = config.num_attention_heads 74 | self.attention_head_size = int(config.hidden_size / config.num_attention_heads) 75 | self.all_head_size = self.num_attention_heads * self.attention_head_size 76 | 77 | self.query = nn.Linear(config.hidden_size, self.all_head_size) 78 | self.key = nn.Linear(config.hidden_size, self.all_head_size) 79 | self.value = nn.Linear(config.hidden_size, self.all_head_size) 80 | 81 | self.dropout = nn.Dropout(config.attention_probs_dropout_prob) 82 | 83 | def transpose_for_scores(self, x): 84 | new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) 85 | x = x.view(*new_x_shape) 86 | return x.permute(0, 2, 1, 3) 87 | 88 | def forward(self, hidden_states, attention_mask): 89 | mixed_query_layer = self.query(hidden_states) 90 | mixed_key_layer = self.key(hidden_states) 91 | mixed_value_layer = self.value(hidden_states) 92 | 93 | query_layer = self.transpose_for_scores(mixed_query_layer) 94 | key_layer = self.transpose_for_scores(mixed_key_layer) 95 | value_layer = self.transpose_for_scores(mixed_value_layer) 96 | 97 | # Take the dot product between "query" and "key" to get the raw attention scores.
98 | attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) 99 | attention_scores = attention_scores / math.sqrt(self.attention_head_size) 100 | # Apply the attention mask (precomputed for all layers in the BertModel forward() function) 101 | if attention_mask is not None: 102 | attention_scores = attention_scores + attention_mask 103 | 104 | # Normalize the attention scores to probabilities. 105 | attention_probs = nn.Softmax(dim=-1)(attention_scores) 106 | 107 | # This is actually dropping out entire tokens to attend to, which might 108 | # seem a bit unusual, but is taken from the original Transformer paper. 109 | attention_probs = self.dropout(attention_probs) 110 | 111 | context_layer = torch.matmul(attention_probs, value_layer) 112 | context_layer = context_layer.permute(0, 2, 1, 3).contiguous() 113 | new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) 114 | context_layer = context_layer.view(*new_context_layer_shape) 115 | return context_layer 116 | 117 | 118 | class BertSelfOutput(nn.Module): 119 | def __init__(self, config): 120 | super(BertSelfOutput, self).__init__() 121 | self.dense = nn.Linear(config.hidden_size, config.hidden_size) 122 | self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12) 123 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 124 | 125 | def forward(self, hidden_states, input_tensor): 126 | hidden_states = self.dense(hidden_states) 127 | hidden_states = self.dropout(hidden_states) 128 | hidden_states = self.LayerNorm(hidden_states + input_tensor) 129 | return hidden_states 130 | 131 | 132 | class BertAttention(nn.Module): 133 | def __init__(self, config): 134 | super(BertAttention, self).__init__() 135 | self.self = BertSelfAttention(config) 136 | self.output = BertSelfOutput(config) 137 | 138 | def forward(self, input_tensor, attention_mask): 139 | self_output = self.self(input_tensor, attention_mask) 140 | attention_output = self.output(self_output, input_tensor) 141 | return attention_output 142 | 143 | 144 | class BertIntermediate(nn.Module): 145 | def __init__(self, config): 146 | super(BertIntermediate, self).__init__() 147 | self.dense = nn.Linear(config.hidden_size, config.intermediate_size) 148 | self.intermediate_act_fn = gelu 149 | 150 | def forward(self, hidden_states): 151 | hidden_states = self.dense(hidden_states) 152 | hidden_states = self.intermediate_act_fn(hidden_states) 153 | return hidden_states 154 | 155 | 156 | class BertOutput(nn.Module): 157 | def __init__(self, config): 158 | super(BertOutput, self).__init__() 159 | self.dense = nn.Linear(config.intermediate_size, config.hidden_size) 160 | self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12) 161 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 162 | 163 | def forward(self, hidden_states, input_tensor): 164 | hidden_states = self.dense(hidden_states) 165 | hidden_states = self.dropout(hidden_states) 166 | hidden_states = self.LayerNorm(hidden_states + input_tensor) 167 | return hidden_states 168 | 169 | 170 | class BertLayer(nn.Module): 171 | def __init__(self, config): 172 | super(BertLayer, self).__init__() 173 | self.attention = BertAttention(config) 174 | self.intermediate = BertIntermediate(config) 175 | self.output = BertOutput(config) 176 | 177 | def forward(self, hidden_states, attention_mask): 178 | attention_output = self.attention(hidden_states, attention_mask) 179 | intermediate_output = self.intermediate(attention_output) 180 | layer_output = self.output(intermediate_output, attention_output) 181 | return layer_output
182 | 183 | 184 | class BertEncoder(nn.Module): 185 | def __init__(self, config): 186 | super(BertEncoder, self).__init__() 187 | layer = BertLayer(config) 188 | self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.num_hidden_layers)]) 189 | 190 | def forward(self, hidden_states, attention_mask=None, output_all_encoded_layers=True): 191 | all_encoder_layers = [] 192 | for layer_module in self.layer: 193 | hidden_states = layer_module(hidden_states, attention_mask) 194 | if output_all_encoded_layers: 195 | all_encoder_layers.append(hidden_states) 196 | if not output_all_encoded_layers: 197 | all_encoder_layers.append(hidden_states) 198 | return all_encoder_layers 199 | 200 | 201 | class BertEmbeddings(nn.Module): 202 | """Construct the embeddings by adding learned position embeddings to already-embedded inputs. 203 | """ 204 | 205 | def __init__(self, config): 206 | super(BertEmbeddings, self).__init__() 207 | self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) 208 | 209 | # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load 210 | # any TensorFlow checkpoint file 211 | self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12) 212 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 213 | 214 | def forward(self, input_ids, token_type_ids=None): # NOTE: despite the name, input_ids is a 3-D float tensor of embeddings, not integer token ids 215 | seq_length = input_ids.size(1) 216 | position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device) 217 | position_ids = position_ids.unsqueeze(0).expand_as(input_ids[:, :, 0]) 218 | 219 | position_embeddings = self.position_embeddings(position_ids) 220 | 221 | embeddings = input_ids + position_embeddings 222 | # embeddings = input_ids 223 | embeddings = self.LayerNorm(embeddings) 224 | embeddings = self.dropout(embeddings) 225 | return embeddings 226 | 227 | 228 | class PositionalEncoding(nn.Module): 229 | def __init__(self, config): 230 | super(PositionalEncoding, self).__init__() 231 | emb_dim = config.hidden_size 232 | max_len = config.max_position_embeddings 233 | self.position_enc = self.position_encoding_init(max_len, emb_dim) 234 | 235 | @staticmethod 236 | def position_encoding_init(n_position, emb_dim): 237 | ''' Init the sinusoid position encoding table ''' 238 | 239 | # keep dim 0 for padding token position encoding zero vector 240 | position_enc = np.array([ 241 | [pos / np.power(10000, 2 * (j // 2) / emb_dim) for j in range(emb_dim)] 242 | if pos != 0 else np.zeros(emb_dim) for pos in range(n_position)]) 243 | 244 | position_enc[1:, 0::2] = np.sin(position_enc[1:, 0::2]) # apply sin on 0th,2nd,4th...emb_dim 245 | position_enc[1:, 1::2] = np.cos(position_enc[1:, 1::2]) # apply cos on 1st,3rd,5th...emb_dim 246 | return torch.from_numpy(position_enc).type(torch.FloatTensor) 247 | 248 | def forward(self, word_seq): 249 | position_encoding = self.position_enc.unsqueeze(0).expand_as(word_seq) 250 | position_encoding = position_encoding.to(word_seq.device) 251 | word_pos_encoded = word_seq + position_encoding 252 | return word_pos_encoded 253 | 254 | 255 | class BertPooler(nn.Module): 256 | def __init__(self, config): 257 | super(BertPooler, self).__init__() 258 | self.dense = nn.Linear(config.hidden_size, config.hidden_size) 259 | self.activation = nn.Tanh() 260 | 261 | def forward(self, hidden_states): 262 | # We "pool" the model by simply taking the hidden state corresponding 263 | # to the first token. 264 | first_token_tensor = hidden_states[:, 0] 265 | pooled_output = self.dense(first_token_tensor) 266 | pooled_output = self.activation(pooled_output) 267 | return pooled_output
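A self-contained sketch of how the building blocks in `attention_modules.py` compose (the hyperparameters are illustrative only, not the paper's settings; assumes the file is importable):

```python
import torch
from attention_modules import BertConfig, BertEncoder, PositionalEncoding

config = BertConfig(vocab_size=0, hidden_size=256, num_hidden_layers=2,
                    num_attention_heads=4, intermediate_size=512,
                    max_position_embeddings=128)
encoder = BertEncoder(config)
pos_enc = PositionalEncoding(config)

features = torch.randn(8, 128, 256)            # (batch, seq_len, hidden), e.g. audio features
hidden = pos_enc(features)                     # add sinusoidal position encodings
layers = encoder(hidden, attention_mask=None)  # list of per-layer outputs
print(layers[-1].shape)                        # torch.Size([8, 128, 256])
```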
-------------------------------------------------------------------------------- /code/data.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('..') 3 | 4 | from torch.utils.data import Dataset 5 | import pickle 6 | import random 7 | from . import LyricsCommentData 8 | 9 | class LyricsCommentsDataset(Dataset): 10 | 11 | def __init__(self, random=False): 12 | super(LyricsCommentsDataset, self).__init__() 13 | self.random = random 14 | with open("dataset.pkl", "rb") as f: 15 | self.data = pickle.load(f) 16 | 17 | def __len__(self): 18 | return len(self.data) 19 | 20 | def __getitem__(self, item): 21 | lyrics = self.data[item].lyrics 22 | # if random: 23 | # comment = random.choice(self.data[item].comments) 24 | # else: 25 | comment = self.data[item].comments[0] 26 | # pick the longest comment 27 | for i, (tmp_item, _) in enumerate(self.data[item].comments): 28 | if len(tmp_item) > len(comment[0]): 29 | comment = self.data[item].comments[i] 30 | 31 | comment = comment[0] # keep the comment text, drop its rating 32 | 33 | return [lyrics, comment] 34 | 35 | 36 | class LyricsCommentsDatasetClean(Dataset): 37 | 38 | def __init__(self, random=False): 39 | super(LyricsCommentsDatasetClean, self).__init__() 40 | self.random = random 41 | with open("cleaned_dataset.pkl", "rb") as f: 42 | self.data = pickle.load(f) 43 | 44 | def __len__(self): 45 | return len(self.data) 46 | 47 | def __getitem__(self, item): 48 | lyrics = self.data[item].lyrics 49 | comment = self.data[item].comment 50 | 51 | return [lyrics, comment] 52 | 53 | 54 | class LyricsCommentsDatasetPsuedo(Dataset): 55 | 56 | def __init__(self, dataset_path, random=False): 57 | super(LyricsCommentsDatasetPsuedo, self).__init__() 58 | self.random = random 59 | with open(dataset_path, "rb") as f: 60 | self.data = pickle.load(f) 61 | 62 | def __len__(self): 63 | return len(self.data) 64 | 65 | def __getitem__(self, item): 66 | lyrics = self.data[item].lyrics.replace('\n', ';') 67 | comment = self.data[item].comment 68 | 69 | return [lyrics, comment] 70 | 71 | 72 | class LyricsCommentsDatasetPsuedo_fusion(Dataset): 73 | 74 | def __init__(self, dataset_path): 75 | super(LyricsCommentsDatasetPsuedo_fusion, self).__init__() 76 | with open(dataset_path, "rb") as f: 77 | self.data = pickle.load(f) 78 | 79 | def __len__(self): 80 | return len(self.data) 81 | 82 | 83 | def __getitem__(self, item): 84 | lyrics = self.data[item].lyrics.replace('\n', ';') 85 | comment = self.data[item].comment 86 | music_id = self.data[item].music4all_id 87 | 88 | return [lyrics, comment, music_id] 89 | 90 | 91 | from torch.utils.data import Dataset, DataLoader 92 | import torch 93 | from MusicData import MusicData 94 | import csv 95 | import os 96 | from pydub import AudioSegment 97 | import matplotlib.pyplot as plt 98 | from scipy.io import wavfile 99 | from tempfile import mktemp 100 | from scipy import signal 101 | import numpy as np 102 | import torchaudio 103 | import transformers 104 | import nltk 105 | 106 | 107 | class Music4AllDataset(Dataset): 108 | def __init__(self, 109 | mel_bins, 110 | audio_length, 111 | pad_length, 112 | tag_file_path=r"Music4All/music4all/id_genres.csv", 113 | augment=True): 114 | self.tag_file_path = tag_file_path 115 | self.allow_cache = True 116 | self.mel_bins = mel_bins 117 | self.audio_length = audio_length 118 |
self.pad_length = pad_length 119 | self.augment = augment 120 | # read all tags 121 | tags_file = open(tag_file_path, 'r', encoding='utf-8') 122 | self.tags_reader = list(csv.reader(tags_file, delimiter='\t'))[1:] 123 | tags_file.close() 124 | if self.augment: 125 | self.data_augmentation() 126 | 127 | def data_augmentation(self): 128 | pass 129 | 130 | def __len__(self): 131 | return len(self.tags_reader) 132 | 133 | def __getitem__(self, item): 134 | """ 135 | 136 | :param item: index 137 | :return: tags and mel-spectrogram. 138 | """ 139 | id = self.tags_reader[item][0] 140 | tags = self.tags_reader[item][1] #.split(',') 141 | 142 | # pad tags 143 | # if len(tags) >= self.pad_length: 144 | # tags = tags[:self.pad_length] 145 | # else: 146 | # for i in range(self.pad_length - len(tags)): 147 | # tags.append("[PAD]") 148 | 149 | spec_path = os.path.join("Music4All/temp_data/specs/data_cache/", id + ".npy") 150 | exist_cache = os.path.isfile(spec_path) 151 | # search cache 152 | # if exist cache, load 153 | if self.allow_cache and exist_cache: 154 | spectrogram = torch.Tensor(np.load(spec_path)) 155 | # if does not exist, calculate and save 156 | else: 157 | audio_path = os.path.join("Music4All/music4all/audios", 158 | id + '.mp3' 159 | ) 160 | (data, sample_rate) = torchaudio.backend.sox_io_backend.load(audio_path) 161 | spectrogram = torchaudio.transforms.MelSpectrogram(n_mels=self.mel_bins, 162 | n_fft=512, 163 | sample_rate=sample_rate, 164 | f_max=8000.0, 165 | f_min=0.0, 166 | )(torch.Tensor(data)) 167 | # TODO: There is a huge bug! 168 | # cut length 169 | if self.audio_length is not None: 170 | spectrogram = spectrogram[:, :, :self.audio_length] 171 | # to mono 172 | spectrogram = spectrogram[0, :, :].unsqueeze(0) 173 | 174 | if self.allow_cache: 175 | np.save(spec_path, spectrogram.numpy()) 176 | 177 | return tags, spectrogram 178 | 179 | 180 | class MusCapsDataset(Dataset): 181 | def __init__(self, 182 | mel_bins, 183 | audio_length, 184 | pad_length, 185 | tag_file_path=r"Music4All/music4all/id_genres.csv", 186 | augment=True): 187 | self.tag_file_path = tag_file_path 188 | self.allow_cache = True 189 | self.mel_bins = mel_bins 190 | self.audio_length = audio_length 191 | self.pad_length = pad_length 192 | self.augment = augment 193 | # read all tags 194 | tags_file = open(tag_file_path, 'r', encoding='utf-8') 195 | self.tags_reader = list(csv.reader(tags_file, delimiter='\t'))[1:] 196 | tags_file.close() 197 | if self.augment: 198 | self.data_augmentation() 199 | 200 | def data_augmentation(self): 201 | pass 202 | 203 | def __len__(self): 204 | return len(self.tags_reader) 205 | 206 | def __getitem__(self, item): 207 | """ 208 | 209 | :param item: index 210 | :return: tags and mel-spectrogram. 
211 | """ 212 | id = self.tags_reader[item][0] 213 | tags = self.tags_reader[item][1] #.split(',') 214 | 215 | # pad tags 216 | # if len(tags) >= self.pad_length: 217 | # tags = tags[:self.pad_length] 218 | # else: 219 | # for i in range(self.pad_length - len(tags)): 220 | # tags.append("[PAD]") 221 | 222 | spec_path = os.path.join("Music4All/temp_data/specs/data_cache/", id + ".npy") 223 | exist_cache = os.path.isfile(spec_path) 224 | # search cache 225 | # if exist cache, load 226 | if self.allow_cache and exist_cache: 227 | spectrogram = torch.Tensor(np.load(spec_path)) 228 | # if does not exist, calculate and save 229 | else: 230 | audio_path = os.path.join("Music4All/music4all/audios", 231 | id + '.mp3' 232 | ) 233 | (data, sample_rate) = torchaudio.backend.sox_io_backend.load(audio_path) 234 | spectrogram = torchaudio.transforms.MelSpectrogram(n_mels=self.mel_bins, 235 | n_fft=512, 236 | sample_rate=sample_rate, 237 | f_max=8000.0, 238 | f_min=0.0, 239 | )(torch.Tensor(data)) 240 | # cut length 241 | if self.audio_length is not None: 242 | spectrogram = spectrogram[:, :, :self.audio_length] 243 | # to mono 244 | spectrogram = spectrogram[0, :, :].unsqueeze(0) 245 | np.save(spec_path, spectrogram.numpy()) 246 | 247 | return tags, spectrogram 248 | 249 | class GTZANDataset(Dataset): 250 | def __init__(self, raw_dataset, is_augment=True, window=1366): 251 | self.raw = raw_dataset 252 | self.data = list() 253 | self.mel_bins = 96 254 | self.gtzan_genres = [ 255 | "blues", 256 | "classical", 257 | "country", 258 | "disco", 259 | "hiphop", 260 | "jazz", 261 | "metal", 262 | "pop", 263 | "reggae", 264 | "rock", 265 | ] 266 | self.is_augment = is_augment 267 | self.window = window 268 | self.init() 269 | 270 | def init(self): 271 | for i, (waveform, sample_rate, label) in enumerate(self.raw): 272 | spectrogram = torchaudio.transforms.MelSpectrogram(n_mels=self.mel_bins)(torch.Tensor(waveform)) 273 | if self.is_augment: 274 | self.augment(spectrogram, label) 275 | else: 276 | self.data.append((spectrogram[:,:,:self.window], label)) 277 | 278 | def augment(self, spectrogram, label): 279 | length = spectrogram.shape[-1] # length 280 | # augment audio with sliding window 281 | hop_length = 250 282 | slices = (length - self.window) // hop_length 283 | for i in range(slices): 284 | self.data.append((spectrogram[:, :, i * hop_length:self.window + i*hop_length], label)) 285 | 286 | 287 | 288 | def __len__(self): 289 | return len(self.data) 290 | 291 | def __getitem__(self, index): 292 | spectrogram, label = self.data[index] 293 | label = self.gtzan_genres.index(label) 294 | return spectrogram, label 295 | 296 | 297 | 298 | -------------------------------------------------------------------------------- /code/eval.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from data import LyricsCommentsDatasetPsuedo_fusion 3 | from torch import utils, nn 4 | from model import CommentGenerator 5 | from model_fusion import CommentGenerator_fusion 6 | import transformers 7 | import datasets 8 | from tqdm import tqdm 9 | import statistics 10 | import os 11 | DATASET_PATH = "dataset_test.pkl" 12 | MODEL_PATH = "model/bart_fusion_full.pt" 13 | # MODEL_NAME = "bart" 14 | 15 | os.environ["CUDA_VISIBLE_DEVICES"] = "4" 16 | 17 | test_dataset = LyricsCommentsDatasetPsuedo_fusion(DATASET_PATH) 18 | dataset_length = len(test_dataset) 19 | 20 | test_dataloader = utils.data.DataLoader(test_dataset, 21 | # batch_size=len(valid_dataset), 22 | batch_size=32, 23 | 
shuffle=False) 24 | 25 | if 'baseline' in MODEL_PATH: 26 | model = CommentGenerator().cuda() 27 | else: 28 | model = CommentGenerator_fusion().cuda() 29 | model.load_state_dict(torch.load(MODEL_PATH)) 30 | 31 | model.eval() 32 | 33 | samples_list = list() 34 | # generate 35 | for batch_index, [lyrics, comment, music_id] in enumerate(tqdm(test_dataloader)): 36 | if 'baseline' in MODEL_PATH: 37 | with torch.no_grad(): 38 | output_samples = model.generate(lyrics) 39 | else: 40 | with torch.no_grad(): 41 | output_samples = model.generate(lyrics, music_id) 42 | samples_list.append(output_samples) 43 | 44 | # ------ ROUGE ------ # 45 | 46 | metrics = datasets.load_metric('rouge') 47 | 48 | for batch_index, [lyrics, comment, music_id] in enumerate(tqdm(test_dataloader)): 49 | output_samples = samples_list[batch_index] 50 | metrics.add_batch(predictions=output_samples, references=comment) 51 | 52 | score = metrics.compute() 53 | print(score) 54 | 55 | # ------ BLEU ------ # 56 | 57 | metrics = datasets.load_metric('sacrebleu') 58 | 59 | for batch_index, [lyrics, comment, music_id] in enumerate(tqdm(test_dataloader)): 60 | output_samples = samples_list[batch_index] 61 | metrics.add_batch(predictions=output_samples, references=[[i] for i in comment]) 62 | 63 | score = metrics.compute() 64 | print(score) 65 | 66 | # ------ BERTScore ------ # 67 | 68 | metrics = datasets.load_metric('bertscore') 69 | 70 | for batch_index, [lyrics, comment, music_id] in enumerate(tqdm(test_dataloader)): 71 | output_samples = samples_list[batch_index] 72 | metrics.add_batch(predictions=output_samples, references=[[i] for i in comment]) 73 | 74 | score = metrics.compute(lang='en') 75 | score = statistics.mean(score['f1']) 76 | print(score) 77 | 78 | # ------ METEOR ------ # 79 | 80 | metrics = datasets.load_metric('meteor') 81 | 82 | for batch_index, [lyrics, comment, music_id] in enumerate(tqdm(test_dataloader)): 83 | output_samples = samples_list[batch_index] 84 | metrics.add_batch(predictions=output_samples, references=[[i] for i in comment]) 85 | 86 | score = metrics.compute() 87 | print(score) -------------------------------------------------------------------------------- /code/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from transformers import BartTokenizer, BartForConditionalGeneration 4 | 5 | 6 | class CommentGenerator(nn.Module): 7 | def __init__(self): 8 | super(CommentGenerator, self).__init__() 9 | self.tokenizer = BartTokenizer.from_pretrained("facebook/bart-base") 10 | self.bart = BartForConditionalGeneration.from_pretrained("facebook/bart-base") 11 | # self.bart_config = BartConfig() 12 | self.condition = None 13 | 14 | 15 | def forward(self, input_sentence_list, labels=None): 16 | encoded_input = self.tokenizer( 17 | input_sentence_list, 18 | padding=True, 19 | truncation=True, 20 | max_length=512, 21 | return_tensors='pt', 22 | ) 23 | if labels is not None: 24 | labels = self.tokenizer( 25 | labels, 26 | padding=True, 27 | truncation=True, 28 | max_length=512, 29 | return_tensors='pt', 30 | ) 31 | output = self.bart(input_ids=encoded_input['input_ids'].cuda(), 32 | attention_mask=encoded_input['attention_mask'].cuda(), 33 | labels=labels['input_ids'].cuda() if labels is not None else None, 34 | # `labels` may be None at inference time 35 | ) 36 | return output 37 |
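A minimal usage sketch for `CommentGenerator` (illustrative, not part of the repo; assumes a CUDA device, since the module moves tensors to the GPU unconditionally, and that `code/` is on `PYTHONPATH`):

```python
from model import CommentGenerator

model = CommentGenerator().cuda()
lyrics = ["first line; second line"]        # data.py feeds ';'-joined lyrics
comments = ["a listener's interpretation"]  # reference comment

output = model(lyrics, labels=comments)     # transformers Seq2SeqLMOutput
print(float(output.loss))                   # training loss

model.eval()
print(model.generate(lyrics)[0])            # sampled beam-search comment
```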
38 | def generate(self, input_sentence_list, is_cuda=True): # NOTE: is_cuda is currently unused; inputs are moved to the GPU unconditionally 39 | encoded_input = self.tokenizer(input_sentence_list, 40 | padding=True, 41 | truncation=True, 42 | return_tensors='pt', 43 | ) 44 | output_ids = self.bart.generate(encoded_input['input_ids'].cuda(), 45 | num_beams=4, 46 | max_length=512, 47 | early_stopping=True, 48 | do_sample=True) 49 | return ([self.tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=True) 50 | for g in output_ids]) 51 | # tokenizer = BartTokenizer.from_pretrained("facebook/bart-base") 52 | # encoded_input = tokenizer(['Hello all', 'Hi all'], return_tensors='pt') 53 | # print(encoded_input) 54 | 55 | 56 | 57 | -------------------------------------------------------------------------------- /code/model_fusion.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from transformers import BartTokenizer 4 | from modeling_bart import BartForMultimodalGeneration 5 | from music_encoder import CNNSA 6 | 7 | 8 | 9 | class CommentGenerator_fusion(nn.Module): 10 | def __init__(self): 11 | super(CommentGenerator_fusion, self).__init__() 12 | self.tokenizer = BartTokenizer.from_pretrained("facebook/bart-base") 13 | model_path = "best_model.pth" 14 | self.music_encoder = CNNSA().cuda() 15 | self.music_encoder.load_state_dict(torch.load(model_path)) 16 | # trial: freeze the music encoder's params 17 | for params in self.music_encoder.parameters(): 18 | params.requires_grad = False 19 | 20 | self.bart = BartForMultimodalGeneration.from_pretrained("facebook/bart-base", 21 | fusion_layers=[4,5], # [4,5] 22 | use_forget_gate=False, # [True] 23 | dim_common=768, # 256 24 | n_attn_heads=1).cuda() 25 | 26 | 27 | def forward(self, input_sentence_list, music_ids, labels=None): 28 | encoded_input = self.tokenizer( 29 | input_sentence_list, 30 | padding=True, 31 | truncation=True, 32 | max_length=512, 33 | return_tensors='pt', 34 | ) 35 | if labels is not None: 36 | labels = self.tokenizer( 37 | labels, 38 | padding=True, 39 | truncation=True, 40 | max_length=512, 41 | return_tensors='pt', 42 | ) 43 | music_features = self.music_encoder(music_ids) 44 | output = self.bart(input_ids=encoded_input['input_ids'].cuda(), 45 | attention_mask=encoded_input['attention_mask'].cuda(), 46 | labels=labels['input_ids'].cuda() if labels is not None else None, 47 | music_features=music_features 48 | # `labels` may be None at inference time 49 | ) 50 | return output 51 | 52 | def generate(self, input_sentence_list, music_ids, is_cuda=True): # NOTE: is_cuda is unused here as well 53 | encoded_input = self.tokenizer(input_sentence_list, 54 | padding=True, 55 | truncation=True, 56 | return_tensors='pt', 57 | ) 58 | music_features = self.music_encoder(music_ids) 59 | output_ids = self.bart.generate(encoded_input['input_ids'].cuda(), 60 | num_beams=5, 61 | max_length=512, 62 | early_stopping=True, 63 | do_sample=True, 64 | music_features=music_features) 65 | return ([self.tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=True) 66 | for g in output_ids]) 67 | # tokenizer = BartTokenizer.from_pretrained("facebook/bart-base") 68 | # encoded_input = tokenizer(['Hello all', 'Hi all'], return_tensors='pt') 69 | # print(encoded_input)
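Likewise, a hedged usage sketch for the fusion model (assumes the `best_model.pth` music-encoder checkpoint and the Music4All spectrogram cache that `CNNSA` reads are in place; the ID below is a placeholder):

```python
from model_fusion import CommentGenerator_fusion

model = CommentGenerator_fusion()     # submodules call .cuda() internally
model.eval()

lyrics = ["first line; second line"]  # lyrics with '\n' replaced by ';'
music_ids = ["<music4all_id>"]        # placeholder Music4All track ID

print(model.generate(lyrics, music_ids)[0])
```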
-------------------------------------------------------------------------------- /code/modeling_bart.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | 3 | # Revised by anonymous. 4 | 5 | # Copyright 2021 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved. 6 | # 7 | # Licensed under the Apache License, Version 2.0 (the "License"); 8 | # you may not use this file except in compliance with the License. 9 | # You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, software 14 | # distributed under the License is distributed on an "AS IS" BASIS, 15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 | # See the License for the specific language governing permissions and 17 | # limitations under the License. 18 | """ PyTorch BART model. """ 19 | import copy 20 | import math 21 | import random 22 | import warnings 23 | from typing import Optional, Tuple 24 | import numpy as np 25 | 26 | import torch.nn.functional as F 27 | import torch 28 | import torch.utils.checkpoint 29 | from torch import nn 30 | from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss 31 | 32 | from transformers.activations import ACT2FN 33 | from transformers.file_utils import ( 34 | add_code_sample_docstrings, 35 | add_end_docstrings, 36 | add_start_docstrings, 37 | add_start_docstrings_to_model_forward, 38 | replace_return_docstrings, 39 | ) 40 | from transformers.modeling_outputs import ( 41 | BaseModelOutput, 42 | BaseModelOutputWithPastAndCrossAttentions, 43 | CausalLMOutputWithCrossAttentions, 44 | Seq2SeqLMOutput, 45 | Seq2SeqModelOutput, 46 | Seq2SeqQuestionAnsweringModelOutput, 47 | Seq2SeqSequenceClassifierOutput, 48 | ) 49 | from transformers.modeling_utils import PreTrainedModel 50 | from transformers.utils import logging 51 | from transformers.models.bart.configuration_bart import BartConfig 52 | 53 | from music_encoder import CNNSA 54 | 55 | logger = logging.get_logger(__name__) 56 | 57 | _CHECKPOINT_FOR_DOC = "facebook/bart-large" 58 | _CONFIG_FOR_DOC = "BartConfig" 59 | _TOKENIZER_FOR_DOC = "BartTokenizer" 60 | 61 | 62 | BART_PRETRAINED_MODEL_ARCHIVE_LIST = [ 63 | "facebook/bart-large", 64 | # See all BART models at https://huggingface.co/models?filter=bart 65 | ] 66 | 67 | 68 | def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int): 69 | """ 70 | Shift input ids one token to the right, e.g. [[5, 6, 7]] with decoder_start_token_id=2 becomes [[2, 5, 6]]. 71 | """ 72 | shifted_input_ids = input_ids.new_zeros(input_ids.shape) 73 | shifted_input_ids[:, 1:] = input_ids[:, :-1].clone() 74 | shifted_input_ids[:, 0] = decoder_start_token_id 75 | 76 | if pad_token_id is None: 77 | raise ValueError("self.model.config.pad_token_id has to be defined.") 78 | # replace possible -100 values in labels by `pad_token_id` 79 | shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id) 80 | 81 | return shifted_input_ids 82 | 83 | 84 | def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_values_length: int = 0): 85 | """ 86 | Make the causal mask used for uni-directional (decoder) self-attention.
87 | """ 88 | bsz, tgt_len = input_ids_shape 89 | mask = torch.full((tgt_len, tgt_len), float("-inf")) 90 | mask_cond = torch.arange(mask.size(-1)) 91 | mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) 92 | mask = mask.to(dtype) 93 | 94 | if past_key_values_length > 0: 95 | mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype), mask], dim=-1) 96 | return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) 97 | 98 | 99 | def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): 100 | """ 101 | Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. 102 | """ 103 | bsz, src_len = mask.size() 104 | tgt_len = tgt_len if tgt_len is not None else src_len 105 | 106 | expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) 107 | 108 | inverted_mask = 1.0 - expanded_mask 109 | 110 | return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min) 111 | 112 | 113 | class BartLearnedPositionalEmbedding(nn.Embedding): 114 | """ 115 | This module learns positional embeddings up to a fixed maximum size. 116 | """ 117 | 118 | def __init__(self, num_embeddings: int, embedding_dim: int): 119 | # Bart is set up so that if padding_idx is specified then offset the embedding ids by 2 120 | # and adjust num_embeddings appropriately. Other models don't have this hack 121 | self.offset = 2 122 | super().__init__(num_embeddings + self.offset, embedding_dim) 123 | 124 | def forward(self, input_ids_shape: torch.Size, past_key_values_length: int = 0): 125 | """`input_ids_shape` is expected to be [bsz x seqlen].""" 126 | bsz, seq_len = input_ids_shape[:2] 127 | positions = torch.arange( 128 | past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device 129 | ) 130 | return super().forward(positions + self.offset) 131 | 132 | 133 | class BartAttention(nn.Module): 134 | """Multi-headed attention from 'Attention Is All You Need' paper""" 135 | 136 | def __init__( 137 | self, 138 | embed_dim: int, 139 | num_heads: int, 140 | dropout: float = 0.0, 141 | is_decoder: bool = False, 142 | bias: bool = True, 143 | ): 144 | super().__init__() 145 | self.embed_dim = embed_dim 146 | self.num_heads = num_heads 147 | self.dropout = dropout 148 | self.head_dim = embed_dim // num_heads 149 | 150 | if (self.head_dim * num_heads) != self.embed_dim: 151 | raise ValueError( 152 | f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}" 153 | f" and `num_heads`: {num_heads})." 
154 | ) 155 | self.scaling = self.head_dim ** -0.5 156 | self.is_decoder = is_decoder 157 | 158 | self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) 159 | self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) 160 | self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) 161 | self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) 162 | 163 | def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): 164 | return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() 165 | 166 | def forward( 167 | self, 168 | hidden_states: torch.Tensor, 169 | key_value_states: Optional[torch.Tensor] = None, 170 | past_key_value: Optional[Tuple[torch.Tensor]] = None, 171 | attention_mask: Optional[torch.Tensor] = None, 172 | layer_head_mask: Optional[torch.Tensor] = None, 173 | output_attentions: bool = False, 174 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: 175 | """Input shape: Batch x Time x Channel""" 176 | 177 | # if key_value_states are provided this layer is used as a cross-attention layer 178 | # for the decoder 179 | is_cross_attention = key_value_states is not None 180 | 181 | bsz, tgt_len, _ = hidden_states.size() 182 | 183 | # get query proj 184 | query_states = self.q_proj(hidden_states) * self.scaling 185 | # get key, value proj 186 | if is_cross_attention and past_key_value is not None: 187 | # reuse k,v, cross_attentions 188 | key_states = past_key_value[0] 189 | value_states = past_key_value[1] 190 | elif is_cross_attention: 191 | # cross_attentions 192 | key_states = self._shape(self.k_proj(key_value_states), -1, bsz) 193 | value_states = self._shape(self.v_proj(key_value_states), -1, bsz) 194 | elif past_key_value is not None: 195 | # reuse k, v, self_attention 196 | key_states = self._shape(self.k_proj(hidden_states), -1, bsz) 197 | value_states = self._shape(self.v_proj(hidden_states), -1, bsz) 198 | key_states = torch.cat([past_key_value[0], key_states], dim=2) 199 | value_states = torch.cat([past_key_value[1], value_states], dim=2) 200 | else: 201 | # self_attention 202 | key_states = self._shape(self.k_proj(hidden_states), -1, bsz) 203 | value_states = self._shape(self.v_proj(hidden_states), -1, bsz) 204 | 205 | if self.is_decoder: 206 | # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. 207 | # Further calls to cross_attention layer can then reuse all cross-attention 208 | # key/value_states (first "if" case) 209 | # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of 210 | # all previous decoder key/value_states. 
Further calls to uni-directional self-attention 211 | # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) 212 | # if encoder bi-directional self-attention `past_key_value` is always `None` 213 | past_key_value = (key_states, value_states) 214 | 215 | proj_shape = (bsz * self.num_heads, -1, self.head_dim) 216 | query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) 217 | key_states = key_states.view(*proj_shape) 218 | value_states = value_states.view(*proj_shape) 219 | 220 | src_len = key_states.size(1) 221 | attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) 222 | 223 | if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): 224 | raise ValueError( 225 | f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}" 226 | ) 227 | 228 | if attention_mask is not None: 229 | if attention_mask.size() != (bsz, 1, tgt_len, src_len): 230 | raise ValueError( 231 | f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" 232 | ) 233 | attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask 234 | attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) 235 | 236 | attn_weights = nn.functional.softmax(attn_weights, dim=-1) 237 | 238 | if layer_head_mask is not None: 239 | if layer_head_mask.size() != (self.num_heads,): 240 | raise ValueError( 241 | f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}" 242 | ) 243 | attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len) 244 | attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) 245 | 246 | if output_attentions: 247 | # this operation is a bit awkward, but it's required to 248 | # make sure that attn_weights keeps its gradient. 249 | # In order to do so, attn_weights have to be reshaped 250 | # twice and have to be reused in the following 251 | attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) 252 | attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) 253 | else: 254 | attn_weights_reshaped = None 255 | 256 | attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) 257 | 258 | attn_output = torch.bmm(attn_probs, value_states) 259 | 260 | if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): 261 | raise ValueError( 262 | f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {attn_output.size()}" 263 | ) 264 | 265 | attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) 266 | attn_output = attn_output.transpose(1, 2) 267 | 268 | # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be 269 | # partitioned across GPUs when using tensor-parallelism.
270 | attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) 271 | 272 | attn_output = self.out_proj(attn_output) 273 | 274 | return attn_output, attn_weights_reshaped, past_key_value 275 | 276 | 277 | class BartEncoderLayer(nn.Module): 278 | def __init__(self, config: BartConfig): 279 | super().__init__() 280 | self.embed_dim = config.d_model 281 | self.self_attn = BartAttention( 282 | embed_dim=self.embed_dim, 283 | num_heads=config.encoder_attention_heads, 284 | dropout=config.attention_dropout, 285 | ) 286 | self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) 287 | self.dropout = config.dropout 288 | self.activation_fn = ACT2FN[config.activation_function] 289 | self.activation_dropout = config.activation_dropout 290 | self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim) 291 | self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim) 292 | self.final_layer_norm = nn.LayerNorm(self.embed_dim) 293 | 294 | def forward( 295 | self, 296 | hidden_states: torch.Tensor, 297 | attention_mask: torch.Tensor, 298 | layer_head_mask: torch.Tensor, 299 | output_attentions: bool = False, 300 | ): 301 | """ 302 | Args: 303 | hidden_states (`torch.FloatTensor`): input to the layer of shape *(seq_len, batch, embed_dim)* 304 | attention_mask (`torch.FloatTensor`): attention mask of size 305 | *(batch, 1, tgt_len, src_len)* where padding elements are indicated by very large negative values. 306 | layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size 307 | *(encoder_attention_heads,)*. 308 | output_attentions (`bool`, *optional*): 309 | Whether or not to return the attentions tensors of all attention layers. See `attentions` under 310 | returned tensors for more detail. 311 | """ 312 | residual = hidden_states 313 | hidden_states, attn_weights, _ = self.self_attn( 314 | hidden_states=hidden_states, 315 | attention_mask=attention_mask, 316 | layer_head_mask=layer_head_mask, 317 | output_attentions=output_attentions, 318 | ) 319 | hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) 320 | hidden_states = residual + hidden_states 321 | hidden_states = self.self_attn_layer_norm(hidden_states) 322 | 323 | residual = hidden_states 324 | hidden_states = self.activation_fn(self.fc1(hidden_states)) 325 | hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training) 326 | hidden_states = self.fc2(hidden_states) 327 | hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) 328 | hidden_states = residual + hidden_states 329 | hidden_states = self.final_layer_norm(hidden_states) 330 | 331 | if hidden_states.dtype == torch.float16 and ( 332 | torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any() 333 | ): 334 | clamp_value = torch.finfo(hidden_states.dtype).max - 1000 335 | hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) 336 | 337 | outputs = (hidden_states,) 338 | 339 | if output_attentions: 340 | outputs += (attn_weights,) 341 | 342 | return outputs 343 | 344 | 345 | class BartDecoderLayer(nn.Module): 346 | def __init__(self, config: BartConfig): 347 | super().__init__() 348 | self.embed_dim = config.d_model 349 | 350 | self.self_attn = BartAttention( 351 | embed_dim=self.embed_dim, 352 | num_heads=config.decoder_attention_heads, 353 | dropout=config.attention_dropout, 354 | is_decoder=True, 355 | ) 356 | self.dropout = config.dropout 357 | self.activation_fn = ACT2FN[config.activation_function] 358 | 
self.activation_dropout = config.activation_dropout 359 | 360 | self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) 361 | self.encoder_attn = BartAttention( 362 | self.embed_dim, 363 | config.decoder_attention_heads, 364 | dropout=config.attention_dropout, 365 | is_decoder=True, 366 | ) 367 | self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim) 368 | self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim) 369 | self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim) 370 | self.final_layer_norm = nn.LayerNorm(self.embed_dim) 371 | 372 | def forward( 373 | self, 374 | hidden_states: torch.Tensor, 375 | attention_mask: Optional[torch.Tensor] = None, 376 | encoder_hidden_states: Optional[torch.Tensor] = None, 377 | encoder_attention_mask: Optional[torch.Tensor] = None, 378 | layer_head_mask: Optional[torch.Tensor] = None, 379 | cross_attn_layer_head_mask: Optional[torch.Tensor] = None, 380 | past_key_value: Optional[Tuple[torch.Tensor]] = None, 381 | output_attentions: Optional[bool] = False, 382 | use_cache: Optional[bool] = True, 383 | ): 384 | """ 385 | Args: 386 | hidden_states (`torch.FloatTensor`): input to the layer of shape *(batch, seq_len, embed_dim)* 387 | attention_mask (`torch.FloatTensor`): attention mask of size 388 | *(batch, 1, tgt_len, src_len)* where padding elements are indicated by very large negative values. 389 | encoder_hidden_states (`torch.FloatTensor`): cross attention input to the layer of shape *(batch, seq_len, embed_dim)* 390 | encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size 391 | *(batch, 1, tgt_len, src_len)* where padding elements are indicated by very large negative values. 392 | layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size 393 | *(encoder_attention_heads,)*. 394 | cross_attn_layer_head_mask (`torch.FloatTensor`): mask for cross-attention heads in a given layer of 395 | size *(decoder_attention_heads,)*. 396 | past_key_value (`Tuple(torch.FloatTensor)`): cached past key and value projection states 397 | output_attentions (`bool`, *optional*): 398 | Whether or not to return the attentions tensors of all attention layers. See `attentions` under 399 | returned tensors for more detail. 
400 | """ 401 | residual = hidden_states 402 | 403 | # Self Attention 404 | # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 405 | self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None 406 | # add present self-attn cache to positions 1,2 of present_key_value tuple 407 | hidden_states, self_attn_weights, present_key_value = self.self_attn( 408 | hidden_states=hidden_states, 409 | past_key_value=self_attn_past_key_value, 410 | attention_mask=attention_mask, 411 | layer_head_mask=layer_head_mask, 412 | output_attentions=output_attentions, 413 | ) 414 | hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) 415 | hidden_states = residual + hidden_states 416 | hidden_states = self.self_attn_layer_norm(hidden_states) 417 | 418 | # Cross-Attention Block 419 | cross_attn_present_key_value = None 420 | cross_attn_weights = None 421 | if encoder_hidden_states is not None: 422 | residual = hidden_states 423 | 424 | # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple 425 | cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None 426 | hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn( 427 | hidden_states=hidden_states, 428 | key_value_states=encoder_hidden_states, 429 | attention_mask=encoder_attention_mask, 430 | layer_head_mask=cross_attn_layer_head_mask, 431 | past_key_value=cross_attn_past_key_value, 432 | output_attentions=output_attentions, 433 | ) 434 | hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) 435 | hidden_states = residual + hidden_states 436 | hidden_states = self.encoder_attn_layer_norm(hidden_states) 437 | 438 | # add cross-attn to positions 3,4 of present_key_value tuple 439 | present_key_value = present_key_value + cross_attn_present_key_value 440 | 441 | # Fully Connected 442 | residual = hidden_states 443 | hidden_states = self.activation_fn(self.fc1(hidden_states)) 444 | hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training) 445 | hidden_states = self.fc2(hidden_states) 446 | hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) 447 | hidden_states = residual + hidden_states 448 | hidden_states = self.final_layer_norm(hidden_states) 449 | 450 | outputs = (hidden_states,) 451 | 452 | if output_attentions: 453 | outputs += (self_attn_weights, cross_attn_weights) 454 | 455 | if use_cache: 456 | outputs += (present_key_value,) 457 | 458 | return outputs 459 | 460 | 461 | class BartClassificationHead(nn.Module): 462 | """Head for sentence-level classification tasks.""" 463 | 464 | def __init__( 465 | self, 466 | input_dim: int, 467 | inner_dim: int, 468 | num_classes: int, 469 | pooler_dropout: float, 470 | ): 471 | super().__init__() 472 | self.dense = nn.Linear(input_dim, inner_dim) 473 | self.dropout = nn.Dropout(p=pooler_dropout) 474 | self.out_proj = nn.Linear(inner_dim, num_classes) 475 | 476 | def forward(self, hidden_states: torch.Tensor): 477 | hidden_states = self.dropout(hidden_states) 478 | hidden_states = self.dense(hidden_states) 479 | hidden_states = torch.tanh(hidden_states) 480 | hidden_states = self.dropout(hidden_states) 481 | hidden_states = self.out_proj(hidden_states) 482 | return hidden_states 483 | 484 | 485 | class BartPretrainedModel(PreTrainedModel): 486 | config_class = BartConfig 487 | base_model_prefix = "model" 488 | 
supports_gradient_checkpointing = True 489 | _keys_to_ignore_on_load_unexpected = [r"encoder\.version", r"decoder\.version"] 490 | 491 | def _init_weights(self, module): 492 | std = self.config.init_std 493 | if isinstance(module, nn.Linear): 494 | module.weight.data.normal_(mean=0.0, std=std) 495 | if module.bias is not None: 496 | module.bias.data.zero_() 497 | elif isinstance(module, nn.Embedding): 498 | module.weight.data.normal_(mean=0.0, std=std) 499 | if module.padding_idx is not None: 500 | module.weight.data[module.padding_idx].zero_() 501 | 502 | def _set_gradient_checkpointing(self, module, value=False): 503 | if isinstance(module, (BartDecoder, BartEncoder)): 504 | module.gradient_checkpointing = value 505 | 506 | @property 507 | def dummy_inputs(self): 508 | pad_token = self.config.pad_token_id 509 | input_ids = torch.tensor([[0, 6, 10, 4, 2], [0, 8, 12, 2, pad_token]], device=self.device) 510 | dummy_inputs = { 511 | "attention_mask": input_ids.ne(pad_token), 512 | "input_ids": input_ids, 513 | } 514 | return dummy_inputs 515 | 516 | 517 | class PretrainedBartModel(BartPretrainedModel): 518 | def __init_subclass__(self): 519 | warnings.warn( 520 | "The class `PretrainedBartModel` has been deprecated, please use `BartPretrainedModel` instead.", 521 | FutureWarning, 522 | ) 523 | 524 | 525 | BART_START_DOCSTRING = r""" 526 | This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic 527 | methods the library implements for all its models (such as downloading or saving, resizing the input embeddings, 528 | pruning heads etc.) 529 | 530 | This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) 531 | subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to 532 | general usage and behavior. 533 | 534 | Parameters: 535 | config ([`BartConfig`]): 536 | Model configuration class with all the parameters of the model. Initializing with a config file does not 537 | load the weights associated with the model, only the configuration. Check out the 538 | [`~PreTrainedModel.from_pretrained`] method to load the model weights. 539 | """ 540 | 541 | BART_GENERATION_EXAMPLE = r""" 542 | Summarization example:: 543 | 544 | >>> from transformers import BartTokenizer, BartForConditionalGeneration, BartConfig 545 | 546 | >>> model = BartForConditionalGeneration.from_pretrained('facebook/bart-large-cnn') 547 | >>> tokenizer = BartTokenizer.from_pretrained('facebook/bart-large-cnn') 548 | 549 | >>> ARTICLE_TO_SUMMARIZE = "My friends are cool but they eat too many carbs." 550 | >>> inputs = tokenizer([ARTICLE_TO_SUMMARIZE], max_length=1024, return_tensors='pt') 551 | 552 | >>> # Generate Summary 553 | >>> summary_ids = model.generate(inputs['input_ids'], num_beams=4, max_length=5, early_stopping=True) 554 | >>> print([tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in summary_ids]) 555 | 556 | Mask filling example:: 557 | 558 | >>> from transformers import BartTokenizer, BartForConditionalGeneration 559 | >>> tokenizer = BartTokenizer.from_pretrained('facebook/bart-large') 560 | >>> TXT = "My friends are <mask> but they eat too many carbs."
561 | 562 | >>> model = BartForConditionalGeneration.from_pretrained('facebook/bart-large') 563 | >>> input_ids = tokenizer([TXT], return_tensors='pt')['input_ids'] 564 | >>> logits = model(input_ids).logits 565 | 566 | >>> masked_index = (input_ids[0] == tokenizer.mask_token_id).nonzero().item() 567 | >>> probs = logits[0, masked_index].softmax(dim=0) 568 | >>> values, predictions = probs.topk(5) 569 | 570 | >>> tokenizer.decode(predictions).split() 571 | """ 572 | 573 | BART_INPUTS_DOCSTRING = r""" 574 | Args: 575 | input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): 576 | Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide 577 | it. 578 | 579 | Indices can be obtained using [`BartTokenizer`]. See 580 | [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for 581 | details. 582 | 583 | [What are input IDs?](../glossary#input-ids) 584 | attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): 585 | Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: 586 | 587 | - 1 for tokens that are **not masked**, 588 | - 0 for tokens that are **masked**. 589 | 590 | [What are attention masks?](../glossary#attention-mask) 591 | decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): 592 | Indices of decoder input sequence tokens in the vocabulary. 593 | 594 | Indices can be obtained using [`BartTokenizer`]. See 595 | [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for 596 | details. 597 | 598 | [What are decoder input IDs?](../glossary#decoder-input-ids) 599 | 600 | Bart uses the `eos_token_id` as the starting token for `decoder_input_ids` generation. If 601 | `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see 602 | `past_key_values`). 603 | 604 | For translation and summarization training, `decoder_input_ids` should be provided. If no 605 | `decoder_input_ids` is provided, the model will create this tensor by shifting the `input_ids` to 606 | the right for denoising pre-training following the paper. 607 | decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): 608 | Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will 609 | also be used by default. 610 | 611 | If you want to change padding behavior, you should read [`modeling_bart._prepare_decoder_inputs`] and 612 | modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more 613 | information on the default strategy. 614 | head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*): 615 | Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in `[0, 1]`: 616 | 617 | - 1 indicates the head is **not masked**, 618 | - 0 indicates the head is **masked**. 619 | 620 | decoder_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): 621 | Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in `[0, 1]`: 622 | 623 | - 1 indicates the head is **not masked**, 624 | - 0 indicates the head is **masked**. 625 | 626 | cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): 627 | Mask to nullify selected heads of the cross-attention modules in the decoder. 
Mask values selected in `[0, 1]`:
628 |
629 | - 1 indicates the head is **not masked**,
630 | - 0 indicates the head is **masked**.
631 |
632 | encoder_outputs (`tuple(tuple(torch.FloatTensor))`, *optional*):
633 | Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`).
634 | `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)` is a sequence of
635 | hidden-states at the output of the last layer of the encoder. Used in the
636 | cross-attention of the decoder.
637 | past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
638 | Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors
639 | of shape `(batch_size, num_heads, sequence_length, embed_size_per_head)` and 2 additional tensors of
640 | shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
641 |
642 | Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
643 | blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
644 |
645 | If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids`
646 | (those that don't have their past key value states given to this model) of shape `(batch_size, 1)`
647 | instead of all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
This is useful if you want more control over how to convert `input_ids` indices into associated
648 | vectors than the model's internal embedding lookup matrix.
649 | decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, target_sequence_length, hidden_size)`, *optional*):
650 | Optionally, instead of passing `decoder_input_ids` you can choose to directly pass an embedded
651 | representation. If `past_key_values` is used, optionally only the last `decoder_inputs_embeds`
652 | have to be input (see `past_key_values`). This is useful if you want more control over how to convert
653 | `decoder_input_ids` indices into associated vectors than the model's internal embedding lookup matrix.
654 |
655 | If `decoder_input_ids` and `decoder_inputs_embeds` are both unset, `decoder_inputs_embeds`
656 | takes the value of `inputs_embeds`.
657 | use_cache (`bool`, *optional*):
658 | If set to `True`, `past_key_values` key value states are returned and can be used to speed up
659 | decoding (see `past_key_values`).
660 | output_attentions (`bool`, *optional*):
661 | Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
662 | tensors for more detail.
663 | output_hidden_states (`bool`, *optional*):
664 | Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
665 | more detail.
666 | return_dict (`bool`, *optional*):
667 | Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
668 | """
669 |
670 |
671 | class BartEncoder(BartPretrainedModel):
672 | """
673 | Transformer encoder consisting of *config.encoder_layers* self-attention layers. Each layer is a
674 | [`BartEncoderLayer`].
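In this repository the encoder is also the fusion site: at each layer index listed in
`fusion_layers`, the text hidden states attend over 256-d music features through a small
multi-head cross-attention block, optionally gated by a learned forget gate, and the gated
output is added back residually (see the `Modification` blocks in `__init__` and the
music-text fusion section of `forward` below).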
675 | 676 | Args: 677 | config: BartConfig 678 | embed_tokens (nn.Embedding): output embedding 679 | """ 680 | 681 | def __init__(self, config: BartConfig, 682 | embed_tokens: Optional[nn.Embedding] = None, 683 | fusion_layers=[5], # 5 is the last layer 684 | use_forget_gate=True, 685 | dim_common=256, 686 | n_attn_heads=1): 687 | super().__init__(config) 688 | 689 | self.dropout = config.dropout 690 | self.layerdrop = config.encoder_layerdrop 691 | 692 | embed_dim = config.d_model 693 | self.padding_idx = config.pad_token_id 694 | self.max_source_positions = config.max_position_embeddings 695 | self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0 696 | 697 | if embed_tokens is not None: 698 | self.embed_tokens = embed_tokens 699 | else: 700 | self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim, self.padding_idx) 701 | 702 | self.embed_positions = BartLearnedPositionalEmbedding( 703 | config.max_position_embeddings, 704 | embed_dim, 705 | ) 706 | self.layers = nn.ModuleList([BartEncoderLayer(config) for _ in range(config.encoder_layers)]) 707 | self.layernorm_embedding = nn.LayerNorm(embed_dim) 708 | 709 | self.gradient_checkpointing = False 710 | 711 | # ==================== Modification Starts ==================== 712 | # 1. params and variables 713 | self.use_forget_gate = use_forget_gate 714 | self.fusion_layers = fusion_layers 715 | music_feature_dim = 256 716 | text_feature_dim = embed_dim # 768 717 | 718 | # 2. define attention 719 | self._linear_1 = nn.Linear(music_feature_dim, dim_common) # K 720 | self._linear_2 = nn.Linear(music_feature_dim, dim_common) # V 721 | self._linear_3 = nn.Linear(text_feature_dim, dim_common) # Q 722 | self._multi_head_attn = nn.MultiheadAttention(dim_common, n_attn_heads) 723 | self._linear_4 = nn.Linear(text_feature_dim + dim_common, text_feature_dim) # TODO: it does not make sense 724 | if use_forget_gate: 725 | self.fg = nn.Linear(dim_common + text_feature_dim, dim_common) 726 | 727 | # ==================== Modification Ends ==================== 728 | self.final_layer_norm = nn.LayerNorm(embed_dim) 729 | self.sigmoid = nn.Sigmoid() 730 | 731 | # Initialize weights and apply final processing 732 | self.post_init() 733 | 734 | def get_input_embeddings(self): 735 | return self.embed_tokens 736 | 737 | def set_input_embeddings(self, value): 738 | self.embed_tokens = value 739 | 740 | def forward( 741 | self, 742 | input_ids=None, 743 | attention_mask=None, 744 | head_mask=None, 745 | inputs_embeds=None, 746 | output_attentions=None, 747 | output_hidden_states=None, 748 | return_dict=None, 749 | music_features=None, 750 | music_len=None 751 | ): 752 | r""" 753 | Args: 754 | input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): 755 | Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you 756 | provide it. 757 | 758 | Indices can be obtained using [`BartTokenizer`]. See 759 | [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] 760 | for details. 761 | 762 | [What are input IDs?](../glossary#input-ids) 763 | attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): 764 | Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: 765 | 766 | - 1 for tokens that are **not masked**, 767 | - 0 for tokens that are **masked**. 
768 | 769 | [What are attention masks?](../glossary#attention-mask) 770 | head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*): 771 | Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: 772 | 773 | - 1 indicates the head is **not masked**, 774 | - 0 indicates the head is **masked**. 775 | 776 | inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): 777 | Optionally, instead of passing `input_ids` you can choose to directly pass an embedded 778 | representation. This is useful if you want more control over how to convert `input_ids` indices 779 | into associated vectors than the model's internal embedding lookup matrix. 780 | output_attentions (`bool`, *optional*): 781 | Whether or not to return the attentions tensors of all attention layers. See `attentions` under 782 | returned tensors for more detail. 783 | output_hidden_states (`bool`, *optional*): 784 | Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors 785 | for more detail. 786 | return_dict (`bool`, *optional*): 787 | Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple. 788 | """ 789 | 790 | # ==================== Modification Starts ==================== 791 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 792 | output_hidden_states = ( 793 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 794 | ) 795 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 796 | 797 | # retrieve input_ids and inputs_embeds 798 | if input_ids is not None and inputs_embeds is not None: 799 | raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") 800 | elif input_ids is not None: 801 | input_shape = input_ids.size() 802 | input_ids = input_ids.view(-1, input_shape[-1]) 803 | elif inputs_embeds is not None: 804 | input_shape = inputs_embeds.size()[:-1] 805 | else: 806 | raise ValueError("You have to specify either input_ids or inputs_embeds") 807 | 808 | if inputs_embeds is None: 809 | inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale 810 | 811 | embed_pos = self.embed_positions(input_shape) 812 | 813 | hidden_states = inputs_embeds + embed_pos 814 | hidden_states = self.layernorm_embedding(hidden_states) 815 | hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) 816 | 817 | # expand attention_mask 818 | if attention_mask is not None: 819 | # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] 820 | attention_mask = _expand_mask(attention_mask, inputs_embeds.dtype) 821 | 822 | encoder_states = () if output_hidden_states else None 823 | all_attentions = () if output_attentions else None 824 | 825 | # check if head_mask has a correct number of layers specified if desired 826 | if head_mask is not None: 827 | if head_mask.size()[0] != (len(self.layers)): 828 | raise ValueError( 829 | f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}." 
830 | )
831 |
832 | for idx, encoder_layer in enumerate(self.layers):
833 | if output_hidden_states:
834 | encoder_states = encoder_states + (hidden_states,)
835 | # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
836 | dropout_probability = random.uniform(0, 1)
837 | if self.training and (dropout_probability < self.layerdrop):  # skip the layer
838 | layer_outputs = (None, None)
839 | else:
840 | if self.gradient_checkpointing and self.training:
841 |
842 | def create_custom_forward(module):
843 | def custom_forward(*inputs):
844 | return module(*inputs, output_attentions)
845 |
846 | return custom_forward
847 |
848 | layer_outputs = torch.utils.checkpoint.checkpoint(
849 | create_custom_forward(encoder_layer),
850 | hidden_states,
851 | attention_mask,
852 | (head_mask[idx] if head_mask is not None else None),
853 | )
854 | else:
855 | layer_outputs = encoder_layer(
856 | hidden_states,
857 | attention_mask,
858 | layer_head_mask=(head_mask[idx] if head_mask is not None else None),
859 | output_attentions=output_attentions,
860 | )
861 |
862 | hidden_states = layer_outputs[0]
863 |
864 | # ==================== music-text fusion =====================
865 |
866 | def forget_gate(music_features, text_features):  # NOTE: currently unused; the same gating is applied inline under `use_forget_gate` below
867 | forget_mask = self.fg(torch.cat((music_features, text_features), 2))
868 | forget_mask = self.sigmoid(forget_mask)
869 | forget_mask = F.dropout(forget_mask, p=self.dropout, training=self.training)
870 | music_features = forget_mask.mul(music_features)
871 | return music_features
872 |
873 | if idx in self.fusion_layers:
874 | '''
875 | => K_a = linear_1(V) in (S_v, D_a)
876 | => V_a = linear_2(V) in (S_v, D_a)
877 | => Q_a = linear_3(T) in (S_t, D_a)
878 | => T_out = MultiHeadAttn(Q_a, K_a, V_a) in (S_t, D_a)
879 | => T_out = linear_4(concat(T, T_out)) in (S_t, D_t)
880 | => T_out = T + T_out (Residual Connection)
881 | '''
882 | K = self._linear_1(music_features).transpose(0, 1)  # keys and values come from the music features
883 | V = self._linear_2(music_features).transpose(0, 1)
884 | Q = self._linear_3(hidden_states).transpose(0, 1)  # queries come from the text hidden states
885 | attn_output, _ = self._multi_head_attn(Q, K, V)
886 | attn_output = attn_output.transpose(0, 1)
887 | if self.use_forget_gate:
888 | forget_mask = self.fg(torch.cat((attn_output, hidden_states), 2))
889 | forget_mask = self.sigmoid(forget_mask)
890 | forget_mask = F.dropout(forget_mask, p=self.dropout, training=self.training)
891 | attn_output = forget_mask.mul(attn_output)
892 | # output = self._linear_4(torch.cat((hidden_states, attn_output), 2))
893 |
894 | # Residual Connection (scaled by 0.1; since the `linear_4` projection from the sketch above is disabled, this addition requires dim_common == d_model)
895 | hidden_states = hidden_states + 0.1 * attn_output
896 | hidden_states = self.final_layer_norm(hidden_states)
897 |
898 | # ==================== music-text fusion =====================
899 |
900 | if output_attentions:
901 | all_attentions = all_attentions + (layer_outputs[1],)
902 |
903 | if output_hidden_states:
904 | encoder_states = encoder_states + (hidden_states,)
905 |
906 | if not return_dict:
907 | return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
908 | return BaseModelOutput(
909 | last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
910 | )
911 |
912 |
913 | class BartDecoder(BartPretrainedModel):
914 | """
915 | Transformer decoder consisting of *config.decoder_layers* layers.
Each layer is a [`BartDecoderLayer`] 916 | 917 | Args: 918 | config: BartConfig 919 | embed_tokens (nn.Embedding): output embedding 920 | """ 921 | 922 | def __init__(self, config: BartConfig, embed_tokens: Optional[nn.Embedding] = None): 923 | super().__init__(config) 924 | self.dropout = config.dropout 925 | self.layerdrop = config.decoder_layerdrop 926 | self.padding_idx = config.pad_token_id 927 | self.max_target_positions = config.max_position_embeddings 928 | self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0 929 | 930 | if embed_tokens is not None: 931 | self.embed_tokens = embed_tokens 932 | else: 933 | self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx) 934 | 935 | self.embed_positions = BartLearnedPositionalEmbedding( 936 | config.max_position_embeddings, 937 | config.d_model, 938 | ) 939 | self.layers = nn.ModuleList([BartDecoderLayer(config) for _ in range(config.decoder_layers)]) 940 | self.layernorm_embedding = nn.LayerNorm(config.d_model) 941 | 942 | self.gradient_checkpointing = False 943 | # Initialize weights and apply final processing 944 | self.post_init() 945 | 946 | def get_input_embeddings(self): 947 | return self.embed_tokens 948 | 949 | def set_input_embeddings(self, value): 950 | self.embed_tokens = value 951 | 952 | def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length): 953 | # create causal mask 954 | # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] 955 | combined_attention_mask = None 956 | if input_shape[-1] > 1: 957 | combined_attention_mask = _make_causal_mask( 958 | input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length 959 | ).to(self.device) 960 | 961 | if attention_mask is not None: 962 | # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] 963 | expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]) 964 | combined_attention_mask = ( 965 | expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask 966 | ) 967 | 968 | return combined_attention_mask 969 | 970 | def forward( 971 | self, 972 | input_ids=None, 973 | attention_mask=None, 974 | encoder_hidden_states=None, 975 | encoder_attention_mask=None, 976 | head_mask=None, 977 | cross_attn_head_mask=None, 978 | past_key_values=None, 979 | inputs_embeds=None, 980 | use_cache=None, 981 | output_attentions=None, 982 | output_hidden_states=None, 983 | return_dict=None, 984 | ): 985 | r""" 986 | Args: 987 | input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): 988 | Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you 989 | provide it. 990 | 991 | Indices can be obtained using [`BartTokenizer`]. See 992 | [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] 993 | for details. 994 | 995 | [What are input IDs?](../glossary#input-ids) 996 | attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): 997 | Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: 998 | 999 | - 1 for tokens that are **not masked**, 1000 | - 0 for tokens that are **masked**. 1001 | 1002 | [What are attention masks?](../glossary#attention-mask) 1003 | encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*): 1004 | Sequence of hidden-states at the output of the last layer of the encoder. 
Used in the cross-attention
1005 | of the decoder.
1006 | encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, encoder_sequence_length)`, *optional*):
1007 | Mask to avoid performing cross-attention on padding token indices of encoder input_ids. Mask values
1008 | selected in `[0, 1]`:
1009 |
1010 | - 1 for tokens that are **not masked**,
1011 | - 0 for tokens that are **masked**.
1012 |
1013 | [What are attention masks?](../glossary#attention-mask)
1014 | head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
1015 | Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
1016 |
1017 | - 1 indicates the head is **not masked**,
1018 | - 0 indicates the head is **masked**.
1019 |
1020 | cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
1021 | Mask to nullify selected heads of the cross-attention modules in the decoder to avoid performing
1022 | cross-attention on hidden heads. Mask values selected in `[0, 1]`:
1023 |
1024 | - 1 indicates the head is **not masked**,
1025 | - 0 indicates the head is **masked**.
1026 |
1027 | past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
1028 | Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2
1029 | tensors of shape `(batch_size, num_heads, sequence_length, embed_size_per_head)` and 2 additional
1030 | tensors of shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
1031 |
1032 | Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
1033 | cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential
1034 | decoding.
1035 |
1036 | If `past_key_values` are used, the user can optionally input only the last
1037 | `decoder_input_ids` (those that don't have their past key value states given to this model) of
1038 | shape `(batch_size, 1)` instead of all `decoder_input_ids` of shape `(batch_size,
1039 | sequence_length)`.
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded
representation. This is useful if you want more control over how to convert `input_ids` indices
1040 | into associated vectors than the model's internal embedding lookup matrix.
1041 | output_attentions (`bool`, *optional*):
1042 | Whether or not to return the attentions tensors of all attention layers. See `attentions` under
1043 | returned tensors for more detail.
1044 | output_hidden_states (`bool`, *optional*):
1045 | Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
1046 | for more detail.
1047 | return_dict (`bool`, *optional*):
1048 | Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
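Note: this decoder is stock BART with no fusion parameters; music information reaches it
only through the ordinary cross-attention over `encoder_hidden_states` produced by the
modified encoder.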
1049 | """
1050 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1051 | output_hidden_states = (
1052 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1053 | )
1054 | use_cache = use_cache if use_cache is not None else self.config.use_cache
1055 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1056 |
1057 | # retrieve input_ids and inputs_embeds
1058 | if input_ids is not None and inputs_embeds is not None:
1059 | raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
1060 | elif input_ids is not None:
1061 | input_shape = input_ids.size()
1062 | input_ids = input_ids.view(-1, input_shape[-1])
1063 | elif inputs_embeds is not None:
1064 | input_shape = inputs_embeds.size()[:-1]
1065 | else:
1066 | raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
1067 |
1068 | # past_key_values_length
1069 | past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
1070 |
1071 | if inputs_embeds is None:
1072 | inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
1073 |
1074 | attention_mask = self._prepare_decoder_attention_mask(
1075 | attention_mask, input_shape, inputs_embeds, past_key_values_length
1076 | )
1077 |
1078 | # expand encoder attention mask
1079 | if encoder_hidden_states is not None and encoder_attention_mask is not None:
1080 | # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
1081 | encoder_attention_mask = _expand_mask(encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])
1082 |
1083 | # embed positions
1084 | positions = self.embed_positions(input_shape, past_key_values_length)
1085 |
1086 | hidden_states = inputs_embeds + positions
1087 | hidden_states = self.layernorm_embedding(hidden_states)
1088 |
1089 | hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
1090 |
1091 | # decoder layers
1092 | all_hidden_states = () if output_hidden_states else None
1093 | all_self_attns = () if output_attentions else None
1094 | all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
1095 | next_decoder_cache = () if use_cache else None
1096 |
1097 | # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
1098 | for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
1099 | if attn_mask is not None:
1100 | if attn_mask.size()[0] != (len(self.layers)):
1101 | raise ValueError(
1102 | f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for {attn_mask.size()[0]}."
1103 | )
1104 |
1105 | for idx, decoder_layer in enumerate(self.layers):
1106 | # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
1107 | if output_hidden_states:
1108 | all_hidden_states += (hidden_states,)
1109 | dropout_probability = random.uniform(0, 1)
1110 | if self.training and (dropout_probability < self.layerdrop):
1111 | continue
1112 |
1113 | past_key_value = past_key_values[idx] if past_key_values is not None else None
1114 |
1115 | if self.gradient_checkpointing and self.training:
1116 |
1117 | if use_cache:
1118 | logger.warning(
1119 | "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
1120 | ) 1121 | use_cache = False 1122 | 1123 | def create_custom_forward(module): 1124 | def custom_forward(*inputs): 1125 | # None for past_key_value 1126 | return module(*inputs, output_attentions, use_cache) 1127 | 1128 | return custom_forward 1129 | 1130 | layer_outputs = torch.utils.checkpoint.checkpoint( 1131 | create_custom_forward(decoder_layer), 1132 | hidden_states, 1133 | attention_mask, 1134 | encoder_hidden_states, 1135 | encoder_attention_mask, 1136 | head_mask[idx] if head_mask is not None else None, 1137 | cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, 1138 | None, 1139 | ) 1140 | else: 1141 | 1142 | layer_outputs = decoder_layer( 1143 | hidden_states, 1144 | attention_mask=attention_mask, 1145 | encoder_hidden_states=encoder_hidden_states, 1146 | encoder_attention_mask=encoder_attention_mask, 1147 | layer_head_mask=(head_mask[idx] if head_mask is not None else None), 1148 | cross_attn_layer_head_mask=( 1149 | cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None 1150 | ), 1151 | past_key_value=past_key_value, 1152 | output_attentions=output_attentions, 1153 | use_cache=use_cache, 1154 | ) 1155 | hidden_states = layer_outputs[0] 1156 | 1157 | if use_cache: 1158 | next_decoder_cache += (layer_outputs[3 if output_attentions else 1],) 1159 | 1160 | if output_attentions: 1161 | all_self_attns += (layer_outputs[1],) 1162 | 1163 | if encoder_hidden_states is not None: 1164 | all_cross_attentions += (layer_outputs[2],) 1165 | 1166 | # add hidden states from the last decoder layer 1167 | if output_hidden_states: 1168 | all_hidden_states += (hidden_states,) 1169 | 1170 | next_cache = next_decoder_cache if use_cache else None 1171 | if not return_dict: 1172 | return tuple( 1173 | v 1174 | for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions] 1175 | if v is not None 1176 | ) 1177 | return BaseModelOutputWithPastAndCrossAttentions( 1178 | last_hidden_state=hidden_states, 1179 | past_key_values=next_cache, 1180 | hidden_states=all_hidden_states, 1181 | attentions=all_self_attns, 1182 | cross_attentions=all_cross_attentions, 1183 | ) 1184 | 1185 | 1186 | @add_start_docstrings( 1187 | "The bare BART Model outputting raw hidden-states without any specific head on top.", 1188 | BART_START_DOCSTRING, 1189 | ) 1190 | class BartModel(BartPretrainedModel): 1191 | def __init__(self, config: BartConfig, 1192 | fusion_layers=None, 1193 | use_forget_gate=None, 1194 | dim_common=256, 1195 | n_attn_heads=1): 1196 | super().__init__(config) 1197 | 1198 | padding_idx, vocab_size = config.pad_token_id, config.vocab_size 1199 | self.shared = nn.Embedding(vocab_size, config.d_model, padding_idx) 1200 | 1201 | self.encoder = BartEncoder(config, self.shared, fusion_layers, use_forget_gate, dim_common, n_attn_heads) 1202 | self.decoder = BartDecoder(config, self.shared) 1203 | 1204 | # Initialize weights and apply final processing 1205 | self.post_init() 1206 | 1207 | def get_input_embeddings(self): 1208 | return self.shared 1209 | 1210 | def set_input_embeddings(self, value): 1211 | self.shared = value 1212 | self.encoder.embed_tokens = self.shared 1213 | self.decoder.embed_tokens = self.shared 1214 | 1215 | def get_encoder(self): 1216 | return self.encoder 1217 | 1218 | def get_decoder(self): 1219 | return self.decoder 1220 | 1221 | @add_start_docstrings_to_model_forward(BART_INPUTS_DOCSTRING) 1222 | @add_code_sample_docstrings( 1223 | processor_class=_TOKENIZER_FOR_DOC, 1224 | checkpoint=_CHECKPOINT_FOR_DOC, 1225 | 
output_type=Seq2SeqModelOutput, 1226 | config_class=_CONFIG_FOR_DOC, 1227 | ) 1228 | def forward( 1229 | self, 1230 | input_ids=None, 1231 | attention_mask=None, 1232 | decoder_input_ids=None, 1233 | decoder_attention_mask=None, 1234 | head_mask=None, 1235 | decoder_head_mask=None, 1236 | cross_attn_head_mask=None, 1237 | encoder_outputs=None, 1238 | past_key_values=None, 1239 | inputs_embeds=None, 1240 | decoder_inputs_embeds=None, 1241 | use_cache=None, 1242 | output_attentions=None, 1243 | output_hidden_states=None, 1244 | return_dict=None, 1245 | music_features=None, 1246 | music_len=None, 1247 | ): 1248 | 1249 | # different to other models, Bart automatically creates decoder_input_ids from 1250 | # input_ids if no decoder_input_ids are provided 1251 | if decoder_input_ids is None and decoder_inputs_embeds is None: 1252 | if input_ids is None: 1253 | raise ValueError( 1254 | "If no `decoder_input_ids` or `decoder_inputs_embeds` are " 1255 | "passed, `input_ids` cannot be `None`. Please pass either " 1256 | "`input_ids` or `decoder_input_ids` or `decoder_inputs_embeds`." 1257 | ) 1258 | 1259 | decoder_input_ids = shift_tokens_right( 1260 | input_ids, self.config.pad_token_id, self.config.decoder_start_token_id 1261 | ) 1262 | 1263 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 1264 | output_hidden_states = ( 1265 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 1266 | ) 1267 | use_cache = use_cache if use_cache is not None else self.config.use_cache 1268 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 1269 | 1270 | if encoder_outputs is None: 1271 | encoder_outputs = self.encoder( 1272 | input_ids=input_ids, 1273 | attention_mask=attention_mask, 1274 | head_mask=head_mask, 1275 | inputs_embeds=inputs_embeds, 1276 | output_attentions=output_attentions, 1277 | output_hidden_states=output_hidden_states, 1278 | return_dict=return_dict, 1279 | music_features=music_features, 1280 | music_len=music_len, 1281 | ) 1282 | # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True 1283 | elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): 1284 | encoder_outputs = BaseModelOutput( 1285 | last_hidden_state=encoder_outputs[0], 1286 | hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, 1287 | attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, 1288 | ) 1289 | 1290 | # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn) 1291 | decoder_outputs = self.decoder( 1292 | input_ids=decoder_input_ids, 1293 | attention_mask=decoder_attention_mask, 1294 | encoder_hidden_states=encoder_outputs[0], 1295 | encoder_attention_mask=attention_mask, 1296 | head_mask=decoder_head_mask, 1297 | cross_attn_head_mask=cross_attn_head_mask, 1298 | past_key_values=past_key_values, 1299 | inputs_embeds=decoder_inputs_embeds, 1300 | use_cache=use_cache, 1301 | output_attentions=output_attentions, 1302 | output_hidden_states=output_hidden_states, 1303 | return_dict=return_dict, 1304 | ) 1305 | 1306 | if not return_dict: 1307 | return decoder_outputs + encoder_outputs 1308 | 1309 | return Seq2SeqModelOutput( 1310 | last_hidden_state=decoder_outputs.last_hidden_state, 1311 | past_key_values=decoder_outputs.past_key_values, 1312 | decoder_hidden_states=decoder_outputs.hidden_states, 1313 | decoder_attentions=decoder_outputs.attentions, 1314 | 
cross_attentions=decoder_outputs.cross_attentions, 1315 | encoder_last_hidden_state=encoder_outputs.last_hidden_state, 1316 | encoder_hidden_states=encoder_outputs.hidden_states, 1317 | encoder_attentions=encoder_outputs.attentions, 1318 | ) 1319 | 1320 | 1321 | @add_start_docstrings( 1322 | "The BART Model with a language modeling head. Can be used for summarization.", BART_START_DOCSTRING 1323 | ) 1324 | class BartForMultimodalGeneration(BartPretrainedModel): 1325 | base_model_prefix = "model" 1326 | _keys_to_ignore_on_load_missing = [r"final_logits_bias", r"lm_head\.weight"] 1327 | 1328 | def __init__(self, config: BartConfig, fusion_layers=None, use_forget_gate=None, dim_common=256, n_attn_heads=1): 1329 | super().__init__(config) 1330 | self.model = BartModel(config, fusion_layers, use_forget_gate, dim_common, n_attn_heads) 1331 | self.register_buffer("final_logits_bias", torch.zeros((1, self.model.shared.num_embeddings))) 1332 | self.lm_head = nn.Linear(config.d_model, self.model.shared.num_embeddings, bias=False) 1333 | 1334 | # Initialize weights and apply final processing 1335 | self.post_init() 1336 | 1337 | def get_encoder(self): 1338 | return self.model.get_encoder() 1339 | 1340 | def get_decoder(self): 1341 | return self.model.get_decoder() 1342 | 1343 | def resize_token_embeddings(self, new_num_tokens: int) -> nn.Embedding: 1344 | new_embeddings = super().resize_token_embeddings(new_num_tokens) 1345 | self._resize_final_logits_bias(new_num_tokens) 1346 | return new_embeddings 1347 | 1348 | def _resize_final_logits_bias(self, new_num_tokens: int) -> None: 1349 | old_num_tokens = self.final_logits_bias.shape[-1] 1350 | if new_num_tokens <= old_num_tokens: 1351 | new_bias = self.final_logits_bias[:, :new_num_tokens] 1352 | else: 1353 | extra_bias = torch.zeros((1, new_num_tokens - old_num_tokens), device=self.final_logits_bias.device) 1354 | new_bias = torch.cat([self.final_logits_bias, extra_bias], dim=1) 1355 | self.register_buffer("final_logits_bias", new_bias) 1356 | 1357 | def get_output_embeddings(self): 1358 | return self.lm_head 1359 | 1360 | def set_output_embeddings(self, new_embeddings): 1361 | self.lm_head = new_embeddings 1362 | 1363 | @add_start_docstrings_to_model_forward(BART_INPUTS_DOCSTRING) 1364 | @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC) 1365 | @add_end_docstrings(BART_GENERATION_EXAMPLE) 1366 | def forward( 1367 | self, 1368 | input_ids=None, 1369 | attention_mask=None, 1370 | decoder_input_ids=None, 1371 | decoder_attention_mask=None, 1372 | head_mask=None, 1373 | decoder_head_mask=None, 1374 | cross_attn_head_mask=None, 1375 | encoder_outputs=None, 1376 | past_key_values=None, 1377 | inputs_embeds=None, 1378 | decoder_inputs_embeds=None, 1379 | labels=None, 1380 | use_cache=None, 1381 | output_attentions=None, 1382 | output_hidden_states=None, 1383 | return_dict=None, 1384 | music_features=None, 1385 | music_len=None, 1386 | ): 1387 | r""" 1388 | labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): 1389 | Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored 1390 | (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. 
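
Example (a minimal sketch for orientation; the token ids and feature shapes below are
assumptions, and `dim_common` is tied to `config.d_model` because the residual fusion step
in `BartEncoder.forward` adds the attended music features directly onto the hidden states):

>>> config = BartConfig()  # d_model defaults to 1024
>>> model = BartForMultimodalGeneration(config, fusion_layers=[5], use_forget_gate=True,
...                                     dim_common=config.d_model)
>>> input_ids = torch.tensor([[0, 31414, 232, 2]])  # a short tokenized lyric fragment
>>> music_features = torch.randn(1, 120, 256)  # (batch, music_seq_len, 256), e.g. CNNSA output
>>> out = model(input_ids=input_ids, music_features=music_features, labels=input_ids)
>>> out.loss.backward()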
1391 | 1392 | Returns: 1393 | """ 1394 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 1395 | 1396 | if labels is not None: 1397 | if decoder_input_ids is None and decoder_inputs_embeds is None: 1398 | decoder_input_ids = shift_tokens_right( 1399 | labels, self.config.pad_token_id, self.config.decoder_start_token_id 1400 | ) 1401 | 1402 | outputs = self.model( 1403 | input_ids, 1404 | attention_mask=attention_mask, 1405 | decoder_input_ids=decoder_input_ids, 1406 | encoder_outputs=encoder_outputs, 1407 | decoder_attention_mask=decoder_attention_mask, 1408 | head_mask=head_mask, 1409 | decoder_head_mask=decoder_head_mask, 1410 | cross_attn_head_mask=cross_attn_head_mask, 1411 | past_key_values=past_key_values, 1412 | inputs_embeds=inputs_embeds, 1413 | decoder_inputs_embeds=decoder_inputs_embeds, 1414 | use_cache=use_cache, 1415 | output_attentions=output_attentions, 1416 | output_hidden_states=output_hidden_states, 1417 | return_dict=return_dict, 1418 | music_features=music_features, 1419 | music_len=music_len, 1420 | ) 1421 | lm_logits = self.lm_head(outputs[0]) + self.final_logits_bias 1422 | 1423 | masked_lm_loss = None 1424 | if labels is not None: 1425 | loss_fct = CrossEntropyLoss() 1426 | masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1)) 1427 | 1428 | if not return_dict: 1429 | output = (lm_logits,) + outputs[1:] 1430 | return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output 1431 | 1432 | return Seq2SeqLMOutput( 1433 | loss=masked_lm_loss, 1434 | logits=lm_logits, 1435 | past_key_values=outputs.past_key_values, 1436 | decoder_hidden_states=outputs.decoder_hidden_states, 1437 | decoder_attentions=outputs.decoder_attentions, 1438 | cross_attentions=outputs.cross_attentions, 1439 | encoder_last_hidden_state=outputs.encoder_last_hidden_state, 1440 | encoder_hidden_states=outputs.encoder_hidden_states, 1441 | encoder_attentions=outputs.encoder_attentions, 1442 | ) 1443 | 1444 | def prepare_inputs_for_generation( 1445 | self, 1446 | decoder_input_ids, 1447 | past=None, 1448 | attention_mask=None, 1449 | head_mask=None, 1450 | decoder_head_mask=None, 1451 | cross_attn_head_mask=None, 1452 | use_cache=None, 1453 | encoder_outputs=None, 1454 | **kwargs 1455 | ): 1456 | # cut decoder_input_ids if past is used 1457 | if past is not None: 1458 | decoder_input_ids = decoder_input_ids[:, -1:] 1459 | 1460 | return { 1461 | "input_ids": None, # encoder_outputs is defined. 
input_ids not needed 1462 | "encoder_outputs": encoder_outputs, 1463 | "past_key_values": past, 1464 | "decoder_input_ids": decoder_input_ids, 1465 | "attention_mask": attention_mask, 1466 | "head_mask": head_mask, 1467 | "decoder_head_mask": decoder_head_mask, 1468 | "cross_attn_head_mask": cross_attn_head_mask, 1469 | "use_cache": use_cache, # change this to avoid caching (presumably for debugging) 1470 | } 1471 | 1472 | def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): 1473 | return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id) 1474 | 1475 | @staticmethod 1476 | def _reorder_cache(past, beam_idx): 1477 | reordered_past = () 1478 | for layer_past in past: 1479 | # cached cross_attention states don't have to be reordered -> they are always the same 1480 | reordered_past += ( 1481 | tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:], 1482 | ) 1483 | return reordered_past -------------------------------------------------------------------------------- /code/music_encoder.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torchaudio 5 | import os 6 | import random 7 | 8 | from attention_modules import BertConfig, BertEncoder, BertPooler 9 | 10 | 11 | class Conv_1d(nn.Module): 12 | def __init__(self, input_channels, output_channels, shape=3, stride=1, pooling=2): 13 | super(Conv_1d, self).__init__() 14 | self.conv = nn.Conv1d(input_channels, output_channels, shape, stride=stride, padding=shape//2) 15 | self.bn = nn.BatchNorm1d(output_channels) 16 | self.relu = nn.ReLU() 17 | self.mp = nn.MaxPool1d(pooling) 18 | def forward(self, x): 19 | out = self.mp(self.relu(self.bn(self.conv(x)))) 20 | return out 21 | 22 | 23 | class Conv_2d(nn.Module): 24 | def __init__(self, input_channels, output_channels, shape=3, stride=1, pooling=2): 25 | super(Conv_2d, self).__init__() 26 | self.conv = nn.Conv2d(input_channels, output_channels, shape, stride=stride, padding=shape//2) 27 | self.bn = nn.BatchNorm2d(output_channels) 28 | self.relu = nn.ReLU() 29 | self.mp = nn.MaxPool2d(pooling) 30 | def forward(self, x): 31 | out = self.mp(self.relu(self.bn(self.conv(x)))) 32 | return out 33 | 34 | 35 | class Res_2d(nn.Module): 36 | def __init__(self, input_channels, output_channels, shape=3, stride=2): 37 | super(Res_2d, self).__init__() 38 | # convolution 39 | self.conv_1 = nn.Conv2d(input_channels, output_channels, shape, stride=stride, padding=shape//2) 40 | self.bn_1 = nn.BatchNorm2d(output_channels) 41 | self.conv_2 = nn.Conv2d(output_channels, output_channels, shape, padding=shape//2) 42 | self.bn_2 = nn.BatchNorm2d(output_channels) 43 | 44 | # residual 45 | self.diff = False 46 | if (stride != 1) or (input_channels != output_channels): 47 | self.conv_3 = nn.Conv2d(input_channels, output_channels, shape, stride=stride, padding=shape//2) 48 | self.bn_3 = nn.BatchNorm2d(output_channels) 49 | self.diff = True 50 | self.relu = nn.ReLU() 51 | 52 | def forward(self, x): 53 | # convolution 54 | out = self.bn_2(self.conv_2(self.relu(self.bn_1(self.conv_1(x))))) 55 | 56 | # residual 57 | if self.diff: 58 | x = self.bn_3(self.conv_3(x)) 59 | out = x + out 60 | out = self.relu(out) 61 | return out 62 | 63 | 64 | class CNNSA(nn.Module): 65 | ''' 66 | Won et al. 2019 67 | Toward interpretable music tagging with self-attention. 68 | Feature extraction with CNN + temporal summary with Transformer encoder. 
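Given a batch of Music4All track ids, `forward` loads the audio, takes a log-mel
spectrogram, extracts features with the residual CNN, prepends a learned [CLS] vector,
and returns the last Transformer layer's output of shape (batch, time_steps + 1, 256).
The tagging head (pooler / dropout / dense) is defined below but bypassed in `forward`,
since only the feature sequence is needed for fusion.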
69 | ''' 70 | def __init__(self, 71 | n_channels=128, 72 | sample_rate=16000, 73 | n_fft=512, 74 | f_min=0.0, 75 | f_max=8000.0, 76 | n_mels=128, 77 | n_class=50): 78 | super(CNNSA, self).__init__() 79 | 80 | # Spectrogram 81 | self.spec = torchaudio.transforms.MelSpectrogram(sample_rate=sample_rate, 82 | n_fft=n_fft, 83 | f_min=f_min, 84 | f_max=f_max, 85 | n_mels=n_mels) 86 | self.to_db = torchaudio.transforms.AmplitudeToDB() 87 | self.spec_bn = nn.BatchNorm2d(1) 88 | 89 | # CNN 90 | self.layer1 = Res_2d(1, n_channels, stride=2) 91 | self.layer2 = Res_2d(n_channels, n_channels, stride=2) 92 | self.layer3 = Res_2d(n_channels, n_channels*2, stride=2) 93 | self.layer4 = Res_2d(n_channels*2, n_channels*2, stride=(2, 1)) 94 | self.layer5 = Res_2d(n_channels*2, n_channels*2, stride=(2, 1)) 95 | self.layer6 = Res_2d(n_channels*2, n_channels*2, stride=(2, 1)) 96 | self.layer7 = Res_2d(n_channels*2, n_channels*2, stride=(2, 1)) 97 | 98 | # Transformer encoder 99 | bert_config = BertConfig(vocab_size=256, 100 | hidden_size=256, 101 | num_hidden_layers=2, 102 | num_attention_heads=8, 103 | intermediate_size=1024, 104 | hidden_act="gelu", 105 | hidden_dropout_prob=0.4, 106 | max_position_embeddings=700, 107 | attention_probs_dropout_prob=0.5) 108 | self.encoder = BertEncoder(bert_config) 109 | self.pooler = BertPooler(bert_config) 110 | self.vec_cls = self.get_cls(256) 111 | 112 | # Dense 113 | self.dropout = nn.Dropout(0.5) 114 | self.dense = nn.Linear(256, n_class) 115 | 116 | def get_cls(self, channel): 117 | np.random.seed(0) 118 | single_cls = torch.Tensor(np.random.random((1, channel))) 119 | vec_cls = torch.cat([single_cls for _ in range(64)], dim=0) 120 | vec_cls = vec_cls.unsqueeze(1) 121 | return vec_cls 122 | 123 | def append_cls(self, x): 124 | batch, _, _ = x.size() 125 | part_vec_cls = self.vec_cls[:batch].clone() 126 | part_vec_cls = part_vec_cls.to(x.device) 127 | return torch.cat([part_vec_cls, x], dim=1) 128 | 129 | def get_spec(self, ids, audio_length=15*16000, allow_random=False): 130 | 131 | wav_list = list() 132 | 133 | for id in ids: 134 | audio_path = os.path.join("/import/c4dm-datasets/Music4All/music4all/audios", id + '.mp3') 135 | (wav, sample_rate) = torchaudio.backend.sox_io_backend.load(audio_path) 136 | 137 | # to mono 138 | mono_wav = torch.mean(wav, dim=0) 139 | 140 | # cut length 141 | if allow_random: 142 | random_index = random.randint(0, len(mono_wav) - audio_length - 1) 143 | else: 144 | random_index = 0 145 | mono_wav_cut = mono_wav[random_index: random_index + audio_length] 146 | 147 | wav_list.append(mono_wav_cut) 148 | 149 | # merge wav to (bs, length) 150 | data = torch.stack(wav_list, dim=0) 151 | 152 | # to spectrogram 153 | spectrogram = self.spec(data.cuda()) 154 | 155 | return spectrogram 156 | 157 | def forward(self, ids): 158 | # Spectrogram 159 | # for batch 160 | spec = self.get_spec(ids) 161 | spec_db = self.to_db(spec) 162 | x = spec_db.unsqueeze(1) # add channel dim 163 | x = self.spec_bn(x) 164 | 165 | # CNN 166 | x = self.layer1(x) 167 | x = self.layer2(x) 168 | x = self.layer3(x) 169 | x = self.layer4(x) 170 | x = self.layer5(x) 171 | x = self.layer6(x) 172 | x = self.layer7(x) 173 | x = x.squeeze(2) 174 | 175 | # Get [CLS] token 176 | x = x.permute(0, 2, 1) 177 | x = self.append_cls(x) 178 | 179 | # Transformer encoder 180 | x = self.encoder(x) 181 | x = x[-1] # last layer 182 | # x = self.pooler(x) 183 | # 184 | # # Dense 185 | # x = self.dropout(x) 186 | # x = self.dense(x) 187 | # x = nn.Sigmoid()(x) 188 | 189 | return x # return the last 
layer. Shape: (length, 256) 190 | 191 | 192 | # test code 193 | # model = CNNSA() 194 | # model.load_state_dict(torch.load("best_model.pth")) 195 | # id = ["wlIcjSZkgW0cgWrm", "wlIcjSZkgW0cgWrm"] 196 | # output = model(id) -------------------------------------------------------------------------------- /code/train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from tqdm import tqdm 3 | from data import LyricsCommentsDatasetPsuedo 4 | from torch import utils, nn 5 | from model import CommentGenerator 6 | import transformers 7 | import time 8 | import statistics 9 | import os 10 | import random 11 | import datasets 12 | 13 | IS_LOAD = False 14 | LOAD_EPOCH = 0 15 | EPOCH = 20 16 | BATCH_SIZE = 8 17 | LOG_INTERVAL = 100 18 | SAMPLE_INTERVAL = 2000 19 | VALIDATION_INTERVAL = 2 20 | LOG_FOLDER = "log/" 21 | MODEL_FOLDER = "model/" 22 | EARLY_STOPPING_INTERVAL = 5 23 | MODEL_NAME = "bart_baseline_full_256" 24 | CHOICE_NUMBER = 5 25 | DATASET_PATH = "dataset_not_negative_256.pkl" 26 | 27 | os.environ["CUDA_VISIBLE_DEVICES"] = "4" 28 | 29 | dataset = LyricsCommentsDatasetPsuedo(dataset_path=DATASET_PATH) 30 | dataset_length = len(dataset) 31 | 32 | train_dataset_length = int(dataset_length * 0.9) 33 | valid_dataset_length = dataset_length - train_dataset_length 34 | train_dataset, valid_dataset = utils.data.random_split(dataset, 35 | [train_dataset_length, 36 | valid_dataset_length], 37 | generator=torch.Generator().manual_seed(42)) 38 | train_dataloader = utils.data.DataLoader(train_dataset, 39 | batch_size=BATCH_SIZE, 40 | shuffle=True) 41 | valid_dataloader = utils.data.DataLoader(valid_dataset, 42 | batch_size=32, 43 | shuffle=False) 44 | 45 | model = CommentGenerator().cuda() 46 | 47 | criterion = nn.CrossEntropyLoss() 48 | 49 | optimizer = transformers.Adafactor(model.parameters(), warmup_init=False, relative_step=False, 50 | lr=6e-4, 51 | ) 52 | 53 | loss_stat = list() 54 | start_time = time.time() 55 | start_time_local = time.strftime("%Y-%m-%d-%H:%M:%S", time.localtime()) 56 | 57 | early_stop_token = (0.0, 0) 58 | 59 | model.train() 60 | for epoch in range(1 + LOAD_EPOCH, EPOCH + 1 + LOAD_EPOCH): 61 | for batch_index, [lyrics, comment] in enumerate(train_dataloader): 62 | # pre-process data 63 | input_sentences = lyrics 64 | raw_labels = comment 65 | output = model(input_sentences, raw_labels) 66 | loss = output.loss 67 | 68 | optimizer.zero_grad() 69 | loss.backward() 70 | optimizer.step() 71 | loss_stat.append(loss.item()) 72 | 73 | # log 74 | if batch_index and batch_index % LOG_INTERVAL == 0: 75 | curr_time = time.time() 76 | passed_time_all = curr_time - start_time 77 | time_str = f"{int(passed_time_all / 60)}:{int(passed_time_all % 60)}" 78 | log = f"{MODEL_NAME}\t" \ 79 | f"Time: {time_str}\t" \ 80 | f"Epoch {epoch}: {batch_index}/{int(len(train_dataloader.dataset) / BATCH_SIZE)}\t" \ 81 | f"Loss: {statistics.mean(loss_stat[-1 * BATCH_SIZE:])}\t" \ 82 | f"Avg loss: {statistics.mean(loss_stat)}" 83 | if __debug__: 84 | print(log) 85 | with open(os.path.join(LOG_FOLDER, MODEL_NAME + "_" + start_time_local + ".txt"), 'a+', encoding='utf-8') as r: 86 | r.write(log) 87 | r.write("\n") 88 | loss_stat = list() 89 | 90 | if batch_index and batch_index % SAMPLE_INTERVAL == 0: 91 | 92 | model.eval() 93 | samples_list = random.choices(valid_dataset, k=CHOICE_NUMBER) 94 | sample_sentence, sample_label = zip(*samples_list) 95 | output_samples = model.generate(sample_sentence) 96 | for sample_index in range(CHOICE_NUMBER): 97 | log = 
f"Lyrics: {sample_sentence[sample_index]}\n" \ 98 | f"Sample outputs: {output_samples[sample_index]}\n" \ 99 | f"Ground Truth: {sample_label[sample_index]}" 100 | if __debug__: 101 | print(log) 102 | with open(os.path.join(LOG_FOLDER, MODEL_NAME + "_" + start_time_local + ".txt"), 'a+', encoding='utf-8') as r: 103 | r.write(log) 104 | r.write("\n") 105 | model.train() 106 | 107 | if epoch and epoch % VALIDATION_INTERVAL == 0: 108 | model.eval() 109 | metrics = datasets.load_metric('rouge') 110 | valid_dataloader = utils.data.DataLoader(valid_dataset, 111 | batch_size=32, 112 | shuffle=False) 113 | for batch_index_valid, [lyrics_valid, comment_valid] in enumerate(valid_dataloader): 114 | output_samples = model.generate(lyrics_valid) 115 | metrics.add_batch(predictions=output_samples, references=comment_valid) 116 | 117 | # control time. 118 | if batch_index_valid > 10: 119 | break 120 | score = metrics.compute() 121 | if __debug__: 122 | print(str(score)) 123 | with open(os.path.join(LOG_FOLDER, MODEL_NAME + '_' + start_time_local + ".txt"), 'a+', 124 | encoding='utf-8') as r: 125 | r.write(str(score)) 126 | r.write("\n") 127 | 128 | # save 129 | if score['rouge1'].mid.recall > early_stop_token[0]: 130 | early_stop_token = [score['rouge1'].mid.recall, epoch] # replace to the best 131 | torch.save(model.state_dict(), os.path.join(MODEL_FOLDER, f"{MODEL_NAME}_best.pt")) 132 | torch.save(optimizer.state_dict(), 133 | os.path.join(MODEL_FOLDER, f"{MODEL_NAME}_optim_best.pt")) 134 | 135 | if epoch: 136 | torch.save(model.state_dict(), os.path.join(MODEL_FOLDER, f"{MODEL_NAME}_epoch{epoch}.pt")) 137 | torch.save(optimizer.state_dict(), 138 | os.path.join(MODEL_FOLDER, f"{MODEL_NAME}_optim_epoch{epoch}.pt")) 139 | 140 | # early stopping 141 | if score['rouge1'].mid.recall <= early_stop_token[0] and epoch > ( 142 | early_stop_token[1] + EARLY_STOPPING_INTERVAL): 143 | print(f"Early Stopping. 
Best Score: {early_stop_token[0]} at Epoch {early_stop_token[1]}.")
144 | break  # stop training once the early-stopping condition is met
145 |
146 | model.train()
--------------------------------------------------------------------------------
/code/train_fusion.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from tqdm import tqdm
3 | from data import LyricsCommentsDatasetPsuedo_fusion
4 | from torch import utils, nn
5 | from model_fusion import CommentGenerator_fusion
6 | import transformers
7 | import time
8 | import statistics
9 | import os
10 | import random
11 | import datasets
12 |
13 | IS_LOAD = False
14 | LOAD_EPOCH = 0
15 | EPOCH = 50
16 | BATCH_SIZE = 8
17 | LOG_INTERVAL = 100
18 | SAMPLE_INTERVAL = 1000
19 | VALIDATION_INTERVAL = 2
20 | LOG_FOLDER = "log/"
21 | MODEL_FOLDER = "model/"
22 | SAVE_INTERVAL = 2
23 | EARLY_STOPPING_INTERVAL = 5
24 | MODEL_NAME = "bart_fusion_full_256"
25 | CHOICE_NUMBER = 2
26 | DATASET_PATH = "/homes/yz007/multimodal-transformer/comment_generator/dataset_full_256.pkl"
27 |
28 | os.environ["CUDA_VISIBLE_DEVICES"] = "2"
29 |
30 | dataset = LyricsCommentsDatasetPsuedo_fusion(dataset_path=DATASET_PATH)
31 | dataset_length = len(dataset)
32 |
33 | train_dataset_length = int(dataset_length * 0.9)
34 | valid_dataset_length = dataset_length - train_dataset_length
35 | train_dataset, valid_dataset = utils.data.random_split(dataset,
36 | [train_dataset_length,
37 | valid_dataset_length],
38 | generator=torch.Generator().manual_seed(42))
39 | train_dataloader = utils.data.DataLoader(train_dataset,
40 | batch_size=BATCH_SIZE,
41 | shuffle=True)
42 | # valid_dataloader = utils.data.DataLoader(valid_dataset,
43 | # batch_size=32,
44 | # shuffle=False)
45 |
46 | model = CommentGenerator_fusion().cuda()
47 |
48 | criterion = nn.CrossEntropyLoss()
49 |
50 |
51 |
52 | # optimizer = transformers.Adafactor(filter(lambda p: p.requires_grad, model.parameters()),
53 | # lr=6e-4,
54 | # )
55 | optimizer = transformers.Adafactor(model.parameters(), warmup_init=False, relative_step=False,
56 | lr=6e-4,
57 | )
58 |
59 | if IS_LOAD:
60 | model.load_state_dict(torch.load("/homes/yz007/multimodal-transformer/comment_generator/model/bart_fusion_positive_256_6e-4_epoch6.pt"))
61 | optimizer.load_state_dict(torch.load("/homes/yz007/multimodal-transformer/comment_generator/model/bart_fusion_positive_256_6e-4_optim_epoch6.pt"))
62 |
63 | loss_stat = list()
64 | start_time = time.time()
65 | start_time_local = time.strftime("%Y-%m-%d-%H:%M:%S", time.localtime())
66 |
67 | early_stop_token = [0.0, 0]
68 | validation_loss_history = list()
69 |
70 | model.train()
71 | for epoch in range(1 + LOAD_EPOCH, EPOCH + 1 + LOAD_EPOCH):
72 | for batch_index, [lyrics, comment, music_id] in enumerate(train_dataloader):
73 | # pre-process data
74 | input_sentences = lyrics
75 | raw_labels = comment
76 | output = model(input_sentences, music_id, raw_labels)
77 | loss = output.loss
78 |
79 | optimizer.zero_grad()
80 | loss.backward()
81 | optimizer.step()
82 | loss_stat.append(loss.item())
83 |
84 | # log
85 | if batch_index and batch_index % LOG_INTERVAL == 0:
86 | curr_time = time.time()
87 | passed_time_all = curr_time - start_time
88 | time_str = f"{int(passed_time_all / 60)}:{int(passed_time_all % 60)}"
89 | log = f"{MODEL_NAME}\t" \
90 | f"Time: {time_str}\t" \
91 | f"Epoch {epoch}: {batch_index}/{int(len(train_dataloader.dataset) / BATCH_SIZE)}\t" \
92 | f"Loss: {statistics.mean(loss_stat[-1 * LOG_INTERVAL * BATCH_SIZE:])}\t" \
93 | f"Avg loss: {statistics.mean(loss_stat)}"
94 | if __debug__:
95 |
print(log) 96 | with open(os.path.join(LOG_FOLDER, MODEL_NAME + '_' + start_time_local + ".txt"), 'a+', 97 | encoding='utf-8') as r: 98 | r.write(log) 99 | r.write("\n") 100 | loss_stat = list() 101 | 102 | if batch_index and batch_index % SAMPLE_INTERVAL == 0: 103 | # make samples 104 | model.eval() 105 | samples_list = random.choices(valid_dataset, k=CHOICE_NUMBER) 106 | sample_sentence, sample_label, music_ids = zip(*samples_list) 107 | with torch.no_grad(): 108 | output_samples = model.generate(sample_sentence, music_ids) 109 | for sample_index in range(CHOICE_NUMBER): 110 | log = f"Lyrics: {sample_sentence[sample_index]}\n" \ 111 | f"Sample outputs: {output_samples[sample_index]}\n" \ 112 | f"Ground Truth: {sample_label[sample_index]}" 113 | if __debug__: 114 | print(log) 115 | with open(os.path.join(LOG_FOLDER, MODEL_NAME + '_' + start_time_local + ".txt"), 'a+', 116 | encoding='utf-8') as r: 117 | r.write(log) 118 | r.write("\n") 119 | 120 | # validation loss 121 | valid_dataloader = utils.data.DataLoader(valid_dataset, 122 | batch_size=8, 123 | shuffle=False) 124 | valid_loss_stat = list() 125 | for batch_index_valid, [lyrics_valid, comment_valid, music_id_valid] in enumerate(valid_dataloader): 126 | with torch.no_grad(): 127 | output_valid = model(lyrics_valid, music_id_valid, comment_valid) 128 | valid_loss = output_valid.loss.item() 129 | valid_loss_stat.append(valid_loss) 130 | if batch_index_valid > 15: 131 | break 132 | valid_loss_mean = statistics.mean(valid_loss_stat) 133 | validation_loss_history.append(valid_loss_mean) 134 | log = f"{MODEL_NAME}\t" \ 135 | f"Time: {time_str}\t" \ 136 | f"Epoch {epoch}: {batch_index}/{int(len(train_dataloader.dataset) / BATCH_SIZE)}\t" \ 137 | f"Validation Loss: {valid_loss_mean}\t" 138 | if __debug__: 139 | print(log) 140 | with open(os.path.join(LOG_FOLDER, MODEL_NAME + '_' + start_time_local + ".txt"), 'a+', 141 | encoding='utf-8') as r: 142 | r.write(log) 143 | r.write("\n") 144 | 145 | # back to train 146 | model.train() 147 | 148 | if epoch and epoch % VALIDATION_INTERVAL == 0: 149 | model.eval() 150 | metrics = datasets.load_metric('rouge') 151 | valid_dataloader = utils.data.DataLoader(valid_dataset, 152 | batch_size=8, 153 | shuffle=False) 154 | for batch_index_valid, [lyrics_valid, comment_valid, music_id_valid] in enumerate(valid_dataloader): 155 | with torch.no_grad(): 156 | output_samples = model.generate(lyrics_valid, music_id_valid) 157 | metrics.add_batch(predictions=output_samples, references=comment_valid) 158 | # control time. 
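# (only the first ~11 validation batches are scored below, so the ROUGE numbers are a quick estimate on a fixed subset rather than the full validation set)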
159 | if batch_index_valid > 10:
160 | break
161 | score = metrics.compute()
162 | if __debug__:
163 | print(str(score))
164 | with open(os.path.join(LOG_FOLDER, MODEL_NAME + '_' + start_time_local + ".txt"), 'a+',
165 | encoding='utf-8') as r:
166 | r.write(str(score))
167 | r.write("\n")
168 |
169 | # save
170 | if score['rouge1'].mid.recall > early_stop_token[0]:
171 | early_stop_token = [score['rouge1'].mid.recall, epoch]  # record the new best
172 | torch.save(model.state_dict(), os.path.join(MODEL_FOLDER, f"{MODEL_NAME}_best.pt"))
173 | torch.save(optimizer.state_dict(),
174 | os.path.join(MODEL_FOLDER, f"{MODEL_NAME}_optim_best.pt"))
175 |
176 | # save
177 | if epoch and epoch % SAVE_INTERVAL == 0:
178 | torch.save(model.state_dict(), os.path.join(MODEL_FOLDER, f"{MODEL_NAME}_epoch{epoch}.pt"))
179 | torch.save(optimizer.state_dict(),
180 | os.path.join(MODEL_FOLDER, f"{MODEL_NAME}_optim_epoch{epoch}.pt"))
181 |
182 | # early stopping
183 | if len(validation_loss_history) >= 2 * EARLY_STOPPING_INTERVAL:  # need enough recorded losses for the lookback index below
184 | if min(validation_loss_history[-2 * EARLY_STOPPING_INTERVAL:]) == validation_loss_history[-2 * EARLY_STOPPING_INTERVAL]:
185 | print(f"Early Stopping. Best Score: {early_stop_token[0]} at Epoch {early_stop_token[1]}.")
186 | break
187 | if score['rouge1'].mid.recall <= early_stop_token[0] and epoch > (
188 | early_stop_token[1] + EARLY_STOPPING_INTERVAL):
189 | print(f"Early Stopping. Best Score: {early_stop_token[0]} at Epoch {early_stop_token[1]}.")
190 | break
191 | model.train()
192 |
193 | print(f"Training Complete. Best Score: {early_stop_token[0]} at Epoch {early_stop_token[1]}.")
194 | --------------------------------------------------------------------------------