├── HiFiGanWrapper.py ├── README.md ├── audio2mel.py ├── checkpoint └── hifigan │ ├── g_00935000 │ └── hifigan_config.json ├── datasets.py ├── extract_code.py ├── inference.py ├── pixelsnail.py ├── scheduler.py ├── train_pixelsnail.py ├── train_vqvae.py └── vqvae.py /HiFiGanWrapper.py: -------------------------------------------------------------------------------- 1 | import json 2 | import torch 3 | import torch.nn.functional as F 4 | import torch.nn as nn 5 | from torch.nn import Conv1d, ConvTranspose1d 6 | from torch.nn.utils import weight_norm, remove_weight_norm 7 | 8 | 9 | class HiFiGanWrapper: 10 | def __init__(self, generator_pretrain_name: str, config_path: str): 11 | self.hifi_gan_generator: Generator 12 | self.generator_pretrain_path: str = generator_pretrain_name 13 | self.config_path: str = config_path 14 | self.load_hifi_gan() 15 | 16 | def load_hifi_gan(self): 17 | with open(self.config_path) as f: 18 | data = f.read() 19 | json_config = json.loads(data) 20 | self.hifi_gan_generator = Generator(AttrDict(json_config)) 21 | 22 | state_dict_g = torch.load(self.generator_pretrain_path, map_location='cpu') 23 | self.hifi_gan_generator.load_state_dict(state_dict_g['generator']) 24 | 25 | self.hifi_gan_generator = self.hifi_gan_generator.cuda() 26 | self.hifi_gan_generator.eval() 27 | self.hifi_gan_generator.remove_weight_norm() 28 | 29 | def generate_audio_by_hifi_gan(self, input_feature): 30 | final_shape_len = 3 31 | 32 | if type(input_feature) != torch.Tensor: 33 | input_feature = torch.from_numpy(input_feature) 34 | 35 | for _ in range(final_shape_len - len(input_feature.shape)): 36 | input_feature = torch.unsqueeze(input_feature, 0) 37 | 38 | input_feature = input_feature.cuda() 39 | 40 | with torch.no_grad(): 41 | # in: (batch,mel_size,time) , out: (batch,channel,time) 42 | audio = self.hifi_gan_generator(input_feature) 43 | audio = audio.squeeze() 44 | audio = audio.cpu().numpy() 45 | return audio 46 | 47 | 48 | class AttrDict(dict): 49 | def __init__(self, *args, **kwargs): 50 | super(AttrDict, self).__init__(*args, **kwargs) 51 | self.__dict__ = self 52 | 53 | 54 | LRELU_SLOPE = 0.1 55 | 56 | 57 | def init_weights(m, mean=0.0, std=0.01): 58 | classname = m.__class__.__name__ 59 | if classname.find("Conv") != -1: 60 | m.weight.data.normal_(mean, std) 61 | 62 | 63 | def get_padding(kernel_size, dilation=1): 64 | return int((kernel_size * dilation - dilation) / 2) 65 | 66 | 67 | class ResBlock1(torch.nn.Module): 68 | def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)): 69 | super(ResBlock1, self).__init__() 70 | self.h = h 71 | self.convs1 = nn.ModuleList( 72 | [ 73 | weight_norm( 74 | Conv1d( 75 | channels, 76 | channels, 77 | kernel_size, 78 | 1, 79 | dilation=dilation[0], 80 | padding=get_padding(kernel_size, dilation[0]), 81 | ) 82 | ), 83 | weight_norm( 84 | Conv1d( 85 | channels, 86 | channels, 87 | kernel_size, 88 | 1, 89 | dilation=dilation[1], 90 | padding=get_padding(kernel_size, dilation[1]), 91 | ) 92 | ), 93 | weight_norm( 94 | Conv1d( 95 | channels, 96 | channels, 97 | kernel_size, 98 | 1, 99 | dilation=dilation[2], 100 | padding=get_padding(kernel_size, dilation[2]), 101 | ) 102 | ), 103 | ] 104 | ) 105 | self.convs1.apply(init_weights) 106 | 107 | self.convs2 = nn.ModuleList( 108 | [ 109 | weight_norm( 110 | Conv1d( 111 | channels, 112 | channels, 113 | kernel_size, 114 | 1, 115 | dilation=1, 116 | padding=get_padding(kernel_size, 1), 117 | ) 118 | ), 119 | weight_norm( 120 | Conv1d( 121 | channels, 122 | channels, 123 | kernel_size, 124 | 1, 
125 | dilation=1, 126 | padding=get_padding(kernel_size, 1), 127 | ) 128 | ), 129 | weight_norm( 130 | Conv1d( 131 | channels, 132 | channels, 133 | kernel_size, 134 | 1, 135 | dilation=1, 136 | padding=get_padding(kernel_size, 1), 137 | ) 138 | ), 139 | ] 140 | ) 141 | self.convs2.apply(init_weights) 142 | 143 | def forward(self, x): 144 | for c1, c2 in zip(self.convs1, self.convs2): 145 | xt = F.leaky_relu(x, LRELU_SLOPE) 146 | xt = c1(xt) 147 | xt = F.leaky_relu(xt, LRELU_SLOPE) 148 | xt = c2(xt) 149 | x = xt + x 150 | return x 151 | 152 | def remove_weight_norm(self): 153 | for l in self.convs1: 154 | remove_weight_norm(l) 155 | for l in self.convs2: 156 | remove_weight_norm(l) 157 | 158 | 159 | class ResBlock2(torch.nn.Module): 160 | def __init__(self, h, channels, kernel_size=3, dilation=(1, 3)): 161 | super(ResBlock2, self).__init__() 162 | self.h = h 163 | self.convs = nn.ModuleList( 164 | [ 165 | weight_norm( 166 | Conv1d( 167 | channels, 168 | channels, 169 | kernel_size, 170 | 1, 171 | dilation=dilation[0], 172 | padding=get_padding(kernel_size, dilation[0]), 173 | ) 174 | ), 175 | weight_norm( 176 | Conv1d( 177 | channels, 178 | channels, 179 | kernel_size, 180 | 1, 181 | dilation=dilation[1], 182 | padding=get_padding(kernel_size, dilation[1]), 183 | ) 184 | ), 185 | ] 186 | ) 187 | self.convs.apply(init_weights) 188 | 189 | def forward(self, x): 190 | for c in self.convs: 191 | xt = F.leaky_relu(x, LRELU_SLOPE) 192 | xt = c(xt) 193 | x = xt + x 194 | return x 195 | 196 | def remove_weight_norm(self): 197 | for l in self.convs: 198 | remove_weight_norm(l) 199 | 200 | 201 | class Generator(torch.nn.Module): 202 | def __init__(self, h): 203 | super(Generator, self).__init__() 204 | self.h = h 205 | self.num_kernels = len(h.resblock_kernel_sizes) 206 | self.num_upsamples = len(h.upsample_rates) 207 | self.conv_pre = weight_norm( 208 | Conv1d(80, h.upsample_initial_channel, 7, 1, padding=3) 209 | ) 210 | resblock = ResBlock1 if h.resblock == '1' else ResBlock2 211 | 212 | self.ups = nn.ModuleList() 213 | for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)): 214 | self.ups.append( 215 | weight_norm( 216 | ConvTranspose1d( 217 | h.upsample_initial_channel // (2**i), 218 | h.upsample_initial_channel // (2 ** (i + 1)), 219 | k, 220 | u, 221 | padding=(k - u) // 2, 222 | ) 223 | ) 224 | ) 225 | 226 | self.resblocks = nn.ModuleList() 227 | for i in range(len(self.ups)): 228 | ch = h.upsample_initial_channel // (2 ** (i + 1)) 229 | for j, (k, d) in enumerate( 230 | zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes) 231 | ): 232 | self.resblocks.append(resblock(h, ch, k, d)) 233 | 234 | self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3)) 235 | self.ups.apply(init_weights) 236 | self.conv_post.apply(init_weights) 237 | 238 | def forward(self, x): 239 | x = self.conv_pre(x) 240 | for i in range(self.num_upsamples): 241 | x = F.leaky_relu(x, LRELU_SLOPE) 242 | x = self.ups[i](x) 243 | xs = None 244 | for j in range(self.num_kernels): 245 | if xs is None: 246 | xs = self.resblocks[i * self.num_kernels + j](x) 247 | else: 248 | xs += self.resblocks[i * self.num_kernels + j](x) 249 | x = xs / self.num_kernels 250 | x = F.leaky_relu(x) 251 | x = self.conv_post(x) 252 | x = torch.tanh(x) 253 | 254 | return x 255 | 256 | def remove_weight_norm(self): 257 | print('Removing weight norm...') 258 | for l in self.ups: 259 | remove_weight_norm(l) 260 | for l in self.resblocks: 261 | l.remove_weight_norm() 262 | remove_weight_norm(self.conv_pre) 263 | 
remove_weight_norm(self.conv_post) 264 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DCASE2023 - Task 7 - Baseline systems 2 | 3 | The code in this repository is largely adapted from [liuxubo717/sound_generation](https://github.com/liuxubo717/sound_generation). If you use this code, please cite the original work as follows: 4 | ``` 5 | @article{liu2021conditional, 6 | title={Conditional Sound Generation Using Neural Discrete Time-Frequency Representation Learning}, 7 | author={Liu, Xubo and Iqbal, Turab and Zhao, Jinzheng and Huang, Qiushi and Plumbley, Mark D and Wang, Wenwu}, 8 | journal={arXiv preprint arXiv:2107.09998}, 9 | year={2021} 10 | } 11 | ``` 12 | For the neural vocoder, we use the generator model code from [jik876/hifi-gan](https://github.com/jik876/hifi-gan). 13 | 14 | ## Set up 15 | 16 | * Clone the repository: 17 | 18 | ``` 19 | git clone https://github.com/DCASE2023-Task7-Foley-Sound-Synthesis/dcase2023_task7_baseline.git 20 | ``` 21 | 22 | * Install the Python requirements listed below: 23 | 24 | ``` 25 | torch==1.13.1 26 | librosa==0.10.0 27 | python-lmdb==1.4.0 28 | tqdm 29 | ``` 30 | 31 | * Download the development dataset and move it to the root folder. The dataset path must be `./DCASEFoleySoundSynthesisDevSet`. 32 | 33 | ## Usage 34 | 35 | 1: (Stage 1) Train a multi-scale VQ-VAE to extract the Discrete T-F Representation (DTFR) of sound. The pre-trained model will be saved to `checkpoint/vqvae/`. 36 | 37 | ``` 38 | python train_vqvae.py --epoch 800 39 | ``` 40 | 41 | 2: Extract the DTFR for Stage 2 training. 42 | 43 | ``` 44 | python extract_code.py --vqvae_checkpoint [VQ-VAE CHECKPOINT] 45 | ``` 46 | 47 | 3: (Stage 2) Train a PixelSNAIL model on the extracted DTFR of sound. The pre-trained model will be saved to `checkpoint/pixelsnail-final/`. 48 | 49 | ``` 50 | python train_pixelsnail.py --epoch 1500 51 | ``` 52 | 53 | 4: Synthesize sounds.
The synthesized sound samples will be saved to `./synthesized` 54 | 55 | ``` 56 | python inference.py --vqvae_checkpoint [VQ-VAE CHECKPOINT] --pixelsnail_checkpoint [PIXELSNAIL CHECKPOINT] --number_of_synthesized_sound_per_class [NUMBER OF SOUND SAMPLES] 57 | ``` 58 | 59 | 60 | 61 | -------------------------------------------------------------------------------- /audio2mel.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import librosa 3 | from scipy.io.wavfile import read as loadwav 4 | import numpy as np 5 | import datasets 6 | 7 | import warnings 8 | 9 | warnings.filterwarnings("ignore") 10 | 11 | MAX_WAV_VALUE = 32768.0 12 | 13 | """ Mel-Spectrogram extraction code from Turab ood_audio""" 14 | 15 | # def mel_spectrogram(audio, n_fft, n_mels, hop_length, sample_rate): 16 | # # Compute mel-scaled spectrogram 17 | # mel_fb = librosa.filters.mel(sr=sample_rate, n_fft=n_fft, n_mels=n_mels) 18 | # spec = librosa.stft(audio, n_fft=n_fft, hop_length=hop_length) 19 | # mel = np.dot(mel_fb, np.abs(spec)) 20 | # 21 | # # return librosa.power_to_db(mel, ref=0., top_db=None) 22 | # return np.log(mel + 1e-9) 23 | 24 | """ Mel-Spectrogram extraction code from HiFi-GAN meldataset.py""" 25 | 26 | 27 | def dynamic_range_compression(x, C=1, clip_val=1e-5): 28 | return np.log(np.clip(x, a_min=clip_val, a_max=None) * C) 29 | 30 | 31 | def dynamic_range_decompression(x, C=1): 32 | return np.exp(x) / C 33 | 34 | 35 | def dynamic_range_compression_torch(x, C=1, clip_val=1e-5): 36 | return torch.log(torch.clamp(x, min=clip_val) * C) 37 | 38 | 39 | def dynamic_range_decompression_torch(x, C=1): 40 | return torch.exp(x) / C 41 | 42 | 43 | def spectral_normalize_torch(magnitudes): 44 | output = dynamic_range_compression_torch(magnitudes) 45 | return output 46 | 47 | 48 | def spectral_de_normalize_torch(magnitudes): 49 | output = dynamic_range_decompression_torch(magnitudes) 50 | return output 51 | 52 | 53 | mel_basis = {} 54 | hann_window = {} 55 | 56 | 57 | def mel_spectrogram_hifi( 58 | audio, n_fft, n_mels, sample_rate, hop_length, fmin, fmax, center=False 59 | ): 60 | audio = torch.FloatTensor(audio) 61 | audio = audio.unsqueeze(0) 62 | 63 | if torch.min(audio) < -1.0: 64 | print('min value is ', torch.min(audio)) 65 | if torch.max(audio) > 1.0: 66 | print('max value is ', torch.max(audio)) 67 | 68 | global mel_basis, hann_window 69 | if fmax not in mel_basis: 70 | mel_fb = librosa.filters.mel( 71 | sr=sample_rate, n_fft=n_fft, n_mels=n_mels, fmin=fmin, fmax=fmax 72 | ) 73 | mel_basis[str(fmax) + '_' + str(audio.device)] = ( 74 | torch.from_numpy(mel_fb).float().to(audio.device) 75 | ) 76 | hann_window[str(audio.device)] = torch.hann_window(n_fft).to(audio.device) 77 | 78 | audio = torch.nn.functional.pad( 79 | audio.unsqueeze(1), 80 | (int((n_fft - hop_length) / 2), int((n_fft - hop_length) / 2)), 81 | mode='reflect', 82 | ) 83 | audio = audio.squeeze(1) 84 | 85 | spec = torch.stft( 86 | audio, 87 | n_fft, 88 | hop_length=hop_length, 89 | window=hann_window[str(audio.device)], 90 | center=center, 91 | pad_mode='reflect', 92 | normalized=False, 93 | onesided=True, 94 | return_complex=False, 95 | ) 96 | 97 | spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-9) 98 | 99 | mel = torch.matmul(mel_basis[str(fmax) + '_' + str(audio.device)], spec) 100 | mel = spectral_normalize_torch(mel).numpy() 101 | 102 | # pad_size = math.ceil(mel.shape[2] / 8) * 8 - mel.shape[2] 103 | # 104 | # mel = np.pad(mel, ((0, 0), (0, 0), (0, pad_size))) 105 | 106 | return mel 107 | 108 | 
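# --- Editorial note: illustrative sketch, not part of the original file. ---
# A minimal, hypothetical example of calling mel_spectrogram_hifi() above with the
# parameter values used elsewhere in this repository (hifigan_config.json and the
# Audio2Mel example at the bottom of this file): n_fft=1024, n_mels=80,
# hop_length=256, sample_rate=22050, fmin=0, fmax=8000. The helper name
# _example_mel_extraction is invented for illustration only.
def _example_mel_extraction():
    # Four seconds of silence stand in for a real 22.05 kHz mono waveform in [-1, 1].
    dummy_audio = np.zeros(22050 * 4, dtype=np.float32)
    mel = mel_spectrogram_hifi(
        dummy_audio,
        n_fft=1024,
        n_mels=80,
        sample_rate=22050,
        hop_length=256,
        fmin=0,
        fmax=8000,
    )
    # Expected shape: (1, 80, T), i.e. (batch, n_mels, frames) -- the layout consumed
    # by the VQ-VAE and by HiFiGanWrapper.generate_audio_by_hifi_gan.
    return mel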
109 | """ Mel-Spectrogram extraction code from HiFi-GAN meldataset.py""" 110 | 111 | 112 | class Audio2Mel(torch.utils.data.Dataset): 113 | def __init__( 114 | self, 115 | audio_files, 116 | max_length, 117 | n_fft, 118 | n_mels, 119 | hop_length, 120 | sample_rate, 121 | fmin, 122 | fmax, 123 | ): 124 | self.audio_files = audio_files 125 | self.max_length = max_length # max length of audio 126 | self.n_fft = n_fft 127 | self.n_mels = n_mels 128 | self.hop_length = hop_length 129 | self.sample_rate = sample_rate 130 | self.fmin = fmin 131 | self.fmax = fmax 132 | 133 | def __getitem__(self, index): 134 | filename = self.audio_files[index]['file_path'] 135 | class_id = self.audio_files[index][ 136 | 'class_id' 137 | ] # datasets.get_class_id(filename) 138 | salience = 1 # datasets.get_salience(filename) 139 | 140 | sample_rate, audio = loadwav(filename) 141 | audio = audio / MAX_WAV_VALUE 142 | 143 | if sample_rate != self.sample_rate: 144 | raise ValueError( 145 | "{} sr doesn't match {} sr ".format(sample_rate, self.sample_rate) 146 | ) 147 | 148 | if len(audio) > self.max_length: 149 | # raise ValueError("{} length overflow".format(filename)) 150 | audio = audio[0 : self.max_length] 151 | 152 | # pad audio to max length, 4s for Urbansound8k dataset 153 | if len(audio) < self.max_length: 154 | # audio = torch.nn.functional.pad(audio, (0, self.max_length - audio.size(1)), 'constant') 155 | audio = np.pad(audio, (0, self.max_length - len(audio)), 'constant') 156 | 157 | # mel = mel_spectrogram(audio, n_fft=self.n_fft, n_mels=self.n_mels, hop_length=self.hop_length, sample_rate=self.sample_rate) 158 | 159 | mel_spec = mel_spectrogram_hifi( 160 | audio, 161 | n_fft=self.n_fft, 162 | n_mels=self.n_mels, 163 | hop_length=self.hop_length, 164 | sample_rate=self.sample_rate, 165 | fmin=self.fmin, 166 | fmax=self.fmax, 167 | ) 168 | 169 | # print(mel_spec.shape) 170 | return mel_spec, class_id, salience, filename 171 | 172 | def __len__(self): 173 | return len(self.audio_files) 174 | 175 | 176 | def extract_flat_mel_from_Audio2Mel(Audio2Mel): 177 | mel = [] 178 | 179 | for item in Audio2Mel: 180 | mel.append(item[0].flatten()) 181 | 182 | return np.array(mel) 183 | 184 | 185 | if __name__ == '__main__': 186 | train_file_list, test_file_list = datasets.get_dataset_filelist() 187 | 188 | print(train_file_list[100]) 189 | 190 | train_set = Audio2Mel( 191 | train_file_list[0:2], 22050 * 4, 1024, 80, 256, 22050, 0, 8000 192 | ) 193 | 194 | print(train_set[0][0].shape) 195 | -------------------------------------------------------------------------------- /checkpoint/hifigan/g_00935000: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DCASE2023-Task7-Foley-Sound-Synthesis/dcase2023_task7_baseline/411ab9ecdb5880223a934666c4d294a4aa7062bc/checkpoint/hifigan/g_00935000 -------------------------------------------------------------------------------- /checkpoint/hifigan/hifigan_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "resblock": "1", 3 | "num_gpus": 0, 4 | "batch_size": 16, 5 | "learning_rate": 0.0002, 6 | "adam_b1": 0.8, 7 | "adam_b2": 0.99, 8 | "lr_decay": 0.999, 9 | "seed": 1234, 10 | 11 | "upsample_rates": [8,8,2,2], 12 | "upsample_kernel_sizes": [16,16,4,4], 13 | "upsample_initial_channel": 512, 14 | "resblock_kernel_sizes": [3,7,11], 15 | "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]], 16 | 17 | "segment_size": 8192, 18 | "num_mels": 80, 19 | "num_freq": 1025, 20 | 
"n_fft": 1024, 21 | "hop_size": 256, 22 | "win_size": 1024, 23 | 24 | "sampling_rate": 22050, 25 | 26 | "fmin": 0, 27 | "fmax": 8000, 28 | "fmax_for_loss": null, 29 | 30 | "num_workers": 15, 31 | 32 | "dist_config": { 33 | "dist_backend": "nccl", 34 | "dist_url": "tcp://localhost:54321", 35 | "world_size": 1 36 | } 37 | } 38 | -------------------------------------------------------------------------------- /datasets.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import os 4 | import csv 5 | import torch 6 | import random 7 | import numpy as np 8 | import audio2mel 9 | from vqvae import VQVAE 10 | import lmdb 11 | import pickle 12 | from torch.utils.data import Dataset 13 | from collections import namedtuple 14 | 15 | CodeRow = namedtuple('CodeRow', ['bottom', 'class_id', 'salience', 'filename']) 16 | clas_dict: dict = { 17 | "DogBark": 0, 18 | "Footstep": 1, 19 | "GunShot": 2, 20 | "Keyboard": 3, 21 | "MovingMotorVehicle": 4, 22 | "Rain": 5, 23 | "Sneeze_Cough": 6, 24 | } 25 | 26 | 27 | def get_file_path(path): 28 | file_list = os.listdir(path) 29 | file_path = [] 30 | for audio in file_list: 31 | # i.e., ['test_mel/7965-3-11-0.npy'] 32 | file_path.append(os.path.join(path, audio)) 33 | return file_path 34 | 35 | 36 | def get_salience(file_name): 37 | annotation_file = 'UrbanSound8K/metadata/UrbanSound8K.csv' 38 | _, filename = os.path.split(file_name) 39 | with open(annotation_file, 'r') as f: 40 | reader = csv.reader(f) 41 | result = list(reader) 42 | result = result[1:] 43 | 44 | for row in result: 45 | if row[0] == filename: 46 | return row[4] 47 | 48 | 49 | def get_class_id(file_name): 50 | annotation_file = 'UrbanSound8K/metadata/UrbanSound8K.csv' 51 | _, filename = os.path.split(file_name) 52 | with open(annotation_file, 'r') as f: 53 | reader = csv.reader(f) 54 | result = list(reader) 55 | result = result[1:] 56 | 57 | for row in result: 58 | if row[0] == filename: 59 | return row[6] 60 | 61 | 62 | def get_dataset_filelist_urbansound8k( 63 | input_wavs_dir, input_annotation_file, test_fold_id: str, class_id=None 64 | ): 65 | training_files = [] 66 | validation_files = [] 67 | 68 | with open(input_annotation_file, 'r') as f: 69 | reader = csv.reader(f) 70 | result = list(reader) 71 | for row in result[1:]: 72 | # slice_file_name, fsID, start, end, salience, fold, classID, class 73 | if class_id is None: 74 | if row[5] == test_fold_id: 75 | validation_files.append(os.path.join(input_wavs_dir, row[0])) 76 | else: 77 | training_files.append(os.path.join(input_wavs_dir, row[0])) 78 | else: 79 | if row[6] == class_id: 80 | 81 | if row[5] == test_fold_id: 82 | validation_files.append(os.path.join(input_wavs_dir, row[0])) 83 | else: 84 | training_files.append(os.path.join(input_wavs_dir, row[0])) 85 | 86 | return training_files, validation_files 87 | 88 | 89 | def get_dataset_filelist() -> list: 90 | training_files: List[dict] = list() 91 | for root_dir, _, file_list in os.walk("./DCASEFoleySoundSynthesisDevSet"): 92 | for file_name in file_list: 93 | if os.path.splitext(file_name)[-1] == ".wav": 94 | training_files.append( 95 | { 96 | "class_id": clas_dict[root_dir.split("/")[-1]], 97 | "file_path": f"{root_dir}/{file_name}", 98 | } 99 | ) 100 | # training_files.append((clas_dict[root_dir.split("/")[-1]],root_dir.split("/")[-2],)) 101 | return training_files 102 | 103 | 104 | """ 105 | Generate Mels from Wav file 106 | Args: 107 | source: wav_path 108 | target: mel_path 109 | """ 110 | 111 | 112 | def mel_extract( 
113 | file_list, 114 | mel_path, 115 | max_length=22050 * 4, 116 | n_fft=1024, 117 | n_mels=80, 118 | hop_length=256, 119 | sample_rate=22050, 120 | fmin=0, 121 | fmax=8000, 122 | ): 123 | # file_list = get_file_path(wav_path) 124 | mel = audio2mel.Audio2Mel( 125 | file_list, max_length, n_fft, n_mels, hop_length, sample_rate, fmin, fmax 126 | ) 127 | 128 | # print(mel[0][0].shape) 129 | 130 | for (item, _, _, filename) in mel: 131 | # print(item.shape) 132 | # item = F.interpolate(torch.tensor(item), scale_factor=2).numpy() 133 | np.save(os.path.join(mel_path, os.path.split(filename)[1]), item) 134 | print(filename, ' finished!') 135 | 136 | 137 | def mel_extract_test(): 138 | sample_size = 5 139 | test_file_list = get_dataset_filelist()[1] 140 | 141 | random.shuffle(test_file_list) 142 | 143 | test_file_list = test_file_list[0:sample_size] 144 | 145 | # extract mel of audio ground true 146 | mel_extract(test_file_list, 'test/test_mel') 147 | 148 | 149 | def mel_generate_test(): 150 | device = 'cuda:0' 151 | 152 | check_point = 'checkpoint/vqvae+/vqvae_560.pt' 153 | # check_point = 'checkpoint/small-vqvae2/vqvae_800.pt' 154 | 155 | model = VQVAE() 156 | model.load_state_dict(torch.load(check_point)) 157 | model = model.to(device) 158 | model.eval() 159 | 160 | test_mel_list = get_file_path('/home/lxb/Desktop/sound-recognition/mel-test9') 161 | 162 | with torch.no_grad(): 163 | for i, filname in enumerate(test_mel_list): 164 | # print("start " + i) 165 | mel = np.load(filname) 166 | mel = torch.FloatTensor(mel).unsqueeze(0).to(device) 167 | # print(mel[0][0][0][-4:]) 168 | out, _ = model(mel) 169 | out = out.squeeze(1).cpu().numpy() 170 | # print(out[0][0][-4:]) 171 | np.save( 172 | os.path.join( 173 | '/home/lxb/Desktop/sound-recognition/mel-generated9_ablation', 174 | os.path.split(filname)[1], 175 | ), 176 | out, 177 | ) 178 | 179 | 180 | class LMDBDataset(Dataset): 181 | def __init__(self, path): 182 | self.env = lmdb.open( 183 | path, 184 | max_readers=32, 185 | readonly=True, 186 | lock=False, 187 | readahead=False, 188 | meminit=False, 189 | ) 190 | 191 | if not self.env: 192 | raise IOError('Cannot open lmdb dataset', path) 193 | 194 | with self.env.begin(write=False) as txn: 195 | self.length = int(txn.get('length'.encode('utf-8')).decode('utf-8')) 196 | 197 | def __len__(self): 198 | return self.length 199 | 200 | def __getitem__(self, index): 201 | with self.env.begin(write=False) as txn: 202 | key = str(index).encode('utf-8') 203 | 204 | row = pickle.loads(txn.get(key)) 205 | 206 | return torch.from_numpy(row.bottom), row.class_id, row.salience, row.filename 207 | 208 | 209 | if __name__ == '__main__': 210 | # sc = '/media/lxb/U-DISK/baseline_30w/baseline_30w/samples' 211 | # audio_list = get_file_path(sc) 212 | # des = 'sample-results/baseline_32_30w' 213 | # mel_extract(audio_list, des) 214 | mel_generate_test() 215 | 216 | # test_9_list = get_dataset_filelist()[1] 217 | # mel_extract(test_9_list, '/media/lxb/sound-recognition/mel-test9') 218 | 219 | # test_file_list = get_file_path('test/test_audio') 220 | # 221 | # mel_extract(test_file_list, 'test/test_mel') 222 | 223 | # mel_generate_test() 224 | 225 | # _, data = get_dataset_filelist(class_id='90') 226 | # print(len(data)) 227 | 228 | # mel_extract_test() 229 | # dataset = LMDBDataset('code') 230 | # 231 | # loader = DataLoader( 232 | # dataset, batch_size=64, shuffle=True, num_workers=4, drop_last=True 233 | # ) 234 | # for i, (top, bottom, class_id, file_name) in enumerate(loader): 235 | # print(top.shape, 
bottom.shape, len(class_id), len(file_name)) 236 | 237 | # test to see latent space 238 | # with torch.no_grad(): 239 | # for i, filname in enumerate(test_mel_list): 240 | # mel = np.load(filname) 241 | # # x = torch.FloatTensor(x).unsqueeze(0).to(device) 242 | # mel = torch.FloatTensor(mel).unsqueeze(0).to(device) 243 | # _, _, _, id_t, id_b = model.encode(mel) 244 | # print(filname, id_b, id_t) 245 | -------------------------------------------------------------------------------- /extract_code.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import pickle 3 | 4 | import torch 5 | from torch.utils.data import DataLoader 6 | import lmdb 7 | from tqdm import tqdm 8 | import audio2mel 9 | from datasets import get_dataset_filelist 10 | from vqvae import VQVAE 11 | from datasets import CodeRow 12 | 13 | 14 | def extract(lmdb_env, loader, model, device): 15 | index = 0 16 | 17 | with lmdb_env.begin(write=True) as txn: 18 | pbar = tqdm(loader) 19 | 20 | for img, class_id, salience, filename in pbar: 21 | img = img.to(device) 22 | 23 | _, _, id_b = model.encode(img) 24 | # id_t = id_t.detach().cpu().numpy() 25 | id_b = id_b.detach().cpu().numpy() 26 | 27 | for c_id, sali, file, bottom in zip(class_id, salience, filename, id_b): 28 | row = CodeRow( 29 | bottom=bottom, class_id=c_id, salience=sali, filename=file 30 | ) 31 | txn.put(str(index).encode('utf-8'), pickle.dumps(row)) 32 | index += 1 33 | pbar.set_description(f'inserted: {index}') 34 | 35 | txn.put('length'.encode('utf-8'), str(index).encode('utf-8')) 36 | 37 | 38 | if __name__ == '__main__': 39 | parser = argparse.ArgumentParser() 40 | parser.add_argument( 41 | '--vqvae_checkpoint', type=str, default='./checkpoint/vqvae/vqvae.pth' 42 | ) 43 | parser.add_argument('--name', type=str, default='vqvae-code') 44 | 45 | args = parser.parse_args() 46 | 47 | device = 'cuda' 48 | 49 | train_file_list = get_dataset_filelist() 50 | 51 | train_set = audio2mel.Audio2Mel( 52 | train_file_list, 22050 * 4, 1024, 80, 256, 22050, 0, 8000 53 | ) 54 | 55 | loader = DataLoader(train_set, batch_size=128, sampler=None, num_workers=2) 56 | 57 | # for i, batch in enumerate(loader):l 58 | # mel, id, name = batch 59 | 60 | model = VQVAE() 61 | model.load_state_dict(torch.load(args.vqvae_checkpoint, map_location='cpu')) 62 | model = model.to(device) 63 | model.eval() 64 | 65 | map_size = 100 * 1024 * 1024 * 1024 66 | 67 | env = lmdb.open(args.name, map_size=map_size) 68 | 69 | extract(env, loader, model, device) 70 | -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | from numpy import ndarray 3 | from torch import Tensor 4 | from abc import ABC, abstractmethod 5 | 6 | import os 7 | import argparse 8 | import math 9 | import time 10 | import datetime 11 | import torch 12 | from tqdm import tqdm 13 | import soundfile as sf 14 | 15 | from vqvae import VQVAE 16 | from pixelsnail import PixelSNAIL 17 | from HiFiGanWrapper import HiFiGanWrapper 18 | 19 | 20 | class SoundSynthesisModel(ABC): 21 | @abstractmethod 22 | def synthesize_sound(self, class_id: str, number_of_sounds: int) -> List[ndarray]: 23 | raise NotImplementedError 24 | 25 | 26 | class DCASE2023FoleySoundSynthesis: 27 | def __init__( 28 | self, number_of_synthesized_sound_per_class: int = 100, batch_size: int = 16 29 | ) -> None: 30 | self.number_of_synthesized_sound_per_class: int = ( 31 | 
number_of_synthesized_sound_per_class 32 | ) 33 | self.batch_size: int = batch_size 34 | self.class_id_dict: dict = { 35 | 0: 'DogBark', 36 | 1: 'Footstep', 37 | 2: 'GunShot', 38 | 3: 'Keyboard', 39 | 4: 'MovingMotorVehicle', 40 | 5: 'Rain', 41 | 6: 'Sneeze_Cough', 42 | } 43 | self.sr: int = 22050 44 | self.save_dir: str = "./synthesized" 45 | 46 | def synthesize(self, synthesis_model: SoundSynthesisModel) -> None: 47 | for sound_class_id in self.class_id_dict: 48 | sample_number: int = 1 49 | save_category_dir: str = ( 50 | f'{self.save_dir}/{self.class_id_dict[sound_class_id]}' 51 | ) 52 | os.makedirs(save_category_dir, exist_ok=True) 53 | for _ in tqdm( 54 | range( 55 | math.ceil( 56 | self.number_of_synthesized_sound_per_class / self.batch_size 57 | ) 58 | ), 59 | desc=f"Synthesizing {self.class_id_dict[sound_class_id]}", 60 | ): 61 | synthesized_sound_list: list = synthesis_model.synthesize_sound( 62 | sound_class_id, self.batch_size 63 | ) 64 | for synthesized_sound in synthesized_sound_list: 65 | if sample_number <= self.number_of_synthesized_sound_per_class: 66 | sf.write( 67 | f"{save_category_dir}/{str(sample_number).zfill(4)}.wav", 68 | synthesized_sound, 69 | samplerate=self.sr, 70 | ) 71 | sample_number += 1 72 | 73 | 74 | # ================================================================================================================================================ 75 | class BaseLineModel(SoundSynthesisModel): 76 | def __init__( 77 | self, pixel_snail_checkpoint: str, vqvae_snail_checkpoint: str 78 | ) -> None: 79 | super().__init__() 80 | self.pixel_snail = PixelSNAIL( 81 | [20, 86], 82 | 512, 83 | 256, 84 | 5, 85 | 4, 86 | 4, 87 | 256, 88 | dropout=0.1, 89 | n_cond_res_block=3, 90 | cond_res_channel=256, 91 | ) 92 | self.pixel_snail.load_state_dict( 93 | torch.load(pixel_snail_checkpoint, map_location='cpu')['model'] 94 | ) 95 | self.pixel_snail.cuda() 96 | self.pixel_snail.eval() 97 | 98 | self.vqvae = VQVAE() 99 | self.vqvae.load_state_dict( 100 | torch.load(vqvae_snail_checkpoint, map_location='cpu') 101 | ) 102 | self.vqvae.cuda() 103 | self.vqvae.eval() 104 | 105 | self.hifi_gan = HiFiGanWrapper( 106 | './checkpoint/hifigan/g_00935000', 107 | './checkpoint/hifigan/hifigan_config.json', 108 | ) 109 | 110 | @torch.no_grad() 111 | def synthesize_sound(self, class_id: str, number_of_sounds: int) -> List[ndarray]: 112 | audio_list: List[ndarray] = list() 113 | 114 | feature_shape: list = [20, 86] 115 | vq_token: Tensor = torch.zeros( 116 | number_of_sounds, *feature_shape, dtype=torch.int64 117 | ).cuda() 118 | cache = dict() 119 | 120 | for i in tqdm(range(feature_shape[0]), desc="pixel_snail"): 121 | for j in range(feature_shape[1]): 122 | out, cache = self.pixel_snail( 123 | vq_token[:, : i + 1, :], 124 | label_condition=torch.full([number_of_sounds, 1], int(class_id)) 125 | .long() 126 | .cuda(), 127 | cache=cache, 128 | ) 129 | prob: Tensor = torch.softmax(out[:, :, i, j], 1) 130 | vq_token[:, i, j] = torch.multinomial(prob, 1).squeeze(-1) 131 | 132 | pred_mel = self.vqvae.decode_code(vq_token).detach() 133 | for j, mel in enumerate(pred_mel): 134 | audio_list.append(self.hifi_gan.generate_audio_by_hifi_gan(mel)) 135 | return audio_list 136 | 137 | 138 | # =============================================================================================================================================== 139 | if __name__ == '__main__': 140 | start = time.time() 141 | parser = argparse.ArgumentParser() 142 | parser.add_argument( 143 | '--vqvae_checkpoint', type=str, 
default='./checkpoint/vqvae/vqvae.pth' 144 | ) 145 | parser.add_argument( 146 | '--pixelsnail_checkpoint', 147 | type=str, 148 | default='./checkpoint/pixelsnail-final/bottom_1400.pt', 149 | ) 150 | parser.add_argument( 151 | '--number_of_synthesized_sound_per_class', type=int, default=100 152 | ) 153 | parser.add_argument('--batch_size', type=int, default=16) 154 | args = parser.parse_args() 155 | dcase_2023_foley_sound_synthesis = DCASE2023FoleySoundSynthesis( 156 | args.number_of_synthesized_sound_per_class, args.batch_size 157 | ) 158 | dcase_2023_foley_sound_synthesis.synthesize( 159 | synthesis_model=BaseLineModel(args.pixelsnail_checkpoint, args.vqvae_checkpoint) 160 | ) 161 | print(str(datetime.timedelta(seconds=time.time() - start))) 162 | -------------------------------------------------------------------------------- /pixelsnail.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Xi Chen 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | # Borrowed from https://github.com/neocxi/pixelsnail-public and ported it to PyTorch 7 | 8 | from math import sqrt 9 | from functools import partial, lru_cache 10 | 11 | import numpy as np 12 | import torch 13 | from torch import nn 14 | from torch.nn import functional as F 15 | 16 | 17 | def wn_linear(in_dim, out_dim): 18 | return nn.utils.weight_norm(nn.Linear(in_dim, out_dim)) 19 | 20 | 21 | class WNConv2d(nn.Module): 22 | def __init__( 23 | self, 24 | in_channel, 25 | out_channel, 26 | kernel_size, 27 | stride=1, 28 | padding=0, 29 | bias=True, 30 | activation=None, 31 | ): 32 | super().__init__() 33 | 34 | self.conv = nn.utils.weight_norm( 35 | nn.Conv2d( 36 | in_channel, 37 | out_channel, 38 | kernel_size, 39 | stride=stride, 40 | padding=padding, 41 | bias=bias, 42 | ) 43 | ) 44 | 45 | self.out_channel = out_channel 46 | 47 | if isinstance(kernel_size, int): 48 | kernel_size = [kernel_size, kernel_size] 49 | 50 | self.kernel_size = kernel_size 51 | 52 | self.activation = activation 53 | 54 | def forward(self, input): 55 | out = self.conv(input) 56 | 57 | if self.activation is not None: 58 | out = self.activation(out) 59 | 60 | return out 61 | 62 | 63 | def shift_down(input, size=1): 64 | return F.pad(input, [0, 0, size, 0])[:, :, : input.shape[2], :] 65 | 66 | 67 | def shift_right(input, size=1): 68 | return F.pad(input, [size, 0, 0, 0])[:, :, :, : input.shape[3]] 69 | 70 | 71 | class CausalConv2d(nn.Module): 72 | def __init__( 73 | self, 74 | in_channel, 75 | out_channel, 76 | kernel_size, 77 | stride=1, 78 | padding='downright', 79 | activation=None, 80 | ): 81 | super().__init__() 82 | 83 | if isinstance(kernel_size, int): 84 | kernel_size = [kernel_size] * 2 85 | 86 | self.kernel_size = kernel_size 87 | 88 | if padding == 'downright': 89 | pad = [kernel_size[1] - 1, 0, kernel_size[0] - 1, 0] 90 | 91 | elif padding == 'down' or padding == 'causal': 92 | pad = kernel_size[1] // 2 93 | 94 | pad = [pad, pad, kernel_size[0] - 1, 0] 95 | 96 | self.causal = 0 97 | if padding == 'causal': 98 | self.causal = kernel_size[1] // 2 99 | 100 | self.pad = nn.ZeroPad2d(pad) 101 | 102 | self.conv = WNConv2d( 103 | in_channel, 104 | out_channel, 105 | kernel_size, 106 | stride=stride, 107 | padding=0, 108 | activation=activation, 109 | ) 110 | 111 | def forward(self, input): 112 | out = self.pad(input) 113 | 114 | if self.causal > 0: 115 | self.conv.conv.weight_v.data[:, :, -1, self.causal :].zero_() 116 | 
117 | out = self.conv(out) 118 | 119 | return out 120 | 121 | 122 | class GatedResBlock(nn.Module): 123 | def __init__( 124 | self, 125 | in_channel, 126 | channel, 127 | kernel_size, 128 | conv='wnconv2d', 129 | activation=nn.ELU, 130 | dropout=0.1, 131 | auxiliary_channel=0, 132 | condition_dim=0, 133 | ): 134 | super().__init__() 135 | 136 | if conv == 'wnconv2d': 137 | conv_module = partial(WNConv2d, padding=kernel_size // 2) 138 | 139 | elif conv == 'causal_downright': 140 | conv_module = partial(CausalConv2d, padding='downright') 141 | 142 | elif conv == 'causal': 143 | conv_module = partial(CausalConv2d, padding='causal') 144 | 145 | self.activation = activation() 146 | self.conv1 = conv_module(in_channel, channel, kernel_size) 147 | 148 | if auxiliary_channel > 0: 149 | self.aux_conv = WNConv2d(auxiliary_channel, channel, 1) 150 | 151 | self.dropout = nn.Dropout(dropout) 152 | 153 | self.conv2 = conv_module(channel, in_channel * 2, kernel_size) 154 | 155 | if condition_dim > 0: 156 | # self.condition = nn.Linear(condition_dim, in_channel * 2, bias=False) 157 | self.condition = WNConv2d(condition_dim, in_channel * 2, 1, bias=False) 158 | 159 | self.gate = nn.GLU(1) 160 | 161 | def forward(self, input, aux_input=None, condition=None): 162 | out = self.conv1(self.activation(input)) 163 | 164 | if aux_input is not None: 165 | out = out + self.aux_conv(self.activation(aux_input)) 166 | 167 | out = self.activation(out) 168 | out = self.dropout(out) 169 | out = self.conv2(out) 170 | 171 | if condition is not None: 172 | # print('condition', condition.shape) 173 | condition = self.condition(condition) 174 | out += condition 175 | # out = out + condition.view(condition.shape[0], 1, 1, condition.shape[1]) 176 | 177 | out = self.gate(out) 178 | out += input 179 | 180 | return out 181 | 182 | 183 | @lru_cache(maxsize=64) 184 | def causal_mask(size): 185 | shape = [size, size] 186 | mask = np.triu(np.ones(shape), k=1).astype(np.uint8).T 187 | start_mask = np.ones(size).astype(np.float32) 188 | start_mask[0] = 0 189 | 190 | return ( 191 | torch.from_numpy(mask).unsqueeze(0), 192 | torch.from_numpy(start_mask).unsqueeze(1), 193 | ) 194 | 195 | 196 | class CausalAttention(nn.Module): 197 | def __init__(self, query_channel, key_channel, channel, n_head=8, dropout=0.1): 198 | super().__init__() 199 | 200 | self.query = wn_linear(query_channel, channel) 201 | self.key = wn_linear(key_channel, channel) 202 | self.value = wn_linear(key_channel, channel) 203 | 204 | self.dim_head = channel // n_head 205 | self.n_head = n_head 206 | 207 | self.dropout = nn.Dropout(dropout) 208 | 209 | def forward(self, query, key): 210 | batch, _, height, width = key.shape 211 | 212 | def reshape(input): 213 | return input.view(batch, -1, self.n_head, self.dim_head).transpose(1, 2) 214 | 215 | query_flat = query.view(batch, query.shape[1], -1).transpose(1, 2) 216 | key_flat = key.view(batch, key.shape[1], -1).transpose(1, 2) 217 | query = reshape(self.query(query_flat)) 218 | key = reshape(self.key(key_flat)).transpose(2, 3) 219 | value = reshape(self.value(key_flat)) 220 | 221 | attn = torch.matmul(query, key) / sqrt(self.dim_head) 222 | mask, start_mask = causal_mask(height * width) 223 | mask = mask.type_as(query) 224 | start_mask = start_mask.type_as(query) 225 | attn = attn.masked_fill(mask == 0, -1e4) 226 | attn = torch.softmax(attn, 3) * start_mask 227 | attn = self.dropout(attn) 228 | 229 | out = attn @ value 230 | out = out.transpose(1, 2).reshape( 231 | batch, height, width, self.dim_head * self.n_head 232 | ) 
233 | out = out.permute(0, 3, 1, 2) 234 | 235 | return out 236 | 237 | 238 | class PixelBlock(nn.Module): 239 | def __init__( 240 | self, 241 | in_channel, 242 | channel, 243 | kernel_size, 244 | n_res_block, 245 | attention=True, 246 | dropout=0.1, 247 | condition_dim=0, 248 | ): 249 | super().__init__() 250 | 251 | resblocks = [] 252 | for i in range(n_res_block): 253 | resblocks.append( 254 | GatedResBlock( 255 | in_channel, 256 | channel, 257 | kernel_size, 258 | conv='causal', 259 | dropout=dropout, 260 | condition_dim=condition_dim, 261 | ) 262 | ) 263 | 264 | self.resblocks = nn.ModuleList(resblocks) 265 | 266 | self.attention = attention 267 | 268 | if attention: 269 | self.key_resblock = GatedResBlock( 270 | in_channel * 2 + 2, in_channel, 1, dropout=dropout 271 | ) 272 | self.query_resblock = GatedResBlock( 273 | in_channel + 2, in_channel, 1, dropout=dropout 274 | ) 275 | 276 | self.causal_attention = CausalAttention( 277 | in_channel + 2, in_channel * 2 + 2, in_channel // 2, dropout=dropout 278 | ) 279 | 280 | self.out_resblock = GatedResBlock( 281 | in_channel, 282 | in_channel, 283 | 1, 284 | auxiliary_channel=in_channel // 2, 285 | dropout=dropout, 286 | ) 287 | 288 | else: 289 | self.out = WNConv2d(in_channel + 2, in_channel, 1) 290 | 291 | def forward(self, input, background, condition=None): 292 | out = input 293 | 294 | for resblock in self.resblocks: 295 | out = resblock(out, condition=condition) 296 | 297 | if self.attention: 298 | key_cat = torch.cat([input, out, background], 1) 299 | key = self.key_resblock(key_cat) 300 | query_cat = torch.cat([out, background], 1) 301 | query = self.query_resblock(query_cat) 302 | attn_out = self.causal_attention(query, key) 303 | out = self.out_resblock(out, attn_out) 304 | 305 | else: 306 | bg_cat = torch.cat([out, background], 1) 307 | out = self.out(bg_cat) 308 | 309 | return out 310 | 311 | 312 | class CondResNet(nn.Module): 313 | def __init__(self, in_channel, channel, kernel_size, n_res_block): 314 | super().__init__() 315 | 316 | blocks = [WNConv2d(in_channel, channel, kernel_size, padding=kernel_size // 2)] 317 | 318 | for i in range(n_res_block): 319 | blocks.append(GatedResBlock(channel, channel, kernel_size)) 320 | 321 | self.blocks = nn.Sequential(*blocks) 322 | 323 | def forward(self, input): 324 | return self.blocks(input) 325 | 326 | 327 | class EmbedNet(nn.Module): 328 | def __init__(self, in_dim, hidden_dim, out_dim): 329 | super().__init__() 330 | 331 | blocks = [ 332 | nn.Linear(in_dim, hidden_dim), 333 | nn.Linear(hidden_dim, hidden_dim), 334 | nn.Linear(hidden_dim, hidden_dim), 335 | nn.Linear(hidden_dim, out_dim), 336 | ] 337 | 338 | self.blocks = nn.Sequential(*blocks) 339 | 340 | def forward(self, input): 341 | return self.blocks(input) 342 | 343 | 344 | class PixelSNAIL(nn.Module): 345 | def __init__( 346 | self, 347 | shape, 348 | n_class, # code nums 349 | channel, 350 | kernel_size, 351 | n_block, 352 | n_res_block, 353 | res_channel, 354 | attention=True, 355 | dropout=0.1, 356 | n_cond_res_block=0, 357 | cond_res_channel=0, 358 | cond_res_kernel=3, 359 | n_out_res_block=0, 360 | cond_embed_channel=1, 361 | ### 362 | n_label=7, # data class nums 363 | embed_dim=2048, 364 | ### 365 | ): 366 | super().__init__() 367 | 368 | height, width = shape 369 | 370 | self.n_class = n_class 371 | 372 | ### 373 | self.n_label = n_label 374 | ### 375 | 376 | if kernel_size % 2 == 0: 377 | kernel = kernel_size + 1 378 | 379 | else: 380 | kernel = kernel_size 381 | 382 | self.horizontal = CausalConv2d( 383 | n_class, 
channel, [kernel // 2, kernel], padding='down' 384 | ) 385 | self.vertical = CausalConv2d( 386 | n_class, channel, [(kernel + 1) // 2, kernel // 2], padding='downright' 387 | ) 388 | 389 | coord_x = (torch.arange(height).float() - height / 2) / height 390 | coord_x = coord_x.view(1, 1, height, 1).expand( 391 | 1, 1, height, width 392 | ) # shape: torch.Size([1, 1, 20, 86]) 393 | coord_y = (torch.arange(width).float() - width / 2) / width 394 | coord_y = coord_y.view(1, 1, 1, width).expand( 395 | 1, 1, height, width 396 | ) # shape: torch.Size([1, 1, 20, 86]) 397 | # print('x', coord_x.shape, 'y', coord_y.shape) 398 | self.register_buffer( 399 | 'background', torch.cat([coord_x, coord_y], 1) 400 | ) # shape: self.background torch.Size([1, 2, 20, 86]) 401 | 402 | self.blocks = nn.ModuleList() 403 | 404 | for i in range(n_block): 405 | self.blocks.append( 406 | PixelBlock( 407 | channel, 408 | res_channel, 409 | kernel_size, 410 | n_res_block, 411 | attention=attention, 412 | dropout=dropout, 413 | condition_dim=cond_embed_channel, 414 | ) 415 | ) 416 | 417 | if n_cond_res_block > 0: 418 | self.cond_resnet = CondResNet( 419 | n_class, cond_res_channel, cond_res_kernel, n_cond_res_block 420 | ) 421 | 422 | ### 423 | self.embedNet = EmbedNet(n_label, embed_dim, 20 * 86) 424 | ### 425 | 426 | out = [] 427 | 428 | for i in range(n_out_res_block): 429 | out.append(GatedResBlock(channel, res_channel, 1)) 430 | 431 | out.extend([nn.ELU(inplace=True), WNConv2d(channel, n_class, 1)]) 432 | 433 | self.out = nn.Sequential(*out) 434 | 435 | def forward(self, input, label_condition=None, cache=None): 436 | if cache is None: 437 | cache = {} 438 | batch, height, width = input.shape 439 | # print('input', input.shape) 440 | input = ( 441 | F.one_hot(input, self.n_class).permute(0, 3, 1, 2).type_as(self.background) 442 | ) 443 | horizontal = shift_down(self.horizontal(input)) 444 | vertical = shift_right(self.vertical(input)) 445 | out = horizontal + vertical 446 | 447 | # print('background-1', self.background.shape) 448 | background = self.background[:, :, :height, :].expand( 449 | batch, 2, height, width 450 | ) # shape: torch.Size([32, 2, 20, 86]) 451 | # print('background-2', background.shape) 452 | 453 | if True: 454 | if 'condition' in cache: 455 | condition = cache['condition'] 456 | condition = condition[:, :, :height, :] 457 | 458 | else: 459 | label = F.one_hot(label_condition, self.n_label).type_as( 460 | self.background 461 | ) 462 | # salience = salience_condition.unsqueeze(1) 463 | # condition = torch.cat((label, salience), 2) 464 | condition = label 465 | 466 | condition = self.embedNet(condition) 467 | condition = condition.view(-1, 1, 20, 86) 468 | # print(condition.shape) #torch.Size([64, 1, 10, 43]) 469 | cache['condition'] = condition.detach().clone() 470 | condition = condition[:, :, :height, :] 471 | 472 | # if code_condition is not None: 473 | # embed_condition = ( 474 | # F.one_hot(label_condition, self.n_label) 475 | # .type_as(self.background) 476 | # ) 477 | # embed_condition = self.embedNet(embed_condition) 478 | # embed_condition = embed_condition.view(-1, 1, 10, 43) 479 | # # print('embed-1', embed_condition.shape) 480 | # embed_condition = F.interpolate(embed_condition, scale_factor=2) 481 | # # print('embed-2', embed_condition.shape) 482 | # 483 | # condition = ( 484 | # F.one_hot(code_condition, self.n_class) 485 | # .permute(0, 3, 1, 2) 486 | # .type_as(self.background) 487 | # ) 488 | # # condition.shape: torch.Size([32, 512, 10, 43])) 489 | # condition = 
self.cond_resnet(condition) 490 | # # print(condition.shape) 491 | # condition = F.interpolate(condition, scale_factor=2) 492 | # # print('before', condition.shape) 493 | # condition = torch.cat([condition, embed_condition], 1) 494 | # # print('after', condition.shape) 495 | # # condition.shape: torch.Size([32, 256, 20, 86])) 496 | # cache['condition'] = condition.detach().clone() 497 | # # print(condition.shape) 498 | # condition = condition[:, :, :height, :] 499 | 500 | for block in self.blocks: 501 | out = block(out, background, condition=condition) # PixelBlock 502 | 503 | out = self.out(out) 504 | 505 | return out, cache 506 | 507 | 508 | if __name__ == '__main__': 509 | from torch.utils.data import DataLoader 510 | from datasets import LMDBDataset 511 | from torch import nn 512 | 513 | device = 'cuda' 514 | 515 | dataset = LMDBDataset('code/') 516 | loader = DataLoader( 517 | dataset, batch_size=2, shuffle=True, num_workers=4, drop_last=True 518 | ) 519 | 520 | model = PixelSNAIL( 521 | [20, 86], 522 | 512, 523 | 256, 524 | 5, 525 | 4, 526 | 4, 527 | 256, 528 | attention=False, 529 | dropout=0.1, 530 | n_cond_res_block=3, 531 | cond_res_channel=256, 532 | ) 533 | 534 | # model = PixelSNAIL( 535 | # [10, 43], 536 | # 512, 537 | # 256, 538 | # 5, 539 | # 4, 540 | # 4, 541 | # 256, 542 | # dropout=0.1, 543 | # n_out_res_block=0, 544 | # cond_res_channel = 0 545 | # ) 546 | 547 | model = nn.DataParallel(model) 548 | model = model.to(device) 549 | 550 | for i, (bottom, class_id, salience, file_name) in enumerate(loader): 551 | class_id = ( 552 | torch.FloatTensor(list(map(eval, list(class_id)))).long().unsqueeze(1) 553 | ) 554 | salience = torch.FloatTensor(list(map(eval, list(salience)))).unsqueeze(1) 555 | out, _ = model(bottom, label_condition=class_id, salience_condition=salience) 556 | if i == 5: 557 | print(class_id, salience, file_name) 558 | break 559 | 560 | 561 | # class UCPixelSNAIL(nn.Module): 562 | # def __init__( 563 | # self, 564 | # shape, 565 | # n_class, # code nums 566 | # channel, 567 | # kernel_size, 568 | # n_block, 569 | # n_res_block, 570 | # res_channel, 571 | # attention=True, 572 | # dropout=0.1, 573 | # n_cond_res_block=0, 574 | # cond_res_channel=0, 575 | # cond_res_kernel=3, 576 | # n_out_res_block=0, 577 | # ): 578 | # super().__init__() 579 | # 580 | # height, width = shape 581 | # 582 | # self.n_class = n_class 583 | # 584 | # if kernel_size % 2 == 0: 585 | # kernel = kernel_size + 1 586 | # 587 | # else: 588 | # kernel = kernel_size 589 | # 590 | # self.horizontal = CausalConv2d( 591 | # n_class, channel, [kernel // 2, kernel], padding='down' 592 | # ) 593 | # self.vertical = CausalConv2d( 594 | # n_class, channel, [(kernel + 1) // 2, kernel // 2], padding='downright' 595 | # ) 596 | # 597 | # coord_x = (torch.arange(height).float() - height / 2) / height 598 | # coord_x = coord_x.view(1, 1, height, 1).expand(1, 1, height, width) 599 | # coord_y = (torch.arange(width).float() - width / 2) / width 600 | # coord_y = coord_y.view(1, 1, 1, width).expand(1, 1, height, width) 601 | # self.register_buffer('background', torch.cat([coord_x, coord_y], 1)) 602 | # 603 | # self.blocks = nn.ModuleList() 604 | # 605 | # for i in range(n_block): 606 | # self.blocks.append( 607 | # PixelBlock( 608 | # channel, 609 | # res_channel, 610 | # kernel_size, 611 | # n_res_block, 612 | # attention=attention, 613 | # dropout=dropout, 614 | # condition_dim=cond_res_channel, 615 | # ) 616 | # ) 617 | # 618 | # if n_cond_res_block > 0: 619 | # self.cond_resnet = CondResNet( 620 | # 
n_class, cond_res_channel, cond_res_kernel, n_cond_res_block 621 | # ) 622 | # 623 | # out = [] 624 | # 625 | # for i in range(n_out_res_block): 626 | # out.append(GatedResBlock(channel, res_channel, 1)) 627 | # 628 | # out.extend([nn.ELU(inplace=True), WNConv2d(channel, n_class, 1)]) 629 | # 630 | # self.out = nn.Sequential(*out) 631 | # 632 | # def forward(self, input, condition=None, condition_type=None, cache=None): 633 | # if cache is None: 634 | # cache = {} 635 | # batch, height, width = input.shape 636 | # input = ( 637 | # F.one_hot(input, self.n_class).permute(0, 3, 1, 2).type_as(self.background) 638 | # ) 639 | # horizontal = shift_down(self.horizontal(input)) 640 | # vertical = shift_right(self.vertical(input)) 641 | # out = horizontal + vertical 642 | # 643 | # background = self.background[:, :, :height, :].expand(batch, 2, height, width) 644 | # 645 | # if condition is not None: 646 | # if 'condition' in cache: 647 | # condition = cache['condition'] 648 | # condition = condition[:, :, :height, :] 649 | # 650 | # else: 651 | # condition = ( 652 | # F.one_hot(condition, self.n_class) 653 | # .permute(0, 3, 1, 2) 654 | # .type_as(self.background) 655 | # ) 656 | # condition = self.cond_resnet(condition) 657 | # condition = F.interpolate(condition, scale_factor=2) 658 | # cache['condition'] = condition.detach().clone() 659 | # condition = condition[:, :, :height, :] 660 | # 661 | # for block in self.blocks: 662 | # out = block(out, background, condition=condition) 663 | # 664 | # out = self.out(out) 665 | # 666 | # return out, cache 667 | -------------------------------------------------------------------------------- /scheduler.py: -------------------------------------------------------------------------------- 1 | from math import cos, pi, floor, sin 2 | 3 | from torch.optim import lr_scheduler 4 | 5 | 6 | class CosineLR(lr_scheduler._LRScheduler): 7 | def __init__(self, optimizer, lr_min, lr_max, step_size): 8 | self.lr_min = lr_min 9 | self.lr_max = lr_max 10 | self.step_size = step_size 11 | self.iteration = 0 12 | 13 | super().__init__(optimizer, -1) 14 | 15 | def get_lr(self): 16 | lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * ( 17 | 1 + cos(self.iteration / self.step_size * pi) 18 | ) 19 | self.iteration += 1 20 | 21 | if self.iteration == self.step_size: 22 | self.iteration = 0 23 | 24 | return [lr for base_lr in self.base_lrs] 25 | 26 | 27 | class PowerLR(lr_scheduler._LRScheduler): 28 | def __init__(self, optimizer, lr_min, lr_max, warmup): 29 | self.lr_min = lr_min 30 | self.lr_max = lr_max 31 | self.warmup = warmup 32 | self.iteration = 0 33 | 34 | super().__init__(optimizer, -1) 35 | 36 | def get_lr(self): 37 | if self.iteration < self.warmup: 38 | lr = ( 39 | self.lr_min + (self.lr_max - self.lr_min) / self.warmup * self.iteration 40 | ) 41 | 42 | else: 43 | lr = self.lr_max * (self.iteration - self.warmup + 1) ** -0.5 44 | 45 | self.iteration += 1 46 | 47 | return [lr for base_lr in self.base_lrs] 48 | 49 | 50 | class SineLR(lr_scheduler._LRScheduler): 51 | def __init__(self, optimizer, lr_min, lr_max, step_size): 52 | self.lr_min = lr_min 53 | self.lr_max = lr_max 54 | self.step_size = step_size 55 | self.iteration = 0 56 | 57 | super().__init__(optimizer, -1) 58 | 59 | def get_lr(self): 60 | lr = self.lr_min + (self.lr_max - self.lr_min) * sin( 61 | self.iteration / self.step_size * pi 62 | ) 63 | self.iteration += 1 64 | 65 | if self.iteration == self.step_size: 66 | self.iteration = 0 67 | 68 | return [lr for base_lr in self.base_lrs] 69 | 70 | 71 | 
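# --- Editorial note: illustrative sketch, not part of the original file. ---
# The schedulers above (CosineLR, PowerLR, SineLR) keep their own iteration counter and
# recompute the learning rate on every get_lr() call, so they are meant to be stepped
# once per training batch, the same way train_pixelsnail.py drives CycleScheduler
# (scheduler.step() before optimizer.step()). The function name and numeric values below
# (_example_scheduler_usage, lr_min=1e-5, lr_max=3e-4, step_size=1000) are invented for
# illustration only.
def _example_scheduler_usage():
    import torch
    from torch import nn, optim

    model = nn.Linear(8, 8)
    optimizer = optim.Adam(model.parameters(), lr=3e-4)
    scheduler = CosineLR(optimizer, lr_min=1e-5, lr_max=3e-4, step_size=1000)

    for _ in range(10):
        optimizer.zero_grad()
        loss = model(torch.randn(4, 8)).pow(2).mean()
        loss.backward()
        scheduler.step()  # update the learning rate for this iteration
        optimizer.step()

    return optimizer.param_groups[0]['lr']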
class LinearLR(lr_scheduler._LRScheduler): 72 | def __init__(self, optimizer, lr_min, lr_max, warmup, step_size): 73 | self.lr_min = lr_min 74 | self.lr_max = lr_max 75 | self.step_size = step_size 76 | self.warmup = warmup 77 | self.iteration = 0 78 | 79 | super().__init__(optimizer, -1) 80 | 81 | def get_lr(self): 82 | if self.iteration < self.warmup: 83 | lr = self.lr_max 84 | 85 | else: 86 | lr = self.lr_max + (self.iteration - self.warmup) * ( 87 | self.lr_min - self.lr_max 88 | ) / (self.step_size - self.warmup) 89 | self.iteration += 1 90 | 91 | if self.iteration == self.step_size: 92 | self.iteration = 0 93 | 94 | return [lr for base_lr in self.base_lrs] 95 | 96 | 97 | class CLR(lr_scheduler._LRScheduler): 98 | def __init__(self, optimizer, lr_min, lr_max, step_size): 99 | self.epoch = 0 100 | self.lr_min = lr_min 101 | self.lr_max = lr_max 102 | self.current_lr = lr_min 103 | self.step_size = step_size 104 | 105 | super().__init__(optimizer, -1) 106 | 107 | def get_lr(self): 108 | cycle = floor(1 + self.epoch / (2 * self.step_size)) 109 | x = abs(self.epoch / self.step_size - 2 * cycle + 1) 110 | lr = self.lr_min + (self.lr_max - self.lr_min) * max(0, 1 - x) 111 | self.current_lr = lr 112 | 113 | self.epoch += 1 114 | 115 | return [lr for base_lr in self.base_lrs] 116 | 117 | 118 | class Warmup(lr_scheduler._LRScheduler): 119 | def __init__(self, optimizer, model_dim, factor=1, warmup=16000): 120 | self.optimizer = optimizer 121 | self.model_dim = model_dim 122 | self.factor = factor 123 | self.warmup = warmup 124 | self.iteration = 0 125 | 126 | super().__init__(optimizer, -1) 127 | 128 | def get_lr(self): 129 | self.iteration += 1 130 | lr = ( 131 | self.factor 132 | * self.model_dim ** (-0.5) 133 | * min(self.iteration ** (-0.5), self.iteration * self.warmup ** (-1.5)) 134 | ) 135 | 136 | return [lr for base_lr in self.base_lrs] 137 | 138 | 139 | # Copyright 2019 fastai 140 | 141 | # Licensed under the Apache License, Version 2.0 (the "License"); 142 | # you may not use this file except in compliance with the License. 143 | # You may obtain a copy of the License at 144 | 145 | # http://www.apache.org/licenses/LICENSE-2.0 146 | 147 | # Unless required by applicable law or agreed to in writing, software 148 | # distributed under the License is distributed on an "AS IS" BASIS, 149 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 150 | # See the License for the specific language governing permissions and 151 | # limitations under the License. 
152 | 153 | 154 | # Borrowed from https://github.com/fastai/fastai and changed to make it runs like PyTorch lr scheduler 155 | 156 | 157 | class CycleAnnealScheduler: 158 | def __init__( 159 | self, optimizer, lr_max, lr_divider, cut_point, step_size, momentum=None 160 | ): 161 | self.lr_max = lr_max 162 | self.lr_divider = lr_divider 163 | self.cut_point = step_size // cut_point 164 | self.step_size = step_size 165 | self.iteration = 0 166 | self.cycle_step = int(step_size * (1 - cut_point / 100) / 2) 167 | self.momentum = momentum 168 | self.optimizer = optimizer 169 | 170 | def get_lr(self): 171 | if self.iteration > 2 * self.cycle_step: 172 | cut = (self.iteration - 2 * self.cycle_step) / ( 173 | self.step_size - 2 * self.cycle_step 174 | ) 175 | lr = self.lr_max * (1 + (cut * (1 - 100) / 100)) / self.lr_divider 176 | 177 | elif self.iteration > self.cycle_step: 178 | cut = 1 - (self.iteration - self.cycle_step) / self.cycle_step 179 | lr = self.lr_max * (1 + cut * (self.lr_divider - 1)) / self.lr_divider 180 | 181 | else: 182 | cut = self.iteration / self.cycle_step 183 | lr = self.lr_max * (1 + cut * (self.lr_divider - 1)) / self.lr_divider 184 | 185 | return lr 186 | 187 | def get_momentum(self): 188 | if self.iteration > 2 * self.cycle_step: 189 | momentum = self.momentum[0] 190 | 191 | elif self.iteration > self.cycle_step: 192 | cut = 1 - (self.iteration - self.cycle_step) / self.cycle_step 193 | momentum = self.momentum[0] + cut * (self.momentum[1] - self.momentum[0]) 194 | 195 | else: 196 | cut = self.iteration / self.cycle_step 197 | momentum = self.momentum[0] + cut * (self.momentum[1] - self.momentum[0]) 198 | 199 | return momentum 200 | 201 | def step(self): 202 | lr = self.get_lr() 203 | 204 | if self.momentum is not None: 205 | momentum = self.get_momentum() 206 | 207 | self.iteration += 1 208 | 209 | if self.iteration == self.step_size: 210 | self.iteration = 0 211 | 212 | for group in self.optimizer.param_groups: 213 | group['lr'] = lr 214 | 215 | if self.momentum is not None: 216 | group['betas'] = (momentum, group['betas'][1]) 217 | 218 | return lr 219 | 220 | 221 | def anneal_linear(start, end, proportion): 222 | return start + proportion * (end - start) 223 | 224 | 225 | def anneal_cos(start, end, proportion): 226 | cos_val = cos(pi * proportion) + 1 227 | 228 | return end + (start - end) / 2 * cos_val 229 | 230 | 231 | class Phase: 232 | def __init__(self, start, end, n_iter, anneal_fn): 233 | self.start, self.end = start, end 234 | self.n_iter = n_iter 235 | self.anneal_fn = anneal_fn 236 | self.n = 0 237 | 238 | def step(self): 239 | self.n += 1 240 | 241 | return self.anneal_fn(self.start, self.end, self.n / self.n_iter) 242 | 243 | def reset(self): 244 | self.n = 0 245 | 246 | @property 247 | def is_done(self): 248 | return self.n >= self.n_iter 249 | 250 | 251 | class CycleScheduler: 252 | def __init__( 253 | self, 254 | optimizer, 255 | lr_max, 256 | n_iter, 257 | momentum=(0.95, 0.85), 258 | divider=25, 259 | warmup_proportion=0.3, 260 | phase=('linear', 'cos'), 261 | ): 262 | self.optimizer = optimizer 263 | 264 | phase1 = int(n_iter * warmup_proportion) 265 | phase2 = n_iter - phase1 266 | lr_min = lr_max / divider 267 | 268 | phase_map = {'linear': anneal_linear, 'cos': anneal_cos} 269 | 270 | self.lr_phase = [ 271 | Phase(lr_min, lr_max, phase1, phase_map[phase[0]]), 272 | Phase(lr_max, lr_min / 1e4, phase2, phase_map[phase[1]]), 273 | ] 274 | 275 | self.momentum = momentum 276 | 277 | if momentum is not None: 278 | mom1, mom2 = momentum 279 | 
self.momentum_phase = [ 280 | Phase(mom1, mom2, phase1, phase_map[phase[0]]), 281 | Phase(mom2, mom1, phase2, phase_map[phase[1]]), 282 | ] 283 | 284 | else: 285 | self.momentum_phase = [] 286 | 287 | self.phase = 0 288 | 289 | def step(self): 290 | lr = self.lr_phase[self.phase].step() 291 | 292 | if self.momentum is not None: 293 | momentum = self.momentum_phase[self.phase].step() 294 | 295 | else: 296 | momentum = None 297 | 298 | for group in self.optimizer.param_groups: 299 | group['lr'] = lr 300 | 301 | if self.momentum is not None: 302 | if 'betas' in group: 303 | group['betas'] = (momentum, group['betas'][1]) 304 | 305 | else: 306 | group['momentum'] = momentum 307 | 308 | if self.lr_phase[self.phase].is_done: 309 | self.phase += 1 310 | 311 | if self.phase >= len(self.lr_phase): 312 | for phase in self.lr_phase: 313 | phase.reset() 314 | 315 | for phase in self.momentum_phase: 316 | phase.reset() 317 | 318 | self.phase = 0 319 | 320 | return lr, momentum 321 | 322 | 323 | class LRFinder(lr_scheduler._LRScheduler): 324 | def __init__(self, optimizer, lr_min, lr_max, step_size, linear=False): 325 | ratio = lr_max / lr_min 326 | self.linear = linear 327 | self.lr_min = lr_min 328 | self.lr_mult = (ratio / step_size) if linear else ratio ** (1 / step_size) 329 | self.iteration = 0 330 | self.lrs = [] 331 | self.losses = [] 332 | 333 | super().__init__(optimizer, -1) 334 | 335 | def get_lr(self): 336 | lr = ( 337 | self.lr_mult * self.iteration 338 | if self.linear 339 | else self.lr_mult**self.iteration 340 | ) 341 | lr = self.lr_min + lr if self.linear else self.lr_min * lr 342 | 343 | self.iteration += 1 344 | self.lrs.append(lr) 345 | 346 | return [lr for base_lr in self.base_lrs] 347 | 348 | def record(self, loss): 349 | self.losses.append(loss) 350 | 351 | def save(self, filename): 352 | with open(filename, 'w') as f: 353 | for lr, loss in zip(self.lrs, self.losses): 354 | f.write('{},{}\n'.format(lr, loss)) 355 | -------------------------------------------------------------------------------- /train_pixelsnail.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import numpy as np 4 | import os 5 | import torch 6 | from torch import nn, optim 7 | from torch.utils.data import DataLoader 8 | from tqdm import tqdm 9 | 10 | amp = None 11 | 12 | from datasets import LMDBDataset 13 | from pixelsnail import PixelSNAIL 14 | from scheduler import CycleScheduler 15 | 16 | 17 | def train(epoch, loader, model, optimizer, scheduler, device): 18 | loader = tqdm(loader) 19 | 20 | criterion = nn.CrossEntropyLoss() 21 | 22 | for i, (bottom, class_id, salience, file_name) in enumerate(loader): 23 | model.zero_grad() 24 | 25 | class_id = torch.FloatTensor(list(list(class_id))).long().unsqueeze(1) 26 | # salience = torch.FloatTensor(list(map(eval, list(salience)))).unsqueeze(1) 27 | 28 | bottom = bottom.to(device) 29 | class_id = class_id.to(device) 30 | # salience = salience.to(device) 31 | 32 | target = bottom 33 | # print(target[0]) 34 | out, _ = model(bottom, label_condition=class_id) 35 | # out, _ = model(bottom, label_condition=class_id, salience_condition=salience) 36 | 37 | loss = criterion(out, target) 38 | # print(loss) 39 | loss.backward() 40 | 41 | if scheduler is not None: 42 | scheduler.step() 43 | optimizer.step() 44 | 45 | _, pred = out.max(1) 46 | correct = (pred == target).float() 47 | accuracy = correct.sum() / target.numel() 48 | 49 | lr = optimizer.param_groups[0]['lr'] 50 | 51 | loader.set_description( 52 | ( 53 | 
f'epoch: {epoch + 1}; loss: {loss.item():.5f}; ' 54 | f'acc: {accuracy:.5f}; lr: {lr:.5f}' 55 | ) 56 | ) 57 | 58 | 59 | class PixelTransform: 60 | def __init__(self): 61 | pass 62 | 63 | def __call__(self, input): 64 | ar = np.array(input) 65 | 66 | return torch.from_numpy(ar).long() 67 | 68 | 69 | if __name__ == '__main__': 70 | os.makedirs("checkpoint/pixelsnail-final", exist_ok=True) 71 | parser = argparse.ArgumentParser() 72 | parser.add_argument('--batch', type=int, default=8) 73 | parser.add_argument('--epoch', type=int, default=1500) 74 | parser.add_argument('--hier', type=str, default='bottom') 75 | parser.add_argument('--lr', type=float, default=3e-4) 76 | parser.add_argument('--channel', type=int, default=256) 77 | parser.add_argument('--n_res_block', type=int, default=4) 78 | parser.add_argument('--n_res_channel', type=int, default=256) 79 | parser.add_argument('--n_out_res_block', type=int, default=0) 80 | parser.add_argument('--n_cond_res_block', type=int, default=3) 81 | parser.add_argument('--dropout', type=float, default=0.1) 82 | parser.add_argument('--amp', type=str, default='O0') 83 | parser.add_argument('--sched', type=str) 84 | parser.add_argument('--ckpt', type=str) 85 | parser.add_argument('--path', type=str, default='vqvae-code/') 86 | 87 | args = parser.parse_args() 88 | 89 | print(args) 90 | 91 | device = 'cuda' 92 | 93 | dataset = LMDBDataset(args.path) 94 | loader = DataLoader( 95 | dataset, batch_size=args.batch, shuffle=True, num_workers=4, drop_last=True 96 | ) 97 | 98 | ckpt = {} 99 | start_point = 0 100 | 101 | if args.ckpt is not None: 102 | _, start_point = args.ckpt.split('_') 103 | start_point = int(start_point[0:-3]) 104 | 105 | ckpt = torch.load(args.ckpt) 106 | args = ckpt['args'] 107 | 108 | if args.hier == 'top': 109 | # model = PixelSNAIL( 110 | # [10, 43], 111 | # 512, 112 | # args.channel, 113 | # 5, 114 | # 4, 115 | # args.n_res_block, 116 | # args.n_res_channel, 117 | # dropout=args.dropout, 118 | # n_out_res_block=args.n_out_res_block, 119 | # ) 120 | 121 | model = PixelSNAIL( 122 | [10, 43], 123 | 512, 124 | 256, 125 | 5, 126 | 4, 127 | 4, 128 | 256, 129 | dropout=0.1, 130 | n_out_res_block=0, 131 | cond_res_channel=0, 132 | ) 133 | 134 | elif args.hier == 'bottom': 135 | model = PixelSNAIL( 136 | [20, 86], 137 | 512, 138 | args.channel, 139 | 5, 140 | 4, 141 | args.n_res_block, 142 | args.n_res_channel, 143 | # attention=False, 144 | dropout=args.dropout, 145 | n_cond_res_block=args.n_cond_res_block, 146 | cond_res_channel=args.n_res_channel, 147 | ) 148 | 149 | if 'model' in ckpt: 150 | model.load_state_dict(ckpt['model']) 151 | 152 | model = model.to(device) 153 | # print(model) 154 | optimizer = optim.Adam(model.parameters(), lr=args.lr) 155 | 156 | if amp is not None: 157 | model, optimizer = amp.initialize(model, optimizer, opt_level=args.amp) 158 | 159 | model = nn.DataParallel(model) 160 | model = model.to(device) 161 | 162 | scheduler = None 163 | if args.sched == 'cycle': 164 | scheduler = CycleScheduler( 165 | optimizer, args.lr, n_iter=len(loader) * args.epoch, momentum=None 166 | ) 167 | 168 | for i in range(start_point, start_point + args.epoch): 169 | train(i, loader, model, optimizer, scheduler, device) 170 | 171 | torch.save( 172 | {'model': model.module.state_dict(), 'args': args}, 173 | f'checkpoint/pixelsnail-final/{args.hier}_{str(i + 1).zfill(3)}.pt', 174 | ) 175 | -------------------------------------------------------------------------------- /train_vqvae.py: 
-------------------------------------------------------------------------------- 1 | from typing import List 2 | import argparse 3 | import sys 4 | import os 5 | 6 | import torch 7 | from torch import nn, optim 8 | from torch.utils.data import DataLoader 9 | 10 | from tqdm import tqdm 11 | 12 | from vqvae import VQVAE 13 | from scheduler import CycleScheduler 14 | 15 | import audio2mel 16 | from datasets import get_dataset_filelist 17 | 18 | 19 | def train(epoch, loader, model, optimizer, scheduler, device): 20 | model.train() 21 | 22 | loader = tqdm(loader) 23 | 24 | criterion = nn.MSELoss() 25 | 26 | latent_loss_weight = 0.25 27 | 28 | mse_sum = 0 29 | mse_n = 0 30 | 31 | for i, (img, _, _, _) in enumerate(loader): 32 | model.zero_grad() 33 | 34 | img = img.to(device) 35 | 36 | out, latent_loss = model(img) 37 | recon_loss = criterion(out, img) 38 | latent_loss = latent_loss.mean() 39 | loss = recon_loss + latent_loss_weight * latent_loss 40 | loss.backward() 41 | 42 | if scheduler is not None: 43 | scheduler.step() 44 | optimizer.step() 45 | mse_sum = recon_loss.item() * img.shape[0] # img.shape[0] = batch_size 46 | mse_n = img.shape[0] 47 | 48 | lr = optimizer.param_groups[0]["lr"] 49 | 50 | loader.set_description( 51 | ( 52 | f"Epoch: {epoch + 1}; MSE: {recon_loss.item():.5f}; " 53 | f"latent: {latent_loss.item():.3f}; Avg MSE: {mse_sum / mse_n:.5f}; " 54 | f"lr: {lr:.5f}" 55 | ) 56 | ) 57 | 58 | latent_diff = latent_loss.item() 59 | return latent_diff, (mse_sum / mse_n) 60 | 61 | 62 | def test(epoch, loader, model, optimizer, scheduler, device): 63 | model.eval() 64 | 65 | criterion = nn.MSELoss() 66 | 67 | mse_sum = 0 68 | mse_n = 0 69 | 70 | for i, (img, _, _, _) in enumerate(loader): 71 | model.zero_grad() 72 | 73 | img = img.to(device) 74 | 75 | out, latent_loss = model(img) 76 | recon_loss = criterion(out, img) 77 | latent_loss = latent_loss.mean() 78 | 79 | if scheduler is not None: 80 | scheduler.step() 81 | optimizer.step() 82 | part_mse_sum = recon_loss.item() * img.shape[0] # img.shape[0] = batch_size 83 | part_mse_n = img.shape[0] 84 | comm = {"mse_sum": part_mse_sum, "mse_n": part_mse_n} 85 | 86 | for part in comm: 87 | mse_sum += part["mse_sum"] 88 | mse_n += part["mse_n"] 89 | 90 | # validation 91 | if i % 100 == 0: 92 | pass 93 | 94 | latent_diff = latent_loss.item() 95 | if (epoch + 1) % 10 == 0: 96 | print( 97 | f"\nTest_Epoch: {epoch + 1}; " 98 | f"latent: {latent_diff:.3f}; Avg MSE: {mse_sum / mse_n:.5f} \n" 99 | ) 100 | return latent_diff, (mse_sum / mse_n) 101 | 102 | 103 | def main(args): 104 | device = "cuda" 105 | 106 | train_file_list: List[dict] = get_dataset_filelist() 107 | 108 | train_set = audio2mel.Audio2Mel( 109 | train_file_list, 22050 * 4, 1024, 80, 256, 22050, 0, 8000 110 | ) 111 | 112 | train_loader = DataLoader( 113 | train_set, batch_size=args.batch // args.n_gpu, num_workers=4, shuffle=True 114 | ) 115 | 116 | print("training set size: " + str(len(train_set))) 117 | 118 | model = VQVAE().to(device) 119 | 120 | optimizer = optim.Adam(model.parameters(), lr=args.lr) 121 | scheduler = None 122 | if args.sched == "cycle": 123 | scheduler = CycleScheduler( 124 | optimizer, 125 | args.lr, 126 | n_iter=len(train_loader) * args.epoch, 127 | momentum=None, 128 | warmup_proportion=0.05, 129 | ) 130 | 131 | for i in range(args.epoch): 132 | train_latent_diff, train_average_loss = train( 133 | i, train_loader, model, optimizer, scheduler, device 134 | ) 135 | 136 | # if dist.is_primary(): 137 | 138 | torch.save( 139 | model.state_dict(), 
f"checkpoint/vqvae/vqvae_{str(i + 1).zfill(3)}.pt" 140 | ) 141 | 142 | 143 | if __name__ == "__main__": 144 | os.makedirs("checkpoint/vqvae/", exist_ok=True) 145 | parser = argparse.ArgumentParser() 146 | parser.add_argument("--n_gpu", type=int, default=1) 147 | 148 | port = ( 149 | 2**15 150 | + 2**14 151 | + hash(os.getuid() if sys.platform != "win32" else 1) % 2**14 152 | ) 153 | 154 | parser.add_argument("--dist_url", default=f"tcp://127.0.0.1:{port}") 155 | 156 | parser.add_argument("--epoch", type=int, default=800) 157 | parser.add_argument("--lr", type=float, default=3e-4) 158 | parser.add_argument("--batch", type=int, default=16) 159 | parser.add_argument("--sched", type=str) 160 | 161 | args = parser.parse_args() 162 | 163 | print(args) 164 | 165 | main(args) 166 | -------------------------------------------------------------------------------- /vqvae.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | 5 | # Copyright 2018 The Sonnet Authors. All Rights Reserved. 6 | # 7 | # Licensed under the Apache License, Version 2.0 (the "License"); 8 | # you may not use this file except in compliance with the License. 9 | # You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, software 14 | # distributed under the License is distributed on an "AS IS" BASIS, 15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 | # See the License for the specific language governing permissions and 17 | # limitations under the License. 18 | # ============================================================================ 19 | 20 | 21 | # Borrowed from https://github.com/deepmind/sonnet and ported it to PyTorch 22 | 23 | 24 | class Quantize(nn.Module): 25 | def __init__(self, dim, n_embed, decay=0.99, eps=1e-5): 26 | super().__init__() 27 | 28 | self.dim = dim 29 | self.n_embed = n_embed 30 | self.decay = decay 31 | self.eps = eps 32 | 33 | embed = torch.randn(dim, n_embed) 34 | self.register_buffer("embed", embed) 35 | self.register_buffer("cluster_size", torch.zeros(n_embed)) 36 | self.register_buffer("embed_avg", embed.clone()) 37 | 38 | def forward(self, input): 39 | flatten = input.reshape(-1, self.dim) 40 | dist = ( 41 | flatten.pow(2).sum(1, keepdim=True) 42 | - 2 * flatten @ self.embed 43 | + self.embed.pow(2).sum(0, keepdim=True) 44 | ) 45 | _, embed_ind = (-dist).max(1) 46 | embed_onehot = F.one_hot(embed_ind, self.n_embed).type(flatten.dtype) 47 | embed_ind = embed_ind.view(*input.shape[:-1]) 48 | quantize = self.embed_code(embed_ind) 49 | 50 | if self.training: 51 | embed_onehot_sum = embed_onehot.sum(0) 52 | embed_sum = flatten.transpose(0, 1) @ embed_onehot 53 | 54 | self.cluster_size.data.mul_(self.decay).add_( 55 | embed_onehot_sum, alpha=1 - self.decay 56 | ) 57 | self.embed_avg.data.mul_(self.decay).add_(embed_sum, alpha=1 - self.decay) 58 | n = self.cluster_size.sum() 59 | cluster_size = ( 60 | (self.cluster_size + self.eps) / (n + self.n_embed * self.eps) * n 61 | ) 62 | embed_normalized = self.embed_avg / cluster_size.unsqueeze(0) 63 | self.embed.data.copy_(embed_normalized) 64 | 65 | diff = (quantize.detach() - input).pow(2).mean() 66 | quantize = input + (quantize - input).detach() 67 | 68 | return quantize, diff, embed_ind 69 | 70 | def embed_code(self, embed_id): 71 | return F.embedding(embed_id, self.embed.transpose(0, 1)) 72 | 
73 | 74 | class ResBlock(nn.Module): 75 | def __init__(self, in_channel, channel): 76 | super().__init__() 77 | 78 | self.conv = nn.Sequential( 79 | nn.ReLU(inplace=True), 80 | nn.Conv2d(in_channel, channel, 3, padding=1), 81 | nn.ReLU(inplace=True), 82 | nn.Conv2d(channel, in_channel, 1), 83 | ) 84 | 85 | def forward(self, input): 86 | out = self.conv(input) 87 | out += input 88 | 89 | return out 90 | 91 | 92 | class Encoder(nn.Module): 93 | def __init__(self, in_channel, channel, n_res_block, n_res_channel, stride): 94 | super().__init__() 95 | 96 | if stride == 4: 97 | blocks_1 = [ 98 | nn.Conv2d(in_channel, channel // 2, 4, stride=2, padding=1), 99 | nn.ReLU(inplace=True), 100 | nn.Conv2d(channel // 2, channel, 4, stride=2, padding=1), 101 | nn.ReLU(inplace=True), 102 | nn.Conv2d(channel, channel, 3, padding=1), 103 | ] 104 | 105 | blocks_2 = [ 106 | nn.Conv2d(in_channel, channel // 2, 2, stride=2, padding=0), 107 | nn.ReLU(inplace=True), 108 | nn.Conv2d(channel // 2, channel, 2, stride=2, padding=0), 109 | nn.ReLU(inplace=True), 110 | nn.Conv2d(channel, channel, 3, padding=1), 111 | ] 112 | 113 | blocks_3 = [ 114 | nn.Conv2d(in_channel, channel // 2, 6, stride=2, padding=2), 115 | nn.ReLU(inplace=True), 116 | nn.Conv2d(channel // 2, channel, 6, stride=2, padding=2), 117 | nn.ReLU(inplace=True), 118 | nn.Conv2d(channel, channel, 3, padding=1), 119 | ] 120 | 121 | blocks_4 = [ 122 | nn.Conv2d(in_channel, channel // 2, 8, stride=2, padding=3), 123 | nn.ReLU(inplace=True), 124 | nn.Conv2d(channel // 2, channel, 8, stride=2, padding=3), 125 | nn.ReLU(inplace=True), 126 | nn.Conv2d(channel, channel, 3, padding=1), 127 | ] 128 | 129 | # elif stride == 2: 130 | # blocks_1 = [ 131 | # nn.Conv2d(in_channel, channel // 2, 4, stride=2, padding=1), 132 | # nn.ReLU(inplace=True), 133 | # nn.Conv2d(channel // 2, channel, 3, padding=1), 134 | # ] 135 | # 136 | # blocks_2 = [ 137 | # nn.Conv2d(in_channel, channel // 2, 2, stride=2, padding=0), 138 | # nn.ReLU(inplace=True), 139 | # nn.Conv2d(channel // 2, channel, 3, padding=1), 140 | # ] 141 | # 142 | # blocks_3 = [ 143 | # nn.Conv2d(in_channel, channel // 2, 8, stride=2, padding=3), 144 | # nn.ReLU(inplace=True), 145 | # nn.Conv2d(channel // 2, channel, 3, padding=1), 146 | # ] 147 | # 148 | # blocks_4 = [ 149 | # nn.Conv2d(in_channel, channel // 2, 16, stride=2, padding=7), 150 | # nn.ReLU(inplace=True), 151 | # nn.Conv2d(channel // 2, channel, 3, padding=1), 152 | # ] 153 | 154 | for i in range(n_res_block): 155 | blocks_1.append(ResBlock(channel, n_res_channel)) 156 | blocks_2.append(ResBlock(channel, n_res_channel)) 157 | blocks_3.append(ResBlock(channel, n_res_channel)) 158 | blocks_4.append(ResBlock(channel, n_res_channel)) 159 | 160 | blocks_1.append(nn.ReLU(inplace=True)) 161 | blocks_2.append(nn.ReLU(inplace=True)) 162 | blocks_3.append(nn.ReLU(inplace=True)) 163 | blocks_4.append(nn.ReLU(inplace=True)) 164 | 165 | self.blocks_1 = nn.Sequential(*blocks_1) 166 | self.blocks_2 = nn.Sequential(*blocks_2) 167 | self.blocks_3 = nn.Sequential(*blocks_3) 168 | self.blocks_4 = nn.Sequential(*blocks_4) 169 | 170 | def forward(self, input): 171 | return ( 172 | self.blocks_1(input) 173 | + self.blocks_2(input) 174 | + self.blocks_3(input) 175 | + self.blocks_4(input) 176 | ) 177 | 178 | # return self.blocks_1(input) 179 | 180 | 181 | class Decoder(nn.Module): 182 | def __init__( 183 | self, in_channel, out_channel, channel, n_res_block, n_res_channel, stride 184 | ): 185 | super().__init__() 186 | 187 | blocks = [nn.Conv2d(in_channel, channel, 3, 
padding=1)] 188 | 189 | for i in range(n_res_block): 190 | blocks.append(ResBlock(channel, n_res_channel)) 191 | 192 | blocks.append(nn.ReLU(inplace=True)) 193 | 194 | if stride == 4: 195 | blocks.extend( 196 | [ 197 | nn.ConvTranspose2d(channel, channel // 2, 4, stride=2, padding=1), 198 | nn.ReLU(inplace=True), 199 | nn.ConvTranspose2d( 200 | channel // 2, out_channel, 4, stride=2, padding=1 201 | ), 202 | ] 203 | ) 204 | 205 | elif stride == 2: 206 | blocks.append( 207 | nn.ConvTranspose2d(channel, out_channel, 4, stride=2, padding=1) 208 | ) 209 | 210 | self.blocks = nn.Sequential(*blocks) 211 | 212 | def forward(self, input): 213 | return self.blocks(input) 214 | 215 | 216 | class VQVAE(nn.Module): 217 | def __init__( 218 | self, 219 | in_channel=1, # for mel-spec. 220 | channel=128, 221 | n_res_block=2, 222 | n_res_channel=32, 223 | embed_dim=64, 224 | n_embed=512, 225 | decay=0.99, 226 | ): 227 | super().__init__() 228 | 229 | self.enc_b = Encoder(in_channel, channel, n_res_block, n_res_channel, stride=4) 230 | # self.enc_t = Encoder(channel, channel, n_res_block, n_res_channel, stride=2) 231 | # self.quantize_conv_t = nn.Conv2d(channel, embed_dim, 1) 232 | # self.quantize_t = Quantize(embed_dim, n_embed) 233 | # self.dec_t = Decoder(embed_dim, embed_dim, channel, n_res_block, n_res_channel, stride=2) 234 | self.quantize_conv_b = nn.Conv2d(channel, embed_dim, 1) 235 | self.quantize_b = Quantize(embed_dim, n_embed) 236 | # self.upsample_t = nn.ConvTranspose2d( 237 | # embed_dim, embed_dim, 4, stride=2, padding=1 238 | # ) 239 | self.dec = Decoder( 240 | embed_dim, 241 | in_channel, 242 | channel, 243 | n_res_block, 244 | n_res_channel, 245 | stride=4, 246 | ) 247 | 248 | def forward(self, input): 249 | quant_b, diff, _ = self.encode(input) 250 | dec = self.decode(quant_b) 251 | 252 | return dec, diff 253 | 254 | def encode(self, input): 255 | enc_b = self.enc_b(input) 256 | # enc_t = self.enc_t(enc_b) 257 | 258 | quant_b = self.quantize_conv_b(enc_b).permute(0, 2, 3, 1) 259 | quant_b, diff_b, id_b = self.quantize_b(quant_b) 260 | quant_b = quant_b.permute(0, 3, 1, 2) 261 | diff_b = diff_b.unsqueeze(0) 262 | 263 | return quant_b, diff_b, id_b 264 | 265 | def decode(self, quant_b): 266 | # _dec = self.dec_t(quant_t) 267 | dec = self.dec(quant_b) 268 | 269 | return dec 270 | 271 | def decode_code(self, code_b): 272 | quant_b = self.quantize_b.embed_code(code_b) 273 | quant_b = quant_b.permute(0, 3, 1, 2) 274 | 275 | dec = self.decode(quant_b) 276 | 277 | return dec 278 | 279 | 280 | if __name__ == '__main__': 281 | import audio2mel 282 | from datasets import get_dataset_filelist 283 | from torch.utils.data import DataLoader 284 | 285 | train_file_list, _ = get_dataset_filelist() 286 | 287 | train_set = audio2mel.Audio2Mel( 288 | train_file_list[0:4], 22050 * 4, 1024, 80, 256, 22050, 0, 8000 289 | ) 290 | 291 | loader = DataLoader(train_set, batch_size=2, sampler=None, num_workers=2) 292 | 293 | model = VQVAE() 294 | 295 | a = torch.randn(3, 3).to('cuda') 296 | print(a) 297 | model = model.to('cuda') 298 | 299 | for i, batch in enumerate(loader): 300 | mel, id, name = batch 301 | mel = mel.to('cuda') 302 | out, latent_loss = model(mel) 303 | print(out.shape) 304 | if i == 5: 305 | break 306 | --------------------------------------------------------------------------------