├── .gitattributes
├── 00_data_preprocess.ipynb
├── 01_basic_ASR_structure.ipynb
├── 02_training_model.ipynb
├── 02_training_model.py
├── 03_predict_and_return_output_wav_files.ipynb
├── 03_predict_and_return_output_wav_files.py
├── README.md
├── __pycache__
│   ├── asr_model.cpython-35.pyc
│   ├── dataset.cpython-35.pyc
│   ├── io_utils.cpython-35.pyc
│   ├── model.cpython-35.pyc
│   ├── standard.cpython-35.pyc
│   ├── subpixel.cpython-35.pyc
│   └── summarization.cpython-35.pyc
├── asr_model.py
├── asr_model.pyc
├── data
│   ├── temp.txt
│   ├── test.txt
│   ├── train.txt
│   └── valid.txt
├── dataset.py
├── default_log_name.lr0.000500.1.g4.b100
│   ├── checkpoint
│   ├── events.out.tfevents.1531203790.smart-deep-learning
│   ├── model.ckpt-1.data-00000-of-00001
│   ├── model.ckpt-1.index
│   ├── model.ckpt-1.meta
│   ├── model.ckpt.data-00000-of-00001
│   ├── model.ckpt.index
│   └── model.ckpt.meta
├── example-hr.wav
├── example-lr.wav
├── git-lfs-2.4.2
│   ├── CHANGELOG.md
│   ├── README.md
│   ├── git-lfs
│   └── install.sh
├── io_utils.py
├── model.py
├── practice.ipynb
├── src
│   └── models
│       ├── __init__.py
│       ├── __init__.pyc
│       ├── __pycache__
│       │   ├── __init__.cpython-35.pyc
│       │   ├── audiounet.cpython-35.pyc
│       │   ├── dataset.cpython-35.pyc
│       │   ├── io.cpython-35.pyc
│       │   └── model.cpython-35.pyc
│       ├── audiounet.py
│       ├── audiounet.pyc
│       ├── dataset.py
│       ├── io.py
│       ├── layers
│       │   ├── __init__.py
│       │   ├── __pycache__
│       │   │   ├── __init__.cpython-35.pyc
│       │   │   └── subpixel.cpython-35.pyc
│       │   ├── standard.py
│       │   ├── subpixel.py
│       │   └── summarization.py
│       └── model.py
├── standard.py
├── subpixel.py
├── summarization.py
├── temp.h5
├── train.h5
└── valid.h5
/.gitattributes:
--------------------------------------------------------------------------------
1 | *.h5 filter=lfs diff=lfs merge=lfs -text
2 | 
--------------------------------------------------------------------------------
/00_data_preprocess.ipynb:
--------------------------------------------------------------------------------
1 | {
2 |  "cells": [
3 |   {
4 |    "cell_type": "markdown",
5 |    "metadata": {},
6 |    "source": [
7 |     "### Reference: https://github.com/kuleshov/audio-super-res\n",
8 |     "* Mentions of audio-super-res below refer to the reference model linked above"
9 |    ]
10 |   },
11 |   {
12 |    "cell_type": "markdown",
13 |    "metadata": {},
14 |    "source": [
15 |     "## 1. Generate h5-format training data from the original training wav files\n",
16 |     "* Store the original training wav files in the ./data/train folder\n",
17 |     "* List the training wav files in ./data/train.txt\n",
18 |     "* Preprocessing creates low-res and high-res versions of the data and saves them in h5 format"
19 |    ]
20 |   },
21 |   {
22 |    "cell_type": "code",
23 |    "execution_count": 1,
24 |    "metadata": {},
25 |    "outputs": [],
26 |    "source": [
27 |     "import os, argparse\n",
28 |     "import numpy as np\n",
29 |     "import h5py\n",
30 |     "import librosa\n",
31 |     "from scipy import interpolate\n",
32 |     "from scipy.signal import decimate\n",
33 |     "from scipy.signal import butter, lfilter\n",
34 |     "args = {\n",
35 |     "    'interpolate' : 0, # False\n",
36 |     "    'dimension' : 8192 *2, # dimension of patches\n",
37 |     "    'stride' : 8192 *2, # stride of patches - 16384 samples is approximately 1 second at sr=16000\n",
38 |     "    'scale' : 4, # downsampling ratio used to create the low-res training data\n",
39 |     "    'sr' : 16000, # sampling rate\n",
40 |     "    'sam' : 1, 'batch_size' : 1, # patch keep-probability; batch_size is used below to crop the patch count\n",
41 |     "    'train_out' : 'train.h5',\n",
42 |     "    'train_in_dir' : '../temp/data/train/', # directory containing the original training wav files\n",
43 |     "    'valid_out' : 'valid.h5',\n",
44 |     "    'valid_in_dir' : '../temp/data/valid/',\n",
45 |     "    'train_file_list' : './data/train.txt', # file name list of the training data\n",
46 |     "    'valid_file_list' : './data/valid.txt',\n",
47 |     "    'temp_out' : 'temp.h5',\n",
48 |     "    'temp_in_dir' : '../temp/data/temp/',\n",
49 |     "    'temp_file_list' : './data/temp.txt'\n",
50 |     "}"
51 |    ]
52 |   },
53 |   {
54 |    "cell_type": "code",
55 |    "execution_count": 2,
56 |    "metadata": {},
57 |    "outputs": [],
58 |    "source": [
59 |     "# cubic-spline upsampling: fit a spline to the low-res samples, evaluate it on the r-times denser grid\n",
60 |     "def upsample(x_lr, r):\n",
61 |     "    x_lr = x_lr.flatten()\n",
62 |     "    x_hr_len = len(x_lr) * r\n",
63 |     "    x_sp = np.zeros(x_hr_len)\n",
64 |     "\n",
65 |     "    i_lr = np.arange(x_hr_len, step=r)\n",
66 |     "    i_hr = np.arange(x_hr_len)\n",
67 |     "\n",
68 |     "    f = interpolate.splrep(i_lr, x_lr)\n",
69 |     "\n",
70 |     "    x_sp = interpolate.splev(i_hr, f)\n",
71 |     "\n",
72 |     "    return x_sp"
73 |    ]
74 |   },
75 |   {
76 |    "cell_type": "code",
77 |    "execution_count": 3,
78 |    "metadata": {},
79 |    "outputs": [],
80 |    "source": [
81 |     "def add_data_preprocessed(h5_file, inputfiles, in_dir, args, save_examples=False):\n",
82 |     "    \n",
83 |     "    # 1) read original dataset\n",
84 |     "    \n",
85 |     "    file_list = []\n",
86 |     "    file_extensions = set(['.wav'])\n",
87 |     "    with open(inputfiles) as f:\n",
88 |     "        for line in f:\n",
89 |     "            filename = line.strip()\n",
90 |     "            ext = os.path.splitext(filename)[1]\n",
91 |     "            if ext in file_extensions:\n",
92 |     "                file_list.append(os.path.join(in_dir, filename))\n",
93 |     "    \n",
94 |     "    num_files = len(file_list)\n",
95 |     "    \n",
96 |     "    # 2) read wav file (we always use interpolate mode) \n",
97 |     "    # 3) create low-res version\n",
98 |     "    # 4) upsample the low-res version back to the same length\n",
99 |     "    # 5) patch the data\n",
100 |     "    \n",
101 |     "    d, d_lr = args['dimension'], (args['dimension'])\n",
102 |     "    s, s_lr = args['stride'], (args['stride'])\n",
103 |     "    hr_patches = list()\n",
104 |     "    lr_patches = list()\n",
105 |     "    for j, file_path in enumerate(file_list):\n",
106 |     "        if j % 10 == 0: print('%d/%d' % (j, num_files))\n",
107 |     "        \n",
108 |     "        # load audio file\n",
109 |     "        x, fs = librosa.load(file_path, sr=args['sr']) # sr = sampling rate\n",
110 |     "        \n",
111 |     "        # crop so that the length is a multiple of the scaling ratio\n",
112 |     "        x_len = len(x)\n",
113 |     "        x = x[ : x_len - (x_len % args['scale'])]\n",
114 |     "        \n",
115 |     "        # generate low-res version\n",
116 |     "        x_lr = decimate(x, args['scale'])\n",
117 |     "        \n",
118 |     "        # upsample the data (we will use the preprocessed low-res data)\n",
119 |     "        # EX. scale x4 on dimension\n",
120 |     "        # data  (low-res )2048 ---> [cubic-upscaling] --> 8192 ---> model input (8192)\n",
121 |     "        # label (high-res)8192 -----------------------------------> model output(8192)\n",
122 |     "        x_lr = upsample(x_lr, args['scale'])\n",
123 |     "        \n",
124 |     "        assert len(x) % args['scale'] == 0\n",
125 |     "        assert len(x_lr) == (len(x))\n",
126 |     "        \n",
127 |     "        # Generate patches\n",
128 |     "        max_i = len(x) - d + 1 # d = dimension\n",
129 |     "        for i in range(0, max_i, s): # s = stride\n",
130 |     "            # keep only a fraction of all the patches\n",
131 |     "            u = np.random.uniform()\n",
132 |     "            if u > args['sam']: continue\n",
133 |     "            \n",
134 |     "            i_lr = i\n",
135 |     "            \n",
136 |     "            hr_patch = np.array( x[i : i+d] )\n",
137 |     "            lr_patch = np.array( x_lr[i_lr : i_lr + d_lr] )  \n",
138 |     "            assert len(hr_patch) == d\n",
139 |     "            assert len(lr_patch) == d_lr\n",
140 |     "            \n",
141 |     "            hr_patches.append(hr_patch.reshape((d,1)))\n",
142 |     "            lr_patches.append(lr_patch.reshape((d_lr,1)))\n",
143 |     "    \n",
144 |     "    \n",
145 |     "    # 6) save as .h5 files \n",
146 |     "    # crop # of patches so that it's a multiple of mini-batch size\n",
147 |     "    num_hr_patches = len(hr_patches)\n",
148 |     "    num_lr_patches = len(lr_patches)\n",
149 |     "    \n",
150 |     "    print('num_hr_patches:', num_hr_patches)\n",
151 |     "    print('num_lr_patches:', num_lr_patches)\n",
152 |     "    print('batch_size:', args['batch_size'])\n",
153 |     "    num_to_keep_hr = int(np.floor(num_hr_patches / args['batch_size']) * args['batch_size'])\n",
154 |     "    hr_patches = np.array(hr_patches[:num_to_keep_hr])\n",
155 |     "    \n",
156 |     "    num_to_keep_lr = int(np.floor(num_lr_patches / args['batch_size']) * args['batch_size'])\n",
157 |     "    lr_patches = np.array(lr_patches[:num_to_keep_lr])\n",
158 |     "\n",
159 |     "    if save_examples:\n",
160 |     "        librosa.output.write_wav('example-hr.wav', hr_patches[40], fs, norm=False)\n",
161 |     "        #librosa.output.write_wav('example-lr.wav', lr_patches[40], int(fs / args['scale']), norm=False)\n",
162 |     "        librosa.output.write_wav('example-lr.wav', lr_patches[40], fs, norm=False)\n",
163 |     "        print (hr_patches[40].shape)\n",
164 |     "        print (lr_patches[40].shape)\n",
165 |     "        print (hr_patches[40][0][:10])\n",
166 |     "        print (lr_patches[40][0][:10])\n",
167 |     "        print ('two examples saved')\n",
168 |     "\n",
169 |     "    print ('hr_patches shape:',hr_patches.shape)\n",
170 |     "    print ('lr_patches shape:',lr_patches.shape)\n",
171 |     "\n",
172 |     "    # create the hdf5 file\n",
173 |     "    data_set = h5_file.create_dataset('data', lr_patches.shape, np.float32) # lr\n",
174 |     "    label_set = h5_file.create_dataset('label', hr_patches.shape, np.float32) # hr\n",
175 |     "\n",
176 |     "    data_set[...] = lr_patches\n",
177 |     "    label_set[...] = hr_patches"
178 |    ]
179 |   },
180 |   {
181 |    "cell_type": "code",
182 |    "execution_count": 4,
183 |    "metadata": {},
184 |    "outputs": [
185 |     {
186 |      "name": "stdout",
187 |      "output_type": "stream",
188 |      "text": [
189 |       "0/231\n",
190 |       "10/231\n",
191 |       "20/231\n",
192 |       "30/231\n",
193 |       "40/231\n",
194 |       "50/231\n",
195 |       "60/231\n",
196 |       "70/231\n",
197 |       "80/231\n",
198 |       "90/231\n",
199 |       "100/231\n",
200 |       "110/231\n",
201 |       "120/231\n",
202 |       "130/231\n",
203 |       "140/231\n",
204 |       "150/231\n",
205 |       "160/231\n",
206 |       "170/231\n",
207 |       "180/231\n",
208 |       "190/231\n",
209 |       "200/231\n",
210 |       "210/231\n",
211 |       "220/231\n",
212 |       "230/231\n",
213 |       "num_hr_patches: 852\n",
214 |       "num_lr_patches: 852\n",
215 |       "batch_size: 1\n",
216 |       "(16384, 1)\n",
217 |       "(16384, 1)\n",
218 |       "[0.14710128]\n",
219 |       "[0.1460274]\n",
220 |       "two examples saved\n",
221 |       "hr_patches shape: (852, 16384, 1)\n",
222 |       "lr_patches shape: (852, 16384, 1)\n"
223 |      ]
224 |     }
225 |    ],
226 |    "source": [
227 |     "# create train\n",
228 |     "with h5py.File(args['train_out'], 'w') as f:\n",
229 |     "    add_data_preprocessed(f, args['train_file_list'], args['train_in_dir'],args, save_examples=True)"
230 |    ]
231 |   },
232 |   {
233 |    "cell_type": "markdown",
234 |    "metadata": {},
235 |    "source": [
236 |     "## 2. Likewise, build the validation dataset as an h5 file\n",
237 |     "* Store the original validation data in the ./data/valid folder\n",
238 |     "* List the validation files in ./data/valid.txt\n",
239 |     "* Generate the h5-format data the same way as in step 1"
240 |    ]
241 |   },
242 |   {
243 |    "cell_type": "code",
244 |    "execution_count": 5,
245 |    "metadata": {},
246 |    "outputs": [
247 |     {
248 |      "name": "stdout",
249 |      "output_type": "stream",
250 |      "text": [
251 |       "0/125\n",
252 |       "10/125\n",
253 |       "20/125\n",
254 |       "30/125\n",
255 |       "40/125\n",
256 |       "50/125\n",
257 |       "60/125\n",
258 |       "70/125\n",
259 |       "80/125\n",
260 |       "90/125\n",
261 |       "100/125\n",
262 |       "110/125\n",
263 |       "120/125\n",
264 |       "num_hr_patches: 287\n",
265 |       "num_lr_patches: 287\n",
266 |       "batch_size: 1\n",
267 |       "(16384, 1)\n",
268 |       "(16384, 1)\n",
269 |       "[-0.03057251]\n",
270 |       "[-0.02435603]\n",
271 |       "two examples saved\n",
272 |       "hr_patches shape: (287, 16384, 1)\n",
273 |       "lr_patches shape: (287, 16384, 1)\n"
274 |      ]
275 |     }
276 |    ],
277 |    "source": [
278 |     "# create validation\n",
279 |     "with h5py.File(args['valid_out'], 'w') as f:\n",
280 |     "    add_data_preprocessed(f, args['valid_file_list'], args['valid_in_dir'],args, save_examples=True)"
281 |    ]
282 |   },
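283 |   {
284 |    "cell_type": "markdown",
285 |    "metadata": {},
286 |    "source": [
287 |     "As a quick sanity check of what was just written (an added illustrative sketch, not part of the original run), the cell below re-opens the h5 files created above, assuming only the 'data'/'label' keys and the args dict defined in this notebook."
288 |    ]
289 |   },
290 |   {
291 |    "cell_type": "code",
292 |    "execution_count": null,
293 |    "metadata": {},
294 |    "outputs": [],
295 |    "source": [
296 |     "# Sanity-check sketch (illustrative addition, not executed in the original run):\n",
297 |     "# 'data' holds the spline-upsampled low-res patches and 'label' the high-res targets,\n",
298 |     "# so both datasets should report the same (num_patches, dimension, 1) shape.\n",
299 |     "for path in (args['train_out'], args['valid_out']):\n",
300 |     "    with h5py.File(path, 'r') as f:\n",
301 |     "        print(path, list(f.keys()), f['data'].shape, f['label'].shape)"
302 |    ]
303 |   },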
304 |   {
305 |    "cell_type": "markdown",
306 |    "metadata": {},
307 |    "source": [
308 |     "## 3. Using other data\n",
309 |     "* Store the original wav files in the ./data/temp folder\n",
310 |     "* Write the file list in ./data/temp.txt\n",
311 |     "* Run the code below to generate the data in h5 format"
312 |    ]
313 |   },
314 |   {
315 |    "cell_type": "code",
316 |    "execution_count": 6,
317 |    "metadata": {},
318 |    "outputs": [
319 |     {
320 |      "name": "stdout",
321 |      "output_type": "stream",
322 |      "text": [
323 |       "num_hr_patches: 0\n",
324 |       "num_lr_patches: 0\n",
325 |       "batch_size: 1\n",
326 |       "hr_patches shape: (0,)\n",
327 |       "lr_patches shape: (0,)\n"
328 |      ]
329 |     }
330 |    ],
331 |    "source": [
332 |     "# create another set\n",
333 |     "with h5py.File(args['temp_out'], 'w') as f:\n",
334 |     "    add_data_preprocessed(f, args['temp_file_list'], args['temp_in_dir'], args, save_examples=False)"
335 |    ]
336 |   },
337 |   {
338 |    "cell_type": "code",
339 |    "execution_count": null,
340 |    "metadata": {},
341 |    "outputs": [],
342 |    "source": []
343 |   }
344 |  ],
345 |  "metadata": {
346 |   "kernelspec": {
347 |    "display_name": "Python 3",
348 |    "language": "python",
349 |    "name": "python3"
350 |   },
351 |   "language_info": {
352 |    "codemirror_mode": {
353 |     "name": "ipython",
354 |     "version": 3
355 |    },
356 |    "file_extension": ".py",
357 |    "mimetype": "text/x-python",
358 |    "name": "python",
359 |    "nbconvert_exporter": "python",
360 |    "pygments_lexer": "ipython3",
361 |    "version": "3.5.2"
362 |   }
363 |  },
364 |  "nbformat": 4,
365 |  "nbformat_minor": 2
366 | }
367 | 
--------------------------------------------------------------------------------
/01_basic_ASR_structure.ipynb:
--------------------------------------------------------------------------------
1 | {
2 |  "cells": [
3 |   {
4 |    "cell_type": "code",
5 |    "execution_count": 1,
6 |    "metadata": {},
7 |    "outputs": [],
8 |    "source": [
9 |     "import numpy as np\n",
10 |     "import tensorflow as tf\n",
11 |     "from scipy import interpolate\n",
12 |     "from model import Model, default_opt\n",
13 |     "from subpixel import SubPixel1D, SubPixel1D_v2\n",
14 |     "from standard import conv1d, deconv1d"
15 |    ]
16 |   },
17 |   {
18 |    "cell_type": "code",
19 |    "execution_count": 2,
20 |    "metadata": {},
21 |    "outputs": [],
22 |    "source": [
23 |     "# shape = (num of examples, num of samples, num of channels)\n",
24 |     "X = tf.placeholder(tf.float32, shape=(100, 8192, 1), name='X') # in the model itself the input & output shapes are (None, None, 1)\n",
25 |     "Y = tf.placeholder(tf.float32, shape=(100, 8192, 1), name='Y') \n",
26 |     "downsampled_l = []\n",
27 |     "di = 0"
28 |    ]
29 |   },
30 |   {
31 |    "cell_type": "code",
32 |    "execution_count": 3,
33 |    "metadata": {},
34 |    "outputs": [],
35 |    "source": [
36 |     "def reshape1Dto2D(X):\n",
37 |     "    n_batch, n_width, n_chan = X.get_shape()\n",
38 |     "    X = tf.reshape(X,[n_batch, 1, n_width, n_chan])\n",
39 |     "    return X\n",
40 |     "\n",
41 |     "def reshape2Dto1D(X):\n",
42 |     "    # reshape 2D -> 1D\n",
43 |     "    n_batch, _, n_width, n_chan = X.get_shape()\n",
44 |     "    X = tf.reshape(X,[n_batch, n_width, n_chan])\n",
45 |     "    return X\n",
46 |     "\n",
47 |     "# D-block: conv1d -> max-pooling (halves the time axis) -> ReLU; B=True adds dropout (bottleneck)\n",
48 |     "def downsample_layer(x , nf, ks, B=False):\n",
49 |     "    x = tf.layers.conv1d(\n",
50 |     "        x,\n",
51 |     "        filters = nf,\n",
52 |     "        kernel_size = ks,\n",
53 |     "        strides=1,\n",
54 |     "        padding='same',\n",
55 |     "        data_format='channels_last',\n",
56 |     "        dilation_rate=1,\n",
57 |     "        activation=None,\n",
58 |     "        use_bias=True,\n",
59 |     "        kernel_initializer=None,\n",
60 |     "        bias_initializer=tf.zeros_initializer(),\n",
61 |     "        kernel_regularizer=None,\n",
62 |     "        bias_regularizer=None,\n",
63 |     "        activity_regularizer=None,\n",
64 |     "        kernel_constraint=None,\n",
65 |     "        bias_constraint=None,\n",
66 |     "        trainable=True,\n",
67 |     "        name=None,\n",
68 |     "        reuse=None\n",
69 |     "    )\n",
70 |     "    x = tf.layers.max_pooling1d(\n",
71 |     "        x,\n",
72 |     "        pool_size = 2,\n",
73 |     "        strides = 2,\n",
74 |     "        padding='same',\n",
75 |     "        data_format='channels_last',\n",
76 |     "        name=None\n",
77 |     "    )\n",
78 |     "    \n",
79 |     "    if B : x = tf.layers.dropout(x, rate=0.5)\n",
80 |     "    \n",
81 |     "    x = tf.nn.relu(x)\n",
82 |     "    return x\n",
83 |     "\n",
84 |     "# U-block: conv1d -> dropout -> ReLU -> SubPixel1D, which reshuffles 2*c channels into c while doubling the time axis\n",
85 |     "def upsample_layer(x, nf, ks):\n",
86 |     "    '''x = tf.layers.conv2d_transpose(\n",
87 |     "        x,\n",
88 |     "        filters = nf,\n",
89 |     "        kernel_size = [1,ks],\n",
90 |     "        strides=(1, 1),\n",
91 |     "        padding='same',\n",
92 |     "        data_format='channels_last',\n",
93 |     "        activation=None,\n",
94 |     "        use_bias=True,\n",
95 |     "        kernel_initializer=None,\n",
96 |     "        bias_initializer=tf.zeros_initializer(),\n",
97 |     "        kernel_regularizer=None,\n",
98 |     "        bias_regularizer=None,\n",
99 |     "        activity_regularizer=None,\n",
100 |     "        kernel_constraint=None,\n",
101 |     "        bias_constraint=None,\n",
102 |     "        trainable=True,\n",
103 |     "        name=None,\n",
104 |     "        reuse=None\n",
105 |     "    )'''\n",
106 |     "    x = tf.layers.conv1d(\n",
107 |     "        x,\n",
108 |     "        filters = nf,\n",
109 |     "        kernel_size = ks,\n",
110 |     "        strides=1,\n",
111 |     "        padding='same',\n",
112 |     "        data_format='channels_last',\n",
113 |     "        dilation_rate=1,\n",
114 |     "        activation=None,\n",
115 |     "        use_bias=True,\n",
116 |     "        kernel_initializer=None,\n",
117 |     "        bias_initializer=tf.zeros_initializer(),\n",
118 |     "        kernel_regularizer=None,\n",
119 |     "        bias_regularizer=None,\n",
120 |     "        activity_regularizer=None,\n",
121 |     "        kernel_constraint=None,\n",
122 |     "        bias_constraint=None,\n",
123 |     "        trainable=True,\n",
124 |     "        name=None,\n",
125 |     "        reuse=None\n",
126 |     "    )\n",
127 |     "    x = tf.layers.dropout(x, rate=0.5)\n",
128 |     "    x = tf.nn.relu(x)\n",
129 |     "    x = SubPixel1D(x,r=2)\n",
130 |     "    return x"
131 |    ]
132 |   },
133 |   {
134 |    "cell_type": "code",
135 |    "execution_count": 4,
136 |    "metadata": {},
137 |    "outputs": [
138 |     {
139 |      "name": "stdout",
140 |      "output_type": "stream",
141 |      "text": [
142 |       ">> Generator Model init...\n",
143 |       "Input >> Tensor(\"X:0\", shape=(100, 8192, 1), dtype=float32)\n",
144 |       "D-Block >> Tensor(\"Relu:0\", shape=(100, 4096, 48), dtype=float32)\n",
145 |       "D-Block >> Tensor(\"Relu_1:0\", shape=(100, 2048, 96), dtype=float32)\n",
146 |       "D-Block >> Tensor(\"Relu_2:0\", shape=(100, 1024, 128), dtype=float32)\n",
147 |       "D-Block >> Tensor(\"Relu_3:0\", shape=(100, 512, 128), dtype=float32)\n",
148 |       "B-Block >> Tensor(\"Relu_4:0\", shape=(100, 256, 128), dtype=float32)\n",
149 |       "U-Block >> Tensor(\"concat:0\", shape=(100, 512, 256), dtype=float32)\n",
150 |       "U-Block >> Tensor(\"concat_1:0\", shape=(100, 1024, 256), dtype=float32)\n",
151 |       "U-Block >> Tensor(\"concat_2:0\", shape=(100, 2048, 192), dtype=float32)\n",
152 |       "U-Block >> Tensor(\"concat_3:0\", shape=(100, 4096, 96), dtype=float32)\n",
153 |       "Fin-Layer >> Tensor(\"Add:0\", shape=(100, 8192, 1), dtype=float32)\n",
154 |       ">> ...finish\n",
155 |       "\n"
156 |      ]
157 |     }
158 |    ],
159 |    "source": [
160 |     "n_filters = [48, 96, 128, 128]\n",
161 |     "n_filtersizes = [24, 12, 6, 3]\n",
162 |     "L = 4\n",
163 |     "\n",
164 |     "# save the original input for the final residual (additive) connection\n",
165 |     "oX = X\n",
166 |     "print('>> Generator Model init...')\n",
167 |     "print('Input >> ', X)\n",
168 |     "# downsampling layers\n",
169 |     "for l, nf, fs in zip(range(L), n_filters, n_filtersizes):\n",
170 |     "    X = downsample_layer(X, nf, fs)\n",
171 |     "    downsampled_l.append(X)\n",
172 |     "    print('D-Block >> ' ,X)\n",
173 |     "\n",
174 |     "# Bottle-neck layer\n",
175 |     "X = downsample_layer(X, n_filters[-1], n_filtersizes[-1], B=True)\n",
176 |     "print('B-Block >> ', X)\n",
177 |     "\n",
178 |     "# Upsampling layers: walk back up in reverse order, concatenating the matching D-block features (U-Net-style skip connections)\n",
179 |     "L = reversed(range(L))\n",
180 |     "n_filters = reversed(n_filters)\n",
181 |     "n_filtersizes = reversed(n_filtersizes)\n",
182 |     "downsampled_l = reversed(downsampled_l)\n",
183 |     "\n",
184 |     "for l, nf, fs, l_in in zip( L, (n_filters), (n_filtersizes), (downsampled_l)):\n",
185 |     "    #X = reshape1Dto2D(X)\n",
186 |     "    X = upsample_layer(X, nf*2, fs)\n",
187 |     "    #X = reshape2Dto1D(X)\n",
188 |     "    X = tf.concat([X,l_in],axis=-1)\n",
189 |     "    print('U-Block >> ',X)\n",
190 |     "\n",
191 |     "# Final layer; add the input back as a residual connection\n",
192 |     "X = upsample_layer(X,nf=2,ks=9) # nf=2 so that SubPixel1D(r=2) leaves a single output channel\n",
193 |     "G = tf.add(X,oX)\n",
194 |     "print('Fin-Layer >> ',G)\n",
195 |     "print('>> ...finish')\n",
196 |     "print()\n"
197 |    ]
198 |   },
199 |   {
200 |    "cell_type": "code",
201 |    "execution_count": 5,
202 |    "metadata": {},
203 |    "outputs": [
204 |     {
205 |      "data": {
206 |       "text/plain": [
207 |        "\"\\n# Model config\\n#n_filters = [128, 256, 512,512]\\nnf = [12, 24, 48, 48, 48]\\n#n_filtersizes = [65, 33, 17, 9]\\nfs = [10, 7, 3, 3, 3]\\n\\nprint(X)\\n\\n# D-level-00\\ndx0 = downsample_layer(X, nf[0], fs[0])\\nprint('D-Block-00:',dx0)\\n\\n# D-level-01\\ndx1 = downsample_layer(dx0, nf[1], fs[1])\\nprint('D-Block-01:',dx1)\\n\\n# D-level-02\\ndx2 = downsample_layer(dx1, nf[2], fs[2])\\nprint('D-Block-02:',dx2)\\n\\n# D-level-03\\ndx3 = downsample_layer(dx2, nf[3], fs[3])\\nprint('D-Block-03:',dx3)\\n\\n# Bottle-Neck\\nbtn = downsample_layer(dx3, nf[4], fs[4])\\nprint('Bottle-Neck:',btn)\\n\\n# U-level-03 \\nux3 = upsample_layer(btn, nf[3] * 2, fs[3])\\nucx3 = tf.concat([ux3, dx3], axis=-1)\\nprint('U-Block-03:',ucx3)\\n\\n# U-level-02\\nux2 = upsample_layer(ucx3, nf[2] * 2, fs[2])\\nucx2 = tf.concat([ux2, dx2], axis=-1)\\nprint('U-Block-02:',ucx2)\\n\\n# U-level-01\\nux1 = upsample_layer(ucx2, nf[1] * 2, fs[1])\\nucx1 = tf.concat([ux1, dx1], axis=-1)\\nprint('U-Block-01:',ucx1)\\n\\n# U-level-00\\nux0 = upsample_layer(ucx1, nf[0] * 2, fs[0])\\nucx0 = tf.concat([ux0, dx0], axis=-1)\\nprint('U-Block-00:',ucx0)\\n\\n# U x 2\\nux = upsample_layer(ucx0, nf[0] * 2, fs[0])\\nprint('U-Block-Fin1:',ux)\\n\\n# U x 2\\nux = upsample_layer(ux, nf[0] * 2, fs[0])\\nprint('U-Block-Fin2:',ux)\\n\\n# Fin\\nG = upsample_layer(ux, 2, 9)\\nprint('Final:',G)\\n\\n# X2?\\nX4 = upsample_layer(X, 2, 9)\\nX4 = upsample_layer(X4, 2, 9)\\nprint('X4:',X4)\\n\\n# Generator\\nG = tf.add(G,X4)\\nprint('Generator:',G)\\n\""
208 |       ]
209 |      },
210 |      "execution_count": 5,
211 |      "metadata": {},
212 |      "output_type": "execute_result"
213 |     }
214 |    ],
215 |    "source": [
216 |     "\n",
217 |     "\n",
218 |     "'''\n",
219 |     "# Model config\n",
220 |     "#n_filters = [128, 256, 512,512]\n",
221 |     "nf = [12, 24, 48, 48, 48]\n",
222 |     "#n_filtersizes = [65, 33, 17, 9]\n",
223 |     "fs = [10, 7, 3, 3, 3]\n",
224 |     "\n",
225 |     "print(X)\n",
226 |     "\n",
227 |     "# D-level-00\n",
228 |     "dx0 = downsample_layer(X, nf[0], fs[0])\n",
229 |     "print('D-Block-00:',dx0)\n",
230 |     "\n",
231 |     "# D-level-01\n",
232 |     "dx1 = downsample_layer(dx0, nf[1], fs[1])\n",
233 |     "print('D-Block-01:',dx1)\n",
234 |     "\n",
235 |     "# D-level-02\n",
236 |     "dx2 = downsample_layer(dx1, nf[2], fs[2])\n",
237 |     "print('D-Block-02:',dx2)\n",
238 |     "\n",
239 |     "# D-level-03\n",
240 |     "dx3 = downsample_layer(dx2, nf[3], fs[3])\n",
241 |     "print('D-Block-03:',dx3)\n",
242 |     "\n",
243 |     "# Bottle-Neck\n",
244 |     "btn = downsample_layer(dx3, nf[4], fs[4])\n",
245 |     "print('Bottle-Neck:',btn)\n",
246 |     "\n",
247 |     "# U-level-03 \n",
248 |     "ux3 = upsample_layer(btn, nf[3] * 2, fs[3])\n",
249 |     "ucx3 = tf.concat([ux3, dx3], axis=-1)\n",
250 |     "print('U-Block-03:',ucx3)\n",
251 |     "\n",
252 |     "# U-level-02\n",
253 |     "ux2 = upsample_layer(ucx3, nf[2] * 2, fs[2])\n",
254 |     "ucx2 = tf.concat([ux2, dx2], axis=-1)\n",
255 |     "print('U-Block-02:',ucx2)\n",
256 |     "\n",
257 |     "# U-level-01\n",
258 |     "ux1 = upsample_layer(ucx2, nf[1] * 2, fs[1])\n",
259 |     "ucx1 = tf.concat([ux1, dx1], axis=-1)\n",
260 |     "print('U-Block-01:',ucx1)\n",
261 |     "\n",
262 |     "# U-level-00\n",
263 |     "ux0 = upsample_layer(ucx1, nf[0] * 2, fs[0])\n",
264 |     "ucx0 = tf.concat([ux0, dx0], axis=-1)\n",
265 |     "print('U-Block-00:',ucx0)\n",
266 |     "\n",
267 |     "# U x 2\n",
268 |     "ux = upsample_layer(ucx0, nf[0] * 2, fs[0])\n",
269 |     "print('U-Block-Fin1:',ux)\n",
270 |     "\n",
271 |     "# U x 2\n",
272 |     "ux = upsample_layer(ux, nf[0] * 2, fs[0])\n",
273 |     "print('U-Block-Fin2:',ux)\n",
274 |     "\n",
275 |     "# Fin\n",
276 |     "G = upsample_layer(ux, 2, 9)\n",
277 |     "print('Final:',G)\n",
278 |     "\n",
279 |     "# X2?\n",
280 |     "X4 = upsample_layer(X, 2, 9)\n",
281 |     "X4 = upsample_layer(X4, 2, 9)\n",
282 |     "print('X4:',X4)\n",
283 |     "\n",
284 |     "# Generator\n",
285 |     "G = tf.add(G,X4)\n",
286 |     "print('Generator:',G)\n",
287 |     "'''"
288 |    ]
289 |   },
290 |   {
291 |    "cell_type": "code",
292 |    "execution_count": null,
293 |    "metadata": {},
294 |    "outputs": [],
295 |    "source": []
296 |   },
297 |   {
298 |    "cell_type": "code",
299 |    "execution_count": null,
300 |    "metadata": {},
301 |    "outputs": [],
302 |    "source": []
303 |   },
304 |   {
305 |    "cell_type": "code",
306 |    "execution_count": null,
307 |    "metadata": {},
308 |    "outputs": [],
309 |    "source": []
310 |   },
311 |   {
312 |    "cell_type": "code",
313 |    "execution_count": null,
314 |    "metadata": {},
315 |    "outputs": [],
316 |    "source": []
317 |   }
318 |  ],
319 |  "metadata": {
320 |   "kernelspec": {
321 |    "display_name": "Python 3",
322 |    "language": "python",
323 |    "name": "python3"
324 |   },
325 |   "language_info": {
326 |    "codemirror_mode": {
327 |     "name": "ipython",
328 |     "version": 3
329 |    },
330 |    "file_extension": ".py",
331 |    "mimetype": "text/x-python",
332 |    "name": "python",
333 |    "nbconvert_exporter": "python",
334 |    "pygments_lexer": "ipython3",
335 |    "version": "3.5.2"
336 |   }
337 |  },
338 |  "nbformat": 4,
339 |  "nbformat_minor": 2
340 | }
341 | 
--------------------------------------------------------------------------------
/02_training_model.ipynb:
--------------------------------------------------------------------------------
1 | {
2 |  "cells": [
3 |   {
4 |    "cell_type": "markdown",
5 |    "metadata": {},
6 |    "source": [
7 |     "#### Reference: https://github.com/kuleshov/audio-super-res"
8 |    ]
9 |   },
10 |   {
11 |    "cell_type": "markdown",
12 |    "metadata": {},
13 |    "source": [
14 |     "# Training ASR model"
15 |    ]
16 |   },
17 |   {
18 |    "cell_type": "code",
19 |    "execution_count": 1,
20 |    "metadata": {},
21 |    "outputs": [],
22 |    "source": [
23 |     "import os\n",
24 |     "os.sys.path.append(os.path.abspath('.'))\n",
25 |     "os.sys.path.append(os.path.dirname(os.path.abspath('.')))\n",
26 |     "import numpy as np\n",
27 |     "import matplotlib\n",
28 |     "from asr_model import ASRNet, default_opt\n",
29 |     "from io_utils import upsample_wav\n",
30 |     "from io_utils import load_h5\n",
31 |     "import tensorflow as tf\n",
32 |     "#matplotlib.use('Agg')"
33 |    ]
34 |   },
35 |   {
36 |    "cell_type": "code",
37 |    "execution_count": 2,
38 |    "metadata": {},
39 |    "outputs": [
40 |     {
41 |      "name": "stdout",
42 |      "output_type": "stream",
43 |      "text": [
44 |       "1.5.0\n"
45 |      ]
46 |     }
47 |    ],
48 |    "source": [
49 |     "args = {\n",
50 |     "    'train' : 'train.h5',\n",
51 |     "    'val' : 'valid.h5',\n",
52 |     "    'alg' : 'adam',\n",
53 |     "    'epochs' : 10,\n",
54 |     "    'logname' : 'default_log_name',\n",
55 |     "    'layers' : 4,\n",
56 |     "    'lr' : 0.0005,\n",
57 |     "    'batch_size' : 100\n",
58 |     "}\n",
59 |     "print(tf.__version__)"
60 |    ]
61 |   },
62 |   {
63 |    "cell_type": "code",
64 |    "execution_count": 3,
65 |    "metadata": {
66 |     "scrolled": true
67 |    },
68 |    "outputs": [
69 |     {
70 |      "name": "stdout",
71 |      "output_type": "stream",
72 |      "text": [
73 |       "List of arrays in input file: KeysView()\n",
74 |       "Shape of X: (852, 16384, 1)\n",
75 |       "Shape of Y: (852, 16384, 1)\n",
76 |       "List of arrays in input file: KeysView()\n",
77 |       "Shape of X: (287, 16384, 1)\n",
78 |       "Shape of Y: (287, 16384, 1)\n"
79 |      ]
80 |     }
81 |    ],
82 |    "source": [
83 |     "# get data\n",
84 |     "X_train, Y_train = load_h5(args['train'])\n",
85 |     "X_val, Y_val = load_h5(args['val'])"
86 |    ]
87 |   },
88 |   {
89 |    "cell_type": "code",
90 |    "execution_count": 4,
91 |    "metadata": {},
92 |    "outputs": [
93 |     {
94 |      "name": "stdout",
95 |      "output_type": "stream",
96 |      "text": [
97 |       "number of dimension Y: 16384\n",
98 |       "number of channel Y: 1\n",
99 |       "number of dimension X: 16384\n",
100 |       "number of channel X: 1\n",
101 |       "r: 1\n"
102 |      ]
103 |     }
104 |    ],
105 |    "source": [
106 |     "# determine the super-resolution level (r = 1 here: the low-res input was already spline-upsampled to full length)\n",
107 |     "n_dim_y, n_chan_y = Y_train[0].shape\n",
108 |     "n_dim_x, n_chan_x = X_train[0].shape\n",
109 |     "print('number of dimension Y:',n_dim_y)\n",
110 |     "print('number of channel Y:',n_chan_y)\n",
111 |     "print('number of dimension X:',n_dim_x)\n",
112 |     "print('number of channel X:',n_chan_x)\n",
113 |     "r = int(Y_train[0].shape[0] / X_train[0].shape[0])\n",
114 |     "print('r:',r)\n",
115 |     "n_chan = n_chan_y\n",
116 |     "n_dim = n_dim_y\n",
117 |     "assert n_chan == 1 # the model assumes mono (single-channel) audio"
118 | ] 119 | }, 120 | { 121 | "cell_type": "code", 122 | "execution_count": 5, 123 | "metadata": { 124 | "scrolled": false 125 | }, 126 | "outputs": [ 127 | { 128 | "name": "stdout", 129 | "output_type": "stream", 130 | "text": [ 131 | ">> Generator Model init...\n", 132 | "D-Block >> Tensor(\"generator/Relu:0\", shape=(?, ?, 32), dtype=float32)\n", 133 | "D-Block >> Tensor(\"generator/Relu_1:0\", shape=(?, ?, 48), dtype=float32)\n", 134 | "D-Block >> Tensor(\"generator/Relu_2:0\", shape=(?, ?, 64), dtype=float32)\n", 135 | "D-Block >> Tensor(\"generator/Relu_3:0\", shape=(?, ?, 64), dtype=float32)\n", 136 | "B-Block >> Tensor(\"generator/Relu_4:0\", shape=(?, ?, 64), dtype=float32)\n", 137 | "U-Block >> Tensor(\"generator/concat:0\", shape=(?, ?, 128), dtype=float32)\n", 138 | "U-Block >> Tensor(\"generator/concat_1:0\", shape=(?, ?, 128), dtype=float32)\n", 139 | "U-Block >> Tensor(\"generator/concat_2:0\", shape=(?, ?, 96), dtype=float32)\n", 140 | "U-Block >> Tensor(\"generator/concat_3:0\", shape=(?, ?, 64), dtype=float32)\n", 141 | "Fin-Layer >> Tensor(\"generator/Add:0\", shape=(?, ?, 1), dtype=float32)\n", 142 | ">> ...finish\n", 143 | "\n", 144 | "creating train_op with params: {'b1': 0.9, 'batch_size': 100, 'layers': 4, 'lr': 0.0005, 'b2': 0.999, 'alg': 'adam'}\n" 145 | ] 146 | } 147 | ], 148 | "source": [ 149 | "# create model\n", 150 | "def get_model(args, n_dim, r, from_ckpt=False, train=True):\n", 151 | " \"\"\"Create a model based on arguments\"\"\"\n", 152 | " \n", 153 | " if train:\n", 154 | " opt_params = {\n", 155 | " 'alg' : args['alg'], \n", 156 | " 'lr' : args['lr'], \n", 157 | " 'b1' : 0.9, \n", 158 | " 'b2' : 0.999,\n", 159 | " 'batch_size': args['batch_size'], \n", 160 | " 'layers': args['layers']}\n", 161 | " else: \n", 162 | " opt_params = default_opt\n", 163 | "\n", 164 | " # create model & init\n", 165 | " model = ASRNet(\n", 166 | " from_ckpt=from_ckpt, \n", 167 | " n_dim=n_dim, \n", 168 | " r=r,\n", 169 | " opt_params=opt_params, \n", 170 | " log_prefix=args['logname'])\n", 171 | " \n", 172 | " return model\n", 173 | "\n", 174 | "model = get_model(args, n_dim, r, from_ckpt=False, train=True)" 175 | ] 176 | }, 177 | { 178 | "cell_type": "code", 179 | "execution_count": 6, 180 | "metadata": { 181 | "scrolled": false 182 | }, 183 | "outputs": [ 184 | { 185 | "name": "stdout", 186 | "output_type": "stream", 187 | "text": [ 188 | "start training epoch (n:10)\n", 189 | "num-of-batch: 100\n", 190 | "count 1 / obj: 0.011600 / snr: 19.355580\n", 191 | "count 2 / obj: 0.008685 / snr: 20.612528\n", 192 | "count 3 / obj: 0.009123 / snr: 20.398688\n", 193 | "count 4 / obj: 0.008165 / snr: 20.880310\n", 194 | "count 5 / obj: 0.007995 / snr: 20.972055\n", 195 | "count 6 / obj: 0.008121 / snr: 20.903914\n", 196 | "count 7 / obj: 0.008175 / snr: 20.875330\n", 197 | "count 8 / obj: 0.008135 / snr: 20.896617\n", 198 | "count 9 / obj: 0.007721 / snr: 21.123295\n", 199 | "\n", 200 | "Epoch 1 of 10 took 215.091s (8 minibatches)\n", 201 | " training l2_loss/segsnr:\t\t0.008504\t14.878838\n", 202 | " validation l2_loss/segsnr:\t\t0.004127\t15.779288\n", 203 | "-----------------------------------------------------------------------\n", 204 | "count 1 / obj: 0.008232 / snr: 20.844777\n", 205 | "count 2 / obj: 0.008422 / snr: 20.745691\n", 206 | "count 3 / obj: 0.008368 / snr: 20.773929\n", 207 | "count 4 / obj: 0.010174 / snr: 19.924932\n", 208 | "count 5 / obj: 0.008494 / snr: 20.709130\n", 209 | "count 6 / obj: 0.008225 / snr: 20.848514\n", 210 | "count 7 / obj: 0.007763 / snr: 
21.099658\n", 211 | "count 8 / obj: 0.008255 / snr: 20.832661\n", 212 | "count 9 / obj: 0.007708 / snr: 21.130599\n", 213 | "\n", 214 | "Epoch 2 of 10 took 209.948s (8 minibatches)\n", 215 | " training l2_loss/segsnr:\t\t0.008491\t14.833782\n", 216 | " validation l2_loss/segsnr:\t\t0.004394\t15.514821\n", 217 | "-----------------------------------------------------------------------\n", 218 | "count 1 / obj: 0.008768 / snr: 20.570860\n", 219 | "count 2 / obj: 0.009080 / snr: 20.419346\n", 220 | "count 3 / obj: 0.008487 / snr: 20.712319\n", 221 | "count 4 / obj: 0.009058 / snr: 20.429468\n", 222 | "count 5 / obj: 0.008655 / snr: 20.627397\n", 223 | "count 6 / obj: 0.008303 / snr: 20.807446\n", 224 | "count 7 / obj: 0.008073 / snr: 20.929529\n", 225 | "count 8 / obj: 0.007903 / snr: 21.021856\n", 226 | "\n", 227 | "Epoch 3 of 10 took 187.922s (8 minibatches)\n", 228 | " training l2_loss/segsnr:\t\t0.008542\t14.899198\n", 229 | " validation l2_loss/segsnr:\t\t0.004204\t15.637467\n", 230 | "-----------------------------------------------------------------------\n", 231 | "count 1 / obj: 0.008562 / snr: 20.674066\n", 232 | "count 2 / obj: 0.008670 / snr: 20.619868\n", 233 | "count 3 / obj: 0.008816 / snr: 20.547241\n", 234 | "count 4 / obj: 0.008630 / snr: 20.639833\n", 235 | "count 5 / obj: 0.007288 / snr: 21.373792\n", 236 | "count 6 / obj: 0.008435 / snr: 20.739280\n", 237 | "count 7 / obj: 0.009120 / snr: 20.400047\n", 238 | "count 8 / obj: 0.008938 / snr: 20.487697\n", 239 | "count 9 / obj: 0.007500 / snr: 21.249393\n", 240 | "\n", 241 | "Epoch 4 of 10 took 210.336s (8 minibatches)\n", 242 | " training l2_loss/segsnr:\t\t0.008450\t14.862305\n", 243 | " validation l2_loss/segsnr:\t\t0.004279\t15.690650\n", 244 | "-----------------------------------------------------------------------\n", 245 | "count 1 / obj: 0.008634 / snr: 20.637885\n", 246 | "count 2 / obj: 0.008482 / snr: 20.715239\n", 247 | "count 3 / obj: 0.008695 / snr: 20.607274\n", 248 | "count 4 / obj: 0.008763 / snr: 20.573319\n", 249 | "count 5 / obj: 0.008851 / snr: 20.530265\n", 250 | "count 6 / obj: 0.009482 / snr: 20.231149\n", 251 | "count 7 / obj: 0.007371 / snr: 21.324759\n", 252 | "count 8 / obj: 0.007992 / snr: 20.973279\n", 253 | "\n", 254 | "Epoch 5 of 10 took 187.988s (8 minibatches)\n", 255 | " training l2_loss/segsnr:\t\t0.008466\t14.879717\n", 256 | " validation l2_loss/segsnr:\t\t0.004271\t15.840616\n", 257 | "-----------------------------------------------------------------------\n", 258 | "count 1 / obj: 0.008731 / snr: 20.589256\n", 259 | "count 2 / obj: 0.008768 / snr: 20.570791\n", 260 | "count 3 / obj: 0.008518 / snr: 20.696596\n", 261 | "count 4 / obj: 0.008011 / snr: 20.962950\n", 262 | "count 5 / obj: 0.008645 / snr: 20.632515\n", 263 | "count 6 / obj: 0.008772 / snr: 20.568880\n", 264 | "count 7 / obj: 0.007220 / snr: 21.414661\n", 265 | "count 8 / obj: 0.009395 / snr: 20.270842\n", 266 | "count 9 / obj: 0.008467 / snr: 20.722581\n", 267 | "\n", 268 | "Epoch 6 of 10 took 220.976s (8 minibatches)\n", 269 | " training l2_loss/segsnr:\t\t0.008544\t14.931331\n", 270 | " validation l2_loss/segsnr:\t\t0.004425\t15.754089\n", 271 | "-----------------------------------------------------------------------\n", 272 | "count 1 / obj: 0.008714 / snr: 20.597965\n", 273 | "count 2 / obj: 0.007440 / snr: 21.284035\n", 274 | "count 3 / obj: 0.008673 / snr: 20.618239\n", 275 | "count 4 / obj: 0.008557 / snr: 20.676878\n", 276 | "count 5 / obj: 0.008860 / snr: 20.525497\n", 277 | "count 6 / obj: 0.007401 / snr: 
21.307359\n", 278 | "count 7 / obj: 0.008811 / snr: 20.549535\n", 279 | "count 8 / obj: 0.009500 / snr: 20.222817\n", 280 | "\n", 281 | "Epoch 7 of 10 took 188.196s (8 minibatches)\n", 282 | " training l2_loss/segsnr:\t\t0.008574\t14.844210\n", 283 | " validation l2_loss/segsnr:\t\t0.004590\t15.532207\n", 284 | "-----------------------------------------------------------------------\n", 285 | "count 1 / obj: 0.008422 / snr: 20.745747\n", 286 | "count 2 / obj: 0.010025 / snr: 19.989302\n", 287 | "count 3 / obj: 0.007982 / snr: 20.978840\n", 288 | "count 4 / obj: 0.007588 / snr: 21.198540\n", 289 | "count 5 / obj: 0.008723 / snr: 20.593132\n", 290 | "count 6 / obj: 0.007882 / snr: 21.033471\n", 291 | "count 7 / obj: 0.007978 / snr: 20.980807\n", 292 | "count 8 / obj: 0.008932 / snr: 20.490646\n", 293 | "count 9 / obj: 0.007827 / snr: 21.064203\n", 294 | "\n", 295 | "Epoch 8 of 10 took 210.552s (8 minibatches)\n", 296 | " training l2_loss/segsnr:\t\t0.008512\t14.860913\n", 297 | " validation l2_loss/segsnr:\t\t0.004529\t15.564570\n", 298 | "-----------------------------------------------------------------------\n", 299 | "count 1 / obj: 0.008904 / snr: 20.503953\n", 300 | "count 2 / obj: 0.008992 / snr: 20.461355\n", 301 | "count 3 / obj: 0.009386 / snr: 20.275136\n", 302 | "count 4 / obj: 0.008390 / snr: 20.762224\n", 303 | "count 5 / obj: 0.008614 / snr: 20.648013\n", 304 | "count 6 / obj: 0.007989 / snr: 20.975313\n", 305 | "count 7 / obj: 0.008420 / snr: 20.746943\n", 306 | "count 8 / obj: 0.008431 / snr: 20.741008\n", 307 | "\n", 308 | "Epoch 9 of 10 took 186.426s (8 minibatches)\n", 309 | " training l2_loss/segsnr:\t\t0.008423\t14.804889\n", 310 | " validation l2_loss/segsnr:\t\t0.004316\t15.689662\n", 311 | "-----------------------------------------------------------------------\n", 312 | "count 1 / obj: 0.008860 / snr: 20.525886\n", 313 | "count 2 / obj: 0.008326 / snr: 20.795741\n", 314 | "count 3 / obj: 0.007961 / snr: 20.990310\n", 315 | "count 4 / obj: 0.008715 / snr: 20.597372\n", 316 | "count 5 / obj: 0.008188 / snr: 20.868377\n", 317 | "count 6 / obj: 0.007310 / snr: 21.360688\n", 318 | "count 7 / obj: 0.009627 / snr: 20.164998\n", 319 | "count 8 / obj: 0.008232 / snr: 20.844775\n", 320 | "count 9 / obj: 0.007970 / snr: 20.985576\n", 321 | "\n", 322 | "Epoch 10 of 10 took 211.568s (8 minibatches)\n", 323 | " training l2_loss/segsnr:\t\t0.008503\t14.870770\n", 324 | " validation l2_loss/segsnr:\t\t0.004452\t15.663491\n", 325 | "-----------------------------------------------------------------------\n" 326 | ] 327 | } 328 | ], 329 | "source": [ 330 | "# train model\n", 331 | "model.fit(X_train, Y_train, X_val, Y_val, n_epoch=args['epochs'])" 332 | ] 333 | }, 334 | { 335 | "cell_type": "code", 336 | "execution_count": null, 337 | "metadata": {}, 338 | "outputs": [], 339 | "source": [] 340 | } 341 | ], 342 | "metadata": { 343 | "kernelspec": { 344 | "display_name": "Python 3", 345 | "language": "python", 346 | "name": "python3" 347 | }, 348 | "language_info": { 349 | "codemirror_mode": { 350 | "name": "ipython", 351 | "version": 3 352 | }, 353 | "file_extension": ".py", 354 | "mimetype": "text/x-python", 355 | "name": "python", 356 | "nbconvert_exporter": "python", 357 | "pygments_lexer": "ipython3", 358 | "version": "3.5.2" 359 | } 360 | }, 361 | "nbformat": 4, 362 | "nbformat_minor": 2 363 | } 364 | -------------------------------------------------------------------------------- /02_training_model.py: 
-------------------------------------------------------------------------------- 1 | 2 | # coding: utf-8 3 | 4 | # #### referenced by - https://github.com/kuleshov/audio-super-res 5 | 6 | # # Training ASR model 7 | 8 | # In[1]: 9 | 10 | 11 | import os 12 | os.sys.path.append(os.path.abspath('.')) 13 | os.sys.path.append(os.path.dirname(os.path.abspath('.'))) 14 | import numpy as np 15 | import matplotlib 16 | from asr_model import ASRNet, default_opt 17 | from io_utils import upsample_wav 18 | from io_utils import load_h5 19 | import tensorflow as tf 20 | #matplotlib.use('Agg') 21 | 22 | 23 | # In[2]: 24 | 25 | 26 | args = { 27 | 'train' : 'train.h5', 28 | 'val' : 'valid.h5', 29 | 'alg' : 'adam', 30 | 'epochs' : 10, 31 | 'logname' : 'default_log_name', 32 | 'layers' : 4, 33 | 'lr' : 0.0005, 34 | 'batch_size' : 100 35 | } 36 | print(tf.__version__) 37 | 38 | 39 | # In[3]: 40 | 41 | 42 | # get data 43 | X_train, Y_train = load_h5(args['train']) 44 | X_val, Y_val = load_h5(args['val']) 45 | 46 | 47 | # In[4]: 48 | 49 | 50 | # determine super-resolution level 51 | n_dim_y, n_chan_y = Y_train[0].shape 52 | n_dim_x, n_chan_x = X_train[0].shape 53 | print('number of dimension Y:',n_dim_y) 54 | print('number of channel Y:',n_chan_y) 55 | print('number of dimension X:',n_dim_x) 56 | print('number of channel X:',n_chan_x) 57 | r = int(Y_train[0].shape[0] / X_train[0].shape[0]) 58 | print('r:',r) 59 | n_chan = n_chan_y 60 | n_dim = n_dim_y 61 | assert n_chan == 1 # if not number of channel is not 0 -> Error assert! 62 | 63 | 64 | # In[5]: 65 | 66 | 67 | # create model 68 | def get_model(args, n_dim, r, from_ckpt=False, train=True): 69 | """Create a model based on arguments""" 70 | 71 | if train: 72 | opt_params = { 73 | 'alg' : args['alg'], 74 | 'lr' : args['lr'], 75 | 'b1' : 0.9, 76 | 'b2' : 0.999, 77 | 'batch_size': args['batch_size'], 78 | 'layers': args['layers']} 79 | else: 80 | opt_params = default_opt 81 | 82 | # create model & init 83 | model = ASRNet( 84 | from_ckpt=from_ckpt, 85 | n_dim=n_dim, 86 | r=r, 87 | opt_params=opt_params, 88 | log_prefix=args['logname']) 89 | 90 | return model 91 | 92 | model = get_model(args, n_dim, r, from_ckpt=False, train=True) 93 | 94 | 95 | # In[ ]: 96 | 97 | 98 | # train model 99 | model.fit(X_train, Y_train, X_val, Y_val, n_epoch=args['epochs']) 100 | 101 | -------------------------------------------------------------------------------- /03_predict_and_return_output_wav_files.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "#### referenced by - https://github.com/kuleshov/audio-super-res\n", 8 | "# Predict & Get output files\n" 9 | ] 10 | }, 11 | { 12 | "cell_type": "code", 13 | "execution_count": 1, 14 | "metadata": {}, 15 | "outputs": [], 16 | "source": [ 17 | "import os\n", 18 | "os.sys.path.append(os.path.abspath('.'))\n", 19 | "os.sys.path.append(os.path.dirname(os.path.abspath('.')))\n", 20 | "import numpy as np\n", 21 | "import matplotlib\n", 22 | "from asr_model import ASRNet, default_opt\n", 23 | "from io_utils import upsample_wav\n", 24 | "from io_utils import load_h5\n", 25 | "import tensorflow as tf" 26 | ] 27 | }, 28 | { 29 | "cell_type": "code", 30 | "execution_count": 2, 31 | "metadata": {}, 32 | "outputs": [ 33 | { 34 | "name": "stdout", 35 | "output_type": "stream", 36 | "text": [ 37 | "1.5.0\n" 38 | ] 39 | } 40 | ], 41 | "source": [ 42 | "args = {\n", 43 | " 'ckpt' : 
'./default_log_name.lr0.000100.1.g4.b100/model.ckpt',\n", 44 | " 'wav_file_list' : './data/test.txt',\n", 45 | " 'r' : 6,\n", 46 | " 'sr' : 16000,\n", 47 | " 'alg' : 'adam',\n", 48 | " 'epochs' : 5,\n", 49 | " 'logname' : 'default_log_name',\n", 50 | " 'layers' : 4,\n", 51 | " 'lr' : 1e-3,\n", 52 | " 'batch_size' : 4,\n", 53 | " 'out_label' : 'asr_pred',\n", 54 | " 'in_dir' : './data/test'\n", 55 | "}\n", 56 | "print(tf.__version__)" 57 | ] 58 | }, 59 | { 60 | "cell_type": "code", 61 | "execution_count": 3, 62 | "metadata": { 63 | "scrolled": false 64 | }, 65 | "outputs": [ 66 | { 67 | "name": "stdout", 68 | "output_type": "stream", 69 | "text": [ 70 | ">> Generator Model init...\n", 71 | "D-Block >> Tensor(\"generator/Relu:0\", shape=(?, ?, 12), dtype=float32)\n", 72 | "D-Block >> Tensor(\"generator/Relu_1:0\", shape=(?, ?, 24), dtype=float32)\n", 73 | "D-Block >> Tensor(\"generator/Relu_2:0\", shape=(?, ?, 48), dtype=float32)\n", 74 | "D-Block >> Tensor(\"generator/Relu_3:0\", shape=(?, ?, 48), dtype=float32)\n", 75 | "B-Block >> Tensor(\"generator/Relu_4:0\", shape=(?, ?, 48), dtype=float32)\n", 76 | "U-Block >> Tensor(\"generator/concat:0\", shape=(?, ?, 96), dtype=float32)\n", 77 | "U-Block >> Tensor(\"generator/concat_1:0\", shape=(?, ?, 96), dtype=float32)\n", 78 | "U-Block >> Tensor(\"generator/concat_2:0\", shape=(?, ?, 48), dtype=float32)\n", 79 | "U-Block >> Tensor(\"generator/concat_3:0\", shape=(?, ?, 24), dtype=float32)\n", 80 | "Fin-Layer >> Tensor(\"generator/Add:0\", shape=(?, ?, 1), dtype=float32)\n", 81 | ">> ...finish\n", 82 | "\n", 83 | "creating train_op with params: {'lr': 0.001, 'layers': 4, 'alg': 'adam', 'batch_size': 4, 'b1': 0.9, 'b2': 0.999}\n", 84 | "checkpoint: ./default_log_name.lr0.000100.1.g4.b100/model.ckpt\n", 85 | "ckpt: ./default_log_name.lr0.000100.1.g4.b100/model.ckpt\n", 86 | "INFO:tensorflow:Restoring parameters from ./default_log_name.lr0.000100.1.g4.b100/model.ckpt\n", 87 | "test_0.wav\n", 88 | "(1, 32768, 1)\n", 89 | "(1, 5462, 1)\n" 90 | ] 91 | }, 92 | { 93 | "data": { 94 | "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAGYAAAEYCAYAAACugINnAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAHRRJREFUeJztnXuQZNd91z+/c2+/Znp2njuzq93VeyVbtrBlZJeCTYhfQQYXhiqXyzZFQjDoH3CFggIM/ANV/AH/4AQqZXAZE1OVxLENCargClEcBQcqVuzY1sOSVitt9jE7O7s7735333t+/HFOP2Z2ZtXamWlfec+naqq7773d92x999xzzu/8HqKqBLKH+XE3ILA7QZiMEoTJKEGYjBKEyShBmIwShMkoQZiMsi9hRORxETkjIq+KyOcOqlEBkFtd+YtIBLwCfBhYBL4LfEpVX9zrO/l4TIuFKWi0kChCkwQEuI2MDxXWV1T16OtdF+/jHu8BXlXVcwAi8lXgY8CewhQLUzz20BPw3FnMzBTp9VXECJqmsNd/EJG9z2UFke2fb9Le39dvXBjmJ/fzKDsBXBr4vOiPbUNEnhCR74nI9zpJfR+3u7049MFfVb+oqo+q6qO5eOywb/cTw36EuQycGvh80h+7yd2EaGULjCBxjORi1Ppuv/Nx0CXrjzFwbRz8OwD2I8x3gdMico+I5IFPAk8eSKsCtz74q2oiIv8A+N9ABHxZVX90YC27zdnPrAxV/SbwzQNqS2CAka/8NZ/DlMfRZnPUt35TEUwyGSUIk1FGKoyKUHn4KIhgt6qQpkgUjbIJbxpCj8koQZiMEoTJKCMVRqxSfmUTUussyoO8GUwvIyT0mIwShMkoQZiMsi9b2RtFOgmcu4h6E78z+ae7XPgm2LU8ZEKPyShBmIwShMkoo7WVdRKIIsxEGTGCGAG1/Qu628u3+fgCocdkliBMRhmtSSYXI8Uiquq8MLsm//DouoHQYzJKECajBGEyymiFiSKYmkArVcz4uJs+y0ATwljTI/SYjBKEyShBmIwyWmFUYaOCttvOLJPzuw4ie3v736aEHpNRXlcYEfmyiFwTkRcGjs2IyFMicta/Th9uM28/hukxvwo8vuPY54Bvqepp4Fv+c+AAeV1hVPXbwNqOwx8DvuLffwX460PdzSrJ6TuQfB7tdCBNwe6ytRy45TFmQVWv+PfLwMIBtSfg2ffgry5RwJ5L9sGo5XYaopaH5VaFuSoixwH867W9LhyMWs5HIWp5WG5VmCeBn/fvfx74n8N8SWNDvNlE4hit1Z37UthO3pVhpsu/Afwx8KCILIrIZ4B/C3xYRM4CH/KfAwfI6zr8qeqn9jj1wQNuS2CAEa/8BWm0kJkpyOWclwyEx9guBJNMRgnCZJQgTEYZrftSmqKbW5BapJDvbysHk/8NhB6TUYIwGSUIk1FGK0wnQYpFiAy0OyO99ZuN0GMyShAmowRhMsqIXWQN5HNgDGml4qLJgp1sV0KPyShBmIwShMkooxVGDLpVgVYbiXOoVaRQCOPMLoQek1GCMBll9NPl2Wm03UEiA2qR3Tz9wzZA6DFZJQiTUYIwGWW0SX6MoMUcTJZdsp9SCdvu3DhdDtPn0GOyShAmowRhMspo3ZeSFFm67t5HEVjrTP8havkGQo/JKMOEYZwSkadF5EUR+ZGI/KI/HiKXD5FhekwC/GNVfQh4DPj7IvIQIXL5UBkmavmKqn7fv68AL+EqxL7xyGVrIbVQqaHtNpraA68f+ZPCG8pULiJ3A48AzzBk5LKIPAE8AVCMyrfaztuOoQd/ESkD/x34h6q6NXjuZpHL24JjTWlfjb2dGEoYEcnhRPk1Vf0f/vDQkcsDP+RM/0niPqtF4pGWF3jTMMysTID/Arykqv9+4NQtRS4HhmOY/67vBf4W8LyI/NAf+xe4SOWv+SjmC8AnDqeJtyfDRC3/X2CvZXmIXD4kRuwlIzA9CWmKqjqvGRumybsRTDIZJQiTUYIwGWW0iwirSCchrTUAXC1Mm4KJ3Ks3/UucAyNgtf+qtj8edWvOaD9JkETR9uTa/rpevc03mckn9JiMEoTJKEGYjDJiQ5WilSoYQfJ5aLfRtnXjQxT1EmTr2+5DOimyeNV9rd2BUhGt1pB8Hlt3KRzNkSMkD5zANDtEa1XoJBAZdH0TmSij7Q7p6prL8hRFbv3kx6t+k3Tvups7j48w6V3oMRklCJNRgjAZZbRjjBgwESKCNhouaSluPSO5GKzFTE1SPVaifjRifr3ixoZGEwp59xrHiAhmapLkznmuvmcMSWD2xSL51QbS6kC1Bsa4UA//+6bkN+kaDcAVrdMk2dttSgSJIjcmdddYYraPT4dI6DEZJQiTUUYqjC3EdB48gTk270z+qZ8qq+15zWhqKV2ps/aBJs3756HeQDsdtFLF1utopeKm2iJEmw1mn28R15XG0RxmvQIrG26avHyNdGXV3VgM2m4jkXGPzzR1ppre42n36a+mKaZYcI/fOHdj+eHBx+ABe5KGHpNRgjAZJQiTUUYrjIHomRex5bHeVLZbb1miCDMzRXL/HSz99BHee99rNGdz2FodrTdI1zeRfB4plbCtFhgDaUpcT9j6UJ3agiFduorW62iz5ccQHxltxI1hzRaq6sY18B6ge0x/xbgpdhRh8rnhEhId4DgTekxGCcJklCBMRhmpMKaVwjsewGz4tUg3Mwa4sWZinMaxIrU7Lc98621MnK34OmZ+HLAWTZJeNJqdHKOxUGT8/40z9VqCdtqAN/GY/pZz1w3Xtjto210jce7m5n6PNhqoqhuzTNQfn3auf7oRCwc0zoQek1GCMBklCJNRRmz2F3j2FXRywq0n4rgfVZamsLpB+WyB0+cjzKVltFpzaxb/7LbNZu9/kl3fwCQpExeWKTeabl0zPo4OJtr2axRn3jduPWO7brlpv027ZebQ1NvvUue47V2jttVV240D2nYOPSajDBMfUxSRPxGRZ33U8r/2x+8RkWdE5FUR+U0RyR9+c28fhukxLeADqvoO4J3A4yLyGPDvgM+r6v3AOvCZw2vm7ccw8TEKVP3HnP9T4APAp/3xrwD/CvjCzX4rLUbIyXudm9GFRbf/knRABE1T0o0NpFpDO21SM2DPGmxPkrgkp0awq2v9dUXvfMdtCftt4972sFqI8kgEWN3uOrvTLUkEiXNu/yafd3s5Ay5QEPXHmoHvbNt63udYM2wMZuSjya4BTwGvARuq6oMpWcSFmO/23V5J306rtq/G3k4MJYyqpqr6TuAk8B7gLcPeYDBqOVcYv8Vm3n68oemyqm6IyNPATwFTIhL7XnMSuDzMbzROTVC64j0qt7b6jwDAFAtomhIfP4adn6Y9N8b66TzTZ9sUX7tOurjU8+AEeo8xc6SMNppIPodMLpDOT5FM5CksbqAXL4Mq2mo5808u5zx0ogjUP8L8/SUyLrGd9V473S2JfB7tJL3HmrvWfdWMjTmzTTcSm6j/mNzL5DMEw8zKjorIlH9fAj6My47xNPBxf1mIWj
5ghukxx4GviEiEE/Jrqvo7IvIi8FUR+TfAD3Ah54EDYphZ2XO4NCU7j5/DjTeBQ2C0Zv+2deNLrem8+wezYqh1npiFAjoxTjJVRKzyC5/9JitvL2Anx934Au574jwtJRf7jBuRq3+WWs7/E7j4d1NWH5t3z36rYCJn3ul0vOtU6r6TzxPdsYC5+yRSLGBKxb6LU7OFrVb7Y0an437LKubIEUypiN55DHN0DikUelP53riyjylzMMlklCBMRgnCZJTRCqNK7a4yXF1x0V1p6reJ3TPctlpuHLi2Qu7yBusPFJiIGkyeT5ClFWyzhW223Lihzp1W22203kDbbdLjcyz9ygQ/95Y/IW1HmI5CFDkzjemvl7pjhoggJ4+z+r4TtE9OYau1XtSZttv9repOgqapd5sSTHkcZiaRYhFpJy6blN96PqhogNBjMkoQJqMEYTLKSLeWNTYUNjrIWAlbc5HHPVO6n/Nrq0XabiO1BmvvnuXz/+nj3PHUD7HdLBoiaIq73j/TRQRz5AhX/sIkf+f+/8Uv/d5HuP/rDXJL17EiLiekGMBnfQL3vfExVh+bJ2or+aUtLGCbze2N7pry/VZ4NDNN+/7j5C+tOvfd9Q20k/TGyYMi9JiMEoTJKKP1klHIn1lCO96TpZuURwyom26aYgGzcJTmPXPcfdd1Sp8vYtvt3o6h5POYQgGZPEI6P0ntzjJbd0bU7rT8jff/Mf/xuZ/hvq83iL5/hjS1zswCYAxmvATeDCSlErq5RVKEuafOk66t9033fiuimxBC8nkXnZDLo0fK2JwhvXIVbbX2Zdq/GaHHZJQgTEYJwmSUkY4xptEifes88uI5721CL0kcJiI+dQete4/SnMnRnDFUvnMH9+Y2ex4v0dwcrYfvpDMR0Rk31I8aaieVdLbN+FSDlXYZzo+Rltrkjs55E78fA6LIvY8jSC1azEPpKPPPbGC3Kr0tY0yE+CREEscwMM5op4PUmxTPNkl3mmC69zmgMSf0mIwShMkoQZiMMtp1TC4HiUXGxkivX3fbsR23dpAoIlmYonE0z9bdhtnnO8z/nzXSs+fcVq9VmCyzdXceFShUlInFlKlzSnsiRuwES+fvYf6kpT6fI65ME1/fQrcqLqncWBE6Cfb8otsSXvNbDfV637xDP1MHuG0IabddJEEncWb/zUp/28AImhxOUrnQYzJKECajBGEyykjHGBsbzKVlpDyOaU5gq1VnJ/Om9ejcEpPrkyy9f45Tv70Gaxu9rEdiBDarTL9UByPkltbpHJsiLcWMv3QdVjfQdptx7iMtxkSVJlhL8uAp4s0mNh9jRJBiAZkoky5dwYyNYcpltNHwLk2unZqmLiIBkMkyMj4GSYqoIvkcWCXd3OonxjPSz+Z0QIQek1GCMBklCJNRRjrGiCrMTsNW1YVM+PGjW0xBT8yz+OFpildwWTCqtYHIL4vd2CR6zq07NI6J6w1yuRx2fcNdk6bID86QE0FzObTTIScCrTam2UKKBcjn0CTBlEpuzZJ3YRmapq4tiUXinAvtaLZgdpq0XEBeuYiZmSJduko0O000ecRnJuz47IGu6EM3Qm4kEWWB0fNG6mBGIvIDEfkd/zlELR8ib6TH/CIuYKnLG49atkrnaBlNUrTRdB71XbO/GMxmjfFlS1yH1/72MVY+/QimWNgeAJumaCchrdawq2vYtXWkPE7yyGnkvrtcRIAxzuQfRWilitbqyLiP/Doxj+TzVP/yw8jcjNti3lknzdeskZPHqZ2eYeWRI1z+ew9Tffg40ayv9R3HvSBbHQy2xSfFKxb3NX0eNjj2JPBXgS/5z4KLWv6Gv2S4WsuBoRm2x/wS8E+BrmPuLLcStZzU99XY24lhYjA/ClxT1T+9lRtsi1qOx27lJ25Lhukx7wX+moicB76Ke4T9Mj5q2V8zVNSyKOReugg27XsvdhPxpCnp5WVm/+gyU6+lLLx7mbW3K1IouHoz3mvftlrOC9+6Z7uMj9N+60ku/8wY9TuPuMhkn3BOW63edNZubiETE2gUsfjxO1n8aErj/jnXLiMDW8TGtc0IG++cY+N0TOUuePcnnmPtoRgtj7nogppLQNTdkjblMmZygmj+KGZ2BjM95ZYDt8jrCqOq/1xVT6rq3cAngT9Q1b9JiFo+VPazjvlnwD8SkVdxY06IWj5A3mgChj8E/tC/D1HLh8iIXWQVyeXQliuAoMnAPN+bQ5KLi0yUxyiWqtTOHHPjRKvltn7F9Gu5+GTW9u5jXPhIgcIaFJfrpFWXr0Yin6Gi0XBjUWq5+rOn2PpQnU6zRe5KnvzqFvb6Sj95tk0Bt6ZqvfMelt+n6FgbrPCdy3cxecFCpUZaqQxsR0M0PelqSOdiV7+mk/TNPLdIMMlklCBMRgnCZJTRjjGpdYlKLyy6ZKX5fD+UAXouppUHpjj39Cz3PXmOpNFw5+2OrVufJO7i4xNMnFeOfXsV+8qfDYxB4iObfZamP3ea2U9fYu2lE8w+a5h/ehldvOKyLfXGF//TRlh5uEB+XSmejZh+pUPxioVzPyKtVvvjl1pnm5uedGNLo+UKELVa+44yCz0mowRhMkoQJqOMvNYy+DC7et3VwxwMXxDBlMts3RVx6g8aaLOFKZWwzRZdw3Y3Y1N0dA67MMN7P/osFz97Lyxfd898nzmpG6WMuIJx9ePjTERrPPjFTewLZ7Fdl6PdtoCjiNKK5eRvL7mQvk7i7j6YYBt64xybVZLVteGK/wxJ6DEZJQiTUUY+XZZOgvqtWzVsy3tsSiXSh+8l+YubVK9OMLl1B9G1dVjf6OfZ96V52w8cZ/3+IrY6jS3FxGNjmE7iHjMDEQQSx0g+T36zw5Vfv5v5P3sBUyxgG43d2+gTm87+0WXSy1f6j7udOZa7CSPa7ZCA4XYiCJNRgjAZZbRJftRCu+NKv0eu1vJgTnzJ5+hM5GidLzHz7YtQ8CYb2OZiZGt1oq0241dzVP/zSSa2Kmil0qvX7LYQEiTnXJk0Sci9vMj89+vO/ak73d1raiviamoOTqd31iPzr/1E2AdL6DEZJQiTUYIwGWW03v4mctFb05Po2sYN5zW1lM6t8cCXIuz6BrbZ6tezNAJR3iXNHh9Dnz/D2BlfBCGKsD75aXfN4cYXb3axiq6v37jVu9OFdXDs6LTdeb993Isa2+36QyD0mIwShMkoQZiMMlpbWWSQsTFnM8vFbgzp1hAT49YYl5ddZDAguXhb5gyJY3e9qrOZWduznTGYNVDd72Ndhj9ttrDtdNt5YJsL0rZE1t2xp1vwx0gvO0d3zHPrqoMz8+8k9JiMEoTJKKMVRoT2vUfRjS2kUPAldp1HpUQRRBGt9z1E89H7MZOuPovk4v5UNU37npZjpV4yUVuru9Lu3jyiVt2uo/dW6ZppdjevWB91sOOR1Mu3bF3UWNfUY9W1ebfvHCChx2SUIExGGWpW5oOWKriS3omqPioiM8BvAncD54FPqOr64TTz9uON9Jj3q+o7VfVR//lzwLdU9TTwLf/55lhLZyx2nvCtVq/OZDdxgcQxaw/m0dig1rpEp
v68a60hKo9j5mb6v9nxtWF8HTJNOt7s33Genr7WTI+dJXd3Gyd8JIF7VW/e99Nxmx74NvJu7OdR9jFctDKEqOUDZ1hhFPg9EflTEXnCH1tQ1Sv+/TKwsNsXB6OW23YPB4jADQy78n+fql4WkXngKRF5efCkqqqI7Dp3VNUvAl8EmMzPH9788ieMoYRR1cv+9ZqI/BYuxO+qiBxX1SsichxXufzmv5OLKV5vIBMT2K0t54k5cF4mJxhfThl7+Srp6pr7zqCp3Vpnyq83sBubzpzfNc+zY10h0qs3tm1MGHbtsfO6Q1yz7MYwcf7jIjLRfQ/8LPAC8CQuWhlC1PKBM0yPWQB+y2UpIQZ+XVV/V0S+C3xNRD4DXAA+cXjNvP0YptbyOeAduxxfBT54GI0KjNrsL4ItxEjd5ZRR7ZvRIcIuX2Py6Rq21e67K4np17A0xiUJ9fXCJI4xpSMuC0Yn6dmyetdrP0Kgn5BOt28p38SFqf++b/7vRxKwY310sGNQMMlklCBMRgnCZJTRui+lFvPMC1ir27P2iXHrkySBegMZKxGPj6GdDnZzy13SLbDjE5wCLvloFPlzTSSOXaGEKHJrHgzadja4niurcdf3oqX3SCza2/L27/ukvcIMmvRrcR40ocdklCBMRhntdNlaooV50pVVl6s4zvnHRYqJcsj0NMxNUX1gmrW3xkydTZn83lI/sstPebuPJRkfx04fYePhKaKOUj5fI1q8jiYJdrPiSrwXC0ihQFQeR4pF0oUpNDLEF6+Rrq7fmCfZRx90p91mahLaHfebzVbvn9Kbzne3BwaiFkKNsp9ggjAZJQiTUUY7xhiDllxdMolzfTOJTVGNEUBzEdU7IhrHLOVF7/WYpu6cjxKLjx8jPTFH7WiJzkTE2tuE4ophbCkiEkEbTaLZadITc1TvLGNzggrYHNhYyFcskyslTLMJhSm3/dBsYruuUXkXVSDjY87Ts1BAtyqYUhGMQURcvbLUj0dpOuCKdTC1MEOPyShBmIwShMkoI17HKOlMGXPVZyzvJvQ0kYsMm5rgyl+aQVI49VTC2JkV7NXr/ed4HMPbT1OfL7mEQKmS30qZe9YwfrlJfOaSu+7EMZqnpqgdy9EpC7maErWUuAn5rQSskswfIbl7FtNOyV2vuiR3ttqPYJs8AnPTqLVIo4UUi951SV2dGr8tLtHAFoDaA6tTFnpMRgnCZJQgTEYZeURZvLRGUq0SL8yTVCr9iC5g413z5GrK/FOXSJevkSadbTas6OgcW/eUKV1tkbu4glYq2FqDvE8caqMIHnkrNmcwrZQj51OiWhuzsonWfFaMe06w/vZJ4qmYwlpCvNWEtU20UvGJ6ArIHQtoueS2sBtJrziQqh9f4tjVxqzV0dT2Epfe4D61j7VM6DEZJQiTUYIwGWW02ZcioXV6gfxWxdm9um5FPnq4dofh5NcvkFxe6rsIDdC5Z4HyuQqcvUDSaLqDA4myzewMnWJM/uwS6fqGC/ezKZrLo2lKfHyB639+CknhyA/W4NoqdmMTm3aTmuYwszOk02Wi5XW0WgMjpOubvchlmZp0yVarNZdMe2dIxm5R0LdA6DEZJQiTUUZcPwZyz7yM7SSIr+3VS8pjBNMG7fgpst5Yhr05X2Dieyuk7c6NjwpV9OgM8fdfIW13XBRA95TfXkjnpylf7lA6t0b62oVtdcmwqUv8EEdE1zdJFl3JNYlz7nFlYhcBNzuNXbxyY6Sap5/3P5j9fyIJwmSUYSvHTonIN0TkZRF5SUR+SkRmROQpETnrX6cPu7G3E8P2mF8GfldV34ILyXiJW4haliRFH7rXZbwYrHHsa8nkK4ouzGDGxvo1ybreksDYpRrkYsx4ybsKDWzjimCqdczcDOZIeVfzu6m3KFxz9WiiI+V+nUpvrtfUYlfWsNdXMeVyf8ou/W1j2aq6++2sz+zpmWf2yTARZZPAT+PLKapqW1U3CFHLh8owPeYe4DrwX33Z+C/5kL83HrWchqjlYRlGmBh4F/AFVX0EqLHjsaWqCuzatwdrLeej0n7be9swjDCLwKKqPuM/fwMn1FUfrcywUcsgmA3nIuRckWI3lhjnnT/9YgWNIpeZKYoGos0AMZizl6DZcpmQdqJKcvGyM6N0km3jQzfDkj1/iWh5Fan5nrsz+0WnjW00e/WZt9Vgtim2Xie5er1XAGJXDigr0zC1lpeBSyLyoD/0QeBFQtTyoTLsyv+zwK+JSB44B/wCTtQQtXxIDJuA4YfAo7ucClHLh8SIo5ahfXKa/LVVF1LRrSUZRW7ceP4s0k1u3c121MuSlLoax5XKjuMDz3Obkq7tyMw1cF5bLZLlq7ueG/wNvVl2JZseRgDZDQSTTEYJwmSUEZdbhPzljV6dF00SbwpJQf37wQQJOz3nd3qh7NzlHDTjAy5EYCDi601E6DEZJQiTUYIwGWW0wiQJUqkhxQISmV553hvGkK5J/2Z0k//QT5AgUdQvAd9Nst3dOuj+/i7eN0MxsP0wCkKPyShBmIwShMkooiNMwikiFeDMyG44PHPAyojudZeqHn29i0a9wDwzkOk8M4jI97LWrvAoyyhBmIwyamG+OOL7DUvm2jXSwT8wPOFRllGCMBllZMKIyOMickZEXhWR1y8CdLD3/rKIXBORFwaO7ep7LY7/4Nv5nIi8a5Rt7TISYUQkAn4F+AjwEPApEXloFPf2/Crw+I5je/lefwQ47f+eAL4wojZuY1Q95j3Aq6p6TlXbwFdxvs8jQVW/DaztOLyX7/XHgP+mju8AU13HxlEyKmFOAJcGPi/6Yz9O9vK9zkRbw+DPzX2vf1yMSpjLwKmBzyf9sR8ne/leZ6KtoxLmu8BpEbnHu9l+Euf7/ONkL9/rJ4Gf87Ozx4DNgUfe6FDVkfwBfwV4BXgN+Jejuq+/928AV4AObsz4DDCLm42dBX4fmPHXCm4G+RrwPPDoKNva/QsmmYwSBv+MEoTJKEGYjBKEyShBmIwShMkoQZiM8v8BgV/2SsGKAysAAAAASUVORK5CYII=\n", 95 | "text/plain": [ 96 | "
" 97 | ] 98 | }, 99 | "metadata": {}, 100 | "output_type": "display_data" 101 | } 102 | ], 103 | "source": [ 104 | "def evaluate(args):\n", 105 | " # load model\n", 106 | " model = get_model(args, 0, args['r'], from_ckpt=False, train=True)\n", 107 | " model.load(args['ckpt']) # from default checkpoint\n", 108 | "\n", 109 | " if args['wav_file_list']:\n", 110 | " with open(args['wav_file_list']) as f:\n", 111 | " for line in f:\n", 112 | " try:\n", 113 | " filename = line.strip()\n", 114 | " print(filename)\n", 115 | " filename = os.path.join(args['in_dir'], filename)\n", 116 | " upsample_wav(filename, args, model)\n", 117 | " except EOFError:\n", 118 | " print('WARNING: Error reading file:', line.strip())\n", 119 | "\n", 120 | "\n", 121 | "def get_model(args, n_dim, r, from_ckpt=False, train=True):\n", 122 | " \"\"\"Create a model based on arguments\"\"\" \n", 123 | " if train:\n", 124 | " opt_params = {\n", 125 | " 'alg' : args['alg'], \n", 126 | " 'lr' : args['lr'], \n", 127 | " 'b1' : 0.9, \n", 128 | " 'b2' : 0.999,\n", 129 | " 'batch_size': args['batch_size'], \n", 130 | " 'layers': args['layers']}\n", 131 | " else: \n", 132 | " opt_params = default_opt\n", 133 | "\n", 134 | " # create model & init\n", 135 | " model = ASRNet(\n", 136 | " from_ckpt=from_ckpt, \n", 137 | " n_dim=n_dim, \n", 138 | " r=r,\n", 139 | " opt_params=opt_params, \n", 140 | " log_prefix=args['logname'])\n", 141 | " \n", 142 | " return model\n", 143 | "\n", 144 | "evaluate(args)" 145 | ] 146 | }, 147 | { 148 | "cell_type": "code", 149 | "execution_count": null, 150 | "metadata": {}, 151 | "outputs": [], 152 | "source": [] 153 | }, 154 | { 155 | "cell_type": "code", 156 | "execution_count": null, 157 | "metadata": {}, 158 | "outputs": [], 159 | "source": [] 160 | }, 161 | { 162 | "cell_type": "code", 163 | "execution_count": null, 164 | "metadata": {}, 165 | "outputs": [], 166 | "source": [] 167 | } 168 | ], 169 | "metadata": { 170 | "kernelspec": { 171 | "display_name": "Python 3", 172 | "language": "python", 173 | "name": "python3" 174 | }, 175 | "language_info": { 176 | "codemirror_mode": { 177 | "name": "ipython", 178 | "version": 3 179 | }, 180 | "file_extension": ".py", 181 | "mimetype": "text/x-python", 182 | "name": "python", 183 | "nbconvert_exporter": "python", 184 | "pygments_lexer": "ipython3", 185 | "version": "3.5.2" 186 | } 187 | }, 188 | "nbformat": 4, 189 | "nbformat_minor": 2 190 | } 191 | -------------------------------------------------------------------------------- /03_predict_and_return_output_wav_files.py: -------------------------------------------------------------------------------- 1 | 2 | # coding: utf-8 3 | 4 | # #### referenced by - https://github.com/kuleshov/audio-super-res 5 | # # Predict & Get output files 6 | # 7 | 8 | # In[1]: 9 | 10 | 11 | import os 12 | os.sys.path.append(os.path.abspath('.')) 13 | os.sys.path.append(os.path.dirname(os.path.abspath('.'))) 14 | import numpy as np 15 | import matplotlib 16 | from asr_model import ASRNet, default_opt 17 | from io_utils import upsample_wav 18 | from io_utils import load_h5 19 | import tensorflow as tf 20 | 21 | 22 | # In[2]: 23 | 24 | 25 | args = { 26 | 'ckpt' : './default_log_name.lr0.000100.1.g4.b100/model.ckpt', 27 | 'wav_file_list' : './data/test.txt', 28 | 'r' : 6, 29 | 'sr' : 16000, 30 | 'alg' : 'adam', 31 | 'epochs' : 5, 32 | 'logname' : 'default_log_name', 33 | 'layers' : 4, 34 | 'lr' : 1e-3, 35 | 'batch_size' : 4, 36 | 'out_label' : 'asr_pred', 37 | 'in_dir' : './data/test' 38 | } 39 | print(tf.__version__) 40 | 41 | 42 | 
# In[3]: 43 | 44 | 45 | def evaluate(args): 46 | # load model 47 | model = get_model(args, 0, args['r'], from_ckpt=False, train=True) 48 | model.load(args['ckpt']) # from default checkpoint 49 | 50 | if args['wav_file_list']: 51 | with open(args['wav_file_list']) as f: 52 | for line in f: 53 | try: 54 | filename = line.strip() 55 | print(filename) 56 | filename = os.path.join(args['in_dir'], filename) 57 | upsample_wav(filename, args, model) 58 | except EOFError: 59 | print('WARNING: Error reading file:', line.strip()) 60 | 61 | 62 | def get_model(args, n_dim, r, from_ckpt=False, train=True): 63 | """Create a model based on arguments""" 64 | if train: 65 | opt_params = { 66 | 'alg' : args['alg'], 67 | 'lr' : args['lr'], 68 | 'b1' : 0.9, 69 | 'b2' : 0.999, 70 | 'batch_size': args['batch_size'], 71 | 'layers': args['layers']} 72 | else: 73 | opt_params = default_opt 74 | 75 | # create model & init 76 | model = ASRNet( 77 | from_ckpt=from_ckpt, 78 | n_dim=n_dim, 79 | r=r, 80 | opt_params=opt_params, 81 | log_prefix=args['logname']) 82 | 83 | return model 84 | 85 | evaluate(args) 86 | 87 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Audio-Super-Resolution -------------------------------------------------------------------------------- /__pycache__/asr_model.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leekh7411/Audio-Super-Resolution-Python3-TF/37cb46c37bde183392dc6109a6309330f24c6dcf/__pycache__/asr_model.cpython-35.pyc -------------------------------------------------------------------------------- /__pycache__/dataset.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leekh7411/Audio-Super-Resolution-Python3-TF/37cb46c37bde183392dc6109a6309330f24c6dcf/__pycache__/dataset.cpython-35.pyc -------------------------------------------------------------------------------- /__pycache__/io_utils.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leekh7411/Audio-Super-Resolution-Python3-TF/37cb46c37bde183392dc6109a6309330f24c6dcf/__pycache__/io_utils.cpython-35.pyc -------------------------------------------------------------------------------- /__pycache__/model.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leekh7411/Audio-Super-Resolution-Python3-TF/37cb46c37bde183392dc6109a6309330f24c6dcf/__pycache__/model.cpython-35.pyc -------------------------------------------------------------------------------- /__pycache__/standard.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leekh7411/Audio-Super-Resolution-Python3-TF/37cb46c37bde183392dc6109a6309330f24c6dcf/__pycache__/standard.cpython-35.pyc -------------------------------------------------------------------------------- /__pycache__/subpixel.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leekh7411/Audio-Super-Resolution-Python3-TF/37cb46c37bde183392dc6109a6309330f24c6dcf/__pycache__/subpixel.cpython-35.pyc -------------------------------------------------------------------------------- /__pycache__/summarization.cpython-35.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/leekh7411/Audio-Super-Resolution-Python3-TF/37cb46c37bde183392dc6109a6309330f24c6dcf/__pycache__/summarization.cpython-35.pyc -------------------------------------------------------------------------------- /asr_model.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | from scipy import interpolate 4 | from subpixel import SubPixel1D, SubPixel1D_v2 5 | from dataset import DataSet 6 | import os 7 | import time 8 | import librosa 9 | 10 | default_opt = { 'alg': 'adam', 'lr': 1e-4, 'b1': 0.99, 'b2': 0.999, 11 | 'layers': 4, 'batch_size': 128 } 12 | class ASRNet(): 13 | # ---------------------------------------------------------------------------- 14 | # initialize sequence 15 | 16 | # based AudioUNet 17 | def __init__(self, from_ckpt=False, n_dim=None, r=2, 18 | opt_params=default_opt, log_prefix='./run'): 19 | self.r = r 20 | # make session 21 | self.sess = tf.Session() 22 | 23 | # save params 24 | self.opt_params = opt_params 25 | self.layers = opt_params['layers'] 26 | 27 | if from_ckpt: 28 | pass # we will instead load the graph from a checkpoint 29 | else: 30 | # create input vars 31 | X = tf.placeholder(tf.float32, shape=(None, None, 1), name='X') 32 | Y = tf.placeholder(tf.float32, shape=(None, None, 1), name='Y') 33 | #alpha = tf.placeholder(tf.float32, shape=(), name='alpha') # weight multiplier 34 | 35 | # save inputs 36 | self.inputs = (X, Y) 37 | tf.add_to_collection('inputs', X) 38 | tf.add_to_collection('inputs', Y) 39 | #tf.add_to_collection('inputs', alpha) 40 | 41 | # create model outputs 42 | self.predictions = self.create_model(n_dim, r) 43 | tf.add_to_collection('preds', self.predictions) 44 | 45 | # init the model 46 | # init = tf.global_variables_initializer() 47 | # init = tf.initialize_all_variables() 48 | # self.sess.run(init) 49 | 50 | # create training updates 51 | self.train_op = self.create_train_op(X, Y) 52 | tf.add_to_collection('train_op', self.train_op) 53 | 54 | # logging 55 | lr_str = '.' 
+ 'lr%f' % opt_params['lr'] 56 | g_str = '.g%d' % self.layers 57 | b_str = '.b%d' % int(opt_params['batch_size']) 58 | 59 | self.logdir = log_prefix + lr_str + '.%d' % r + g_str + b_str 60 | self.checkpoint_root = os.path.join(self.logdir, 'model.ckpt') 61 | 62 | def create_model(self, n_dim, r): 63 | # load inputs 64 | X, _ = self.inputs 65 | L = self.layers 66 | 67 | #n_filters = [128, 256, 512,512] 68 | #n_filtersizes = [65, 33, 17, 9] 69 | 70 | n_filters = [32, 48, 64, 64] 71 | n_filtersizes = [16, 10, 5, 5] 72 | 73 | #n_filters = [10, 20, 40, 40, 40, 40, 40, 40] 74 | #n_filtersizes = [30, 20, 15, 10, 10, 10, 10, 10] 75 | 76 | downsampled_l = [] 77 | with tf.name_scope('generator'): 78 | # save origin-X 79 | oX = X 80 | print('>> Generator Model init...') 81 | # downsampling layers 82 | for l, nf, fs in zip(range(L), n_filters, n_filtersizes): 83 | X = downsample_layer(X, nf, fs) 84 | downsampled_l.append(X) 85 | print('D-Block >> ' ,X) 86 | 87 | # Bottle-neck layer 88 | X = downsample_layer(X, n_filters[-1], n_filtersizes[-1], B=True) 89 | print('B-Block >> ', X) 90 | 91 | # Upsample layer 92 | L = reversed(range(L)) 93 | n_filters = reversed(n_filters) 94 | n_filtersizes = reversed(n_filtersizes) 95 | downsampled_l = reversed(downsampled_l) 96 | 97 | for l, nf, fs, l_in in zip( L, (n_filters), (n_filtersizes), (downsampled_l)): 98 | #X = reshape1Dto2D(X) 99 | X = upsample_layer(X, nf*2, fs) 100 | #X = reshape2Dto1D(X) 101 | X = tf.concat([X,l_in],axis=-1) 102 | print('U-Block >> ',X) 103 | 104 | # Final layer and add input layer 105 | X = upsample_layer(X,nf=2,ks=9) 106 | G = tf.add(X,oX) 107 | print('Fin-Layer >> ',G) 108 | print('>> ...finish') 109 | print() 110 | 111 | return G 112 | 113 | def create_train_op(self, X, Y): 114 | # load params 115 | opt_params = self.opt_params 116 | print('creating train_op with params:', opt_params) 117 | 118 | # create loss 119 | self.loss = self.create_objective(X, Y, opt_params) 120 | 121 | # create params - get trainable variables 122 | # params = self.get_params() 123 | 124 | # create optimizer 125 | self.optimizer = self.create_optimzier(opt_params) 126 | 127 | # create gradients 128 | # grads = self.create_gradients(self.loss, params) 129 | 130 | # create training op 131 | #with tf.name_scope('optimizer'): 132 | # ref - https://www.tensorflow.org/api_docs/python/tf/train/AdamOptimizer 133 | #train_op = self.create_updates(params, grads, alpha, opt_params) 134 | train_op = self.optimizer.minimize(self.loss) 135 | 136 | # initialize the optimizer variabLes 137 | #optimizer_vars = [] 138 | #for v in tf.global_variables(): 139 | # if 'optimizer/' in v.name: 140 | # optimizer_vars.append(v) 141 | 142 | #init = tf.variables_initializer(optimizer_vars) 143 | #self.sess.run(init) 144 | 145 | return train_op 146 | 147 | 148 | def create_objective(self, X, Y, opt_params): 149 | # load model output and true output 150 | P = self.predictions 151 | 152 | # compute l2 loss 153 | sqrt_l2_loss = tf.sqrt(tf.reduce_mean((P-Y)**2 + 1e-6, axis=[1,2])) 154 | sqrn_l2_norm = tf.sqrt(tf.reduce_mean(Y**2, axis=[1,2])) 155 | snr = 20 * tf.log(sqrn_l2_norm / sqrt_l2_loss + 1e-8) / tf.log(10.) 
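# Note on the three lines above: sqrt_l2_loss is the per-example RMS error
# between prediction P and target Y, and sqrn_l2_norm is the per-example RMS
# of the target itself, so snr is the signal-to-noise ratio in decibels,
# SNR = 20 * log10(||Y|| / ||P - Y||); the small epsilons guard against
# division by zero, and tf.log is the natural logarithm, hence the explicit
# division by tf.log(10.).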
156 | 157 | avg_sqrt_l2_loss = tf.reduce_mean(sqrt_l2_loss, axis=0) 158 | avg_snr = tf.reduce_mean(snr, axis=0) 159 | 160 | # track losses 161 | tf.summary.scalar('l2_loss', avg_sqrt_l2_loss) 162 | tf.summary.scalar('snr', avg_snr) 163 | 164 | # save losses into collection 165 | tf.add_to_collection('losses', avg_sqrt_l2_loss) 166 | tf.add_to_collection('losses', avg_snr) 167 | 168 | return avg_sqrt_l2_loss 169 | 170 | 171 | def create_optimzier(self, opt_params): 172 | if opt_params['alg'] == 'adam': 173 | lr, b1, b2 = opt_params['lr'], opt_params['b1'], opt_params['b2'] 174 | optimizer = tf.train.AdamOptimizer(lr, b1, b2) 175 | else: 176 | raise ValueError('Invalid optimizer: ' + opt_params['alg']) 177 | 178 | return optimizer 179 | 180 | 181 | # ---------------------------------------------------------------------------- 182 | # in training sequence 183 | 184 | def fit(self, X_train, Y_train, X_val, Y_val, n_epoch=100): 185 | 186 | # init the model 187 | init = tf.global_variables_initializer() 188 | # init = tf.initialize_all_variables() 189 | self.sess.run(init) 190 | 191 | 192 | # initialize log directory 193 | if tf.gfile.Exists(self.logdir): tf.gfile.DeleteRecursively(self.logdir) 194 | tf.gfile.MakeDirs(self.logdir) 195 | 196 | # load some training params 197 | n_batch = self.opt_params['batch_size'] 198 | 199 | # create saver 200 | self.saver = tf.train.Saver() 201 | 202 | # summarization 203 | summary = tf.summary.merge_all() 204 | summary_writer = tf.summary.FileWriter(self.logdir, self.sess.graph) 205 | 206 | # load data into DataSet 207 | train_data = DataSet(X_train, Y_train) 208 | val_data = DataSet(X_val, Y_val) 209 | 210 | # train the model 211 | start_time = time.time() 212 | step, epoch = 0, train_data.epochs_completed 213 | 214 | print('start training epoch (n:%d)'%n_epoch) 215 | print('num-of-batch:',n_batch) 216 | 217 | 218 | 219 | for epoch in range(n_epoch): 220 | is_batch_fin = False 221 | step = 1 222 | start_time = time.time() 223 | # loop train data on batch size 224 | count = 0 225 | while not is_batch_fin: 226 | count += 1 227 | 228 | # load next batch data 229 | d,l,is_batch_fin = train_data.next_batch(n_batch) 230 | batch = (d,l) 231 | 232 | # get batch feed-dict 233 | feed_dict = self.load_batch(batch) 234 | 235 | # training batch-size 236 | tr_objective = self.train_batch(feed_dict) 237 | tr_obj_snr = 20 * np.log10(1. 
/ np.sqrt(tr_objective) + 1e-8) 238 | 239 | # print batch log 240 | print('count %d / obj: %f / snr: %f'%(count, tr_objective, tr_obj_snr)) 241 | 242 | # last case 243 | if is_batch_fin: 244 | end_time = time.time() 245 | 246 | # evaluation model each epoch 247 | tr_l2_loss, tr_l2_snr = self.eval_err(X_train, Y_train, n_batch=n_batch) 248 | va_l2_loss, va_l2_snr = self.eval_err(X_val, Y_val, n_batch=n_batch) 249 | 250 | # print epoch log 251 | print() 252 | print("Epoch {} of {} took {:.3f}s ({} minibatches)".format( 253 | epoch+1, n_epoch, end_time - start_time, len(X_train) // n_batch)) 254 | print(" training l2_loss/segsnr:\t\t{:.6f}\t{:.6f}".format( 255 | tr_l2_loss, tr_l2_snr)) 256 | print(" validation l2_loss/segsnr:\t\t{:.6f}\t{:.6f}".format( 257 | va_l2_loss, va_l2_snr)) 258 | print("-----------------------------------------------------------------------") 259 | 260 | # compute summaries for overall loss 261 | objectives_summary = tf.Summary() 262 | objectives_summary.value.add(tag='tr_l2_loss', simple_value=tr_l2_loss) 263 | objectives_summary.value.add(tag='tr_l2_snr' , simple_value=tr_l2_snr) 264 | objectives_summary.value.add(tag='va_l2_snr' , simple_value=va_l2_loss) 265 | 266 | # compute summaries for all other metrics 267 | summary_str = self.sess.run(summary, feed_dict=feed_dict) 268 | summary_writer.add_summary(summary_str, step) 269 | summary_writer.add_summary(objectives_summary, step) 270 | 271 | # write summaries and checkpoints 272 | summary_writer.flush() 273 | self.saver.save(self.sess, self.checkpoint_root, global_step=step) 274 | self.saver.save(self.sess, self.checkpoint_root) 275 | 276 | # restart clock 277 | start_time = time.time() 278 | 279 | # in for loop 280 | step += 1 281 | 282 | def load_batch(self, batch, train=True): 283 | X_in, Y_in = self.inputs 284 | X, Y = batch 285 | 286 | if Y is not None: 287 | feed_dict = {X_in : X, Y_in : Y} 288 | else: 289 | feed_dict = {X_in : X} 290 | 291 | return feed_dict 292 | 293 | 294 | def train_batch(self, feed_dict): 295 | _, loss = self.sess.run([self.train_op, self.loss], feed_dict=feed_dict) 296 | return loss 297 | 298 | 299 | def eval_err(self, X, Y, n_batch=128): 300 | batch_iterator = iterate_minibatches(X, Y, n_batch, shuffle=True) 301 | l2_loss_op, l2_snr_op = tf.get_collection('losses') 302 | l2_loss, snr = 0, 0 303 | tot_l2_loss, tot_snr = 0, 0 304 | for bn, batch in enumerate(batch_iterator): 305 | feed_dict = self.load_batch(batch, train=False) 306 | l2_loss, l2_snr = self.sess.run([l2_loss_op, l2_snr_op], feed_dict=feed_dict) 307 | tot_l2_loss += l2_loss 308 | tot_snr += l2_snr 309 | 310 | return tot_l2_loss / (bn+1), tot_snr / (bn+1) 311 | 312 | 313 | def predict(self, X, Y): 314 | assert len(X) == 1 315 | x_sp = spline_up(X, self.r) 316 | x_sp = x_sp[:len(x_sp) - (len(x_sp) % (2**(self.layers+1)))] 317 | X = x_sp.reshape((1,len(x_sp),1)) 318 | feed_dict = self.load_batch((X,Y), train=False) 319 | return self.sess.run(self.predictions, feed_dict=feed_dict) 320 | 321 | 322 | def load(self, ckpt): 323 | # get checkpoint name 324 | # ref - https://www.tensorflow.org/api_docs/python/tf/train/latest_checkpoint 325 | if os.path.isdir(ckpt): checkpoint = tf.train.latest_checkpoint(ckpt) 326 | else: checkpoint = ckpt 327 | meta = checkpoint + '.meta' 328 | print('checkpoint:',checkpoint) 329 | print('ckpt:',ckpt) 330 | 331 | # load 332 | self.saver = tf.train.Saver() 333 | #self.sess.run(tf.global_variables_initializer()) 334 | self.saver.restore(self.sess, checkpoint) 335 | 336 | ''' 337 | # load graph 338 | 
self.saver = tf.train.import_meta_graph(meta) 339 | g = tf.get_default_graph() 340 | # load weights 341 | self.saver.restore(self.sess, checkpoint) 342 | 343 | # get graph tensors 344 | X, Y = tf.get_collection('inputs') 345 | 346 | # save tensors as instance variables 347 | self.inputs = X, Y 348 | self.predictions = tf.get_collection('preds')[0] 349 | 350 | # load existing loss, or erase it, if creating new one 351 | g.clear_collection('losses') 352 | 353 | # create a new training op 354 | self.train_op = self.create_train_op(X, Y) 355 | g.clear_collection('train_op') 356 | tf.add_to_collection('train_op', self.train_op) 357 | 358 | ''' 359 | 360 | # ---------------------------------------------------------------------------- 361 | # helpers 362 | 363 | def iterate_minibatches(inputs, targets, batchsize, shuffle=False): 364 | assert len(inputs) == len(targets) 365 | if shuffle: 366 | indices = np.arange(len(inputs)) 367 | np.random.shuffle(indices) 368 | for start_idx in range(0, len(inputs) - batchsize + 1, batchsize): 369 | if shuffle: 370 | excerpt = indices[start_idx:start_idx + batchsize] 371 | else: 372 | excerpt = slice(start_idx, start_idx + batchsize) 373 | yield inputs[excerpt], targets[excerpt] 374 | 375 | 376 | def downsample_layer(x , nf, ks, B=False): 377 | x = tf.layers.conv1d( 378 | x, 379 | filters = nf, 380 | kernel_size = ks, 381 | strides=1, 382 | padding='same', 383 | data_format='channels_last', 384 | dilation_rate=1, 385 | activation=None, 386 | use_bias=True, 387 | kernel_initializer=None, 388 | bias_initializer=tf.zeros_initializer(), 389 | kernel_regularizer=None, 390 | bias_regularizer=None, 391 | activity_regularizer=None, 392 | kernel_constraint=None, 393 | bias_constraint=None, 394 | trainable=True, 395 | name=None, 396 | reuse=None 397 | ) 398 | x = tf.layers.max_pooling1d( 399 | x, 400 | pool_size = 2, 401 | strides = 2, 402 | padding='same', 403 | data_format='channels_last', 404 | name=None 405 | ) 406 | 407 | if B : x = tf.layers.dropout(x, rate=0.5) 408 | 409 | x = tf.nn.relu(x) 410 | return x 411 | 412 | def upsample_layer(x, nf, ks): 413 | '''x = tf.layers.conv2d_transpose( 414 | x, 415 | filters = nf, 416 | kernel_size = [1,ks], 417 | strides=(1, 1), 418 | padding='same', 419 | data_format='channels_last', 420 | activation=None, 421 | use_bias=True, 422 | kernel_initializer=None, 423 | bias_initializer=tf.zeros_initializer(), 424 | kernel_regularizer=None, 425 | bias_regularizer=None, 426 | activity_regularizer=None, 427 | kernel_constraint=None, 428 | bias_constraint=None, 429 | trainable=True, 430 | name=None, 431 | reuse=None 432 | )''' 433 | x = tf.layers.conv1d( 434 | x, 435 | filters = nf, 436 | kernel_size = ks, 437 | strides=1, 438 | padding='same', 439 | data_format='channels_last', 440 | dilation_rate=1, 441 | activation=None, 442 | use_bias=True, 443 | kernel_initializer=None, 444 | bias_initializer=tf.zeros_initializer(), 445 | kernel_regularizer=None, 446 | bias_regularizer=None, 447 | activity_regularizer=None, 448 | kernel_constraint=None, 449 | bias_constraint=None, 450 | trainable=True, 451 | name=None, 452 | reuse=None 453 | ) 454 | x = tf.layers.dropout(x, rate=0.5) 455 | x = tf.nn.relu(x) 456 | x = SubPixel1D(x,r=2) 457 | return x 458 | 459 | def spline_up(x_lr, r): 460 | x_lr = x_lr.flatten() 461 | x_hr_len = len(x_lr) * r 462 | x_sp = np.zeros(x_hr_len) 463 | 464 | i_lr = np.arange(x_hr_len, step=r) 465 | i_hr = np.arange(x_hr_len) 466 | 467 | f = interpolate.splrep(i_lr, x_lr) 468 | 469 | x_sp = interpolate.splev(i_hr, f) 470 
| 471 | return x_sp 472 | -------------------------------------------------------------------------------- /asr_model.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leekh7411/Audio-Super-Resolution-Python3-TF/37cb46c37bde183392dc6109a6309330f24c6dcf/asr_model.pyc -------------------------------------------------------------------------------- /data/temp.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leekh7411/Audio-Super-Resolution-Python3-TF/37cb46c37bde183392dc6109a6309330f24c6dcf/data/temp.txt -------------------------------------------------------------------------------- /data/test.txt: -------------------------------------------------------------------------------- 1 | test_0.wav -------------------------------------------------------------------------------- /data/train.txt: -------------------------------------------------------------------------------- 1 | p225_001.wav 2 | p225_002.wav 3 | p225_003.wav 4 | p225_004.wav 5 | p225_005.wav 6 | p225_006.wav 7 | p225_007.wav 8 | p225_008.wav 9 | p225_009.wav 10 | p225_010.wav 11 | p225_011.wav 12 | p225_012.wav 13 | p225_013.wav 14 | p225_014.wav 15 | p225_016.wav 16 | p225_017.wav 17 | p225_018.wav 18 | p225_019.wav 19 | p225_020.wav 20 | p225_021.wav 21 | p225_022.wav 22 | p225_023.wav 23 | p225_024.wav 24 | p225_025.wav 25 | p225_026.wav 26 | p225_027.wav 27 | p225_028.wav 28 | p225_029.wav 29 | p225_030.wav 30 | p225_033.wav 31 | p225_035.wav 32 | p225_036.wav 33 | p225_037.wav 34 | p225_038.wav 35 | p225_039.wav 36 | p225_040.wav 37 | p225_044.wav 38 | p225_045.wav 39 | p225_046.wav 40 | p225_049.wav 41 | p225_051.wav 42 | p225_052.wav 43 | p225_053.wav 44 | p225_054.wav 45 | p225_056.wav 46 | p225_057.wav 47 | p225_058.wav 48 | p225_059.wav 49 | p225_060.wav 50 | p225_061.wav 51 | p225_062.wav 52 | p225_063.wav 53 | p225_064.wav 54 | p225_065.wav 55 | p225_066.wav 56 | p225_067.wav 57 | p225_070.wav 58 | p225_071.wav 59 | p225_072.wav 60 | p225_073.wav 61 | p225_081.wav 62 | p225_082.wav 63 | p225_083.wav 64 | p225_084.wav 65 | p225_086.wav 66 | p225_089.wav 67 | p225_090.wav 68 | p225_092.wav 69 | p225_094.wav 70 | p225_103.wav 71 | p225_104.wav 72 | p225_108.wav 73 | p225_109.wav 74 | p225_110.wav 75 | p225_111.wav 76 | p225_113.wav 77 | p225_114.wav 78 | p225_115.wav 79 | p225_116.wav 80 | p225_117.wav 81 | p225_118.wav 82 | p225_120.wav 83 | p225_121.wav 84 | p225_122.wav 85 | p225_123.wav 86 | p225_124.wav 87 | p225_126.wav 88 | p225_127.wav 89 | p225_128.wav 90 | p225_131.wav 91 | p225_133.wav 92 | p225_135.wav 93 | p225_136.wav 94 | p225_141.wav 95 | p225_142.wav 96 | p225_143.wav 97 | p225_144.wav 98 | p225_145.wav 99 | p225_147.wav 100 | p225_149.wav 101 | p225_150.wav 102 | p225_151.wav 103 | p225_152.wav 104 | p225_153.wav 105 | p225_156.wav 106 | p225_157.wav 107 | p225_158.wav 108 | p225_159.wav 109 | p225_165.wav 110 | p225_166.wav 111 | p225_169.wav 112 | p225_171.wav 113 | p225_172.wav 114 | p225_173.wav 115 | p225_174.wav 116 | p225_175.wav 117 | p225_176.wav 118 | p225_177.wav 119 | p225_179.wav 120 | p225_182.wav 121 | p225_191.wav 122 | p225_192.wav 123 | p225_193.wav 124 | p225_195.wav 125 | p225_196.wav 126 | p225_197.wav 127 | p225_199.wav 128 | p225_200.wav 129 | p225_201.wav 130 | p225_202.wav 131 | p225_203.wav 132 | p225_208.wav 133 | p225_210.wav 134 | p225_211.wav 135 | p225_212.wav 136 | p225_218.wav 137 | p225_219.wav 138 | p225_220.wav 139 | p225_221.wav 
140 | p225_222.wav 141 | p225_223.wav 142 | p225_224.wav 143 | p225_225.wav 144 | p225_235.wav 145 | p225_236.wav 146 | p225_237.wav 147 | p225_238.wav 148 | p225_239.wav 149 | p225_240.wav 150 | p225_241.wav 151 | p225_242.wav 152 | p225_243.wav 153 | p225_244.wav 154 | p225_248.wav 155 | p225_253.wav 156 | p225_254.wav 157 | p225_257.wav 158 | p225_258.wav 159 | p225_264.wav 160 | p225_265.wav 161 | p225_266.wav 162 | p225_268.wav 163 | p225_273.wav 164 | p225_274.wav 165 | p225_275.wav 166 | p225_276.wav 167 | p225_277.wav 168 | p225_279.wav 169 | p225_280.wav 170 | p225_281.wav 171 | p225_282.wav 172 | p225_285.wav 173 | p225_286.wav 174 | p225_287.wav 175 | p225_289.wav 176 | p225_290.wav 177 | p225_291.wav 178 | p225_293.wav 179 | p225_294.wav 180 | p225_295.wav 181 | p225_296.wav 182 | p225_297.wav 183 | p225_298.wav 184 | p225_299.wav 185 | p225_300.wav 186 | p225_301.wav 187 | p225_302.wav 188 | p225_303.wav 189 | p225_305.wav 190 | p225_308.wav 191 | p225_309.wav 192 | p225_310.wav 193 | p225_312.wav 194 | p225_314.wav 195 | p225_315.wav 196 | p225_316.wav 197 | p225_317.wav 198 | p225_318.wav 199 | p225_319.wav 200 | p225_320.wav 201 | p225_322.wav 202 | p225_323.wav 203 | p225_324.wav 204 | p225_325.wav 205 | p225_326.wav 206 | p225_328.wav 207 | p225_329.wav 208 | p225_330.wav 209 | p225_331.wav 210 | p225_332.wav 211 | p225_334.wav 212 | p225_335.wav 213 | p225_336.wav 214 | p225_337.wav 215 | p225_346.wav 216 | p225_347.wav 217 | p225_348.wav 218 | p225_349.wav 219 | p225_350.wav 220 | p225_351.wav 221 | p225_352.wav 222 | p225_353.wav 223 | p225_354.wav 224 | p225_355.wav 225 | p225_356.wav 226 | p225_357.wav 227 | p225_358.wav 228 | p225_359.wav 229 | p225_363.wav 230 | p225_365.wav 231 | p225_366.wav 232 | -------------------------------------------------------------------------------- /data/valid.txt: -------------------------------------------------------------------------------- 1 | p362_300.wav 2 | p362_301.wav 3 | p362_302.wav 4 | p362_303.wav 5 | p362_304.wav 6 | p362_305.wav 7 | p362_306.wav 8 | p362_307.wav 9 | p362_308.wav 10 | p362_309.wav 11 | p362_310.wav 12 | p362_311.wav 13 | p362_312.wav 14 | p362_313.wav 15 | p362_314.wav 16 | p362_315.wav 17 | p362_316.wav 18 | p362_317.wav 19 | p362_318.wav 20 | p362_319.wav 21 | p362_320.wav 22 | p362_321.wav 23 | p362_322.wav 24 | p362_323.wav 25 | p362_324.wav 26 | p362_325.wav 27 | p362_326.wav 28 | p362_327.wav 29 | p362_328.wav 30 | p362_329.wav 31 | p362_330.wav 32 | p362_331.wav 33 | p362_332.wav 34 | p362_333.wav 35 | p362_334.wav 36 | p362_335.wav 37 | p362_336.wav 38 | p362_337.wav 39 | p362_338.wav 40 | p362_339.wav 41 | p362_340.wav 42 | p362_341.wav 43 | p362_342.wav 44 | p362_343.wav 45 | p362_344.wav 46 | p362_345.wav 47 | p362_346.wav 48 | p362_347.wav 49 | p362_348.wav 50 | p362_349.wav 51 | p362_350.wav 52 | p362_351.wav 53 | p362_352.wav 54 | p362_353.wav 55 | p362_354.wav 56 | p362_355.wav 57 | p362_356.wav 58 | p362_357.wav 59 | p362_358.wav 60 | p362_359.wav 61 | p362_360.wav 62 | p362_361.wav 63 | p362_362.wav 64 | p362_363.wav 65 | p362_364.wav 66 | p362_365.wav 67 | p362_366.wav 68 | p362_367.wav 69 | p362_368.wav 70 | p362_369.wav 71 | p362_370.wav 72 | p362_371.wav 73 | p362_372.wav 74 | p362_373.wav 75 | p362_374.wav 76 | p362_375.wav 77 | p362_376.wav 78 | p362_377.wav 79 | p362_378.wav 80 | p362_379.wav 81 | p362_380.wav 82 | p362_381.wav 83 | p362_382.wav 84 | p362_383.wav 85 | p362_384.wav 86 | p362_385.wav 87 | p362_386.wav 88 | p362_387.wav 89 | p362_388.wav 90 | p362_389.wav 91 | 
p362_390.wav 92 | p362_391.wav 93 | p362_392.wav 94 | p362_393.wav 95 | p362_394.wav 96 | p362_395.wav 97 | p362_396.wav 98 | p362_397.wav 99 | p362_398.wav 100 | p362_399.wav 101 | p362_400.wav 102 | p362_401.wav 103 | p362_402.wav 104 | p362_403.wav 105 | p362_404.wav 106 | p362_405.wav 107 | p362_406.wav 108 | p362_407.wav 109 | p362_408.wav 110 | p362_409.wav 111 | p362_410.wav 112 | p362_411.wav 113 | p362_412.wav 114 | p362_413.wav 115 | p362_414.wav 116 | p362_415.wav 117 | p362_416.wav 118 | p362_417.wav 119 | p362_418.wav 120 | p362_419.wav 121 | p362_420.wav 122 | p362_421.wav 123 | p362_422.wav 124 | p362_423.wav 125 | p362_424.wav 126 | -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | """Class for doing iterations over datasets 2 | 3 | This is stolen from the tensorflow tutorial 4 | """ 5 | 6 | import numpy 7 | 8 | from tensorflow.contrib.learn.python.learn.datasets import base 9 | from tensorflow.python.framework import dtypes 10 | 11 | # ---------------------------------------------------------------------------- 12 | 13 | class DataSet(object): 14 | 15 | def __init__(self, 16 | datapoints, 17 | labels, 18 | fake_data=False, 19 | one_hot=False, 20 | dtype=dtypes.float32): 21 | """Construct a DataSet. 22 | one_hot arg is used only if fake_data is true. `dtype` can be either 23 | `uint8` to leave the input as `[0, 255]`, or `float32` to rescale into 24 | `[0, 1]`. 25 | """ 26 | dtype = dtypes.as_dtype(dtype).base_dtype 27 | if dtype not in (dtypes.uint8, dtypes.float32): 28 | raise TypeError('Invalid image dtype %r, expected uint8 or float32' % 29 | dtype) 30 | 31 | if labels is None: 32 | labels = np.zeros((len(datapoints),)) 33 | 34 | if fake_data: 35 | self._num_examples = 10000 36 | self.one_hot = one_hot 37 | else: 38 | assert datapoints.shape[0] == labels.shape[0], ( 39 | 'datapoints.shape: %s labels.shape: %s' % (datapoints.shape, labels.shape)) 40 | self._num_examples = datapoints.shape[0] 41 | 42 | self._datapoints = datapoints 43 | self._labels = labels 44 | self._epochs_completed = 0 45 | self._index_in_epoch = 0 46 | 47 | @property 48 | def datapoints(self): 49 | return self._datapoints 50 | 51 | @property 52 | def labels(self): 53 | return self._labels 54 | 55 | @property 56 | def num_examples(self): 57 | return self._num_examples 58 | 59 | @property 60 | def epochs_completed(self): 61 | return self._epochs_completed 62 | 63 | def next_batch(self, batch_size, fake_data=False, shuffle=True): 64 | """Return the next `batch_size` examples from this data set.""" 65 | if fake_data: 66 | fake_image = [1] * 784 67 | if self.one_hot: 68 | fake_label = [1] + [0] * 9 69 | else: 70 | fake_label = 0 71 | return [fake_image for _ in xrange(batch_size)], [ 72 | fake_label for _ in xrange(batch_size) 73 | ] 74 | 75 | start = self._index_in_epoch 76 | 77 | # Shuffle for the first epoch 78 | if self._epochs_completed == 0 and start == 0 and shuffle: 79 | perm0 = numpy.arange(self._num_examples) 80 | numpy.random.shuffle(perm0) 81 | self._datapoints = self.datapoints[perm0] 82 | self._labels = self.labels[perm0] 83 | 84 | # Go to the next epoch 85 | if start + batch_size > self._num_examples: 86 | # Finished epoch 87 | self._epochs_completed += 1 88 | 89 | # Get the rest examples in this epoch 90 | rest_num_examples = self._num_examples - start 91 | datapoints_rest_part = self._datapoints[start:self._num_examples] 92 | labels_rest_part = 
self._labels[start:self._num_examples] 93 | 94 | # Shuffle the data 95 | if shuffle: 96 | perm = numpy.arange(self._num_examples) 97 | numpy.random.shuffle(perm) 98 | self._datapoints = self.datapoints[perm] 99 | self._labels = self.labels[perm] 100 | 101 | # Start next epoch 102 | start = 0 103 | self._index_in_epoch = batch_size - rest_num_examples 104 | end = self._index_in_epoch 105 | datapoints_new_part = self._datapoints[start:end] 106 | labels_new_part = self._labels[start:end] 107 | return numpy.concatenate((datapoints_rest_part, datapoints_new_part), axis=0),numpy.concatenate((labels_rest_part, labels_new_part), axis=0), True 108 | else: 109 | self._index_in_epoch += batch_size 110 | end = self._index_in_epoch 111 | return self._datapoints[start:end], self._labels[start:end], False 112 | -------------------------------------------------------------------------------- /default_log_name.lr0.000500.1.g4.b100/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "model.ckpt" 2 | all_model_checkpoint_paths: "model.ckpt-1" 3 | all_model_checkpoint_paths: "model.ckpt" 4 | -------------------------------------------------------------------------------- /default_log_name.lr0.000500.1.g4.b100/events.out.tfevents.1531203790.smart-deep-learning: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leekh7411/Audio-Super-Resolution-Python3-TF/37cb46c37bde183392dc6109a6309330f24c6dcf/default_log_name.lr0.000500.1.g4.b100/events.out.tfevents.1531203790.smart-deep-learning -------------------------------------------------------------------------------- /default_log_name.lr0.000500.1.g4.b100/model.ckpt-1.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leekh7411/Audio-Super-Resolution-Python3-TF/37cb46c37bde183392dc6109a6309330f24c6dcf/default_log_name.lr0.000500.1.g4.b100/model.ckpt-1.data-00000-of-00001 -------------------------------------------------------------------------------- /default_log_name.lr0.000500.1.g4.b100/model.ckpt-1.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leekh7411/Audio-Super-Resolution-Python3-TF/37cb46c37bde183392dc6109a6309330f24c6dcf/default_log_name.lr0.000500.1.g4.b100/model.ckpt-1.index -------------------------------------------------------------------------------- /default_log_name.lr0.000500.1.g4.b100/model.ckpt-1.meta: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leekh7411/Audio-Super-Resolution-Python3-TF/37cb46c37bde183392dc6109a6309330f24c6dcf/default_log_name.lr0.000500.1.g4.b100/model.ckpt-1.meta -------------------------------------------------------------------------------- /default_log_name.lr0.000500.1.g4.b100/model.ckpt.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leekh7411/Audio-Super-Resolution-Python3-TF/37cb46c37bde183392dc6109a6309330f24c6dcf/default_log_name.lr0.000500.1.g4.b100/model.ckpt.data-00000-of-00001 -------------------------------------------------------------------------------- /default_log_name.lr0.000500.1.g4.b100/model.ckpt.index: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/leekh7411/Audio-Super-Resolution-Python3-TF/37cb46c37bde183392dc6109a6309330f24c6dcf/default_log_name.lr0.000500.1.g4.b100/model.ckpt.index -------------------------------------------------------------------------------- /default_log_name.lr0.000500.1.g4.b100/model.ckpt.meta: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leekh7411/Audio-Super-Resolution-Python3-TF/37cb46c37bde183392dc6109a6309330f24c6dcf/default_log_name.lr0.000500.1.g4.b100/model.ckpt.meta -------------------------------------------------------------------------------- /example-hr.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leekh7411/Audio-Super-Resolution-Python3-TF/37cb46c37bde183392dc6109a6309330f24c6dcf/example-hr.wav -------------------------------------------------------------------------------- /example-lr.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leekh7411/Audio-Super-Resolution-Python3-TF/37cb46c37bde183392dc6109a6309330f24c6dcf/example-lr.wav -------------------------------------------------------------------------------- /git-lfs-2.4.2/README.md: -------------------------------------------------------------------------------- 1 | # Git Large File Storage 2 | 3 | | Linux | macOS | Windows | 4 | | :---- | :------ | :---- | 5 | [ ![Linux build status][1]][2] | [![macOS build status][3]][4] | [![Windows build status][5]][6] | 6 | 7 | [1]: https://travis-ci.org/git-lfs/git-lfs.svg?branch=master 8 | [2]: https://travis-ci.org/git-lfs/git-lfs 9 | [3]: https://circleci.com/gh/git-lfs/git-lfs.svg?style=shield&circle-token=856152c2b02bfd236f54d21e1f581f3e4ebf47ad 10 | [4]: https://circleci.com/gh/git-lfs/git-lfs 11 | [5]: https://ci.appveyor.com/api/projects/status/46a5yoqc3hk59bl5/branch/master?svg=true 12 | [6]: https://ci.appveyor.com/project/git-lfs/git-lfs/branch/master 13 | 14 | [Git LFS](https://git-lfs.github.com) is a command line extension and 15 | [specification](docs/spec.md) for managing large files with Git. 16 | 17 | The client is written in Go, with pre-compiled binaries available for Mac, 18 | Windows, Linux, and FreeBSD. Check out the [website](http://git-lfs.github.com) 19 | for an overview of features. 20 | 21 | ## Getting Started 22 | 23 | ### Installation 24 | 25 | You can install the Git LFS client in several different ways, depending on your 26 | setup and preferences. 27 | 28 | * **Linux users**. Debian and RPM packages are available from 29 | [PackageCloud](https://packagecloud.io/github/git-lfs/install). 30 | * **macOS users**. [Homebrew](https://brew.sh) bottles are distributed, and can 31 | be installed via `brew install git-lfs`. 32 | * **Windows users**. Chocolatey packages are distributed, and can be installed 33 | via `choco install git-lfs`. 34 | 35 | In addition, [binary packages](https://github.com/git-lfs/git-lfs/releases) are 36 | available for Linux, macOS, Windows, and FreeBSD. This repository can also be 37 | built-from-source using the latest version of [Go](https://golang.org). 38 | 39 | ### Usage 40 | 41 | Git LFS requires a global installation once per-machine. This can be done by 42 | running: 43 | 44 | ```bash 45 | $ git lfs install 46 | ``` 47 | 48 | To begin using Git LFS within your Git repository, you can indicate which files 49 | you would like Git LFS to manage. 
This can be done by running the following 50 | _from within Git repository_: 51 | 52 | ```bash 53 | $ git lfs track "*.psd" 54 | ``` 55 | 56 | (Where `*.psd` is the pattern of filenames that you wish to track. You can read 57 | more about this pattern syntax 58 | [here](https://git-scm.com/docs/gitattributes)). 59 | 60 | After any invocation of `git-lfs-track(1)` or `git-lfs-untrack(1)`, you _must 61 | commit changes to your `.gitattributes` file_. This can be done by running: 62 | 63 | ```bash 64 | $ git add .gitattributes 65 | $ git commit -m "track *.psd files using Git LFS" 66 | ``` 67 | 68 | You can now interact with your Git repository as usual, and Git LFS will take 69 | care of managing your large files. For example, changing a file named `my.psd` 70 | (tracked above via `*.psd`): 71 | 72 | ```bash 73 | $ git add my.psd 74 | $ git commit -m "add psd" 75 | ``` 76 | 77 | > _Tip:_ if you have large files already in your repository's history, `git lfs 78 | > track` will _not_ track them retroactively. To migrate existing large files 79 | > in your history to use Git LFS, use `git lfs migrate`. For example: 80 | > 81 | > ``` 82 | > $ git lfs migrate import --include="*.psd" 83 | > ``` 84 | > 85 | > For more information, read [`git-lfs-migrate(1)`](https://github.com/git-lfs/git-lfs/blob/master/docs/man/git-lfs-migrate.1.ronn). 86 | 87 | You can confirm that Git LFS is managing your PSD file: 88 | 89 | ```bash 90 | $ git lfs ls-files 91 | 3c2f7aedfb * my.psd 92 | ``` 93 | 94 | Once you've made your commits, push your files to the Git remote: 95 | 96 | ```bash 97 | $ git push origin master 98 | Uploading LFS objects: 100% (1/1), 810 B, 1.2 KB/s 99 | # ... 100 | To https://github.com/git-lfs/git-lfs-test 101 | 67fcf6a..47b2002 master -> master 102 | ``` 103 | 104 | Note: Git LFS requires Git v1.8.5 or higher. 105 | 106 | ## Limitations 107 | 108 | Git LFS maintains a list of currently known limitations, which you can find and 109 | edit [here](https://github.com/git-lfs/git-lfs/wiki/Limitations). 110 | 111 | ## Need Help? 112 | 113 | You can get help on specific commands directly: 114 | 115 | ```bash 116 | $ git lfs help 117 | ``` 118 | 119 | The [official documentation](docs) has command references and specifications for 120 | the tool. 121 | 122 | You can always [open an issue](https://github.com/git-lfs/git-lfs/issues), and 123 | one of the Core Team members will respond to you. Please be sure to include: 124 | 125 | 1. The output of `git lfs env`, which displays helpful information about your 126 | Git repository useful in debugging. 127 | 2. Any failed commands re-run with `GIT_TRACE=1` in the environment, which 128 | displays additional information pertaining to why a command crashed. 129 | 130 | ## Contributing 131 | 132 | See [CONTRIBUTING.md](CONTRIBUTING.md) for info on working on Git LFS and 133 | sending patches. Related projects are listed on the [Implementations wiki 134 | page](https://github.com/git-lfs/git-lfs/wiki/Implementations). 135 | 136 | ## Core Team 137 | 138 | These are the humans that form the Git LFS core team, which runs the project. 
139 | 140 | In alphabetical order: 141 | 142 | | [@larsxschneider][larsxschneider-user] | [@ttaylorr][ttaylorr-user] | 143 | |---|---| 144 | | [![][larsxschneider-img]][larsxschneider-user] | [![][ttaylorr-img]][ttaylorr-user] | 145 | 146 | [larsxschneider-img]: https://avatars1.githubusercontent.com/u/477434?s=100&v=4 147 | [ttaylorr-img]: https://avatars2.githubusercontent.com/u/443245?s=100&v=4 148 | [larsxschneider-user]: https://github.com/larsxschneider 149 | [ttaylorr-user]: https://github.com/ttaylorr 150 | 151 | ### Alumni 152 | 153 | These are the humans that have in the past formed the Git LFS core team, or 154 | have otherwise contributed a significant amount to the project. Git LFS would 155 | not be possible without them. 156 | 157 | In alphabetical order: 158 | 159 | | [@andyneff][andyneff-user] | [@rubyist][rubyist-user] | [@sinbad][sinbad-user] | [@technoweenie][technoweenie-user] | 160 | |---|---|---|---| 161 | | [![][andyneff-img]][andyneff-user] | [![][rubyist-img]][rubyist-user] | [![][sinbad-img]][sinbad-user] | [![][technoweenie-img]][technoweenie-user] | 162 | 163 | [andyneff-img]: https://avatars1.githubusercontent.com/u/7596961?v=3&s=100 164 | [rubyist-img]: https://avatars1.githubusercontent.com/u/143?v=3&s=100 165 | [sinbad-img]: https://avatars1.githubusercontent.com/u/142735?v=3&s=100 166 | [technoweenie-img]: https://avatars3.githubusercontent.com/u/21?v=3&s=100 167 | [andyneff-user]: https://github.com/andyneff 168 | [sinbad-user]: https://github.com/sinbad 169 | [rubyist-user]: https://github.com/rubyist 170 | [technoweenie-user]: https://github.com/technoweenie 171 | -------------------------------------------------------------------------------- /git-lfs-2.4.2/git-lfs: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leekh7411/Audio-Super-Resolution-Python3-TF/37cb46c37bde183392dc6109a6309330f24c6dcf/git-lfs-2.4.2/git-lfs -------------------------------------------------------------------------------- /git-lfs-2.4.2/install.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | set -eu 3 | 4 | prefix="/usr/local" 5 | 6 | if [ "${PREFIX:-}" != "" ] ; then 7 | prefix=${PREFIX:-} 8 | elif [ "${BOXEN_HOME:-}" != "" ] ; then 9 | prefix=${BOXEN_HOME:-} 10 | fi 11 | 12 | mkdir -p $prefix/bin 13 | rm -rf $prefix/bin/git-lfs* 14 | 15 | pushd "$( dirname "${BASH_SOURCE[0]}" )" > /dev/null 16 | for g in git*; do 17 | install $g "$prefix/bin/$g" 18 | done 19 | popd > /dev/null 20 | 21 | PATH+=:$prefix/bin 22 | git lfs install 23 | -------------------------------------------------------------------------------- /io_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import h5py 4 | import librosa 5 | 6 | from scipy.signal import decimate 7 | 8 | from matplotlib import pyplot as plt 9 | 10 | # ---------------------------------------------------------------------------- 11 | 12 | def load_h5(h5_path): 13 | # load training data 14 | with h5py.File(h5_path, 'r') as hf: 15 | print ('List of arrays in input file:', hf.keys()) 16 | X = np.array(hf.get('data')) 17 | Y = np.array(hf.get('label')) 18 | print ('Shape of X:', X.shape) 19 | print ('Shape of Y:', Y.shape) 20 | 21 | return X, Y 22 | 23 | def upsample_wav(wav, args, model): 24 | # load signal 25 | x_hr, fs = librosa.load(wav, sr=args['sr']) 26 | 27 | # downscale signal 28 | # x_lr = np.array(x_hr[0::args.r]) 29 | 
x_lr = decimate(x_hr, args['r']) 30 | # x_lr = decimate(x_hr, args.r, ftype='fir', zero_phase=True) 31 | # x_lr = downsample_bt(x_hr, args.r) 32 | 33 | # upscale the low-res version using trained model 34 | X_hr = x_hr.reshape((1,len(x_hr),1)) 35 | X_lr = x_lr.reshape((1,len(x_lr),1)) 36 | print(X_hr.shape) 37 | print(X_lr.shape) 38 | P = model.predict(X_lr,X_hr) 39 | x_pr = P.flatten() 40 | 41 | # crop so that it works with scaling ratio 42 | #x_hr = x_hr[:len(x_pr)] 43 | #x_lr = x_lr[:len(x_pr)] 44 | 45 | # save the file 46 | outname = wav + '.' + args['out_label'] 47 | librosa.output.write_wav(outname + '.hr.wav', x_hr, fs) 48 | librosa.output.write_wav(outname + '.lr.wav', x_lr, int(fs/args['r'])) 49 | librosa.output.write_wav(outname + '.pr.wav', x_pr, fs) 50 | 51 | # save the spectrum 52 | S = get_spectrum(x_pr, n_fft=2048) 53 | save_spectrum(S, outfile=outname + '.pr.png') 54 | S = get_spectrum(x_hr, n_fft=2048) 55 | save_spectrum(S, outfile=outname + '.hr.png') 56 | n_fft = int(2048/args['r']) 57 | S = get_spectrum(x_lr, n_fft=n_fft) 58 | save_spectrum(S, outfile=outname + '.lr.png') 59 | 60 | # ---------------------------------------------------------------------------- 61 | 62 | def get_spectrum(x, n_fft=2048): 63 | S = librosa.stft(x, n_fft) 64 | p = np.angle(S) 65 | S = np.log1p(np.abs(S)) 66 | return S 67 | 68 | def save_spectrum(S, lim=800, outfile='spectrogram.png'): 69 | plt.imshow(S.T, aspect=10) 70 | # plt.xlim([0,lim]) 71 | plt.tight_layout() 72 | plt.savefig(outfile) -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import numpy as np 4 | import tensorflow as tf 5 | import librosa 6 | from dataset import DataSet 7 | default_opt = { 'alg': 'adam', 'lr': 1e-4, 'b1': 0.99, 'b2': 0.999, 8 | 'layers': 4, 'batch_size': 128 } 9 | 10 | class Model(object): 11 | """Generic tensorflow model training code""" 12 | 13 | def __init__(self, from_ckpt=False, n_dim=None, r=2,opt_params=default_opt, log_prefix='./run'): 14 | 15 | # make session 16 | self.sess = tf.Session() 17 | 18 | # save params 19 | self.opt_params = opt_params 20 | self.layers = opt_params['layers'] 21 | 22 | if from_ckpt: 23 | pass # we will instead load the graph from a checkpoint 24 | else: 25 | # create input vars 26 | X = tf.placeholder(tf.float32, shape=(None, 8192, 1), name='X') 27 | Y = tf.placeholder(tf.float32, shape=(None, 16384, 1), name='Y') 28 | alpha = tf.placeholder(tf.float32, shape=(), name='alpha') # weight multiplier 29 | 30 | # save inputs 31 | self.inputs = (X, Y, alpha) 32 | tf.add_to_collection('inputs', X) 33 | tf.add_to_collection('inputs', Y) 34 | tf.add_to_collection('inputs', alpha) 35 | 36 | # create model outputs 37 | self.predictions = self.create_model(n_dim, r) 38 | tf.add_to_collection('preds', self.predictions) 39 | 40 | # init the model 41 | init = tf.global_variables_initializer() 42 | self.sess.run(init) 43 | 44 | # create training updates 45 | self.train_op = self.create_train_op(X, Y, alpha) 46 | tf.add_to_collection('train_op', self.train_op) 47 | 48 | # logging 49 | lr_str = '.' 
+ 'lr%f' % opt_params['lr'] 50 | g_str = '.g%d' % self.layers 51 | b_str = '.b%d' % int(opt_params['batch_size']) 52 | 53 | self.logdir = log_prefix + lr_str + '.%d' % r + g_str + b_str 54 | self.checkpoint_root = os.path.join(self.logdir, 'model.ckpt') 55 | 56 | 57 | def get_params(self): 58 | return [ v for v in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES) 59 | if 'soundnet' not in v.name ] 60 | 61 | 62 | def create_model(self, n_dim, r): 63 | raise NotImplementedError() # subclasses must implement this method 64 | 65 | 66 | def create_optimzier(self, opt_params): 67 | if opt_params['alg'] == 'adam': 68 | lr, b1, b2 = opt_params['lr'], opt_params['b1'], opt_params['b2'] 69 | optimizer = tf.train.AdamOptimizer(lr, b1, b2) 70 | else: 71 | raise ValueError('Invalid optimizer: ' + opt_params['alg']) 72 | 73 | return optimizer 74 | 75 | 76 | def create_gradients(self, loss, params): 77 | ''' 78 | 79 | compute_gradients( 80 | loss, 81 | var_list=None, 82 | gate_gradients=GATE_OP, 83 | aggregation_method=None, 84 | colocate_gradients_with_ops=False, 85 | grad_loss=None 86 | ) 87 | Compute gradients of loss for the variables in var_list. 88 | 89 | This is the first part of minimize(). 90 | It returns a list of (gradient, variable) pairs where "gradient" is the gradient for "variable". 91 | Note that "gradient" can be a Tensor, an IndexedSlices, 92 | or None if there is no gradient for the given variable. 93 | 94 | ''' 95 | gv = self.optimizer.compute_gradients(loss, params) # return 'gradient' and 'variable' 96 | g, v = zip(*gv) 97 | return g 98 | 99 | 100 | def create_objective(self, X, Y, opt_params): 101 | # load model output and true output 102 | P = self.predictions 103 | 104 | # compute l2 loss 105 | sqrt_l2_loss = tf.sqrt(tf.reduce_mean((P-Y)**2 + 1e-6, axis=[1,2])) 106 | sqrt_l2_norm = tf.sqrt(tf.reduce_mean(Y**2, axis=[1,2])) 107 | snr = 20 * tf.log(sqrt_l2_norm / sqrt_l2_loss + 1e-8) / tf.log(10.) 108 | 109 | avg_sqrt_l2_loss = tf.reduce_mean(sqrt_l2_loss, axis=0) 110 | avg_snr = tf.reduce_mean(snr, axis=0) 111 | 112 | # track losses 113 | tf.summary.scalar('l2_loss', avg_sqrt_l2_loss) 114 | tf.summary.scalar('snr', avg_snr) 115 | 116 | # save losses into collection 117 | tf.add_to_collection('losses', avg_sqrt_l2_loss) 118 | tf.add_to_collection('losses', avg_snr) 119 | 120 | return avg_sqrt_l2_loss 121 | 122 | 123 | def create_train_op(self, X, Y, alpha): 124 | # load params 125 | opt_params = self.opt_params 126 | print('creating train_op with params:', opt_params) 127 | 128 | # create loss 129 | self.loss = self.create_objective(X, Y, opt_params) 130 | 131 | # create params - get trainable variables 132 | params = self.get_params() 133 | 134 | # create optimizer 135 | self.optimizer = self.create_optimzier(opt_params) 136 | 137 | # create gradients 138 | grads = self.create_gradients(self.loss, params) 139 | 140 | # create training op 141 | with tf.name_scope('optimizer'): 142 | # ref - https://www.tensorflow.org/api_docs/python/tf/train/AdamOptimizer 143 | train_op = self.create_updates(params, grads, alpha, opt_params) 144 | 145 | # initialize the optimizer variables 146 | optimizer_vars = [] 147 | for v in tf.global_variables(): 148 | if 'optimizer/' in v.name: 149 | optimizer_vars.append(v) 150 | 151 | init = tf.variables_initializer(optimizer_vars) 152 | self.sess.run(init) 153 | 154 | return train_op 155 | 156 | 157 | def create_updates(self, params, grads, alpha, opt_params): 158 | # create a variable to track the global step. 
159 | self.global_step = tf.Variable(0, name='global_step', trainable=False) 160 | 161 | # update grads 162 | grads = [alpha*g for g in grads] 163 | 164 | # use the optimizer to apply the gradients that minimize the loss 165 | gv = zip(grads, params) 166 | ''' 167 | apply_gradients( 168 | grads_and_vars, 169 | global_step=None, 170 | name=None 171 | ) 172 | Apply gradients to variables. 173 | 174 | This is the second part of minimize(). It returns an Operation that applies gradients. 175 | 176 | Args: 177 | grads_and_vars: List of (gradient, variable) pairs as returned by compute_gradients(). 178 | global_step: Optional Variable to increment by one after the variables have been updated. 179 | name: Optional name for the returned operation. Default to the name passed to the Optimizer constructor. 180 | Returns: 181 | An Operation that applies the specified gradients. If global_step was not None, that operation also increments global_step. 182 | 183 | Raises: 184 | TypeError: If grads_and_vars is malformed. 185 | ValueError: If none of the variables have gradients. 186 | RuntimeError: If you should use _distributed_apply() instead. 187 | ''' 188 | train_op = self.optimizer.apply_gradients(gv, global_step=self.global_step) 189 | 190 | return train_op 191 | 192 | ######################################################################################################################################################### 193 | ######################################################################################################################################################### 194 | 195 | def load(self, ckpt): 196 | # get checkpoint name 197 | # ref - https://www.tensorflow.org/api_docs/python/tf/train/latest_checkpoint 198 | if os.path.isdir(ckpt): checkpoint = tf.train.latest_checkpoint(ckpt) 199 | else: checkpoint = ckpt 200 | meta = checkpoint + '.meta' 201 | print(checkpoint) 202 | 203 | # load graph 204 | self.saver = tf.train.import_meta_graph(meta) 205 | g = tf.get_default_graph() 206 | 207 | # load weights 208 | self.saver.restore(self.sess, checkpoint) 209 | 210 | # get graph tensors 211 | X, Y, alpha = tf.get_collection('inputs') 212 | 213 | # save tensors as instance variables 214 | self.inputs = X, Y, alpha 215 | self.predictions = tf.get_collection('preds')[0] 216 | 217 | # load existing loss, or erase it, if creating new one 218 | g.clear_collection('losses') 219 | 220 | # create a new training op 221 | self.train_op = self.create_train_op(X, Y, alpha) 222 | g.clear_collection('train_op') 223 | tf.add_to_collection('train_op', self.train_op) 224 | 225 | 226 | def load_batch(self, batch, alpha=1, train=True): 227 | X_in, Y_in, alpha_in = self.inputs 228 | X, Y = batch 229 | 230 | if Y is not None: 231 | feed_dict = {X_in : X, Y_in : Y, alpha_in : alpha} 232 | else: 233 | feed_dict = {X_in : X, alpha_in : alpha} 234 | 235 | '''# this is ugly, but only way I found to get this var after model reload 236 | g = tf.get_default_graph() 237 | 238 | k_tensors = [] 239 | for n in g.as_graph_def().node: 240 | if 'keras_learning_phase' in n.name and 'input' not in n.name: 241 | print('tf.default_graph.node:',n.name) 242 | k_tensors.append(n) 243 | 244 | #k_tensors = [n for n in g.as_graph_def().node if 'keras_learning_phase' in n.name] 245 | 246 | # ?????????????????????????/ 247 | #assert len(k_tensors) <= 1 248 | assert len(k_tensors) <= 1 249 | 250 | if k_tensors: 251 | k_learning_phase = g.get_tensor_by_name(k_tensors[0].name + ':0') 252 | feed_dict[k_learning_phase] = train''' 253 | 254 | 
return feed_dict 255 | 256 | def train(self, feed_dict): 257 | _, loss = self.sess.run([self.train_op, self.loss], feed_dict=feed_dict) 258 | return loss 259 | 260 | def fit(self, X_train, Y_train, X_val, Y_val, n_epoch=10): 261 | # initialize log directory 262 | if tf.gfile.Exists(self.logdir): tf.gfile.DeleteRecursively(self.logdir) 263 | tf.gfile.MakeDirs(self.logdir) 264 | 265 | # load some training params 266 | n_batch = self.opt_params['batch_size'] 267 | 268 | # create saver 269 | self.saver = tf.train.Saver() 270 | 271 | # summarization 272 | summary = tf.summary.merge_all() 273 | summary_writer = tf.summary.FileWriter(self.logdir, self.sess.graph) 274 | 275 | # load data into DataSet 276 | train_data = DataSet(X_train, Y_train) 277 | val_data = DataSet(X_val, Y_val) 278 | 279 | # train the model 280 | start_time = time.time() 281 | step, epoch = 0, train_data.epochs_completed 282 | 283 | print('start training epoch (n:%d)'%n_epoch) 284 | print('num-of-batch:',n_batch) 285 | 286 | while train_data.epochs_completed < n_epoch: 287 | 288 | step += 1 289 | # load the batch 290 | # alpha = min((n_epoch - train_data.epochs_completed) / 200, 1.) 291 | # alpha = 1.0 if epoch < 100 else 0.1 292 | alpha = 1.0 293 | 294 | print('get next batch from train data...') 295 | batch = train_data.next_batch(n_batch) 296 | print('...done') 297 | 298 | print('load batch and get feed-dict...') 299 | feed_dict = self.load_batch(batch, alpha) 300 | print('..done') 301 | 302 | # take training step 303 | print('train sequence start...') 304 | tr_objective = self.train(feed_dict) 305 | print('...done') 306 | 307 | tr_obj_snr = 20 * np.log10(1. / np.sqrt(tr_objective) + 1e-8) 308 | 309 | if step % 50 == 0: 310 | print(step, tr_objective, tr_obj_snr) 311 | 312 | # log results at the end of each epoch 313 | if train_data.epochs_completed > epoch: 314 | print('epoch-complete!') 315 | epoch = train_data.epochs_completed 316 | end_time = time.time() 317 | 318 | print('eval-err start...') 319 | tr_l2_loss, tr_l2_snr = self.eval_err(X_train, Y_train, n_batch=n_batch) 320 | va_l2_loss, va_l2_snr = self.eval_err(X_val, Y_val, n_batch=n_batch) 321 | print('...done!') 322 | 323 | print("Epoch {} of {} took {:.3f}s ({} minibatches)".format( 324 | epoch, n_epoch, end_time - start_time, len(X_train) // n_batch)) 325 | print(" training l2_loss/segsnr:\t\t{:.6f}\t{:.6f}".format( 326 | tr_l2_loss, tr_l2_snr)) 327 | print(" validation l2_loss/segsnr:\t\t{:.6f}\t{:.6f}".format( 328 | va_l2_loss, va_l2_snr)) 329 | 330 | # compute summaries for overall loss 331 | objectives_summary = tf.Summary() 332 | objectives_summary.value.add(tag='tr_l2_loss', simple_value=tr_l2_loss) 333 | objectives_summary.value.add(tag='tr_l2_snr' , simple_value=tr_l2_snr) 334 | objectives_summary.value.add(tag='va_l2_loss', simple_value=va_l2_loss) 335 | 336 | # compute summaries for all other metrics 337 | summary_str = self.sess.run(summary, feed_dict=feed_dict) 338 | summary_writer.add_summary(summary_str, step) 339 | summary_writer.add_summary(objectives_summary, step) 340 | 341 | # write summaries and checkpoints 342 | summary_writer.flush() 343 | self.saver.save(self.sess, self.checkpoint_root, global_step=step) 344 | 345 | # restart clock 346 | start_time = time.time() 347 | 348 | 349 | def eval_err(self, X, Y, n_batch=128): 350 | batch_iterator = iterate_minibatches(X, Y, n_batch, shuffle=True) 351 | l2_loss_op, l2_snr_op = tf.get_collection('losses') 352 | 353 | l2_loss, snr = 0, 0 354 | tot_l2_loss, tot_snr = 0, 0 355 | for bn, batch in 
enumerate(batch_iterator): 356 | feed_dict = self.load_batch(batch, train=False) 357 | l2_loss, l2_snr = self.sess.run([l2_loss_op, l2_snr_op], feed_dict=feed_dict) 358 | tot_l2_loss += l2_loss 359 | tot_snr += l2_snr 360 | 361 | return tot_l2_loss / (bn+1), tot_snr / (bn+1) 362 | 363 | def predict(self, X): 364 | raise NotImplementedError() 365 | 366 | # ---------------------------------------------------------------------------- 367 | # helpers 368 | 369 | def iterate_minibatches(inputs, targets, batchsize, shuffle=False): 370 | assert len(inputs) == len(targets) 371 | if shuffle: 372 | indices = np.arange(len(inputs)) 373 | np.random.shuffle(indices) 374 | for start_idx in range(0, len(inputs) - batchsize + 1, batchsize): 375 | if shuffle: 376 | excerpt = indices[start_idx:start_idx + batchsize] 377 | else: 378 | excerpt = slice(start_idx, start_idx + batchsize) 379 | yield inputs[excerpt], targets[excerpt] 380 | 381 | 382 | 383 | 384 | 385 | 386 | 387 | 388 | 389 | 390 | 391 | 392 | 393 | 394 | 395 | 396 | 397 | 398 | 399 | 400 | 401 | 402 | 403 | 404 | 405 | 406 | 407 | 408 | 409 | 410 | 411 | 412 | 413 | 414 | 415 | 416 | 417 | 418 | 419 | 420 | 421 | 422 | 423 | 424 | 425 | 426 | 427 | 428 | 429 | 430 | 431 | 432 | 433 | 434 | 435 | 436 | 437 | 438 | -------------------------------------------------------------------------------- /practice.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import tensorflow as tf" 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": 2, 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "a1 = tf.constant([\n", 19 | " [[1,1],[1,1],[1,1]],\n", 20 | " [[1,1],[1,1],[1,1]],\n", 21 | " [[1,1],[1,1],[1,1]],\n", 22 | " [[1,1],[1,1],[1,1]]\n", 23 | "])\n", 24 | "\n", 25 | "a2 = tf.constant([\n", 26 | " [[1,1],[1,1]],\n", 27 | " [[1,1],[1,1]],\n", 28 | " [[1,1],[1,1]],\n", 29 | " [[1,1],[1,1]]\n", 30 | "])" 31 | ] 32 | }, 33 | { 34 | "cell_type": "code", 35 | "execution_count": 3, 36 | "metadata": {}, 37 | "outputs": [ 38 | { 39 | "data": { 40 | "text/plain": [ 41 | "array([4, 3, 2], dtype=int32)" 42 | ] 43 | }, 44 | "execution_count": 3, 45 | "metadata": {}, 46 | "output_type": "execute_result" 47 | } 48 | ], 49 | "source": [ 50 | "sess = tf.Session()\n", 51 | "sess.run(tf.shape(a1))" 52 | ] 53 | }, 54 | { 55 | "cell_type": "code", 56 | "execution_count": 4, 57 | "metadata": {}, 58 | "outputs": [ 59 | { 60 | "data": { 61 | "text/plain": [ 62 | "array([4, 2, 2], dtype=int32)" 63 | ] 64 | }, 65 | "execution_count": 4, 66 | "metadata": {}, 67 | "output_type": "execute_result" 68 | } 69 | ], 70 | "source": [ 71 | "sess.run(tf.shape(a2))" 72 | ] 73 | }, 74 | { 75 | "cell_type": "code", 76 | "execution_count": 9, 77 | "metadata": {}, 78 | "outputs": [ 79 | { 80 | "name": "stdout", 81 | "output_type": "stream", 82 | "text": [ 83 | "0\n", 84 | "1\n", 85 | "2\n", 86 | "3\n", 87 | "4\n" 88 | ] 89 | } 90 | ], 91 | "source": [ 92 | "for i in range(5):\n", 93 | " print(i)" 94 | ] 95 | }, 96 | { 97 | "cell_type": "code", 98 | "execution_count": null, 99 | "metadata": {}, 100 | "outputs": [], 101 | "source": [] 102 | }, 103 | { 104 | "cell_type": "code", 105 | "execution_count": null, 106 | "metadata": {}, 107 | "outputs": [], 108 | "source": [] 109 | } 110 | ], 111 | "metadata": { 112 | "kernelspec": { 113 | "display_name": "Python 3", 114 | "language": "python", 115 | 
"name": "python3" 116 | }, 117 | "language_info": { 118 | "codemirror_mode": { 119 | "name": "ipython", 120 | "version": 3 121 | }, 122 | "file_extension": ".py", 123 | "mimetype": "text/x-python", 124 | "name": "python", 125 | "nbconvert_exporter": "python", 126 | "pygments_lexer": "ipython3", 127 | "version": "3.5.2" 128 | } 129 | }, 130 | "nbformat": 4, 131 | "nbformat_minor": 2 132 | } 133 | -------------------------------------------------------------------------------- /src/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .audiounet import AudioUNet -------------------------------------------------------------------------------- /src/models/__init__.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leekh7411/Audio-Super-Resolution-Python3-TF/37cb46c37bde183392dc6109a6309330f24c6dcf/src/models/__init__.pyc -------------------------------------------------------------------------------- /src/models/__pycache__/__init__.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leekh7411/Audio-Super-Resolution-Python3-TF/37cb46c37bde183392dc6109a6309330f24c6dcf/src/models/__pycache__/__init__.cpython-35.pyc -------------------------------------------------------------------------------- /src/models/__pycache__/audiounet.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leekh7411/Audio-Super-Resolution-Python3-TF/37cb46c37bde183392dc6109a6309330f24c6dcf/src/models/__pycache__/audiounet.cpython-35.pyc -------------------------------------------------------------------------------- /src/models/__pycache__/dataset.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leekh7411/Audio-Super-Resolution-Python3-TF/37cb46c37bde183392dc6109a6309330f24c6dcf/src/models/__pycache__/dataset.cpython-35.pyc -------------------------------------------------------------------------------- /src/models/__pycache__/io.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leekh7411/Audio-Super-Resolution-Python3-TF/37cb46c37bde183392dc6109a6309330f24c6dcf/src/models/__pycache__/io.cpython-35.pyc -------------------------------------------------------------------------------- /src/models/__pycache__/model.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leekh7411/Audio-Super-Resolution-Python3-TF/37cb46c37bde183392dc6109a6309330f24c6dcf/src/models/__pycache__/model.cpython-35.pyc -------------------------------------------------------------------------------- /src/models/audiounet.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | 4 | from scipy import interpolate 5 | from .model import Model, default_opt 6 | 7 | from .layers.subpixel import SubPixel1D, SubPixel1D_v2 8 | 9 | from keras import backend as K 10 | from keras.layers import merge, concatenate, add, Dropout 11 | from keras.layers.core import Activation 12 | from keras.layers.convolutional import Convolution1D 13 | from keras.layers.normalization import BatchNormalization 14 | from keras.layers.advanced_activations import LeakyReLU 15 | from 
keras.layers import Conv1D 16 | 17 | #from keras.initializations import normal, orthogonal # in Keras 1 18 | from keras.initializers import normal, orthogonal 19 | 20 | # ---------------------------------------------------------------------------- 21 | 22 | class AudioUNet(Model): 23 | """Generic tensorflow model training code""" 24 | 25 | def __init__(self, from_ckpt=False, n_dim=None, r=2, 26 | opt_params=default_opt, log_prefix='./run'): 27 | # perform the usual initialization 28 | self.r = r 29 | Model.__init__(self, from_ckpt=from_ckpt, n_dim=n_dim, r=r, 30 | opt_params=opt_params, log_prefix=log_prefix) 31 | 32 | def create_model(self, n_dim, r): 33 | # load inputs 34 | X, _, _ = self.inputs 35 | K.set_session(self.sess) 36 | 37 | with tf.name_scope('generator'): 38 | x = X 39 | L = self.layers 40 | # dim/layer: 4096, 2048, 1024, 512, 256, 128, 64, 32, 41 | # n_filters = [ 64, 128, 256, 384, 384, 384, 384, 384] 42 | n_filters = [128, 256, 512, 512, 512, 512, 512, 512] 43 | # n_filters = [ 256, 512, 512, 512, 512, 1024, 1024, 1024] 44 | # n_filtersizes = [129, 65, 33, 17, 9, 9, 9, 9] 45 | # n_filtersizes = [31, 31, 31, 31, 31, 31, 31, 31] 46 | n_filtersizes = [65, 33, 17, 9, 9, 9, 9, 9, 9] 47 | downsampling_l = [] 48 | 49 | print('building model...') 50 | 51 | # downsampling layers 52 | for l, nf, fs in zip(range(L), n_filters, n_filtersizes): 53 | with tf.name_scope('downsc_conv%d' % l): 54 | # in Keras 2 55 | x = Conv1D(padding='same', kernel_initializer='Orthogonal', filters=nf, kernel_size=fs, activation=None)(x) 56 | 57 | #x = (Convolution1D(nb_filter=nf, filter_length=fs, 58 | # activation=None, border_mode='same', init=orthogonal_init, 59 | # subsample_length=2))(x) 60 | # if l > 0: x = BatchNormalization(mode=2)(x) 61 | 62 | x = LeakyReLU(0.2)(x) 63 | print('D-Block: ', x.get_shape()) 64 | downsampling_l.append(x) 65 | 66 | # bottleneck layer 67 | with tf.name_scope('bottleneck_conv'): 68 | # in Keras 2 69 | x = Conv1D(padding='same', kernel_initializer='Orthogonal', 70 | filters=n_filters[-1], kernel_size=n_filtersizes[-1], activation=None)(x) 71 | 72 | #x = (Convolution1D(nb_filter=n_filters[-1], filter_length=n_filtersizes[-1], 73 | # activation=None, border_mode='same', init=orthogonal_init, 74 | # subsample_length=2))(x) 75 | x = Dropout(rate=0.5)(x) 76 | # x = BatchNormalization(mode=2)(x) 77 | x = LeakyReLU(0.2)(x) 78 | 79 | # upsampling layers 80 | # for l, nf, fs, l_in in reversed(zip(range(L), n_filters, n_filtersizes, downsampling_l)): 81 | for l, nf, fs, l_in in zip(reversed(range(L)), reversed(n_filters), reversed(n_filtersizes), reversed(downsampling_l)): 82 | with tf.name_scope('upsc_conv%d' % l): 83 | # (-1, n/2, 2f) 84 | # in Keras 2 85 | x = Conv1D(padding='same', kernel_initializer='Orthogonal', 86 | filters=2*nf, kernel_size=fs, activation=None)(x) 87 | #x = (Convolution1D(nb_filter=2*nf, filter_length=fs, 88 | # activation=None, border_mode='same', init=orthogonal_init))(x) 89 | # x = BatchNormalization(mode=2)(x) 90 | x = Dropout(rate=0.5)(x) 91 | x = Activation('relu')(x) 92 | # (-1, n, f) 93 | x = SubPixel1D(x, r=2) 94 | # (-1, n, 2f) 95 | 96 | # in Keras 2 97 | x = concatenate([x, l_in])# axis = -1 (by default) 98 | 99 | # in Keras 1 100 | #x = merge([x, l_in], mode='concat', concat_axis=-1) 101 | 102 | print('U-Block: ', x.get_shape()) 103 | 104 | # final conv layer 105 | with tf.name_scope('lastconv'): 106 | # in Keras 2 107 | x = Conv1D(padding='same', kernel_initializer='he_normal', 108 | filters=2, kernel_size=9, activation=None)(x) 109 | #x = 
Convolution1D(nb_filter=2, filter_length=9, 110 | # activation=None, border_mode='same', init=normal_init)(x) 111 | x = SubPixel1D(x, r=2) 112 | print(x.get_shape()) 113 | 114 | # in Keras 2 115 | g = add([x,X]) 116 | 117 | # in Keras 1 118 | #g = merge([x, X], mode='sum') 119 | 120 | return g 121 | 122 | def predict(self, X): 123 | assert len(X) == 1 124 | x_sp = spline_up(X, self.r) 125 | x_sp = x_sp[:len(x_sp) - (len(x_sp) % (2**(self.layers+1)))] 126 | X = x_sp.reshape((1,len(x_sp),1)) 127 | feed_dict = self.load_batch((X,X), train=False) 128 | return self.sess.run(self.predictions, feed_dict=feed_dict) 129 | 130 | # ---------------------------------------------------------------------------- 131 | # helpers 132 | 133 | def normal_init(shape, dim_ordering='tf', name=None): 134 | return normal(shape, scale=1e-3, name=name, dim_ordering=dim_ordering) 135 | 136 | def orthogonal_init(shape, dim_ordering='tf', name=None): 137 | return orthogonal(shape, name=name, dim_ordering=dim_ordering) 138 | 139 | def spline_up(x_lr, r): 140 | x_lr = x_lr.flatten() 141 | x_hr_len = len(x_lr) * r 142 | x_sp = np.zeros(x_hr_len) 143 | 144 | i_lr = np.arange(x_hr_len, step=r) 145 | i_hr = np.arange(x_hr_len) 146 | 147 | f = interpolate.splrep(i_lr, x_lr) 148 | 149 | x_sp = interpolate.splev(i_hr, f) 150 | 151 | return x_sp -------------------------------------------------------------------------------- /src/models/audiounet.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leekh7411/Audio-Super-Resolution-Python3-TF/37cb46c37bde183392dc6109a6309330f24c6dcf/src/models/audiounet.pyc -------------------------------------------------------------------------------- /src/models/dataset.py: -------------------------------------------------------------------------------- 1 | """Class for doing iterations over datasets 2 | 3 | This is stolen from the tensorflow tutorial 4 | """ 5 | 6 | import numpy 7 | 8 | from tensorflow.contrib.learn.python.learn.datasets import base 9 | from tensorflow.python.framework import dtypes 10 | 11 | # ---------------------------------------------------------------------------- 12 | 13 | class DataSet(object): 14 | 15 | def __init__(self, 16 | datapoints, 17 | labels, 18 | fake_data=False, 19 | one_hot=False, 20 | dtype=dtypes.float32): 21 | """Construct a DataSet. 22 | one_hot arg is used only if fake_data is true. `dtype` can be either 23 | `uint8` to leave the input as `[0, 255]`, or `float32` to rescale into 24 | `[0, 1]`. 
25 | """ 26 | dtype = dtypes.as_dtype(dtype).base_dtype 27 | if dtype not in (dtypes.uint8, dtypes.float32): 28 | raise TypeError('Invalid image dtype %r, expected uint8 or float32' % 29 | dtype) 30 | 31 | if labels is None: 32 | labels = numpy.zeros((len(datapoints),)) 33 | 34 | if fake_data: 35 | self._num_examples = 10000 36 | self.one_hot = one_hot 37 | else: 38 | assert datapoints.shape[0] == labels.shape[0], ( 39 | 'datapoints.shape: %s labels.shape: %s' % (datapoints.shape, labels.shape)) 40 | self._num_examples = datapoints.shape[0] 41 | 42 | self._datapoints = datapoints 43 | self._labels = labels 44 | self._epochs_completed = 0 45 | self._index_in_epoch = 0 46 | 47 | @property 48 | def datapoints(self): 49 | return self._datapoints 50 | 51 | @property 52 | def labels(self): 53 | return self._labels 54 | 55 | @property 56 | def num_examples(self): 57 | return self._num_examples 58 | 59 | @property 60 | def epochs_completed(self): 61 | return self._epochs_completed 62 | 63 | def next_batch(self, batch_size, fake_data=False, shuffle=True): 64 | """Return the next `batch_size` examples from this data set.""" 65 | if fake_data: 66 | fake_image = [1] * 784 67 | if self.one_hot: 68 | fake_label = [1] + [0] * 9 69 | else: 70 | fake_label = 0 71 | return [fake_image for _ in range(batch_size)], [ 72 | fake_label for _ in range(batch_size) 73 | ] 74 | 75 | start = self._index_in_epoch 76 | # Shuffle for the first epoch 77 | if self._epochs_completed == 0 and start == 0 and shuffle: 78 | perm0 = numpy.arange(self._num_examples) 79 | numpy.random.shuffle(perm0) 80 | self._datapoints = self.datapoints[perm0] 81 | self._labels = self.labels[perm0] 82 | # Go to the next epoch 83 | if start + batch_size > self._num_examples: 84 | # Finished epoch 85 | self._epochs_completed += 1 86 | 87 | # Get the rest examples in this epoch 88 | rest_num_examples = self._num_examples - start 89 | datapoints_rest_part = self._datapoints[start:self._num_examples] 90 | labels_rest_part = self._labels[start:self._num_examples] 91 | 92 | # Shuffle the data 93 | if shuffle: 94 | perm = numpy.arange(self._num_examples) 95 | numpy.random.shuffle(perm) 96 | self._datapoints = self.datapoints[perm] 97 | self._labels = self.labels[perm] 98 | 99 | # Start next epoch 100 | start = 0 101 | self._index_in_epoch = batch_size - rest_num_examples 102 | end = self._index_in_epoch 103 | datapoints_new_part = self._datapoints[start:end] 104 | labels_new_part = self._labels[start:end] 105 | return numpy.concatenate((datapoints_rest_part, datapoints_new_part), axis=0),numpy.concatenate((labels_rest_part, labels_new_part), axis=0) 106 | else: 107 | self._index_in_epoch += batch_size 108 | end = self._index_in_epoch 109 | return self._datapoints[start:end], self._labels[start:end] 110 | -------------------------------------------------------------------------------- /src/models/io.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import h5py 4 | import librosa 5 | 6 | from scipy.signal import decimate 7 | 8 | from matplotlib import pyplot as plt 9 | 10 | # ---------------------------------------------------------------------------- 11 | 12 | def load_h5(h5_path): 13 | # load training data 14 | with h5py.File(h5_path, 'r') as hf: 15 | print ('List of arrays in input file:', hf.keys()) 16 | X = np.array(hf.get('data')) 17 | Y = np.array(hf.get('label')) 18 | print ('Shape of X:', X.shape) 19 | print ('Shape of Y:', Y.shape) 20 | 21 | return X, Y 22 | 23 | def 
upsample_wav(wav, args, model): 24 | # load signal 25 | x_hr, fs = librosa.load(wav, sr=args.sr) 26 | 27 | # downscale signal 28 | # x_lr = np.array(x_hr[0::args.r]) 29 | x_lr = decimate(x_hr, args.r) 30 | # x_lr = decimate(x_hr, args.r, ftype='fir', zero_phase=True) 31 | # x_lr = downsample_bt(x_hr, args.r) 32 | 33 | # upscale the low-res version 34 | P = model.predict(x_lr.reshape((1,len(x_lr),1))) 35 | x_pr = P.flatten() 36 | 37 | # crop so that it works with scaling ratio 38 | x_hr = x_hr[:len(x_pr)] 39 | x_lr = x_lr[:len(x_pr)] 40 | 41 | # save the file 42 | outname = wav + '.' + args.out_label 43 | librosa.output.write_wav(outname + '.hr.wav', x_hr, fs) 44 | librosa.output.write_wav(outname + '.lr.wav', x_lr, int(fs / args.r)) 45 | librosa.output.write_wav(outname + '.pr.wav', x_pr, fs) 46 | 47 | # save the spectrum 48 | S = get_spectrum(x_pr, n_fft=2048) 49 | save_spectrum(S, outfile=outname + '.pr.png') 50 | S = get_spectrum(x_hr, n_fft=2048) 51 | save_spectrum(S, outfile=outname + '.hr.png') 52 | S = get_spectrum(x_lr, n_fft=int(2048/args.r)) 53 | save_spectrum(S, outfile=outname + '.lr.png') 54 | 55 | # ---------------------------------------------------------------------------- 56 | 57 | def get_spectrum(x, n_fft=2048): 58 | S = librosa.stft(x, n_fft) 59 | p = np.angle(S) 60 | S = np.log1p(np.abs(S)) 61 | return S 62 | 63 | def save_spectrum(S, lim=800, outfile='spectrogram.png'): 64 | plt.imshow(S.T, aspect=10) 65 | # plt.xlim([0,lim]) 66 | plt.tight_layout() 67 | plt.savefig(outfile) -------------------------------------------------------------------------------- /src/models/layers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leekh7411/Audio-Super-Resolution-Python3-TF/37cb46c37bde183392dc6109a6309330f24c6dcf/src/models/layers/__init__.py -------------------------------------------------------------------------------- /src/models/layers/__pycache__/__init__.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leekh7411/Audio-Super-Resolution-Python3-TF/37cb46c37bde183392dc6109a6309330f24c6dcf/src/models/layers/__pycache__/__init__.cpython-35.pyc -------------------------------------------------------------------------------- /src/models/layers/__pycache__/subpixel.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leekh7411/Audio-Super-Resolution-Python3-TF/37cb46c37bde183392dc6109a6309330f24c6dcf/src/models/layers/__pycache__/subpixel.cpython-35.pyc -------------------------------------------------------------------------------- /src/models/layers/standard.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | 4 | from keras.layers.advanced_activations import PReLU 5 | 6 | from .summarization import create_var_summaries 7 | 8 | # ---------------------------------------------------------------------------- 9 | 10 | def conv1d(x, n_filters, n_size, stride=1, nl='relu', name='conv1d'): 11 | n_batch, n_dim, n_input_chan = x.get_shape() 12 | with tf.variable_scope(name): 13 | # create and track weights 14 | with tf.name_scope('weights'): 15 | W = tf.get_variable('W', shape=[n_size, n_input_chan, n_filters], 16 | initializer=tf.random_normal_initializer(stddev=1e-3)) 17 | create_var_summaries(W) 18 | 19 | # create and track biases 20 | with tf.name_scope('biases'): 21 | b 
= tf.get_variable('b', [n_filters], initializer=tf.constant_initializer(0.)) 22 | create_var_summaries(b) 23 | 24 | # create and track pre-activations 25 | with tf.name_scope('preactivations'): 26 | x = tf.nn.conv1d(x, W, stride=1, padding='SAME') 27 | x = tf.nn.bias_add(x, b) 28 | tf.summary.histogram('preactivations', x) 29 | 30 | # create and track activations 31 | if nl == 'relu': 32 | x = tf.nn.relu(x) 33 | elif nl == 'prelu': 34 | x = PReLU()(x) 35 | elif nl == None: 36 | pass 37 | else: 38 | raise ValueError('Invalid non-linearity') 39 | 40 | tf.summary.histogram('activations', x) 41 | 42 | return x 43 | 44 | def deconv1d(x, r, n_chan, n_in_dim, n_in_chan, name='deconv1d'): 45 | x = tf.reshape(x, [128, 1, n_in_dim, n_in_chan]) 46 | with tf.variable_scope(name): 47 | # filter : [height, width, output_channels, in_channels] 48 | W = tf.get_variable('W', shape=[1, r, n_chan, n_in_chan], 49 | initializer=tf.random_normal_initializer(stddev=1e-3)) 50 | b = tf.get_variable('b', [n_chan], initializer=tf.constant_initializer(0.)) 51 | 52 | x = tf.nn.conv2d_transpose(x, W, output_shape=(128, 1, r*n_in_dim, n_chan), 53 | strides=[1, 1, r, 1]) 54 | x = tf.nn.bias_add(x, b) 55 | 56 | return tf.reshape(x, [-1, r*n_in_dim, n_chan]) -------------------------------------------------------------------------------- /src/models/layers/subpixel.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | 4 | # ---------------------------------------------------------------------------- 5 | 6 | def SubPixel1D_v2(I, r): 7 | """One-dimensional subpixel upsampling layer 8 | 9 | Based on https://github.com/Tetrachrome/subpixel/blob/master/subpixel.py 10 | """ 11 | with tf.name_scope('subpixel'): 12 | bsize, a, r = I.get_shape().as_list() 13 | bsize = tf.shape(I)[0] # Handling Dimension(None) type for undefined batch dim 14 | X = tf.split(I, a, 1) # a, [bsize, 1, r] 15 | if 'axis' in tf.squeeze.__code__.co_varnames: 16 | X = tf.concat([tf.squeeze(x, axis=1) for x in X], 1) # bsize, a*r 17 | elif 'squeeze_dims' in tf.squeeze.__code__.co_varnames: 18 | X = tf.concat([tf.squeeze(x, squeeze_dims=[1]) for x in X], 1) # bsize, a*r 19 | else: 20 | raise Exception('Unsupported version of tensorflow') 21 | 22 | return tf.reshape(X, (bsize, a*r, 1)) 23 | 24 | def SubPixel1D(I, r): 25 | """One-dimensional subpixel upsampling layer 26 | 27 | Calls a tensorflow function that directly implements this functionality. 28 | We assume input has dim (batch, width, r) 29 | """ 30 | with tf.name_scope('subpixel'): 31 | X = tf.transpose(I, [2,1,0]) # (r, w, b) 32 | X = tf.batch_to_space_nd(X, [r], [[0,0]]) # (1, r*w, b) 33 | X = tf.transpose(X, [2,1,0]) 34 | return X 35 | 36 | def SubPixel1D_multichan(I, r): 37 | """One-dimensional subpixel upsampling layer 38 | 39 | Calls a tensorflow function that directly implements this functionality. 40 | We assume input has dim (batch, width, r). 
41 | 42 | Works with multiple channels: (B,L,rC) -> (B,rL,C) 43 | """ 44 | with tf.name_scope('subpixel'): 45 | _, w, rc = I.get_shape() 46 | assert rc % r == 0 47 | c = rc / r 48 | X = tf.transpose(I, [2,1,0]) # (rc, w, b) 49 | X = tf.batch_to_space_nd(X, [r], [[0,0]]) # (c, r*w, b) 50 | X = tf.transpose(X, [2,1,0]) 51 | return X 52 | 53 | # ---------------------------------------------------------------------------- 54 | 55 | # demonstration 56 | if __name__ == "__main__": 57 | with tf.Session() as sess: 58 | x = np.arange(2*4*2).reshape(2, 4, 2) 59 | X = tf.placeholder("float32", shape=(2, 4, 2), name="X") 60 | Y = SubPixel1D(X, 2) 61 | y = sess.run(Y, feed_dict={X: x}) 62 | 63 | print ('single-channel:') 64 | print ('original, element 0 (2 channels):', x[0,:,0], x[0,:,1]) 65 | print ('rescaled, element 1:', y[0,:,0]) 66 | print () 67 | print ('original, element 0 (2 channels) :', x[1,:,0], x[1,:,1]) 68 | print ('rescaled, element 1:', y[1,:,0]) 69 | print () 70 | 71 | x = np.arange(2*4*4).reshape(2, 4, 4) 72 | X = tf.placeholder("float32", shape=(2, 4, 4), name="X") 73 | Y = SubPixel1D(X, 2) 74 | y = sess.run(Y, feed_dict={X: x}) 75 | 76 | print ('multichannel:') 77 | print ('original, element 0 (4 channels):', x[0,:,0], x[0,:,1], x[0,:,2], x[0,:,3]) 78 | print ('rescaled, element 1:', y[0,:,0], y[0,:,1]) 79 | print () 80 | print ('original, element 0 (2 channels) :', x[1,:,0], x[1,:,1], x[1,:,2], x[1,:,3]) 81 | print ('rescaled, element 1:', y[1,:,0], y[1,:,1]) -------------------------------------------------------------------------------- /src/models/layers/summarization.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | # ---------------------------------------------------------------------------- 4 | 5 | def create_var_summaries(var): 6 | """Attach a lot of summaries to a Tensor (for TensorBoard visualization).""" 7 | with tf.name_scope('summaries'): 8 | mean = tf.reduce_mean(var) 9 | tf.summary.scalar('mean', mean) 10 | with tf.name_scope('stddev'): 11 | stddev = tf.sqrt(tf.reduce_mean(tf.square(var - mean))) 12 | tf.summary.scalar('stddev', stddev) 13 | tf.summary.scalar('max', tf.reduce_max(var)) 14 | tf.summary.scalar('min', tf.reduce_min(var)) 15 | tf.summary.histogram('histogram', var) -------------------------------------------------------------------------------- /src/models/model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | 4 | import numpy as np 5 | import tensorflow as tf 6 | 7 | import librosa 8 | from keras import backend as K 9 | from .dataset import DataSet 10 | 11 | # ---------------------------------------------------------------------------- 12 | 13 | default_opt = { 'alg': 'adam', 'lr': 1e-4, 'b1': 0.99, 'b2': 0.999, 14 | 'layers': 4, 'batch_size': 128 } 15 | 16 | class Model(object): 17 | """Generic tensorflow model training code""" 18 | 19 | def __init__(self, from_ckpt=False, n_dim=None, r=2,opt_params=default_opt, log_prefix='./run'): 20 | 21 | # create session 22 | # if use CPU ? 23 | #self.sess = tf.Session(config=tf.ConfigProto(log_device_placement=True)) 24 | 25 | # if use GPU ? 
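# Note: allow_growth=False (below) tells TensorFlow to reserve nearly all GPU
# memory when the session is created; setting it to True grows the allocation
# on demand, which is friendlier when the GPU is shared. An alternative
# configuration (not what this file does) would be:
#   gpu_options = tf.GPUOptions(allow_growth=True)
#   self.sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options,
#                                                allow_soft_placement=True))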
26 | gpu_options = tf.GPUOptions(allow_growth=False) 27 | self.sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options, allow_soft_placement=True)) 28 | 29 | 30 | K.set_session(self.sess) # pass keras the session 31 | 32 | # save params 33 | self.opt_params = opt_params 34 | self.layers = opt_params['layers'] 35 | 36 | if from_ckpt: 37 | pass # we will instead load the graph from a checkpoint 38 | else: 39 | # create input vars 40 | X = tf.placeholder(tf.float32, shape=(None, None, 1), name='X') 41 | Y = tf.placeholder(tf.float32, shape=(None, None, 1), name='Y') 42 | alpha = tf.placeholder(tf.float32, shape=(), name='alpha') # weight multiplier 43 | 44 | # save inputs 45 | self.inputs = (X, Y, alpha) 46 | tf.add_to_collection('inputs', X) 47 | tf.add_to_collection('inputs', Y) 48 | tf.add_to_collection('inputs', alpha) 49 | 50 | # create model outputs 51 | self.predictions = self.create_model(n_dim, r) 52 | tf.add_to_collection('preds', self.predictions) 53 | 54 | # init the model 55 | init = tf.global_variables_initializer() 56 | self.sess.run(init) 57 | 58 | # create training updates 59 | self.train_op = self.create_train_op(X, Y, alpha) 60 | tf.add_to_collection('train_op', self.train_op) 61 | 62 | # logging 63 | lr_str = '.' + 'lr%f' % opt_params['lr'] 64 | g_str = '.g%d' % self.layers 65 | b_str = '.b%d' % int(opt_params['batch_size']) 66 | 67 | self.logdir = log_prefix + lr_str + '.%d' % r + g_str + b_str 68 | self.checkpoint_root = os.path.join(self.logdir, 'model.ckpt') 69 | 70 | def create_train_op(self, X, Y, alpha): 71 | # load params 72 | opt_params = self.opt_params 73 | print('creating train_op with params:', opt_params) 74 | 75 | # create loss 76 | self.loss = self.create_objective(X, Y, opt_params) 77 | 78 | # create params 79 | params = self.get_params() 80 | 81 | # create optimizer 82 | self.optimizer = self.create_optimzier(opt_params) 83 | 84 | # create gradients 85 | grads = self.create_gradients(self.loss, params) 86 | 87 | # create training op 88 | with tf.name_scope('optimizer'): 89 | train_op = self.create_updates(params, grads, alpha, opt_params) 90 | 91 | # initialize the optimizer variables 92 | optimizer_vars = [] 93 | for v in tf.global_variables(): 94 | if 'optimizer/' in v.name: 95 | optimizer_vars.append(v) 96 | 97 | #optimizer_vars = [ v for v in tf.global_variables() if 'optimizer/' in v.name ] 98 | 99 | init = tf.variables_initializer(optimizer_vars) 100 | self.sess.run(init) 101 | 102 | return train_op 103 | 104 | def create_model(self, n_dim, r): 105 | raise NotImplementedError() 106 | 107 | def create_objective(self, X, Y, opt_params): 108 | # load model output and true output 109 | P = self.predictions 110 | 111 | # compute l2 loss 112 | sqrt_l2_loss = tf.sqrt(tf.reduce_mean((P-Y)**2 + 1e-6, axis=[1,2])) 113 | sqrt_l2_norm = tf.sqrt(tf.reduce_mean(Y**2, axis=[1,2])) 114 | snr = 20 * tf.log(sqrt_l2_norm / sqrt_l2_loss + 1e-8) / tf.log(10.) 
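# The two lines above compute, per example in the batch,
#   sqrt_l2_loss = sqrt( mean((P - Y)^2) )              -- an RMS reconstruction loss
#   snr          = 20 * log10( rms(Y) / sqrt_l2_loss )  -- signal-to-noise ratio in dB
# tf.log is the natural logarithm, so dividing by tf.log(10.) converts it to
# log10; the small 1e-6 and 1e-8 constants only guard against sqrt(0) and log(0).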
115 | 116 | avg_sqrt_l2_loss = tf.reduce_mean(sqrt_l2_loss, axis=0) 117 | avg_snr = tf.reduce_mean(snr, axis=0) 118 | 119 | # track losses 120 | tf.summary.scalar('l2_loss', avg_sqrt_l2_loss) 121 | tf.summary.scalar('snr', avg_snr) 122 | 123 | # save losses into collection 124 | tf.add_to_collection('losses', avg_sqrt_l2_loss) 125 | tf.add_to_collection('losses', avg_snr) 126 | 127 | return avg_sqrt_l2_loss 128 | 129 | def get_params(self): 130 | return [ v for v in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES) 131 | if 'soundnet' not in v.name ] 132 | 133 | def create_optimzier(self, opt_params): 134 | if opt_params['alg'] == 'adam': 135 | lr, b1, b2 = opt_params['lr'], opt_params['b1'], opt_params['b2'] 136 | optimizer = tf.train.AdamOptimizer(lr, b1, b2) 137 | else: 138 | raise ValueError('Invalid optimizer: ' + opt_params['alg']) 139 | 140 | return optimizer 141 | 142 | def create_gradients(self, loss, params): 143 | gv = self.optimizer.compute_gradients(loss, params) 144 | g, v = zip(*gv) 145 | return g 146 | 147 | def create_updates(self, params, grads, alpha, opt_params): 148 | # create a variable to track the global step. 149 | self.global_step = tf.Variable(0, name='global_step', trainable=False) 150 | 151 | # update grads 152 | grads = [alpha*g for g in grads] 153 | 154 | # use the optimizer to apply the gradients that minimize the loss 155 | gv = zip(grads, params) 156 | train_op = self.optimizer.apply_gradients(gv, global_step=self.global_step) 157 | 158 | return train_op 159 | 160 | def load(self, ckpt): 161 | # get checkpoint name 162 | if os.path.isdir(ckpt): checkpoint = tf.train.latest_checkpoint(ckpt) 163 | else: checkpoint = ckpt 164 | meta = checkpoint + '.meta' 165 | print(checkpoint) 166 | 167 | # load graph 168 | self.saver = tf.train.import_meta_graph(meta) 169 | g = tf.get_default_graph() 170 | 171 | # load weights 172 | self.saver.restore(self.sess, checkpoint) 173 | 174 | # get graph tensors 175 | X, Y, alpha = tf.get_collection('inputs') 176 | 177 | # save tensors as instance variables 178 | self.inputs = X, Y, alpha 179 | self.predictions = tf.get_collection('preds')[0] 180 | 181 | # load existing loss, or erase it, if creating new one 182 | g.clear_collection('losses') 183 | 184 | # create a new training op 185 | self.train_op = self.create_train_op(X, Y, alpha) 186 | g.clear_collection('train_op') 187 | tf.add_to_collection('train_op', self.train_op) 188 | 189 | # or, get existing train op: 190 | # self.train_op = tf.get_collection('train_op') 191 | 192 | def fit(self, X_train, Y_train, X_val, Y_val, n_epoch=100): 193 | # initialize log directory 194 | if tf.gfile.Exists(self.logdir): tf.gfile.DeleteRecursively(self.logdir) 195 | tf.gfile.MakeDirs(self.logdir) 196 | 197 | # load some training params 198 | n_batch = self.opt_params['batch_size'] 199 | 200 | # create saver 201 | self.saver = tf.train.Saver() 202 | 203 | # summarization 204 | summary = tf.summary.merge_all() 205 | summary_writer = tf.summary.FileWriter(self.logdir, self.sess.graph) 206 | 207 | # load data into DataSet 208 | train_data = DataSet(X_train, Y_train) 209 | val_data = DataSet(X_val, Y_val) 210 | 211 | # train the model 212 | start_time = time.time() 213 | step, epoch = 0, train_data.epochs_completed 214 | while train_data.epochs_completed < n_epoch: 215 | 216 | step += 1 217 | 218 | # load the batch 219 | # alpha = min((n_epoch - train_data.epochs_completed) / 200, 1.) 
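# alpha is fed into the `alpha` placeholder and multiplies every gradient in
# create_updates (grads = [alpha*g for g in grads]), so it acts as an extra
# learning-rate factor. The commented assignments above and just below are two
# alternative decay schedules; the active code keeps alpha fixed at 1.0.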
220 | # alpha = 1.0 if epoch < 100 else 0.1 221 | alpha = 1.0 222 | batch = train_data.next_batch(n_batch) 223 | feed_dict = self.load_batch(batch, alpha) 224 | 225 | # take training step 226 | tr_objective = self.train(feed_dict) 227 | # tr_obj_snr = 20 * np.log10(1. / np.sqrt(tr_objective) + 1e-8) 228 | # if step % 50 == 0: 229 | # print step, tr_objective, tr_obj_snr 230 | 231 | # log results at the end of each epoch 232 | if train_data.epochs_completed > epoch: 233 | epoch = train_data.epochs_completed 234 | end_time = time.time() 235 | 236 | tr_l2_loss, tr_l2_snr = self.eval_err(X_train, Y_train, n_batch=n_batch) 237 | va_l2_loss, va_l2_snr = self.eval_err(X_val, Y_val, n_batch=n_batch) 238 | 239 | print("Epoch {} of {} took {:.3f}s ({} minibatches)".format( 240 | epoch, n_epoch, end_time - start_time, len(X_train) // n_batch)) 241 | print(" training l2_loss/segsnr:\t\t{:.6f}\t{:.6f}".format( 242 | tr_l2_loss, tr_l2_snr)) 243 | print(" validation l2_loss/segsnr:\t\t{:.6f}\t{:.6f}".format( 244 | va_l2_loss, va_l2_snr)) 245 | 246 | # compute summaries for overall loss 247 | objectives_summary = tf.Summary() 248 | objectives_summary.value.add(tag='tr_l2_loss', simple_value=tr_l2_loss) 249 | objectives_summary.value.add(tag='tr_l2_snr' , simple_value=tr_l2_snr) 250 | objectives_summary.value.add(tag='va_l2_loss', simple_value=va_l2_loss) 251 | 252 | # compute summaries for all other metrics 253 | summary_str = self.sess.run(summary, feed_dict=feed_dict) 254 | summary_writer.add_summary(summary_str, step) 255 | summary_writer.add_summary(objectives_summary, step) 256 | 257 | # write summaries and checkpoints 258 | summary_writer.flush() 259 | self.saver.save(self.sess, self.checkpoint_root, global_step=step) 260 | 261 | # restart clock 262 | start_time = time.time() 263 | 264 | def train(self, feed_dict): 265 | _, loss = self.sess.run([self.train_op, self.loss], feed_dict=feed_dict) 266 | return loss 267 | 268 | def load_batch(self, batch, alpha=1, train=True): 269 | X_in, Y_in, alpha_in = self.inputs 270 | X, Y = batch 271 | 272 | if Y is not None: 273 | feed_dict = {X_in : X, Y_in : Y, alpha_in : alpha} 274 | else: 275 | feed_dict = {X_in : X, alpha_in : alpha} 276 | 277 | # this is ugly, but only way I found to get this var after model reload 278 | g = tf.get_default_graph() 279 | 280 | k_tensors = [] 281 | for n in g.as_graph_def().node: 282 | if 'keras_learning_phase' in n.name and 'input' not in n.name: 283 | print('tf.default_graph.node:',n.name) 284 | k_tensors.append(n) 285 | 286 | #k_tensors = [n for n in g.as_graph_def().node if 'keras_learning_phase' in n.name] 287 | 288 | # there should be at most one learning-phase tensor in the graph 289 | 290 | assert len(k_tensors) <= 1 291 | 292 | if k_tensors: 293 | k_learning_phase = g.get_tensor_by_name(k_tensors[0].name + ':0') 294 | feed_dict[k_learning_phase] = train 295 | 296 | return feed_dict 297 | 298 | def eval_err(self, X, Y, n_batch=128): 299 | batch_iterator = iterate_minibatches(X, Y, n_batch, shuffle=True) 300 | l2_loss_op, l2_snr_op = tf.get_collection('losses') 301 | 302 | l2_loss, snr = 0, 0 303 | tot_l2_loss, tot_snr = 0, 0 304 | for bn, batch in enumerate(batch_iterator): 305 | feed_dict = self.load_batch(batch, train=False) 306 | l2_loss, l2_snr = self.sess.run([l2_loss_op, l2_snr_op], feed_dict=feed_dict) 307 | tot_l2_loss += l2_loss 308 | tot_snr += l2_snr 309 | 310 | return tot_l2_loss / (bn+1), tot_snr / (bn+1) 311 | 312 | def predict(self, X): 313 | raise NotImplementedError() 314 | 315 | # 
316 | # helpers
317 | 
318 | def iterate_minibatches(inputs, targets, batchsize, shuffle=False):
319 |   assert len(inputs) == len(targets)
320 |   if shuffle:
321 |     indices = np.arange(len(inputs))
322 |     np.random.shuffle(indices)
323 |   for start_idx in range(0, len(inputs) - batchsize + 1, batchsize):  # drops any trailing partial batch
324 |     if shuffle:
325 |       excerpt = indices[start_idx:start_idx + batchsize]
326 |     else:
327 |       excerpt = slice(start_idx, start_idx + batchsize)
328 |     yield inputs[excerpt], targets[excerpt]
--------------------------------------------------------------------------------
/standard.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import tensorflow as tf
3 | 
4 | from summarization import create_var_summaries
5 | 
6 | # ----------------------------------------------------------------------------
7 | 
8 | def parametric_relu(_x):
9 |   alphas = tf.get_variable('alpha', _x.get_shape()[-1],
10 |                            initializer=tf.constant_initializer(0.0),
11 |                            dtype=tf.float32)
12 |   pos = tf.nn.relu(_x)
13 |   neg = alphas * (_x - abs(_x)) * 0.5
14 | 
15 |   return pos + neg
16 | 
17 | 
18 | def conv1d(x, n_filters, n_size, stride=1, nl='relu', name='conv1d', dropOut=False):
19 |   n_batch, n_dim, n_input_chan = x.get_shape()
20 |   #with tf.variable_scope(name):
21 | 
22 |   # create and track weights
23 |   #with tf.name_scope('weights') as scope:
24 |   W = tf.get_variable('W', shape=[n_size, n_input_chan, n_filters], initializer=tf.random_normal_initializer(stddev=1e-3))
25 |   create_var_summaries(W)
26 | 
27 |   # create and track biases
28 |   #with tf.name_scope('biases'):
29 |   b = tf.get_variable('b', [n_filters], initializer=tf.constant_initializer(0.))
30 |   create_var_summaries(b)
31 | 
32 |   # create dropout layer
33 |   if dropOut:
34 |     #with tf.name_scope('dropout'):
35 |     x = tf.layers.dropout(x, training=dropOut)  # rate defaults to 0.5
36 | 
37 |   # create and track pre-activations
38 |   #with tf.name_scope('preactivations'):
39 |   x = tf.nn.conv1d(x, W, stride=stride, padding='SAME')
40 |   x = tf.nn.bias_add(x, b)
41 |   tf.summary.histogram('preactivations', x)
42 | 
43 |   # create and track activations
44 |   if nl == 'relu':
45 |     x = tf.nn.relu(x)
46 |   elif nl == 'prelu':
47 |     x = parametric_relu(x)
48 |   elif nl is None:
49 |     pass
50 |   else:
51 |     raise ValueError('Invalid non-linearity')
52 | 
53 |   tf.summary.histogram('activations', x)
54 | 
55 |   return x
56 | 
57 | 
58 | def deconv1d(x, r, n_chan, n_in_dim, n_in_chan, name='deconv1d'):
59 |   x = tf.reshape(x, [128, 1, n_in_dim, n_in_chan])  # NOTE: batch size is hard-coded to 128
60 |   with tf.variable_scope(name, reuse=tf.AUTO_REUSE):
61 |     # filter : [height, width, output_channels, in_channels]
62 |     W = tf.get_variable('W', shape=[1, r, n_chan, n_in_chan],
63 |                         initializer=tf.random_normal_initializer(stddev=1e-3))
64 |     b = tf.get_variable('b', [n_chan], initializer=tf.constant_initializer(0.))
65 | 
66 |     x = tf.nn.conv2d_transpose(x, W, output_shape=(128, 1, r*n_in_dim, n_chan),
67 |                                strides=[1, 1, r, 1])
68 |     x = tf.nn.bias_add(x, b)
69 | 
70 |   return tf.reshape(x, [-1, r*n_in_dim, n_chan])
--------------------------------------------------------------------------------
/subpixel.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import tensorflow as tf
3 | 
4 | # ----------------------------------------------------------------------------
5 | 
6 | def SubPixel1D_v2(I, r):
7 |   """One-dimensional subpixel upsampling layer
8 | 
9 |   Based on https://github.com/Tetrachrome/subpixel/blob/master/subpixel.py
10 |   """
11 |   with tf.name_scope('subpixel'):
12 |     bsize, a, r = I.get_shape().as_list()
13 |     bsize = tf.shape(I)[0]  # handle Dimension(None) for an undefined batch dim
14 |     X = tf.split(1, a, I)  # a, [bsize, 1, r] (note: pre-1.0 tf.split argument order)
15 |     if 'axis' in tf.squeeze.__code__.co_varnames:
16 |       X = tf.concat(1, [tf.squeeze(x, axis=1) for x in X])  # bsize, a*r
17 |     elif 'squeeze_dims' in tf.squeeze.__code__.co_varnames:
18 |       X = tf.concat(1, [tf.squeeze(x, squeeze_dims=[1]) for x in X])  # bsize, a*r
19 |     else:
20 |       raise Exception('Unsupported version of tensorflow')
21 | 
22 |     return tf.reshape(X, (bsize, a*r, 1))
23 | 
24 | def SubPixel1D(I, r):
25 |   """One-dimensional subpixel upsampling layer
26 | 
27 |   Calls a tensorflow function that directly implements this functionality.
28 |   We assume input has dim (batch, width, r)
29 |   """
30 |   with tf.name_scope('subpixel'):
31 |     X = tf.transpose(I, [2,1,0])  # (r, w, b)
32 |     X = tf.batch_to_space_nd(X, [r], [[0,0]])  # (1, r*w, b)
33 |     X = tf.transpose(X, [2,1,0])
34 |     return X
35 | 
36 | def SubPixel1D_multichan(I, r):
37 |   """One-dimensional subpixel upsampling layer
38 | 
39 |   Calls a tensorflow function that directly implements this functionality.
40 |   We assume input has dim (batch, width, r).
41 | 
42 |   Works with multiple channels: (B,L,rC) -> (B,rL,C)
43 |   """
44 |   with tf.name_scope('subpixel'):
45 |     _, w, rc = I.get_shape()
46 |     assert rc % r == 0
47 |     c = rc // r
48 |     X = tf.transpose(I, [2,1,0])  # (rc, w, b)
49 |     X = tf.batch_to_space_nd(X, [r], [[0,0]])  # (c, r*w, b)
50 |     X = tf.transpose(X, [2,1,0])
51 |     return X
52 | 
53 | # ----------------------------------------------------------------------------
54 | 
55 | # demonstration
56 | if __name__ == "__main__":
57 |   with tf.Session() as sess:
58 |     x = np.arange(2*4*2).reshape(2, 4, 2)
59 |     X = tf.placeholder("float32", shape=(2, 4, 2), name="X")
60 |     Y = SubPixel1D(X, 2)
61 |     y = sess.run(Y, feed_dict={X: x})
62 | 
63 |     print('single-channel:')
64 |     print('original, element 0 (2 channels):', x[0,:,0], x[0,:,1])
65 |     print('rescaled, element 0:', y[0,:,0])
66 |     print()
67 |     print('original, element 1 (2 channels):', x[1,:,0], x[1,:,1])
68 |     print('rescaled, element 1:', y[1,:,0])
69 |     print()
70 | 
71 |     x = np.arange(2*4*4).reshape(2, 4, 4)
72 |     X = tf.placeholder("float32", shape=(2, 4, 4), name="X")
73 |     Y = SubPixel1D(X, 2)
74 |     y = sess.run(Y, feed_dict={X: x})
75 | 
76 |     print('multichannel:')
77 |     print('original, element 0 (4 channels):', x[0,:,0], x[0,:,1], x[0,:,2], x[0,:,3])
78 |     print('rescaled, element 0:', y[0,:,0], y[0,:,1])
79 |     print()
80 |     print('original, element 1 (4 channels):', x[1,:,0], x[1,:,1], x[1,:,2], x[1,:,3])
81 |     print('rescaled, element 1:', y[1,:,0], y[1,:,1])
--------------------------------------------------------------------------------
/summarization.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | 
3 | # ----------------------------------------------------------------------------
4 | 
5 | def create_var_summaries(var):
6 |   """Attach a lot of summaries to a Tensor (for TensorBoard visualization)."""
7 |   with tf.name_scope('summaries'):
8 |     mean = tf.reduce_mean(var)
9 |     tf.summary.scalar('mean', mean)
10 |     with tf.name_scope('stddev'):
11 |       stddev = tf.sqrt(tf.reduce_mean(tf.square(var - mean)))
12 |     tf.summary.scalar('stddev', stddev)
13 |     tf.summary.scalar('max', tf.reduce_max(var))
14 |     tf.summary.scalar('min', tf.reduce_min(var))
15 |     tf.summary.histogram('histogram', var)
--------------------------------------------------------------------------------
/temp.h5:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:5415a2a27fc6bf576e43dd33d03c6bbb3ca918a3125098e0f5d93d8197b13985
3 | size 1672
4 | 
--------------------------------------------------------------------------------
/train.h5:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:fb30335022ce394282e8dc972b473fdbb1d57934ef15e3df244f2045b5c6890c
3 | size 111675392
4 | 
--------------------------------------------------------------------------------
/valid.h5:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:ed5e61aea96a163423d7744bfd2f97c69c7f624fa2a97908561c8e3e7b5f0de1
3 | size 37619712
4 | 
--------------------------------------------------------------------------------
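
A note on the subpixel layers in subpixel.py above: for 1-D inputs, the
transpose + tf.batch_to_space_nd shuffle computes the same result as a plain
row-major reshape. The following minimal NumPy sketch illustrates that
equivalence (subpixel1d_ref is an illustrative name of ours and is not part of
this repository):

import numpy as np

def subpixel1d_ref(x, r):
  # (B, W, r*C) -> (B, r*W, C): a row-major reshape interleaves the r
  # sub-channels along the width axis, which is what SubPixel1D and
  # SubPixel1D_multichan compute via transpose + batch_to_space_nd
  b, w, rc = x.shape
  assert rc % r == 0
  return x.reshape(b, r * w, rc // r)

# mirrors the single-channel demonstration in subpixel.py:
x = np.arange(2 * 4 * 2).reshape(2, 4, 2)
print(subpixel1d_ref(x, 2)[0, :, 0])  # -> [0 1 2 3 4 5 6 7]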