├── model.jpg ├── scripts ├── train_model.sh ├── synthesize_audio.sh ├── process_data.sh └── download_data.sh ├── requirements.txt ├── README.md └── src ├── process_data.py ├── inference.py ├── train.py └── model.py /model.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bwang514/PerformanceNet/HEAD/model.jpg -------------------------------------------------------------------------------- /scripts/train_model.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" > /dev/null && pwd )" 3 | python3 "${DIR}/../src/train.py" $1 $2 $3 $4 4 | 5 | 6 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | scipy==1.1.0 2 | numpy==1.14.2 3 | tqdm==4.28.1 4 | h5py==2.9.0 5 | intervaltree==3.0.2 6 | librosa==0.6.3 7 | pretty_midi==0.2.8 8 | scikit_learn==0.20.2 9 | torch==0.4.0 10 | -------------------------------------------------------------------------------- /scripts/synthesize_audio.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Run model inference and synthesize the output audio 3 | 4 | DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" > /dev/null && pwd )" 5 | python3 "${DIR}/../src/inference.py" $1 $2 6 | -------------------------------------------------------------------------------- /scripts/process_data.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # This script pre-processes the training data and stores it in data/train_data.hdf5. 3 | # Usage: process_data.sh 4 | DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" > /dev/null && pwd )" 5 | python3 "${DIR}/../src/process_data.py" 6 | -------------------------------------------------------------------------------- /scripts/download_data.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # This script downloads the MusicNet dataset to the default data 3 | # directory. 4 | # Usage: download_data.sh 5 | DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" > /dev/null && pwd )" 6 | DST="${DIR}/../data" 7 | mkdir -p "$DST" 8 | 9 | wget -P "$DST" "https://homes.cs.washington.edu/~thickstn/media/musicnet.npz" 10 | 11 | 12 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | # PerformanceNet 3 | 4 | 5 | ![Model image](https://github.com/bwang514/PerformanceNet/blob/master/model.jpg) 6 | 7 | PerformanceNet is a deep convolutional model that learns, in an end-to-end manner, the score-to-audio mapping between musical scores and the corresponding real audio performances. Our work represents a humble yet valuable step towards the dream of **The AI Musician**. Find more details in our AAAI 2019 [paper](https://arxiv.org/abs/1811.04357)! 8 | 9 | 10 | ## Prerequisites 11 | 12 | > __Below we assume the working directory is the repository root.__ 13 | 14 | ### Install dependencies 15 | 16 | ```sh 17 | # Install the dependencies 18 | pip install -r requirements.txt 19 | ``` 20 | 21 | ### Prepare training data 22 | 23 | > PerformanceNet utilizes the [MusicNet](https://homes.cs.washington.edu/~thickstn/start.html) dataset, 24 | which provides musical scores and the corresponding performance audio data.
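Once `musicnet.npz` has been downloaded to `data/` (see the commands below), each entry of the archive maps a MusicNet recording ID to its 44.1 kHz audio waveform and an interval tree of note labels, which `src/process_data.py` turns into spectrogram/pianoroll pairs. A minimal inspection sketch, not part of the repository (the ID `2293` is one of the cello recordings listed in `process_data.py`):

```python
# Hypothetical inspection sketch; assumes musicnet.npz has already been downloaded to data/.
import numpy as np

data = np.load(open('data/musicnet.npz', 'rb'), encoding='latin1')  # add allow_pickle=True on NumPy >= 1.16.3

audio, labels = data['2293']               # waveform (44.1 kHz floats) and an interval tree of note labels
print(len(audio) / 44100, 'seconds of audio')

for interval in sorted(labels)[:5]:        # first few notes, sorted by onset sample
    onset_sample, offset_sample, note_data = interval
    print(onset_sample, offset_sample, note_data[1])  # note_data[1] is the pitch used by process_data.py
```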
25 | 26 | ```sh 27 | # Download the training data 28 | ./scripts/download_data.sh 29 | ``` 30 | You can also download the training data manually 31 | ([musicnet.npz](https://homes.cs.washington.edu/~thickstn/media/musicnet.npz)). 32 | 33 | > Pre-process the dataset into the pianorolls and spectrograms used for training PerformanceNet. 34 | 35 | ```sh 36 | # Pre-process the dataset 37 | ./scripts/process_data.sh 38 | ``` 39 | ## Scripts 40 | 41 | > __Below we assume the working directory is the repository root.__ 42 | 43 | We provide scripts for easily managing the experiments. 44 | 45 | ### Train a new model 46 | 47 | 1. Run the following command to set up a new experiment (roughly 300 epochs are needed to obtain good results). 48 | 49 | > The arguments are (in order) __1. instrument 2. training epochs 3. testing frequency 4. experiment name.__ 50 | 51 | ```sh 52 | # Set up a new experiment 53 | ./scripts/train_model.sh cello 300 10 cello_exp_1 54 | ``` 55 | 56 | ### Inference and generate audio 57 | 58 | We use the Griffin-Lim algorithm to convert the output spectrogram into an audio waveform. (__Note:__ it can take a very long time to synthesize longer audio.) 59 | 60 | 1. Synthesize with the test data split from the MusicNet dataset (a new folder containing the generated audio will be created automatically in your experiment directory). 61 | 62 | > The arguments are (in order) 1. experiment directory 2. data source (TEST_DATA means using the test data split from the training dataset). 63 | 64 | ```sh 65 | # Generates five 5-second audio clips by default 66 | ./scripts/synthesize_audio.sh cello_exp_1 TEST_DATA 67 | ``` 68 | 69 | 2. Synthesize audio from your own MIDI file: 70 | 71 | > Please manually create a directory called "midi" in your experiment directory, then put the MIDI files into it before executing this script. 72 | 73 | ```sh 74 | # Generates one audio clip; its length depends on your MIDI score. 75 | ./scripts/synthesize_audio.sh cello_exp_1 YOUR_MIDI_FILE.midi 76 | ``` 77 | 78 | Our model can perform any solo music given the score, so we provide a convenient script to convert any .midi file into the input format our model expects. Quality may vary across keys, as some notes may never appear in the training data; common keys (C, D, G) should work well. Also make sure the note range stays within the instrument's range. 79 | 80 | 81 | ## Sound examples 82 | 83 | 1. Violin: https://www.youtube.com/watch?v=kAEbbNUEEgI 84 | 2. Flute: https://www.youtube.com/watch?v=Y38Z2De1NFo 85 | 3. Cello: https://www.youtube.com/watch?v=3LzN3GvMNeU 86 | 4. Cover of 吳萼洋's "蜂蜜檸檬" (Honey Lemon): https://youtu.be/k0-cT6GxS3g 87 | 88 | ## Attribution 89 | 90 | If you use this code in your research, please cite the following papers: 91 | 92 | 1. __PerformanceNet: Score-to-Audio Music Generation with Multi-Band Convolutional Residual Network.__ Bryan Wang, Yi-Hsuan Yang. _in Proceedings of the 33rd AAAI Conference on Artificial Intelligence __(AAAI)__, 2019_. [[paper](https://www.aaai.org/ojs/index.php/AAAI/article/view/3911)]
93 | 2. __Demonstration of PerformanceNet: A Convolutional Neural Network Model for Score-to-Audio Music Generation.__ Yu-Hua Chen, Bryan Wang, Yi-Hsuan Yang. _in Proceedings of the Twenty-Eighth International Joint Conference on Artificial Intelligence __(IJCAI)__, Demos, 2019_. [[paper](https://www.ijcai.org/proceedings/2019/938)] 94 | 95 | 96 | -------------------------------------------------------------------------------- /src/process_data.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import librosa.output 3 | import librosa 4 | from intervaltree import Interval,IntervalTree 5 | from scipy import fft 6 | import pickle 7 | import h5py 8 | import sys 9 | 10 | 11 | class hyperparams(object): 12 | def __init__(self): 13 | self.sr = 44100 # Sampling rate. 14 | self.n_fft = 2048 # fft points (samples) 15 | self.stride = 256 # 256 samples hop between windows 16 | self.wps = 44100 // 256 # ~86 windows/second 17 | self.instrument = { 18 | 'cello': [2217, 2218, 2219, 2220 ,2221, 2222, 2293, 2294, 2295, 2296, 2297, 2298], 19 | 'violin': [2191, 2244, 2288, 2289, 2659], 20 | 'flute':[2202, 2203, 2204] 21 | } 22 | self.hop_inst = {'cello': self.wps, 'violin': int(self.wps * 0.5), 'flute': int(self.wps*0.25)} 23 | 24 | 25 | hp = hyperparams() 26 | 27 | 28 | def get_data(): 29 | ''' 30 | 31 | Extract the desired solo data from the dataset. 32 | 33 | Default: 34 | Process cello, violin, flute 35 | 36 | ''' 37 | dataset = np.load(open('data/musicnet.npz','rb'), encoding = 'latin1') 38 | train_data = h5py.File('data/train_data.hdf5', 'w') 39 | 40 | for inst in hp.instrument: 41 | print ('------ Processing ' + inst + ' ------') 42 | score = [] 43 | audio = [] 44 | for song in hp.instrument[inst]: 45 | a,b = dataset[str(song)] 46 | score.append(a) 47 | audio.append(b) 48 | 49 | spec_list, score_list, onoff_list = process_data(score,audio,inst) 50 | train_data.create_dataset(inst + "_spec", data=spec_list) 51 | train_data.create_dataset(inst + "_pianoroll", data=score_list) 52 | train_data.create_dataset(inst + "_onoff", data=onoff_list) 53 | 54 | 55 | def process_data(X, Y, inst): 56 | ''' 57 | Data Pre-processing 58 | 59 | Score: 60 | Generate pianoroll from interval tree data structure 61 | 62 | Audio: 63 | Convert waveform into power spectrogram 64 | 65 | ''' 66 | def process_spectrum(X, step, hop): 67 | audio = X[i][(step * hop * hp.stride): (step * hop * hp.stride) + ((hp.wps*5 - 1)* hp.stride)] 68 | spec = librosa.stft(audio, n_fft= hp.n_fft, hop_length = hp.stride) 69 | magnitude = np.log1p(np.abs(spec)**2) 70 | return magnitude 71 | 72 | def process_score(Y, step, hop): 73 | score = np.zeros((hp.wps*5, 128)) 74 | onset = np.zeros(score.shape) 75 | offset = np.zeros(score.shape) 76 | 77 | for window in range(score.shape[0]): 78 | 79 | #For score, set all notes to 1 if they are played at this window timestep 80 | labels = Y[i][(step * hop + window) * hp.stride] 81 | for label in labels: 82 | score[window,label.data[1]] = 1 83 | 84 | 85 | #For onset/offset, set onset to 1 and offset to -1 86 | if window != 0: 87 | onset[window][np.setdiff1d(score[window].nonzero(), score[window-1].nonzero())] = 1 88 | offset[window][np.setdiff1d(score[window-1].nonzero(), score[window].nonzero())] = -1 89 | else: 90 | onset[window][score[window].nonzero()] = 1 91 | 92 | 93 | onset += offset 94 | return score, onset 95 | 96 | 97 | spec_list=[] 98 | score_list=[] 99 | onoff_list=[] 100 | num_songs = len(X) 101 | hop = hp.hop_inst[inst] 102 | for i in 
range(num_songs): 103 | song_length = len(X[i]) 104 | num_spec = (song_length) // (hop * hp.stride) 105 | print ('{} song {} has {} windows'.format(inst, i, num_spec)) 106 | 107 | for step in range(num_spec - 30): 108 | if step % 50 == 0: 109 | print ('{} steps of {} song {} has been done'.format(step,inst,i)) 110 | spec_list.append(process_spectrum(X,step,hop)) 111 | score, onoff = process_score(Y,step,hop) 112 | score_list.append(score) 113 | onoff_list.append(onoff) 114 | 115 | return np.array(spec_list), np.array(score_list), np.array(onoff_list) 116 | 117 | 118 | 119 | 120 | 121 | 122 | def main(): 123 | get_data() 124 | 125 | 126 | if __name__ == "__main__": 127 | main() 128 | 129 | -------------------------------------------------------------------------------- /src/inference.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pretty_midi 3 | import numpy as np 4 | import h5py 5 | import pickle 6 | import torch.nn as nn 7 | import torch.utils.data as utils 8 | import json 9 | import os 10 | from model import PerformanceNet 11 | import librosa 12 | from tqdm import tqdm 13 | import sys 14 | class AudioSynthesizer(): 15 | def __init__(self, checkpoint, exp_dir, data_source): 16 | self.exp_dir = exp_dir 17 | self.checkpoint = torch.load(os.path.join(exp_dir,checkpoint)) 18 | self.sample_rate = 44100 19 | self.wps = 44100//256 20 | self.data_source = data_source 21 | 22 | def get_test_midi(self): 23 | 24 | X = np.load(os.path.join(self.exp_dir,'test_data/test_X.npy')) 25 | rand = np.random.randint(len(X),size=5) 26 | score = [X[i] for i in rand] 27 | return torch.Tensor(score).cuda() 28 | 29 | def process_custom_midi(self, midi_filename): 30 | 31 | midi_dir = os.path.join(self.exp_dir,'midi') 32 | midi = pretty_midi.PrettyMIDI(os.path.join(midi_dir,midi_filename)) 33 | pianoroll = midi.get_piano_roll(fs=self.wps).T 34 | pianoroll[pianoroll.nonzero()] = 1 35 | onoff = np.zeros(pianoroll.shape) 36 | for i in range(pianoroll.shape[0]): 37 | if i == 0: 38 | onoff[i][pianoroll[i].nonzero()] = 1 39 | else: 40 | onoff[i][np.setdiff1d(pianoroll[i-1].nonzero(), pianoroll[i].nonzero())] = -1 41 | onoff[i][np.setdiff1d(pianoroll[i].nonzero(), pianoroll[i-1].nonzero())] = 1 42 | 43 | return pianoroll, onoff 44 | 45 | 46 | def inference(self): 47 | model = PerformanceNet().cuda() 48 | model.load_state_dict(self.checkpoint['state_dict']) 49 | 50 | if self.data_source == 'TEST_DATA': 51 | score = self.get_test_midi() 52 | score, onoff = torch.split(score, 128, dim=1) 53 | else: 54 | score, onoff = self.process_custom_midi(self.data_source) 55 | 56 | print ('Inferencing spectrogram......') 57 | 58 | with torch.no_grad(): 59 | model.eval() 60 | test_results = model(score, onoff) 61 | test_results = test_results.cpu().numpy() 62 | 63 | output_dir = self.create_output_dir() 64 | 65 | for i in range(len(test_results)): 66 | audio = self.griffinlim(test_results[i], audio_id = i+1) 67 | librosa.output.write_wav(os.path.join(output_dir,'output-{}.wav'.format(i+1)), audio, self.sample_rate) 68 | 69 | def create_output_dir(self): 70 | success = False 71 | dir_id = 1 72 | while not success: 73 | try: 74 | audio_out_dir = os.path.join(self.exp_dir,'audio_output_{}'.format(dir_id)) 75 | os.makedirs(audio_out_dir) 76 | success = True 77 | except FileExistsError: 78 | dir_id += 1 79 | return audio_out_dir 80 | 81 | def griffinlim(self, spectrogram, audio_id, n_iter = 300, window = 'hann', n_fft = 2048, hop_length = 256, verbose = False): 82 | 83 | print 
('Synthesizing audio {}'.format(audio_id)) 84 | 85 | if hop_length == -1: 86 | hop_length = n_fft // 4 87 | spectrogram[0:5] = 0 88 | 89 | spectrogram[150:] = 0 90 | angles = np.exp(2j * np.pi * np.random.rand(*spectrogram.shape)) 91 | 92 | t = tqdm(range(n_iter), ncols=100, mininterval=2.0, disable=not verbose) 93 | for i in t: 94 | full = np.abs(spectrogram).astype(np.complex) * angles 95 | inverse = librosa.istft(full, hop_length = hop_length, window = window) 96 | rebuilt = librosa.stft(inverse, n_fft = n_fft, hop_length = hop_length, window = window) 97 | angles = np.exp(1j * np.angle(rebuilt)) 98 | 99 | if verbose: 100 | diff = np.abs(spectrogram) - np.abs(rebuilt) 101 | t.set_postfix(loss=np.linalg.norm(diff, 'fro')) 102 | 103 | full = np.abs(spectrogram).astype(np.complex) * angles 104 | inverse = librosa.istft(full, hop_length = hop_length, window = window) 105 | 106 | return inverse 107 | 108 | 109 | def main(): 110 | exp_dir = os.path.join(os.path.abspath('./experiments'), sys.argv[1]) # which experiment to test 111 | data_source = sys.argv[2] # test with testing data or customized data 112 | with open(os.path.join(exp_dir,'hyperparams.json'), 'r') as hpfile: 113 | hp = json.load(hpfile) 114 | checkpoints = 'checkpoint-{}.tar'.format(hp['best_epoch']) 115 | AudioSynth = AudioSynthesizer(checkpoints, exp_dir, data_source) 116 | AudioSynth.inference() 117 | 118 | 119 | if __name__ == "__main__": 120 | main() 121 | 122 | 123 | -------------------------------------------------------------------------------- /src/train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import init 4 | import torch.nn.functional as F 5 | import numpy as np 6 | import torch.optim as optim 7 | from sklearn.model_selection import train_test_split 8 | import torch.utils.data as utils 9 | import h5py 10 | import sys 11 | import os 12 | import json 13 | from model import PerformanceNet 14 | cuda = torch.device("cuda") 15 | 16 | class hyperparams(object): 17 | def __init__(self): 18 | self.instrument = sys.argv[1] 19 | self.train_epoch = int(sys.argv[2]) #default = 300 20 | self.test_freq = int(sys.argv[3]) #default = 10 21 | self.exp_name = sys.argv[4] 22 | self.iter_train_loss = [] 23 | self.iter_test_loss = [] 24 | self.loss_history = [] 25 | self.test_loss_history = [] 26 | self.best_loss = 1e10 27 | self.best_epoch = 0 28 | 29 | def Process_Data(instr, exp_dir): 30 | dataset = h5py.File('data/train_data.hdf5','r') 31 | score = dataset['{}_pianoroll'.format(instr)][:] 32 | spec = dataset['{}_spec'.format(instr)][:] 33 | onoff = dataset['{}_onoff'.format(instr)][:] 34 | score = np.concatenate((score, onoff),axis = -1) 35 | score = np.transpose(score,(0,2,1)) 36 | 37 | X_train, X_test, Y_train, Y_test = train_test_split(score, spec, test_size=0.2) 38 | 39 | test_data_dir = os.path.join(exp_dir,'test_data') 40 | os.makedirs(test_data_dir) 41 | 42 | np.save(os.path.join(test_data_dir, "test_X.npy"), X_test) 43 | np.save(os.path.join(test_data_dir, "test_Y.npy"), Y_test) 44 | 45 | train_dataset = utils.TensorDataset(torch.Tensor(X_train, device=cuda), torch.Tensor(Y_train, device=cuda)) 46 | train_loader = utils.DataLoader(train_dataset, batch_size=16, shuffle=True) 47 | test_dataset = utils.TensorDataset(torch.Tensor(X_test, device=cuda), torch.Tensor(Y_test,device=cuda)) 48 | test_loader = utils.DataLoader(test_dataset, batch_size=16, shuffle=True) 49 | 50 | return train_loader, test_loader 51 | 52 | def 
train(model, epoch, train_loader, optimizer,iter_train_loss): 53 | model.train() 54 | train_loss = 0 55 | for batch_idx, (data, target) in enumerate(train_loader): 56 | optimizer.zero_grad() 57 | split = torch.split(data, 128, dim=1) 58 | y_pred = model(split[0].cuda(),split[1].cuda()) 59 | loss_function = nn.MSELoss() 60 | loss = loss_function(y_pred, target.cuda()) 61 | loss.backward() 62 | iter_train_loss.append(loss.item()) 63 | train_loss += loss 64 | optimizer.step() 65 | 66 | if batch_idx % 2 == 0: 67 | print ('Train Epoch: {} [{}/{} ({:.0f}%)]\t Loss: {:.6f}'.format(epoch, batch_idx * len(data), len(train_loader.dataset), 100. * batch_idx/len(train_loader), loss.item()/len(data))) 68 | 69 | print('====> Epoch: {} Average loss: {:.4f}'.format(epoch, train_loss/ len(train_loader.dataset))) 70 | return train_loss/ len(train_loader.dataset) 71 | 72 | def test(model, epoch, test_loader, scheduler, iter_test_loss): 73 | with torch.no_grad(): 74 | model.eval() 75 | test_loss = 0 76 | for idx, (data, target) in enumerate(test_loader): 77 | split = torch.split(data,128,dim = 1) 78 | y_pred = model(split[0].cuda(),split[1].cuda()) 79 | loss_function = nn.MSELoss() 80 | loss = loss_function(y_pred,target.cuda()) 81 | iter_test_loss.append(loss.item()) 82 | test_loss += loss 83 | test_loss/= len(test_loader.dataset) 84 | scheduler.step(test_loss) 85 | print ('====> Test set loss: {:.4f}'.format(test_loss)) 86 | return test_loss 87 | 88 | 89 | def main(): 90 | hp = hyperparams() 91 | 92 | try: 93 | exp_root = os.path.join(os.path.abspath('./'),'experiments') 94 | os.makedirs(exp_root) 95 | except FileExistsError: 96 | pass 97 | 98 | exp_dir = os.path.join(exp_root, hp.exp_name) 99 | os.makedirs(exp_dir) 100 | 101 | model = PerformanceNet() 102 | model.cuda() 103 | optimizer = optim.Adam(model.parameters(), lr=1e-3) 104 | model.zero_grad() 105 | optimizer.zero_grad() 106 | scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min') 107 | train_loader, test_loader = Process_Data(hp.instrument, exp_dir) 108 | print ('start training') 109 | for epoch in range(hp.train_epoch): 110 | loss = train(model, epoch, train_loader, optimizer,hp.iter_train_loss) 111 | hp.loss_history.append(loss.item()) 112 | if epoch % hp.test_freq == 0: 113 | test_loss = test(model, epoch, test_loader, scheduler, hp.iter_test_loss) 114 | hp.test_loss_history.append(test_loss.item()) 115 | if test_loss < hp.best_loss: 116 | torch.save({'epoch': epoch + 1, 'state_dict': model.state_dict(), 'optimizer' : optimizer.state_dict()}, os.path.join(exp_dir, 'checkpoint-{}.tar'.format(str(epoch + 1 )))) 117 | hp.best_loss = test_loss.item() 118 | hp.best_epoch = epoch + 1 119 | with open(os.path.join(exp_dir,'hyperparams.json'), 'w') as outfile: 120 | json.dump(hp.__dict__, outfile) 121 | 122 | 123 | if __name__ == "__main__": 124 | main() 125 | -------------------------------------------------------------------------------- /src/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import init 4 | import torch.nn.functional as F 5 | import numpy as np 6 | import torch.optim as optim 7 | from sklearn.model_selection import train_test_split 8 | import torch.utils.data as utils 9 | import sys 10 | import pickle as pkl 11 | 12 | cuda = torch.device("cuda") 13 | 14 | def conv1x3(in_channels, out_channels, stride=1, padding=1, bias=True,groups=1): 15 | return nn.Conv1d( 16 | in_channels, 17 | out_channels, 18 | kernel_size=3, 19 | 
stride=stride, 20 | padding=padding, 21 | bias=bias, 22 | groups=groups) 23 | 24 | def upconv1x2(in_channels, out_channels, kernel): 25 | return nn.ConvTranspose1d( 26 | in_channels, 27 | out_channels, 28 | kernel_size=kernel, 29 | stride=2, 30 | padding=1 31 | ) 32 | 33 | class DownConv(nn.Module): 34 | def __init__(self, in_channels, out_channels, block_id, pooling = True): 35 | super(DownConv,self).__init__() 36 | self.in_channels = in_channels 37 | self.out_channels = out_channels 38 | self.pooling = pooling 39 | self.activation = nn.LeakyReLU(0.01) 40 | self.conv1 = conv1x3(self.in_channels, self.out_channels) 41 | self.conv1_BN = nn.InstanceNorm1d(self.out_channels) 42 | self.conv2 = conv1x3(self.out_channels, self.out_channels) 43 | self.conv2_BN = nn.InstanceNorm1d(self.out_channels) 44 | self.pool = nn.MaxPool1d(kernel_size=2, stride=2) 45 | def forward(self,x): 46 | x = self.activation(self.conv1_BN(self.conv1(x))) 47 | x = self.activation(self.conv1_BN(self.conv2(x))) 48 | before_pool = x 49 | if self.pooling: 50 | x = self.pool(x) 51 | return x, before_pool 52 | 53 | class UpConv(nn.Module): 54 | def __init__(self, in_channels, out_channels, skip_channels, cond_channels, block_id, activation = nn.LeakyReLU(0.01), upconv_kernel=2): 55 | super(UpConv, self).__init__() 56 | self.skip_channels = skip_channels 57 | self.in_channels = in_channels 58 | self.out_channels = out_channels 59 | self.cond_channels = cond_channels 60 | self.activation = activation 61 | self.upconv = upconv1x2(self.in_channels, self.out_channels,kernel=upconv_kernel) 62 | self.upconv_BN = nn.InstanceNorm1d(self.out_channels) 63 | self.conv1 = conv1x3( self.skip_channels + self.out_channels, self.out_channels) 64 | self.conv1_BN = nn.InstanceNorm1d(self.out_channels) 65 | self.conv2 = conv1x3(self.out_channels + self.cond_channels, self.out_channels) 66 | self.conv2_BN = nn.InstanceNorm1d(self.out_channels) 67 | 68 | def crop_and_concat(self, upsampled, bypass): 69 | c = (bypass.size()[2] - upsampled.size()[2]) // 2 70 | bypass = F.pad(bypass, (-c, -c)) 71 | if bypass.shape[2] > upsampled.shape[2]: 72 | bypass = F.pad(bypass, (0, -(bypass.shape[2] - upsampled.shape[2]))) 73 | else: 74 | bypass = F.pad(bypass, ((0, bypass.shape[2] - upsampled.shape[2]) )) 75 | return torch.cat((upsampled, bypass), 1) 76 | 77 | def forward(self, res, dec, cond): 78 | x = self.activation(self.upconv_BN(self.upconv(dec))) 79 | x = self.crop_and_concat(x, res) 80 | x = self.activation(self.conv1_BN(self.conv1(x))) 81 | 82 | if self.cond_channels: 83 | x = self.crop_and_concat(x, cond) 84 | 85 | x = self.conv2(x) 86 | x = self.activation(self.conv2_BN(x)) 87 | return x 88 | 89 | class Onset_Offset_Encoder(nn.Module): 90 | def __init__(self, depth = 3, start_channels = 128): 91 | super(Onset_Offset_Encoder, self).__init__() 92 | self.start_channels = start_channels 93 | self.depth = depth 94 | self.down_convs = [] 95 | self.construct_layers() 96 | self.down_convs = nn.ModuleList(self.down_convs) 97 | self.reset_params() 98 | def construct_layers(self): 99 | for i in range(self.depth): 100 | ins = self.start_channels if i == 0 else outs 101 | outs = self.start_channels * (2 ** (i+1)) 102 | pooling = True if i < self.depth else False 103 | DC = DownConv(ins, outs, pooling=pooling, block_id = i + 9) 104 | self.down_convs.append(DC) 105 | @staticmethod 106 | def weight_init(m): 107 | if isinstance(m, nn.Conv1d): 108 | init.xavier_normal_(m.weight) 109 | init.constant_(m.bias, 0) 110 | def reset_params(self): 111 | for i, m in 
enumerate(self.modules()): 112 | self.weight_init(m) 113 | def forward(self, x): 114 | condition_tensors = [] 115 | for i, module in enumerate(self.down_convs): 116 | x,_ = module(x) 117 | if (i > self.depth - 3): 118 | condition_tensors.append(x) 119 | return condition_tensors 120 | 121 | class MBRBlock(nn.Module): 122 | def __init__(self, in_channels, num_of_band): 123 | super(MBRBlock, self).__init__() 124 | self.in_dim = in_channels 125 | self.num_of_band = num_of_band 126 | self.conv_list1 = [] 127 | self.bn_list1 = [] 128 | self.conv_list2 = [] 129 | self.bn_list2 = [] 130 | self.activation = nn.LeakyReLU(0.01) 131 | self.band_dim = self.in_dim // self.num_of_band 132 | for i in range(self.num_of_band): 133 | self.conv_list1.append(nn.Conv1d(in_channels = self.band_dim, out_channels = self.band_dim, kernel_size = 3, padding = 1)) 134 | for i in range(self.num_of_band): 135 | self.conv_list2.append(nn.Conv1d(in_channels = self.band_dim, out_channels = self.band_dim, kernel_size = 3, padding = 1)) 136 | for i in range(self.num_of_band): 137 | self.bn_list1.append(nn.InstanceNorm1d(self.band_dim)) 138 | for i in range(self.num_of_band): 139 | self.bn_list2.append(nn.InstanceNorm1d(self.band_dim)) 140 | self.conv_list1 = nn.ModuleList(self.conv_list1) 141 | self.conv_list2 = nn.ModuleList(self.conv_list2) 142 | self.bn_list1 = nn.ModuleList(self.bn_list1) 143 | self.bn_list2 = nn.ModuleList(self.bn_list2) 144 | 145 | def forward(self,x): 146 | bands = list(torch.chunk(x, self.num_of_band, dim = 1))  # list() so the per-band residual outputs can be assigned back below (torch.chunk returns an immutable tuple) 147 | for i in range(len(bands)): 148 | t = self.activation(self.bn_list1[i](self.conv_list1[i](bands[i]))) 149 | t = self.bn_list2[i](self.conv_list2[i](t)) 150 | bands[i] = torch.add(bands[i],1,t) 151 | x = torch.add(x,1,torch.cat(bands, dim = 1)) 152 | return x 153 | 154 | class PerformanceNet(nn.Module): 155 | def __init__(self, depth = 5,start_channels = 128): 156 | super(PerformanceNet, self).__init__() 157 | self.depth = depth 158 | self.start_channels = start_channels 159 | self.construct_layers() 160 | self.reset_params() 161 | 162 | #@staticmethod 163 | def construct_layers(self): 164 | self.down_convs = [] 165 | self.up_convs = [] 166 | for i in range(self.depth): 167 | ins = self.start_channels if i == 0 else outs 168 | outs = self.start_channels * (2 ** (i+1)) 169 | pooling = True if i < self.depth-1 else False 170 | DC = DownConv(ins, outs, pooling=pooling, block_id=i) 171 | self.down_convs.append(DC) 172 | self.up_convs.append(UpConv(4096,2048,2048, 1024, block_id = 5, upconv_kernel=6)) 173 | self.up_convs.append(UpConv(2048,1024,1024, 512, block_id = 6, upconv_kernel=4)) 174 | self.up_convs.append(UpConv(1024,1024,512,0,block_id= 7, upconv_kernel=3)) 175 | self.up_convs.append(UpConv(1024,1024,256,0, block_id = 8)) 176 | self.down_convs = nn.ModuleList(self.down_convs) 177 | self.up_convs = nn.ModuleList(self.up_convs) 178 | self.MBRBlock1 = MBRBlock(1024,2) 179 | self.MBRBlock2 = MBRBlock(1024,4) 180 | self.MBRBlock3 = MBRBlock(1024,8) 181 | self.MBRBlock4 = MBRBlock(1024,16) 182 | self.lastconv = nn.ConvTranspose1d(1024,1025,kernel_size=3, stride=1, padding=1) 183 | self.lrelu = nn.LeakyReLU(0.01) 184 | self.onset_offset_encoder = Onset_Offset_Encoder() 185 | 186 | @staticmethod 187 | def weight_init(m): 188 | if isinstance(m, nn.Conv1d): 189 | init.xavier_normal_(m.weight) 190 | init.constant_(m.bias, 0) 191 | if isinstance(m, nn.ConvTranspose1d): 192 | init.xavier_normal_(m.weight) 193 | init.constant_(m.bias, 0) 194 | 195 | 196 | def reset_params(self): 197 | for i, m in enumerate(self.modules()): 198 | self.weight_init(m) 199 | 200 | def forward(self, x, cond): 201 | encoder_layer_outputs = [] 202 | for i, module in enumerate(self.down_convs):  # contracting path over the pianoroll 203 | x, before_pool = module(x) 204 | encoder_layer_outputs.append(before_pool) 205 | 206 | Onoff_Conditions = self.onset_offset_encoder(cond)  # conditioning features from the onset/offset matrix 207 | 208 | for i, module in enumerate(self.up_convs):  # expanding path with skip connections 209 | before_pool = encoder_layer_outputs[-(i+2)] 210 | if i < self.onset_offset_encoder.depth - 1: 211 | x = module(before_pool, x, Onoff_Conditions[i-1])  # i-1 wraps to -1 for the first block, picking the deepest (1024-channel) condition 212 | else: 213 | x = module(before_pool, x, None) 214 | x = self.MBRBlock1(x)  # multi-band residual refinement with progressively more bands 215 | x = self.MBRBlock2(x) 216 | x = self.MBRBlock3(x) 217 | x = self.MBRBlock4(x) 218 | x = self.lrelu(self.lastconv(x)) 219 | return x 220 | 221 | --------------------------------------------------------------------------------
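For reference, a small smoke test of the model interface, not part of the repository; it assumes a CUDA device and that `src/` is on the Python path. Input shapes follow `train.py` and `process_data.py`: the pianoroll and the onset/offset matrix each carry 128 pitch channels over (44100 // 256) * 5 = 860 time steps per 5-second excerpt, and the model returns a 1025-bin log-power spectrogram of the same length.

```python
# Hypothetical usage sketch; the shapes mirror what train.py feeds the model.
import torch
from model import PerformanceNet

model = PerformanceNet().cuda()
model.eval()

time_steps = (44100 // 256) * 5                        # ~5 seconds at the hop size of process_data.py
pianoroll = torch.zeros(1, 128, time_steps).cuda()     # binary pianoroll
onoff = torch.zeros(1, 128, time_steps).cuda()         # +1 onsets / -1 offsets

with torch.no_grad():
    spec = model(pianoroll, onoff)
print(spec.shape)                                      # expected: torch.Size([1, 1025, 860])
```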