├── web
│   ├── img
│   │   └── favicon.ico
│   ├── ckpts
│   │   ├── dumper
│   │   │   ├── dump.sh
│   │   │   ├── dump_checkpoint_vars.py
│   │   │   ├── tensorflow_checkpoint_dumper.py
│   │   │   └── checkpoint_dumper.py
│   │   └── drums
│   │       └── manifest.json
│   ├── dev.py
│   ├── bundle.py
│   ├── js
│   │   ├── wavegan_reqs.js
│   │   ├── wavegan_cfg.js
│   │   ├── wavegan_player.js
│   │   ├── wavegan_visualizer.js
│   │   ├── wavegan_savewav.js
│   │   ├── wavegan_net.js
│   │   ├── wavegan_sequencer.js
│   │   └── wavegan_ui.js
│   ├── css
│   │   └── wavegan.css
│   └── index.html
├── images
│   ├── spectrogram.png
│   ├── Architecture.pdf
│   └── Architecture-1.png
├── slides
│   └── ts-rir_final.pdf
├── backup.py
├── RT60.py
├── loader.py
├── generator
│   ├── loader.py
│   └── generator.py
├── README.md
├── TSRIRgan.py
└── train_TSRIRgan.py

/web/img/favicon.ico:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/anton-jeran/TS-RIR/HEAD/web/img/favicon.ico
--------------------------------------------------------------------------------
/images/spectrogram.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/anton-jeran/TS-RIR/HEAD/images/spectrogram.png
--------------------------------------------------------------------------------
/images/Architecture.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/anton-jeran/TS-RIR/HEAD/images/Architecture.pdf
--------------------------------------------------------------------------------
/slides/ts-rir_final.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/anton-jeran/TS-RIR/HEAD/slides/ts-rir_final.pdf
--------------------------------------------------------------------------------
/images/Architecture-1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/anton-jeran/TS-RIR/HEAD/images/Architecture-1.png
--------------------------------------------------------------------------------
/web/ckpts/dumper/dump.sh:
--------------------------------------------------------------------------------
 1 | python dump_checkpoint_vars.py \
 2 |     --model_type tensorflow \
 3 |     --checkpoint_file ${1} \
 4 |     --output_dir dumped
 5 | 
--------------------------------------------------------------------------------
/web/dev.py:
--------------------------------------------------------------------------------
 1 | from flask import Flask, request, send_from_directory
 2 | 
 3 | app = Flask(__name__, static_url_path='')
 4 | 
 5 | @app.route('/js/<path:path>')
 6 | def send_js(path):
 7 |     return send_from_directory('js', path)
 8 | 
 9 | @app.route('/img/<path:path>')
10 | def send_img(path):
11 |     return send_from_directory('img', path)
12 | 
13 | @app.route('/css/<path:path>')
14 | def send_css(path):
15 |     return send_from_directory('css', path)
16 | 
17 | @app.route('/ckpts/<path:path>')
18 | def send_ckpts(path):
19 |     return send_from_directory('ckpts', path)
20 | 
21 | @app.route('/')
22 | def root():
23 |     return send_from_directory('', 'index.html')
24 | 
25 | if __name__ == "__main__":
26 |     app.run(host='0.0.0.0', port=6006)
27 | 
--------------------------------------------------------------------------------
/web/bundle.py:
--------------------------------------------------------------------------------
 1 | from __future__ import print_function
 2 | import os
 3 | import shutil
 4 | 
 5 | bundle_dir = 'bundle'
 6 | 
 7 | paths = [
 8 |     'ckpts/drums',
 9 |     'css',
10 |     'img',
11 |     'js',
12 |     'index.html'
13 | ]
14 | 
15 | if 
os.path.exists(bundle_dir): 16 | shutil.rmtree(bundle_dir) 17 | 18 | for path in paths: 19 | out_path = os.path.join(bundle_dir, path) 20 | print('{}->{}'.format(path, out_path)) 21 | 22 | if os.path.isdir(path): 23 | shutil.copytree(path, out_path) 24 | else: 25 | out_dir = os.path.split(out_path)[0] 26 | if not os.path.exists(out_dir): 27 | os.makedirs(out_dir) 28 | 29 | shutil.copy(path, out_path) 30 | 31 | wavegan_cfg_fp = os.path.join(bundle_dir, 'js', 'wavegan_cfg.js') 32 | with open(wavegan_cfg_fp, 'r') as f: 33 | wavegan_cfg = f.read() 34 | 35 | wavegan_cfg = wavegan_cfg.replace('var debug = true;', 'var debug = false;') 36 | 37 | with open(wavegan_cfg_fp, 'w') as f: 38 | f.write(wavegan_cfg) 39 | -------------------------------------------------------------------------------- /backup.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | if __name__ == '__main__': 4 | import glob 5 | import os 6 | import shutil 7 | import sys 8 | import time 9 | 10 | import tensorflow as tf 11 | 12 | train_dir, nmin = sys.argv[1:3] 13 | nsec = int(float(nmin) * 60.) 14 | 15 | backup_dir = os.path.join(train_dir, 'backup') 16 | 17 | if not os.path.exists(backup_dir): 18 | os.makedirs(backup_dir) 19 | 20 | while tf.train.latest_checkpoint(train_dir) is None: 21 | print('Waiting for first checkpoint') 22 | time.sleep(1) 23 | 24 | while True: 25 | latest_ckpt = tf.train.latest_checkpoint(train_dir) 26 | 27 | # Sleep for two seconds in case file flushing 28 | time.sleep(2) 29 | 30 | for fp in glob.glob(latest_ckpt + '*'): 31 | _, name = os.path.split(fp) 32 | backup_fp = os.path.join(backup_dir, name) 33 | print('{}->{}'.format(fp, backup_fp)) 34 | shutil.copyfile(fp, backup_fp) 35 | print('-' * 80) 36 | 37 | # Sleep for an hour 38 | time.sleep(nsec) 39 | -------------------------------------------------------------------------------- /web/js/wavegan_reqs.js: -------------------------------------------------------------------------------- 1 | window.wavegan = window.wavegan || {}; 2 | 3 | (function (deeplearn, wavegan) { 4 | var cfg = wavegan.cfg; 5 | 6 | // Prompt if no WebGL 7 | try { 8 | var math = new deeplearn.NDArrayMathGPU(); 9 | } 10 | catch (err) { 11 | cfg.debugMsg('WebGL error: ' + String(err)); 12 | 13 | if (confirm(cfg.reqs.noWebGlWarning) === false) { 14 | cfg.reqs.userCanceled = true; 15 | cfg.debugMsg('User canceled demo (no WebGL)'); 16 | } 17 | 18 | document.getElementById('canceled').removeAttribute('hidden'); 19 | document.getElementById('content').removeAttribute('hidden'); 20 | document.getElementById('overlay').setAttribute('hidden', ''); 21 | 22 | return; 23 | } 24 | 25 | // Prompt if mobile 26 | if(/Android|webOS|iPhone|iPad|iPod|BlackBerry|IEMobile|Opera Mini/i.test(navigator.userAgent) ) { 27 | if (confirm(cfg.reqs.mobileWarning) === false) { 28 | cfg.reqs.userCanceled = true; 29 | cfg.debugMsg('User canceled demo (mobile)'); 30 | } 31 | 32 | document.getElementById('canceled').removeAttribute('hidden'); 33 | document.getElementById('content').removeAttribute('hidden'); 34 | document.getElementById('overlay').setAttribute('hidden', ''); 35 | 36 | return; 37 | } 38 | 39 | })(window.deeplearn, window.wavegan); 40 | -------------------------------------------------------------------------------- /web/js/wavegan_cfg.js: -------------------------------------------------------------------------------- 1 | window.wavegan = window.wavegan || {}; 2 | 3 | (function (wavegan) { 4 | var debug = true; 5 | 6 | 
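// Note: bundle.py rewrites `var debug = true;` above to `var debug = false;`
// when it packages the web demo, so debugMsg logging below is only active in
// development builds served through dev.py.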
// Config 7 | wavegan.cfg = { 8 | reqs: { 9 | userCanceled: false, 10 | noWebGlWarning: 'Warning: We did not find WebGL in your browser. This demo uses WebGL to accelerate neural network computation. Performance will be slow and may hang your browser. Continue?', 11 | mobileWarning: 'Warning: This demo runs a neural network in your browser. It appears you are on a mobile device. Consider running the demo on your laptop instead. Continue?' 12 | }, 13 | net: { 14 | ckptDir: 'ckpts/drums', 15 | ppFilt: true, 16 | zDim: 100, 17 | cherries: [5, 2, 0, 62, 55, 12, 56, 21] 18 | }, 19 | audio: { 20 | gainDefault: 0.5, 21 | reverbDefault: 0.25, 22 | reverbLen: 2, 23 | reverbDecay: 10 24 | }, 25 | ui: { 26 | canvasFlushDelayMs: 25, 27 | visualizerGain: 1, 28 | zactorNumRows: 2, 29 | zactorNumCols: 4, 30 | rmsAnimDelayMs: 25 31 | }, 32 | sequencer: { 33 | labelWidth: 80, 34 | numCols: 16, 35 | tempoMin: 30, 36 | tempoMax: 300, 37 | tempoDefault: 120, 38 | swingMin: 0.5, 39 | swingMax: 0.8, 40 | swingDefault: 0.5, 41 | pattern: { 42 | 0: [1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1], 43 | 1: [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], 44 | 2: [1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 0, 1, 0], 45 | 3: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 46 | 4: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 47 | 5: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0], 48 | 6: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0], 49 | 7: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0] 50 | } 51 | } 52 | }; 53 | 54 | wavegan.cfg.debugMsg = function (msg) { 55 | if (debug) { 56 | console.log(msg); 57 | } 58 | }; 59 | 60 | })(window.wavegan); 61 | -------------------------------------------------------------------------------- /web/js/wavegan_player.js: -------------------------------------------------------------------------------- 1 | window.wavegan = window.wavegan || {}; 2 | 3 | (function (wavegan) { 4 | // Config 5 | var cfg = wavegan.cfg; 6 | 7 | var ResamplingPlayer = function (fs) { 8 | this.sample = null; 9 | this.sampleIdx = 0; 10 | this.sampleIdxInc = 0; 11 | 12 | this.playing = false; 13 | this.fs = fs; 14 | this.rms = 0; 15 | }; 16 | ResamplingPlayer.prototype.setSample = function (sample, sampleFs) { 17 | var samplePadded = new Float32Array(sample.length + 1); 18 | for (var i = 0; i < sample.length; ++i) { 19 | samplePadded[i] = sample[i]; 20 | } 21 | samplePadded[i] = 0; 22 | 23 | this.sample = samplePadded; 24 | this.sampleLength = sample.length; 25 | this.sampleIdx = 0; 26 | this.sampleIdxInc = sampleFs / this.fs; 27 | this.playing = false; 28 | }; 29 | ResamplingPlayer.prototype.bang = function () { 30 | this.sampleIdx = 0; 31 | this.playing = true; 32 | }; 33 | ResamplingPlayer.prototype.readBlock = function (buffer) { 34 | if (!this.playing) { 35 | this.rms = 0; 36 | return; 37 | } 38 | 39 | var sample = this.sample; 40 | var sampleLength = this.sampleLength; 41 | var sampleIdx = this.sampleIdx; 42 | var sampleIdxInc = this.sampleIdxInc; 43 | var floor, frac; 44 | var samp; 45 | var rms = 0; 46 | for (var i = 0; i < buffer.length; ++i) { 47 | floor = Math.floor(sampleIdx); 48 | frac = sampleIdx - floor; 49 | 50 | if (floor < sampleLength) { 51 | samp = (1 - frac) * sample[floor] + frac * sample[floor + 1]; 52 | buffer[i] += samp; 53 | rms += (samp * samp); 54 | } 55 | else { 56 | this.playing = false; 57 | break; 58 | } 59 | 60 | sampleIdx += sampleIdxInc; 61 | } 62 | 63 | this.rms = Math.sqrt(rms / buffer.length); 64 | 65 | this.sampleIdx = sampleIdx; 66 | }; 67 | 
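// Note: readBlock above resamples by linear interpolation, advancing sampleIdx
// by sampleFs / fs per output sample; for example, a 16 kHz sample rendered
// into a 44.1 kHz audio context steps about 0.363 input samples per output sample.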
ResamplingPlayer.prototype.getRmsAmplitude = function () { 68 | return this.rms; 69 | }; 70 | 71 | // Exports 72 | wavegan.player = {}; 73 | wavegan.player.ResamplingPlayer = ResamplingPlayer; 74 | 75 | })(window.wavegan); 76 | -------------------------------------------------------------------------------- /web/ckpts/drums/manifest.json: -------------------------------------------------------------------------------- 1 | { 2 | "G/pp_filt/conv1d/kernel": { 3 | "filename": "G_pp_filt_conv1d_kernel", 4 | "shape": [ 5 | 512, 6 | 1, 7 | 1 8 | ] 9 | }, 10 | "G/upconv_0/conv2d_transpose/bias": { 11 | "filename": "G_upconv_0_conv2d_transpose_bias", 12 | "shape": [ 13 | 512 14 | ] 15 | }, 16 | "G/upconv_0/conv2d_transpose/kernel": { 17 | "filename": "G_upconv_0_conv2d_transpose_kernel", 18 | "shape": [ 19 | 1, 20 | 25, 21 | 512, 22 | 1024 23 | ] 24 | }, 25 | "G/upconv_1/conv2d_transpose/bias": { 26 | "filename": "G_upconv_1_conv2d_transpose_bias", 27 | "shape": [ 28 | 256 29 | ] 30 | }, 31 | "G/upconv_1/conv2d_transpose/kernel": { 32 | "filename": "G_upconv_1_conv2d_transpose_kernel", 33 | "shape": [ 34 | 1, 35 | 25, 36 | 256, 37 | 512 38 | ] 39 | }, 40 | "G/upconv_2/conv2d_transpose/bias": { 41 | "filename": "G_upconv_2_conv2d_transpose_bias", 42 | "shape": [ 43 | 128 44 | ] 45 | }, 46 | "G/upconv_2/conv2d_transpose/kernel": { 47 | "filename": "G_upconv_2_conv2d_transpose_kernel", 48 | "shape": [ 49 | 1, 50 | 25, 51 | 128, 52 | 256 53 | ] 54 | }, 55 | "G/upconv_3/conv2d_transpose/bias": { 56 | "filename": "G_upconv_3_conv2d_transpose_bias", 57 | "shape": [ 58 | 64 59 | ] 60 | }, 61 | "G/upconv_3/conv2d_transpose/kernel": { 62 | "filename": "G_upconv_3_conv2d_transpose_kernel", 63 | "shape": [ 64 | 1, 65 | 25, 66 | 64, 67 | 128 68 | ] 69 | }, 70 | "G/upconv_4/conv2d_transpose/bias": { 71 | "filename": "G_upconv_4_conv2d_transpose_bias", 72 | "shape": [ 73 | 1 74 | ] 75 | }, 76 | "G/upconv_4/conv2d_transpose/kernel": { 77 | "filename": "G_upconv_4_conv2d_transpose_kernel", 78 | "shape": [ 79 | 1, 80 | 25, 81 | 1, 82 | 64 83 | ] 84 | }, 85 | "G/z_project/dense/bias": { 86 | "filename": "G_z_project_dense_bias", 87 | "shape": [ 88 | 16384 89 | ] 90 | }, 91 | "G/z_project/dense/kernel": { 92 | "filename": "G_z_project_dense_kernel", 93 | "shape": [ 94 | 100, 95 | 16384 96 | ] 97 | }, 98 | "cherries": { 99 | "filename": "cherries", 100 | "shape": [ 101 | 65, 102 | 100 103 | ] 104 | } 105 | } 106 | -------------------------------------------------------------------------------- /web/css/wavegan.css: -------------------------------------------------------------------------------- 1 | html, body { 2 | margin: 0px; 3 | padding: 8px; 4 | width: 696px; 5 | } 6 | 7 | body { 8 | background-color: #242424; 9 | color: #dfdfdf; 10 | font-family: Arial, Helvetica, sans-serif; 11 | font-size: 1em; 12 | } 13 | 14 | button { 15 | margin: 0px; 16 | padding: 0px; 17 | height: 30px; 18 | width: 80px; 19 | user-select: none; 20 | } 21 | 22 | .slider label { 23 | display: inline; 24 | margin-right: 8px; 25 | } 26 | .slider input { 27 | width: 135px; 28 | height: 1em; 29 | } 30 | 31 | #gain-slider { 32 | display: inline; 33 | } 34 | 35 | #gain-slider input { 36 | width: 240px; 37 | } 38 | 39 | #reverb-slider { 40 | display: inline; 41 | margin-left: 8px; 42 | } 43 | 44 | #reverb-slider input { 45 | width: 240px; 46 | } 47 | 48 | /* Loading screen */ 49 | 50 | #spinner { 51 | position: absolute; 52 | left: 50%; 53 | top: 50%; 54 | z-index: 1; 55 | width: 150px; 56 | height: 150px; 57 | margin: -75px 0 0 -75px; 58 | border: 
16px solid #dfdfdf; 59 | border-radius: 50%; 60 | border-top: 16px solid #3498db; 61 | width: 120px; 62 | height: 120px; 63 | -webkit-animation: spin 2s linear infinite; 64 | animation: spin 2s linear infinite; 65 | } 66 | 67 | @-webkit-keyframes spin { 68 | 0% { -webkit-transform: rotate(0deg); } 69 | 100% { -webkit-transform: rotate(360deg); } 70 | } 71 | 72 | @keyframes spin { 73 | 0% { transform: rotate(0deg); } 74 | 100% { transform: rotate(360deg); } 75 | } 76 | 77 | .overlay_txt { 78 | position: absolute; 79 | left: 50%; 80 | top: 50%; 81 | width: 300px; 82 | height: 20px; 83 | margin: 120px 0px 0 -116px; 84 | font-size: 2em; 85 | } 86 | 87 | #overlay_canceled { 88 | display: none; 89 | } 90 | 91 | /* Zactor grid */ 92 | 93 | .zactor { 94 | display: inline-block; 95 | width: 160px; 96 | height: 120px; 97 | margin-right: 8px; 98 | margin-bottom: 8px; 99 | padding: 0px; 100 | } 101 | 102 | .zactor canvas { 103 | display: inline; 104 | margin: 0px; 105 | padding: 0px; 106 | user-select: none; 107 | } 108 | 109 | .zactor button { 110 | float: left; 111 | } 112 | 113 | .row { 114 | display: block; 115 | } 116 | 117 | #zactors { 118 | margin-top: 24px; 119 | } 120 | 121 | /* Sequencer */ 122 | 123 | #sequencer { 124 | margin-top: 16px; 125 | } 126 | 127 | #sequencer button { 128 | display: inline-block; 129 | } 130 | 131 | #sequencer-buttons { 132 | display: inline-block; 133 | } 134 | 135 | #sequencer-sliders { 136 | display: inline-block; 137 | } 138 | 139 | #sequencer-sliders div { 140 | display: inline-block; 141 | } 142 | 143 | #sequencer-ui { 144 | margin-top: 8px; 145 | } 146 | 147 | #sequencer-canvas { 148 | user-select: none; 149 | } 150 | 151 | 152 | /* 153 | #sequencer-sliders label { 154 | display: inline-block; 155 | } 156 | 157 | #sequencer-sliders input { 158 | display: inline-block; 159 | } 160 | */ 161 | -------------------------------------------------------------------------------- /RT60.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import math 3 | from scipy.io import wavfile 4 | from scipy import stats 5 | 6 | from acoustics.utils import _is_1d 7 | from acoustics.signal import bandpass 8 | from acoustics.bands import (_check_band_type, octave_low, octave_high, third_low, third_high) 9 | 10 | def t60_impulse(raw_signal,fs): # pylint: disable=too-many-locals 11 | """ 12 | Reverberation time from a WAV impulse response. 13 | :param file_name: name of the WAV file containing the impulse response. 14 | :param bands: Octave or third bands as NumPy array. 15 | :param rt: Reverberation time estimator. It accepts `'t30'`, `'t20'`, `'t10'` and `'edt'`. 
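    :param raw_signal: impulse response samples as a 1-D NumPy array (the actual first argument).
    :param fs: sample rate of raw_signal in Hz (the actual second argument).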
16 |     :returns: Reverberation time :math:`T_{60}`
17 |     """
18 |     bands = np.array([62.5, 125, 250, 500, 1000, 2000])
19 |     if np.max(raw_signal) == 0 and np.min(raw_signal) == 0:  # all-zero RIR
20 |         print('came 1')
21 |         return .5  # fall back to a default T60 of 0.5 s
22 | 
23 |     # fs, raw_signal = wavfile.read(file_name)
24 |     band_type = _check_band_type(bands)
25 | 
26 |     # if band_type == 'octave':
27 |     low = octave_low(bands[0], bands[-1])
28 |     high = octave_high(bands[0], bands[-1])
29 |     # elif band_type == 'third':
30 |     #     low = third_low(bands[0], bands[-1])
31 |     #     high = third_high(bands[0], bands[-1])
32 | 
33 | 
34 |     init = -0.0
35 |     end = -60.0
36 |     factor = 1.0
37 |     bands = bands[3:5]  # keep only the 500 Hz and 1 kHz octave bands
38 |     low = low[3:5]
39 |     high = high[3:5]
40 | 
41 |     t60 = np.zeros(bands.size)
42 | 
43 |     for band in range(bands.size):
44 |         # Filtering signal
45 |         filtered_signal = bandpass(raw_signal, low[band], high[band], fs, order=8)
46 |         abs_signal = np.abs(filtered_signal) / np.max(np.abs(filtered_signal))
47 | 
48 |         # Schroeder integration
49 |         sch = np.cumsum(abs_signal[::-1]**2)[::-1]
50 | 
51 |         sch_db = 10.0 * np.log10(sch / np.max(sch))
52 |         if math.isnan(sch_db[1]):
53 |             print('came 2')
54 |             return .5
55 |         # print("leng sch_db ", sch_db.size)
56 |         # print("sch_db ", sch_db)
57 |         # Linear regression
58 |         sch_init = sch_db[np.abs(sch_db - init).argmin()]
59 |         sch_end = sch_db[np.abs(sch_db - end).argmin()]
60 |         init_sample = np.where(sch_db == sch_init)[0][0]
61 |         end_sample = np.where(sch_db == sch_end)[0][0]
62 |         x = np.arange(init_sample, end_sample + 1) / fs
63 |         y = sch_db[init_sample:end_sample + 1]
64 |         slope, intercept = stats.linregress(x, y)[0:2]
65 | 
66 |         # Reverberation time (T30, T20, T10 or EDT)
67 |         db_regress_init = (init - intercept) / slope
68 |         db_regress_end = (end - intercept) / slope
69 |         t60[band] = factor * (db_regress_end - db_regress_init)
70 |     mean_t60 = (t60[1] + t60[0]) / 2
71 |     # print("meant60 is ", mean_t60)
72 |     if math.isnan(mean_t60):
73 |         print('came 3')
74 |         return .5
75 |     return mean_t60
76 | 
77 | if __name__ == '__main__':
78 |     fs, sig = wavfile.read('/home/anton/Desktop/gamma101/data/evaluation_all/SF1/Hotel_SkalskyDvur_ConferenceRoom2-MicID01-SpkID01_20170906_S-09-RIR-IR_sweep_15s_45Hzto22kHz_FS16kHz.v00.wav')
79 |     print(t60_impulse(sig, fs))  # pass (signal, fs) to match the signature above
--------------------------------------------------------------------------------
/web/ckpts/dumper/dump_checkpoint_vars.py:
--------------------------------------------------------------------------------
 1 | # Copyright 2018 Google LLC. All Rights Reserved.
 2 | #
 3 | # Licensed under the Apache License, Version 2.0 (the "License");
 4 | # you may not use this file except in compliance with the License.
 5 | # You may obtain a copy of the License at
 6 | #
 7 | #     http://www.apache.org/licenses/LICENSE-2.0
 8 | #
 9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | 
16 | """
17 | This script is an entry point for dumping checkpoints for various deeplearning
18 | frameworks.
19 | """
20 | from __future__ import print_function
21 | 
22 | import argparse
23 | 
24 | 
25 | def get_checkpoint_dumper(model_type, checkpoint_file, output_dir, remove_variables_regex):
26 |     """Returns Checkpoint dumper instance for a given model type.
27 | 28 | Parameters 29 | ---------- 30 | model_type : str 31 | Type of deeplearning framework 32 | checkpoint_file : str 33 | Path to checkpoint file 34 | output_dir : str 35 | Path to output directory 36 | remove_variables_regex : str 37 | Regex for variables to be ignored 38 | 39 | Returns 40 | ------- 41 | (TensorflowCheckpointDumper, PytorchCheckpointDumper) 42 | Checkpoint Dumper Instance for corresponding model type 43 | 44 | Raises 45 | ------ 46 | Error 47 | If particular model type is not supported 48 | """ 49 | if model_type == 'tensorflow': 50 | from tensorflow_checkpoint_dumper import TensorflowCheckpointDumper 51 | 52 | return TensorflowCheckpointDumper( 53 | checkpoint_file, output_dir, remove_variables_regex) 54 | elif model_type == 'pytorch': 55 | from pytorch_checkpoint_dumper import PytorchCheckpointDumper 56 | 57 | return PytorchCheckpointDumper( 58 | checkpoint_file, output_dir, remove_variables_regex) 59 | else: 60 | raise ValueError('Currently, "{}" models are not supported'.format(model_type)) 61 | 62 | 63 | if __name__ == '__main__': 64 | parser = argparse.ArgumentParser() 65 | parser.add_argument( 66 | '--model_type', 67 | type=str, 68 | required=True, 69 | help='Model checkpoint type') 70 | parser.add_argument( 71 | '--checkpoint_file', 72 | type=str, 73 | required=True, 74 | help='Path to the model checkpoint') 75 | parser.add_argument( 76 | '--output_dir', 77 | type=str, 78 | required=True, 79 | help='The output directory where to store the converted weights') 80 | parser.add_argument( 81 | '--remove_variables_regex', 82 | type=str, 83 | default='', 84 | help='A regular expression to match against variable names that should ' 85 | 'not be included') 86 | FLAGS, unparsed = parser.parse_known_args() 87 | 88 | if unparsed: 89 | parser.print_help() 90 | print('Unrecognized flags: ', unparsed) 91 | exit(-1) 92 | 93 | checkpoint_dumper = get_checkpoint_dumper( 94 | FLAGS.model_type, FLAGS.checkpoint_file, FLAGS.output_dir, FLAGS.remove_variables_regex) 95 | checkpoint_dumper.build_and_dump_vars() 96 | -------------------------------------------------------------------------------- /web/ckpts/dumper/tensorflow_checkpoint_dumper.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google LLC. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """This script defines TensorflowCheckpointDumper class. 17 | 18 | This class takes a tensorflow checkpoint file and writes all of the variables in the 19 | checkpoint to a directory which deeplearnjs can take as input. 
20 | """ 21 | 22 | from __future__ import absolute_import 23 | from __future__ import division 24 | from __future__ import print_function 25 | 26 | from six import iteritems 27 | 28 | import argparse 29 | import json 30 | import os 31 | import re 32 | 33 | import tensorflow as tf 34 | 35 | from checkpoint_dumper import CheckpointDumper 36 | 37 | class TensorflowCheckpointDumper(CheckpointDumper): 38 | 39 | """Class for dumping Tensorflow Checkpoints. 40 | 41 | Attributes 42 | ---------- 43 | reader : NewCheckpointReader 44 | Reader for given tensorflow checkpoint 45 | """ 46 | 47 | def __init__(self, checkpoint_file, output_dir, remove_variables_regex): 48 | """Constructs object for Tensorflow Checkpoint Dumper. 49 | 50 | Parameters 51 | ---------- 52 | checkpoint_file : str 53 | Path to the model checkpoint 54 | output_dir : str 55 | Output directory path 56 | remove_variables_regex : str 57 | Regex expression for variables to be ignored 58 | """ 59 | super(TensorflowCheckpointDumper, self).__init__( 60 | checkpoint_file, output_dir, remove_variables_regex) 61 | 62 | self.reader = tf.train.NewCheckpointReader(self.checkpoint_file) 63 | 64 | def var_name_to_filename(self, var_name): 65 | """Converts variable names to standard file names. 66 | 67 | Parameters 68 | ---------- 69 | var_name : str 70 | Variable name to be converted 71 | 72 | Returns 73 | ------- 74 | str 75 | Standardized file name 76 | """ 77 | chars = [] 78 | 79 | for c in var_name: 80 | if c in CheckpointDumper.FILENAME_CHARS: 81 | chars.append(c) 82 | elif c == '/': 83 | chars.append('_') 84 | 85 | return ''.join(chars) 86 | 87 | def build_and_dump_vars(self): 88 | """Builds and dumps variables and a manifest file. 89 | """ 90 | var_to_shape_map = self.reader.get_variable_to_shape_map() 91 | 92 | for (var_name, var_shape) in iteritems(var_to_shape_map): 93 | if self.should_ignore(var_name) or var_name == 'global_step': 94 | print('Ignoring ' + var_name) 95 | continue 96 | 97 | var_filename = self.var_name_to_filename(var_name) 98 | self.manifest[var_name] = {'filename': var_filename, 'shape': var_shape} 99 | 100 | tensor = self.reader.get_tensor(var_name) 101 | self.dump_weights(var_name, var_filename, var_shape, tensor) 102 | 103 | self.dump_manifest() 104 | -------------------------------------------------------------------------------- /web/js/wavegan_visualizer.js: -------------------------------------------------------------------------------- 1 | window.wavegan = window.wavegan || {}; 2 | 3 | (function (wavegan) { 4 | // Config 5 | var cfg = wavegan.cfg; 6 | 7 | var WaveformVisualizer = function (canvas, name, color) { 8 | this.canvas = canvas; 9 | this.canvasCtx = this.canvas.getContext('2d'); 10 | this.canvasWidth = this.canvas.width; 11 | this.canvasHeight = this.canvas.height; 12 | 13 | this.canvasBuffer = document.createElement('canvas'); 14 | this.canvasBufferCtx = this.canvasBuffer.getContext('2d'); 15 | this.canvasBuffer.width = this.canvasWidth; 16 | this.canvasBuffer.height = this.canvasHeight; 17 | 18 | this.name = name; 19 | this.color = color; 20 | }; 21 | WaveformVisualizer.prototype.render = function (rms) { 22 | rms = rms === undefined ? 
0 : rms; 23 | var ctx = this.canvasCtx; 24 | var w = this.canvasWidth; 25 | var h = this.canvasHeight; 26 | 27 | // Draw buffer 28 | ctx.clearRect(0, 0, w, h); 29 | ctx.drawImage(this.canvasBuffer, 0, 0); 30 | 31 | // Draw outline 32 | ctx.globalAlpha = Math.min(rms * 2, 1); 33 | ctx.fillStyle = '#FF0000'; 34 | ctx.fillRect(0, 0, w, h); 35 | ctx.globalAlpha = 1; 36 | }; 37 | WaveformVisualizer.prototype.setSample = function (sample) { 38 | var ctx = this.canvasBufferCtx; 39 | var w = this.canvasWidth; 40 | var h = this.canvasHeight; 41 | var gain = cfg.ui.visualizerGain; 42 | 43 | var hd2 = h / 2; 44 | var t = sample.length; 45 | var pxdt = t / w; 46 | 47 | // Clear background 48 | ctx.clearRect(0, 0, w, h); 49 | ctx.fillStyle = '#000000'; 50 | ctx.fillRect(0, 0, w, h); 51 | 52 | // Draw DC line 53 | if (this.color !== undefined) { 54 | ctx.fillStyle = this.color; 55 | } 56 | else { 57 | ctx.fillStyle = '#33ccff'; 58 | } 59 | ctx.fillRect(0, hd2, w, 1); 60 | 61 | // Draw waveform 62 | for (var i = 0; i < w; ++i) { 63 | var tl = Math.floor(i * pxdt); 64 | var th = Math.floor((i + 1) * pxdt); 65 | 66 | var max = 0; 67 | for (var k = tl; k < th; ++k) { 68 | if (Math.abs(sample[k]) > max) { 69 | max = Math.abs(sample[k]); 70 | } 71 | } 72 | 73 | var rect_height = max * hd2 * gain; 74 | 75 | ctx.fillRect(i, hd2 - rect_height, 1, rect_height); 76 | ctx.fillRect(i, hd2, 1, rect_height); 77 | } 78 | 79 | // Draw name 80 | if (this.name !== undefined) { 81 | var textHeight = 14; 82 | ctx.font = String(textHeight) + 'px sans-serif'; 83 | var textSize = ctx.measureText(this.name); 84 | var textWidth = Math.ceil(textSize.width); 85 | var boxWidth = textWidth + 6; 86 | var boxHeight = textHeight + 6; 87 | 88 | ctx.strokeStyle = '#ffffff'; 89 | ctx.lineWidth = '1'; 90 | //ctx.rect(w - boxWidth - 2, 2, boxWidth, boxHeight); 91 | //ctx.stroke(); 92 | ctx.fillStyle = '#ffffff'; 93 | ctx.fillText(this.name, w - boxWidth + 1, textHeight + 3); 94 | } 95 | 96 | // Render to canvas 97 | this.render(); 98 | }; 99 | 100 | // Exports 101 | wavegan.visualizer = {}; 102 | wavegan.visualizer.WaveformVisualizer = WaveformVisualizer; 103 | 104 | })(window.wavegan); 105 | -------------------------------------------------------------------------------- /web/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 12 | 13 | 14 | 15 | 16 | 17 | WaveGAN Demo 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 |
[The body of index.html was stripped to bare text during extraction; only stray line numbers survive, and the only visible page text remaining is the loading-overlay message "Loading network...".]
--------------------------------------------------------------------------------
/web/js/wavegan_savewav.js:
--------------------------------------------------------------------------------
  1 | // Sourced from https://gist.github.com/asanoboy/3979747
  2 | 
  3 | window.wavegan = window.wavegan || {};
  4 | 
  5 | (function (wavegan) {
  6 |     var Wav = function(opt_params){
  7 |         this._sampleRate = opt_params && opt_params.sampleRate ? opt_params.sampleRate : 44100;
  8 |         this._channels = opt_params && opt_params.channels ? opt_params.channels : 2;
  9 |         this._eof = true;
 10 |         this._bufferNeedle = 0;
 11 |         this._buffer;
 12 |     };
 13 | 
 14 |     Wav.prototype.setBuffer = function(buffer){
 15 |         this._buffer = this.getWavInt16Array(buffer);
 16 |         this._bufferNeedle = 0;
 17 |         this._internalBuffer = '';
 18 |         this._hasOutputHeader = false;
 19 |         this._eof = false;
 20 |     };
 21 | 
 22 |     Wav.prototype.getBuffer = function(len){
 23 |         var rt;
 24 |         if( this._bufferNeedle + len >= this._buffer.length ){
 25 |             rt = new Int16Array(this._buffer.length - this._bufferNeedle);
 26 |             this._eof = true;
 27 |         }
 28 |         else {
 29 |             rt = new Int16Array(len);
 30 |         }
 31 | 
 32 |         for(var i=0; i> 16; // RIFF size
 52 | 
 53 |         intBuffer[4] = 0x4157; // "WA"
 54 |         intBuffer[5] = 0x4556; // "VE"
 55 | 
 56 |         intBuffer[6] = 0x6d66; // "fm"
 57 |         intBuffer[7] = 0x2074; // "t "
 58 | 
 59 |         intBuffer[8] = 0x0012; // fmt chunksize: 18
 60 |         intBuffer[9] = 0x0000; //
 61 | 
 62 |         intBuffer[10] = 0x0001; // format tag : 1
 63 |         intBuffer[11] = this._channels; // channels: 2
 64 | 
 65 |         intBuffer[12] = this._sampleRate & 0x0000ffff; // sample per sec
 66 |         intBuffer[13] = (this._sampleRate & 0xffff0000) >> 16; // sample per sec
 67 | 
 68 |         intBuffer[14] = (2*this._channels*this._sampleRate) & 0x0000ffff; // byte per sec
 69 |         intBuffer[15] = ((2*this._channels*this._sampleRate) & 0xffff0000) >> 16; // byte per sec
 70 | 
 71 |         intBuffer[16] = 0x0004; // block align
 72 |         intBuffer[17] = 0x0010; // bit per sample
 73 |         intBuffer[18] = 0x0000; // cb size
 74 |         intBuffer[19] = 0x6164; // "da"
 75 |         intBuffer[20] = 0x6174; // "ta"
 76 |         intBuffer[21] = (2*buffer.length) & 0x0000ffff; // data size[byte]
 77 |         intBuffer[22] = ((2*buffer.length) & 0xffff0000) >> 16; // data size[byte]
 78 | 
 79 |         for (var i = 0; i < buffer.length; i++) {
 80 |             tmp = buffer[i];
 81 |             if (tmp >= 1) {
 82 |                 intBuffer[i+23] = (1 << 15) - 1;
 83 |             }
 84 |             else if (tmp <= -1) {
 85 |                 intBuffer[i+23] = -(1 << 15);
 86 |             }
 87 |             else {
 88 |                 intBuffer[i+23] = Math.round(tmp * (1 << 15));
 89 |             }
 90 |         }
 91 | 
 92 |         return intBuffer;
 93 |     };
 94 | 
 95 |     wavegan.savewav = {};
 96 |     wavegan.savewav.randomFilename = function () {
 97 |         return Math.random().toString(36).substring(7) + '.wav';
 98 |     };
 99 |     wavegan.savewav.saveWav = function (fn, buffer) {
100 |         var wav = new Wav({sampleRate: 16000, channels: 1});
101 |         wav.setBuffer(buffer);
102 | 
103 |         // Create file
104 |         var srclist = [];
105 |         while (!wav.eof()) {
106 |             srclist.push(wav.getBuffer(1024));
107 |         }
108 |         var b = new Blob(srclist, {type:'audio/wav'});
109 | 
110 |         // Download
111 |         var URLObject = window.webkitURL || window.URL;
112 |         var url = URLObject.createObjectURL(b);
113 |         var a = document.createElement('a');
114 |         a.style = 'display:none';
115 |         a.href = url;
116 |         a.download = fn;
117 |         a.click();
118 |         URLObject.revokeObjectURL(url);
119 |     };
120 | 
121 | })(window.wavegan);
122 | 
--------------------------------------------------------------------------------
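For orientation, a minimal usage sketch (not part of the repository) tying the demo modules together, assuming wavegan_net.js and wavegan_savewav.js are loaded on the page and `wavegan.net.isReady()` returns true:

```js
// Generate one waveform from an all-zeros latent vector and save it as a
// 16 kHz mono WAV file (matching the hard-coded Wav options in saveWav above).
var z = [new Float32Array(wavegan.cfg.net.zDim)];  // zDim = 100 per wavegan_cfg.js
var wavs = wavegan.net.eval(z);                    // -> [Float32Array(16384)]
wavegan.savewav.saveWav(wavegan.savewav.randomFilename(), wavs[0]);
```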
/web/ckpts/dumper/checkpoint_dumper.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google LLC. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """This script defines CheckpointDumper class. 17 | 18 | This class serves as a base class for other deeplearning checkpoint dumper 19 | classes and defines common methods, attributes etc. 20 | """ 21 | 22 | from __future__ import absolute_import 23 | from __future__ import division 24 | from __future__ import print_function 25 | 26 | import json 27 | import os 28 | import re 29 | import string 30 | 31 | class CheckpointDumper(object): 32 | 33 | """Base Checkpoint Dumper class. 34 | 35 | Attributes 36 | ---------- 37 | checkpoint_file : str 38 | Path to the model checkpoint 39 | FILENAME_CHARS : str 40 | Allowed file char names 41 | manifest : dict 42 | Manifest file defining variables 43 | output_dir : str 44 | Output directory path 45 | remove_variables_regex : str 46 | Regex expression for variables to be ignored 47 | remove_variables_regex_re : sre.SRE_Pattern 48 | Compiled `remove variable` regex 49 | """ 50 | 51 | FILENAME_CHARS = string.ascii_letters + string.digits + '_' 52 | 53 | def __init__(self, checkpoint_file, output_dir, remove_variables_regex): 54 | """Constructs object for Checkpoint Dumper. 55 | 56 | Parameters 57 | ---------- 58 | checkpoint_file : str 59 | Path to the model checkpoint 60 | output_dir : str 61 | Output directory path 62 | remove_variables_regex : str 63 | Regex expression for variables to be ignored 64 | """ 65 | self.checkpoint_file = os.path.expanduser(checkpoint_file) 66 | self.output_dir = os.path.expanduser(output_dir) 67 | self.remove_variables_regex = remove_variables_regex 68 | 69 | self.manifest = {} 70 | self.remove_variables_regex_re = re.compile(self.remove_variables_regex) 71 | 72 | self.make_dir(self.output_dir) 73 | 74 | 75 | @staticmethod 76 | def make_dir(directory): 77 | """Makes directory if not existing. 78 | 79 | Parameters 80 | ---------- 81 | directory : str 82 | Path to directory 83 | """ 84 | if not os.path.exists(directory): 85 | os.makedirs(directory) 86 | 87 | 88 | def should_ignore(self, name): 89 | """Checks whether name should be ignored or not. 90 | 91 | Parameters 92 | ---------- 93 | name : str 94 | Name to be checked 95 | 96 | Returns 97 | ------- 98 | bool 99 | Whether to ignore the name or not 100 | """ 101 | return self.remove_variables_regex and re.match(self.remove_variables_regex_re, name) 102 | 103 | 104 | def dump_weights(self, variable_name, filename, shape, weights): 105 | """Creates a file with given name and dumps byte weights in it. 
106 | 107 | Parameters 108 | ---------- 109 | variable_name : str 110 | Name of given variable 111 | filename : str 112 | File name for given variable 113 | shape : list 114 | Shape of given variable 115 | weights : ndarray 116 | Weights for given variable 117 | """ 118 | self.manifest[variable_name] = {'filename': filename, 'shape': shape} 119 | 120 | print('Writing variable ' + variable_name + '...') 121 | with open(os.path.join(self.output_dir, filename), 'wb') as f: 122 | f.write(weights.tobytes()) 123 | 124 | 125 | def dump_manifest(self, filename='manifest.json'): 126 | """Creates a manifest file with given name and dumps meta information 127 | related to model. 128 | 129 | Parameters 130 | ---------- 131 | filename : str, optional 132 | Manifest file name 133 | """ 134 | manifest_fpath = os.path.join(self.output_dir, filename) 135 | 136 | print('Writing manifest to ' + manifest_fpath) 137 | with open(manifest_fpath, 'w') as f: 138 | f.write(json.dumps(self.manifest, indent=2, sort_keys=True)) 139 | -------------------------------------------------------------------------------- /web/js/wavegan_net.js: -------------------------------------------------------------------------------- 1 | window.wavegan = window.wavegan || {}; 2 | 3 | (function (dl, wavegan) { 4 | // Config 5 | var cfg = wavegan.cfg; 6 | if (cfg.reqs.userCanceled) { 7 | return; 8 | } 9 | 10 | // Network state 11 | var net = { 12 | vars: null, 13 | ready: false 14 | }; 15 | 16 | // Hardware state 17 | var hw = { 18 | math: null, 19 | ready: false 20 | }; 21 | 22 | // Initialize hardware (uses WebGL if possible) 23 | var initHw = function (graph) { 24 | // TODO: update this 25 | try { 26 | new dl.NDArrayMathGPU(); 27 | cfg.debugMsg('WebGL supported'); 28 | } 29 | catch(err) { 30 | new dl.NDArrayMathCPU(); 31 | cfg.debugMsg('WebGL not supported'); 32 | } 33 | 34 | hw.math = dl.ENV.math; 35 | 36 | hw.ready = true; 37 | cfg.debugMsg('Hardware ready'); 38 | }; 39 | 40 | // Initialize network and hardware 41 | var initVars = function () { 42 | var varLoader = new dl.CheckpointLoader(cfg.net.ckptDir); 43 | varLoader.getAllVariables().then(function (vars) { 44 | net.vars = vars; 45 | net.ready = true; 46 | 47 | cfg.debugMsg('Variables loaded'); 48 | }); 49 | }; 50 | 51 | // Exports 52 | wavegan.net = {}; 53 | 54 | wavegan.net.isReady = function () { 55 | return net.ready && hw.ready; 56 | }; 57 | 58 | wavegan.net.getCherries = function () { 59 | if (!wavegan.net.isReady()) { 60 | throw 'Hardware not ready'; 61 | } 62 | if ('cherries' in net.vars) { 63 | var cherries = net.vars['cherries']; 64 | var _zs = []; 65 | for (var i = 0; i < cherries.shape[0]; ++i) { 66 | var _z = new Float32Array(cfg.net.zDim); 67 | for (var j = 0; j < cfg.net.zDim; ++j) { 68 | _z[j] = cherries.get(i, j); 69 | } 70 | _zs.push(_z); 71 | } 72 | return _zs; 73 | } 74 | else { 75 | return null; 76 | } 77 | }; 78 | 79 | wavegan.net.eval = function (_z) { 80 | if (!wavegan.net.isReady()) { 81 | throw 'Hardware not ready'; 82 | } 83 | for (var i = 0; i < _z.length; ++i) { 84 | if (_z[i].length !== cfg.net.zDim) { 85 | throw 'Input shape incorrect' 86 | } 87 | } 88 | 89 | var m = hw.math; 90 | 91 | // Reshape input to 2D array 92 | var b = _z.length; 93 | var _z_flat = new Float32Array(b * cfg.net.zDim); 94 | for (var i = 0; i < b; ++i) { 95 | for (var j = 0; j < cfg.net.zDim; ++j) { 96 | _z_flat[i * cfg.net.zDim + j] = _z[i][j]; 97 | } 98 | } 99 | var x = dl.Array2D.new([b, cfg.net.zDim], _z_flat); 100 | 101 | // Project to [b, 1, 16, 1024] 102 | x = 
m.matMul(x, net.vars['G/z_project/dense/kernel']); 103 | x = m.add(x, net.vars['G/z_project/dense/bias']); 104 | x = m.relu(x); 105 | x = x.reshape([b, 1, 16, 1024]); 106 | 107 | // Conv 0 to [b, 1, 64, 512] 108 | x = m.conv2dTranspose(x, 109 | net.vars['G/upconv_0/conv2d_transpose/kernel'], 110 | [b, 1, 64, 512], 111 | [1, 4], 112 | 'same'); 113 | x = m.add(x, net.vars['G/upconv_0/conv2d_transpose/bias']); 114 | x = m.relu(x); 115 | 116 | // Conv 1 to [b, 1, 256, 256] 117 | x = m.conv2dTranspose(x, 118 | net.vars['G/upconv_1/conv2d_transpose/kernel'], 119 | [b, 1, 256, 256], 120 | [1, 4], 121 | 'same'); 122 | x = m.add(x, net.vars['G/upconv_1/conv2d_transpose/bias']); 123 | x = m.relu(x); 124 | 125 | // Conv 2 to [b, 1, 1024, 128] 126 | x = m.conv2dTranspose(x, 127 | net.vars['G/upconv_2/conv2d_transpose/kernel'], 128 | [b, 1, 1024, 128], 129 | [1, 4], 130 | 'same'); 131 | x = m.add(x, net.vars['G/upconv_2/conv2d_transpose/bias']); 132 | x = m.relu(x); 133 | 134 | // Conv 3 to [b, 1, 4096, 64] 135 | x = m.conv2dTranspose(x, 136 | net.vars['G/upconv_3/conv2d_transpose/kernel'], 137 | [b, 1, 4096, 64], 138 | [1, 4], 139 | 'same'); 140 | x = m.add(x, net.vars['G/upconv_3/conv2d_transpose/bias']); 141 | x = m.relu(x); 142 | 143 | // Conv 4 to [b, 1, 16384, 1] 144 | x = m.conv2dTranspose(x, 145 | net.vars['G/upconv_4/conv2d_transpose/kernel'], 146 | [b, 1, 16384, 1], 147 | [1, 4], 148 | 'same'); 149 | x = m.add(x, net.vars['G/upconv_4/conv2d_transpose/bias']); 150 | x = m.tanh(x); 151 | 152 | // Post processing filter 153 | x = m.reshape(x, [b, 16384, 1]); 154 | if (cfg.net.ppFilt) { 155 | x = m.conv1d(x, 156 | net.vars['G/pp_filt/conv1d/kernel'], 157 | null, 158 | 1, 159 | 'same'); 160 | } 161 | 162 | // Create Float32Arrays with result 163 | wavs = [] 164 | for (var i = 0; i < b; ++i) { 165 | var wav = new Float32Array(16384); 166 | for (var j = 0; j < 16384; ++j) { 167 | wav[j] = x.get(i, j, 0); 168 | } 169 | wavs.push(wav); 170 | } 171 | 172 | return wavs 173 | }; 174 | 175 | // Run immediately 176 | initVars(); 177 | initHw(); 178 | 179 | })(window.dl, window.wavegan); 180 | -------------------------------------------------------------------------------- /loader.py: -------------------------------------------------------------------------------- 1 | from scipy.io.wavfile import read as wavread 2 | import numpy as np 3 | 4 | import tensorflow as tf 5 | 6 | import sys 7 | 8 | 9 | def decode_audio(fp, fs=None, num_channels=1, normalize=False, fast_wav=False): 10 | """Decodes audio file paths into 32-bit floating point vectors. 11 | 12 | Args: 13 | fp: Audio file path. 14 | fs: If specified, resamples decoded audio to this rate. 15 | mono: If true, averages channels to mono. 16 | fast_wav: Assume fp is a standard WAV file (PCM 16-bit or float 32-bit). 17 | 18 | Returns: 19 | A np.float32 array containing the audio samples at specified sample rate. 20 | """ 21 | if fast_wav: 22 | # Read with scipy wavread (fast). 23 | _fs, _wav = wavread(fp) 24 | if fs is not None and fs != _fs: 25 | raise NotImplementedError('Scipy cannot resample audio.') 26 | if _wav.dtype == np.int16: 27 | _wav = _wav.astype(np.float32) 28 | _wav /= 32768. 29 | elif _wav.dtype == np.float32: 30 | _wav = np.copy(_wav) 31 | else: 32 | raise NotImplementedError('Scipy cannot process atypical WAV files.') 33 | else: 34 | # Decode with librosa load (slow but supports file formats like mp3). 
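# (librosa returns np.float32 in [-1, 1]; sr=fs resamples when fs is not None,
# and mono=False preserves channels as [nch, nsamps], hence the swapaxes below.)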
35 | import librosa 36 | _wav, _fs = librosa.core.load(fp, sr=fs, mono=False) 37 | if _wav.ndim == 2: 38 | _wav = np.swapaxes(_wav, 0, 1) 39 | 40 | assert _wav.dtype == np.float32 41 | 42 | # At this point, _wav is np.float32 either [nsamps,] or [nsamps, nch]. 43 | # We want [nsamps, 1, nch] to mimic 2D shape of spectral feats. 44 | if _wav.ndim == 1: 45 | nsamps = _wav.shape[0] 46 | nch = 1 47 | else: 48 | nsamps, nch = _wav.shape 49 | _wav = np.reshape(_wav, [nsamps, 1, nch]) 50 | 51 | # Average (mono) or expand (stereo) channels 52 | if nch != num_channels: 53 | if num_channels == 1: 54 | _wav = np.mean(_wav, 2, keepdims=True) 55 | elif nch == 1 and num_channels == 2: 56 | _wav = np.concatenate([_wav, _wav], axis=2) 57 | else: 58 | raise ValueError('Number of audio channels not equal to num specified') 59 | if normalize: 60 | factor = np.max(np.abs(_wav)) 61 | if factor > 0: 62 | _wav /= factor 63 | 64 | return _wav 65 | 66 | 67 | def decode_extract_and_batch( 68 | fps, 69 | batch_size, 70 | slice_len, 71 | decode_fs, 72 | decode_num_channels, 73 | decode_normalize=True, 74 | decode_fast_wav=False, 75 | decode_parallel_calls=1, 76 | slice_randomize_offset=False, 77 | slice_first_only=False, 78 | slice_overlap_ratio=0, 79 | slice_pad_end=False, 80 | repeat=False, 81 | shuffle=False, 82 | shuffle_buffer_size=None, 83 | prefetch_size=None, 84 | prefetch_gpu_num=None): 85 | # tf.debugging.set_log_device_placement(True) 86 | """Decodes audio file paths into mini-batches of samples. 87 | 88 | Args: 89 | fps: List of audio file paths. 90 | batch_size: Number of items in the batch. 91 | slice_len: Length of the sliceuences in samples or feature timesteps. 92 | decode_fs: (Re-)sample rate for decoded audio files. 93 | decode_num_channels: Number of channels for decoded audio files. 94 | decode_normalize: If false, do not normalize audio waveforms. 95 | decode_fast_wav: If true, uses scipy to decode standard wav files. 96 | decode_parallel_calls: Number of parallel decoding threads. 97 | slice_randomize_offset: If true, randomize starting position for slice. 98 | slice_first_only: If true, only use first slice from each audio file. 99 | slice_overlap_ratio: Ratio of overlap between adjacent slices. 100 | slice_pad_end: If true, allows zero-padded examples from the end of each audio file. 101 | repeat: If true (for training), continuously iterate through the dataset. 102 | shuffle: If true (for training), buffer and shuffle the sliceuences. 103 | shuffle_buffer_size: Number of examples to queue up before grabbing a batch. 104 | prefetch_size: Number of examples to prefetch from the queue. 105 | prefetch_gpu_num: If specified, prefetch examples to GPU. 106 | 107 | Returns: 108 | A tuple of np.float32 tensors representing audio waveforms. 
109 | audio: [batch_size, slice_len, 1, nch] 110 | """ 111 | # Create dataset of filepaths 112 | dataset = tf.data.Dataset.from_tensor_slices(fps) 113 | 114 | # Shuffle all filepaths every epoch 115 | if shuffle: 116 | dataset = dataset.shuffle(buffer_size=len(fps)) 117 | 118 | # Repeat 119 | if repeat: 120 | dataset = dataset.repeat() 121 | 122 | def _decode_audio_shaped(fp): 123 | _decode_audio_closure = lambda _fp: decode_audio( 124 | _fp, 125 | fs=decode_fs, 126 | num_channels=decode_num_channels, 127 | normalize=decode_normalize, 128 | fast_wav=decode_fast_wav) 129 | 130 | audio = tf.py_func( 131 | _decode_audio_closure, 132 | [fp], 133 | tf.float32, 134 | stateful=False) 135 | audio.set_shape([None, 1, decode_num_channels]) 136 | 137 | return audio 138 | 139 | # Decode audio 140 | dataset = dataset.map( 141 | _decode_audio_shaped, 142 | num_parallel_calls=decode_parallel_calls) 143 | 144 | # Parallel 145 | def _slice(audio): 146 | # Calculate hop size 147 | if slice_overlap_ratio < 0: 148 | raise ValueError('Overlap ratio must be greater than 0') 149 | slice_hop = int(round(slice_len * (1. - slice_overlap_ratio)) + 1e-4) 150 | if slice_hop < 1: 151 | raise ValueError('Overlap ratio too high') 152 | 153 | # Randomize starting phase: 154 | if slice_randomize_offset: 155 | start = tf.random_uniform([], maxval=slice_len, dtype=tf.int32) 156 | audio = audio[start:] 157 | 158 | # Extract sliceuences 159 | audio_slices = tf.contrib.signal.frame( 160 | audio, 161 | slice_len, 162 | slice_hop, 163 | pad_end=slice_pad_end, 164 | pad_value=0, 165 | axis=0) 166 | 167 | # Only use first slice if requested 168 | if slice_first_only: 169 | audio_slices = audio_slices[:1] 170 | 171 | return audio_slices 172 | 173 | def _slice_dataset_wrapper(audio): 174 | audio_slices = _slice(audio) 175 | return tf.data.Dataset.from_tensor_slices(audio_slices) 176 | 177 | # Extract parallel sliceuences from both audio and features 178 | dataset = dataset.flat_map(_slice_dataset_wrapper) 179 | 180 | # Shuffle examples 181 | if shuffle: 182 | dataset = dataset.shuffle(buffer_size=shuffle_buffer_size) 183 | 184 | # Make batches 185 | dataset = dataset.batch(batch_size, drop_remainder=True) 186 | 187 | # Prefetch a number of batches 188 | if prefetch_size is not None: 189 | dataset = dataset.prefetch(prefetch_size) 190 | if prefetch_gpu_num is not None and prefetch_gpu_num >= 0: 191 | print('prefetch_gpu_num : ',prefetch_gpu_num) 192 | dataset = dataset.apply( 193 | tf.data.experimental.prefetch_to_device( 194 | '/device:GPU:{}'.format(prefetch_gpu_num))) 195 | 196 | # Get tensors 197 | iterator = dataset.make_one_shot_iterator() 198 | 199 | return iterator.get_next() 200 | -------------------------------------------------------------------------------- /generator/loader.py: -------------------------------------------------------------------------------- 1 | from scipy.io.wavfile import read as wavread 2 | import numpy as np 3 | 4 | import tensorflow as tf 5 | 6 | import sys 7 | 8 | 9 | def decode_audio(fp, fs=None, num_channels=1, normalize=False, fast_wav=False): 10 | """Decodes audio file paths into 32-bit floating point vectors. 11 | 12 | Args: 13 | fp: Audio file path. 14 | fs: If specified, resamples decoded audio to this rate. 15 | mono: If true, averages channels to mono. 16 | fast_wav: Assume fp is a standard WAV file (PCM 16-bit or float 32-bit). 17 | 18 | Returns: 19 | A np.float32 array containing the audio samples at specified sample rate. 20 | """ 21 | if fast_wav: 22 | # Read with scipy wavread (fast). 
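# (For int16 PCM, the samples are rescaled below from [-32768, 32767] to
# roughly [-1, 1] by dividing by 32768.; float32 WAVs are copied unchanged.)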
23 | _fs, _wav = wavread(fp) 24 | if fs is not None and fs != _fs: 25 | raise NotImplementedError('Scipy cannot resample audio.') 26 | if _wav.dtype == np.int16: 27 | _wav = _wav.astype(np.float32) 28 | _wav /= 32768. 29 | elif _wav.dtype == np.float32: 30 | _wav = np.copy(_wav) 31 | else: 32 | raise NotImplementedError('Scipy cannot process atypical WAV files.') 33 | else: 34 | # Decode with librosa load (slow but supports file formats like mp3). 35 | import librosa 36 | _wav, _fs = librosa.core.load(fp, sr=fs, mono=False) 37 | if _wav.ndim == 2: 38 | _wav = np.swapaxes(_wav, 0, 1) 39 | 40 | assert _wav.dtype == np.float32 41 | 42 | # At this point, _wav is np.float32 either [nsamps,] or [nsamps, nch]. 43 | # We want [nsamps, 1, nch] to mimic 2D shape of spectral feats. 44 | if _wav.ndim == 1: 45 | nsamps = _wav.shape[0] 46 | nch = 1 47 | else: 48 | nsamps, nch = _wav.shape 49 | _wav = np.reshape(_wav, [nsamps, 1, nch]) 50 | 51 | # Average (mono) or expand (stereo) channels 52 | if nch != num_channels: 53 | if num_channels == 1: 54 | _wav = np.mean(_wav, 2, keepdims=True) 55 | elif nch == 1 and num_channels == 2: 56 | _wav = np.concatenate([_wav, _wav], axis=2) 57 | else: 58 | raise ValueError('Number of audio channels not equal to num specified') 59 | if normalize: 60 | factor = np.max(np.abs(_wav)) 61 | if factor > 0: 62 | _wav /= factor 63 | 64 | return _wav 65 | 66 | 67 | def decode_extract_and_batch( 68 | fps, 69 | batch_size, 70 | slice_len, 71 | decode_fs, 72 | decode_num_channels, 73 | decode_normalize=True, 74 | decode_fast_wav=False, 75 | decode_parallel_calls=1, 76 | slice_randomize_offset=False, 77 | slice_first_only=False, 78 | slice_overlap_ratio=0, 79 | slice_pad_end=False, 80 | repeat=False, 81 | shuffle=False, 82 | shuffle_buffer_size=None, 83 | prefetch_size=None, 84 | prefetch_gpu_num=None): 85 | # tf.debugging.set_log_device_placement(True) 86 | """Decodes audio file paths into mini-batches of samples. 87 | 88 | Args: 89 | fps: List of audio file paths. 90 | batch_size: Number of items in the batch. 91 | slice_len: Length of the sliceuences in samples or feature timesteps. 92 | decode_fs: (Re-)sample rate for decoded audio files. 93 | decode_num_channels: Number of channels for decoded audio files. 94 | decode_normalize: If false, do not normalize audio waveforms. 95 | decode_fast_wav: If true, uses scipy to decode standard wav files. 96 | decode_parallel_calls: Number of parallel decoding threads. 97 | slice_randomize_offset: If true, randomize starting position for slice. 98 | slice_first_only: If true, only use first slice from each audio file. 99 | slice_overlap_ratio: Ratio of overlap between adjacent slices. 100 | slice_pad_end: If true, allows zero-padded examples from the end of each audio file. 101 | repeat: If true (for training), continuously iterate through the dataset. 102 | shuffle: If true (for training), buffer and shuffle the sliceuences. 103 | shuffle_buffer_size: Number of examples to queue up before grabbing a batch. 104 | prefetch_size: Number of examples to prefetch from the queue. 105 | prefetch_gpu_num: If specified, prefetch examples to GPU. 106 | 107 | Returns: 108 | A tuple of np.float32 tensors representing audio waveforms. 
109 | audio: [batch_size, slice_len, 1, nch] 110 | """ 111 | # Create dataset of filepaths 112 | dataset = tf.data.Dataset.from_tensor_slices(fps) 113 | 114 | # Shuffle all filepaths every epoch 115 | if shuffle: 116 | dataset = dataset.shuffle(buffer_size=len(fps)) 117 | 118 | # Repeat 119 | if repeat: 120 | dataset = dataset.repeat() 121 | 122 | def _decode_audio_shaped(fp): 123 | _decode_audio_closure = lambda _fp: decode_audio( 124 | _fp, 125 | fs=decode_fs, 126 | num_channels=decode_num_channels, 127 | normalize=decode_normalize, 128 | fast_wav=decode_fast_wav) 129 | 130 | audio = tf.py_func( 131 | _decode_audio_closure, 132 | [fp], 133 | tf.float32, 134 | stateful=False) 135 | audio.set_shape([None, 1, decode_num_channels]) 136 | 137 | return audio 138 | 139 | # Decode audio 140 | dataset = dataset.map( 141 | _decode_audio_shaped, 142 | num_parallel_calls=decode_parallel_calls) 143 | 144 | # Parallel 145 | def _slice(audio): 146 | # Calculate hop size 147 | if slice_overlap_ratio < 0: 148 | raise ValueError('Overlap ratio must be greater than 0') 149 | slice_hop = int(round(slice_len * (1. - slice_overlap_ratio)) + 1e-4) 150 | if slice_hop < 1: 151 | raise ValueError('Overlap ratio too high') 152 | 153 | # Randomize starting phase: 154 | if slice_randomize_offset: 155 | start = tf.random_uniform([], maxval=slice_len, dtype=tf.int32) 156 | audio = audio[start:] 157 | 158 | # Extract sliceuences 159 | audio_slices = tf.contrib.signal.frame( 160 | audio, 161 | slice_len, 162 | slice_hop, 163 | pad_end=slice_pad_end, 164 | pad_value=0, 165 | axis=0) 166 | 167 | # Only use first slice if requested 168 | if slice_first_only: 169 | audio_slices = audio_slices[:1] 170 | 171 | return audio_slices 172 | 173 | def _slice_dataset_wrapper(audio): 174 | audio_slices = _slice(audio) 175 | return tf.data.Dataset.from_tensor_slices(audio_slices) 176 | 177 | # Extract parallel sliceuences from both audio and features 178 | dataset = dataset.flat_map(_slice_dataset_wrapper) 179 | 180 | # Shuffle examples 181 | if shuffle: 182 | dataset = dataset.shuffle(buffer_size=shuffle_buffer_size) 183 | 184 | # Make batches 185 | dataset = dataset.batch(batch_size, drop_remainder=True) 186 | 187 | # Prefetch a number of batches 188 | if prefetch_size is not None: 189 | dataset = dataset.prefetch(prefetch_size) 190 | if prefetch_gpu_num is not None and prefetch_gpu_num >= 0: 191 | print('prefetch_gpu_num : ',prefetch_gpu_num) 192 | dataset = dataset.apply( 193 | tf.data.experimental.prefetch_to_device( 194 | '/device:GPU:{}'.format(prefetch_gpu_num))) 195 | 196 | # Get tensors 197 | iterator = dataset.make_one_shot_iterator() 198 | 199 | return iterator.get_next() 200 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Related Works 2 | 3 | 1) [**IR-GAN: Room Impulse Response Generator for Far-field Speech Recognition (INTERSPEECH 2021)**](https://github.com/anton-jeran/IR-GAN) 4 | 2) [**FAST-RIR: FAST NEURAL DIFFUSE ROOM IMPULSE RESPONSE GENERATOR (ICASSP 2022)**](https://github.com/anton-jeran/FAST-RIR) 5 | 3) [**MESH2IR: Neural Acoustic Impulse Response Generator for Complex 3D Scenes (ACM Multimedia 2022)**](https://anton-jeran.github.io/M2IR/) 6 | 7 | **NEWS: We release MULTI-CHANNEL MULTI-SPEAKER MULTI-SPATIAL AUDIO CODEC. 
The official code of our network [**M3-AUDIODEC**](https://github.com/anton-jeran/MULTI-AUDIODEC) is available.**
8 | 
9 | # TS-RIR (Accepted to IEEE ASRU 2021)
10 | 
11 | This is the official implementation of **TS-RIRGAN**. We started our implementation from [**WaveGAN**](https://github.com/chrisdonahue/wavegan). TS-RIRGAN is a one-dimensional CycleGAN that takes synthetic RIRs as raw waveform audio and translates them into real RIRs. Our network architecture is shown below.
12 | 
13 | 
14 | 
15 | ![Architecture-1.png](https://github.com/anton-jeran/TS-RIR/blob/main/images/Architecture-1.png)
16 | 
17 | You can find more details about our implementation from [**TS-RIR: Translated synthetic room impulse responses for speech augmentation**](https://arxiv.org/pdf/2103.16804v2.pdf).
18 | 
19 | 
20 | ## Requirements
21 | 
22 | ```
23 | tensorflow-gpu==1.12.0
24 | scipy==1.0.0
25 | matplotlib==3.0.2
26 | librosa==0.6.2
27 | ffmpeg ==4.2.1
28 | cuda ==9.0.176
29 | cudnn ==7.6.5
30 | ```
31 | 
32 | ## Datasets
33 | 
34 | In order to train **TS-RIRGAN** to translate Synthetic RIRs to Real RIRs, download the RIRs from [**IRs_for_GAN**](https://drive.google.com/file/d/1ivj_UZ5j5inAZwsDTCQ6jEvI5JDtwH_2/view?usp=sharing). Unzip the **IRs_for_GAN** directory inside the **TS-RIR** folder.
35 | 
36 | This folder contains Synthetic RIRs generated using the [**Geometric Acoustic Simulator**](https://github.com/RoyJames/pygsound) and Real RIRs from the [**BUT ReverbDB**](https://speech.fit.vutbr.cz/software/but-speech-fit-reverb-database) dataset.
37 | 
38 | ## Translate Synthetic RIRs to Real RIRs using the trained model
39 | 
40 | Download all the [**MODEL FILES**](https://drive.google.com/file/d/1fdAaIkvFbky-Xf7iuYCFa87nWpSaI1Ow/view?usp=sharing) and move them to the **generator** folder. Create the same directory structure as the dataset inside the **generator** folder. You can convert **Synthetic RIRs** to **Real RIRs** by running the following command inside the **generator** folder.
41 | 
42 | 
43 | ```
44 | export CUDA_VISIBLE_DEVICES=1
45 | python3 generator.py --data1_dir ../IRs_for_GAN/Real_IRs/train --data1_first_slice --data1_pad_end --data1_fast_wav --data2_dir ../IRs_for_GAN/Synthetic_IRs/train --data2_first_slice --data2_pad_end --data2_fast_wav
46 | ```
47 | 
48 | ## Training TS-RIRGAN
49 | 
50 | Run the following command to train TS-RIRGAN.
51 | 
52 | ```
53 | export CUDA_VISIBLE_DEVICES=0
54 | python3 train_TSRIRgan.py train ./train --data1_dir ./IRs_for_GAN/Real_IRs/train --data1_first_slice --data1_pad_end --data1_fast_wav --data2_dir ./IRs_for_GAN/Synthetic_IRs/train --data2_first_slice --data2_pad_end --data2_fast_wav
55 | ```
56 | 
57 | To back up the model every hour, run the following command.
58 | 
59 | 
60 | ```
61 | export CUDA_VISIBLE_DEVICES=1
62 | python3 backup.py ./train 60
63 | ```
64 | 
65 | To monitor the training using TensorBoard, run the following command.
66 | 
67 | ```
68 | tensorboard --logdir=./train
69 | ```
70 | 
71 | ## Results
72 | The figure below shows a Synthetic RIR generated using the [**Geometric Acoustic Simulator**](https://github.com/RoyJames/pygsound), the Synthetic RIR translated to a Real RIR using our [**TS-RIRGAN**](https://arxiv.org/pdf/2103.16804v2.pdf), and a Real RIR from the [**BUT ReverbDB**](https://speech.fit.vutbr.cz/software/but-speech-fit-reverb-database) dataset. Please note that there is no one-to-one relationship between the Synthetic RIR and the Real RIR from **BUT ReverbDB**.
85 | 86 | ### Attribution 87 | 88 | If you use this code in your research, please consider citing 89 | 90 | ``` 91 | @article{DBLP:journals/corr/abs-2103-16804, 92 | author = {Anton Ratnarajah and 93 | Zhenyu Tang and 94 | Dinesh Manocha}, 95 | title = {{TS-RIR:} Translated synthetic room impulse responses for speech augmentation}, 96 | journal = {CoRR}, 97 | volume = {abs/2103.16804}, 98 | year = {2021} 99 | } 100 | ``` 101 | 102 | ``` 103 | @inproceedings{donahue2019wavegan, 104 | title={Adversarial Audio Synthesis}, 105 | author={Donahue, Chris and McAuley, Julian and Puckette, Miller}, 106 | booktitle={ICLR}, 107 | year={2019} 108 | } 109 | ``` 110 | 111 | If you use **Sub-band Room Equalization**, please consider citing 112 | ``` 113 | @inproceedings{9054454, 114 | author={Z. {Tang} and H. {Meng} and D. {Manocha}}, 115 | booktitle={ICASSP 2020 - 2020 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP)}, 116 | title={Low-Frequency Compensated Synthetic Impulse Responses For Improved Far-Field Speech Recognition}, 117 | year={2020}, 118 | volume={}, 119 | number={}, 120 | pages={6974-6978}, 121 | } 122 | 123 | ``` 124 | If you use **Real RIRs** from our dataset folder ([**IRs_for_GAN**](https://drive.google.com/file/d/1ivj_UZ5j5inAZwsDTCQ6jEvI5JDtwH_2/view?usp=sharing)), please consider citing 125 | 126 | ``` 127 | @article{DBLP:journals/jstsp/SzokeSMPC19, 128 | author = {Igor Sz{\"{o}}ke and 129 | Miroslav Sk{\'{a}}cel and 130 | Ladislav Mosner and 131 | Jakub Paliesek and 132 | Jan Honza Cernock{\'{y}}}, 133 | title = {Building and Evaluation of a Real Room Impulse Response Dataset}, 134 | journal = {{IEEE} J. Sel. Top. Signal Process.}, 135 | volume = {13}, 136 | number = {4}, 137 | pages = {863--876}, 138 | year = {2019} 139 | } 140 | ``` 141 | If you use **Synthetic RIRs** from our dataset folder ([**IRs_for_GAN**](https://drive.google.com/file/d/1ivj_UZ5j5inAZwsDTCQ6jEvI5JDtwH_2/view?usp=sharing)), please consider citing 142 | 143 | ``` 144 | @inproceedings{9052932, 145 | author={Z. {Tang} and L. {Chen} and B. {Wu} and D. {Yu} and D. 
{Manocha}}, 146 | booktitle={ICASSP 2020 - 2020 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP)}, 147 | title={Improving Reverberant Speech Training Using Diffuse Acoustic Simulation}, 148 | year={2020}, 149 | volume={}, 150 | number={}, 151 | pages={6969-6973}, 152 | } 153 | ``` 154 | 155 | 156 | -------------------------------------------------------------------------------- /web/js/wavegan_sequencer.js: -------------------------------------------------------------------------------- 1 | window.wavegan = window.wavegan || {}; 2 | 3 | (function (wavegan) { 4 | // Config 5 | var cfg = wavegan.cfg; 6 | 7 | var Sequencer = function (canvas, voices) { 8 | this.canvas = canvas; 9 | this.canvasCtx = this.canvas.getContext('2d'); 10 | this.canvasWidth = this.canvas.width; 11 | this.canvasHeight = this.canvas.height; 12 | 13 | this.voices = voices; 14 | 15 | this.linesBuffer = document.createElement('canvas'); 16 | this.linesBufferCtx = this.linesBuffer.getContext('2d'); 17 | this.linesBuffer.width = this.canvasWidth; 18 | this.linesBuffer.height = this.canvasHeight; 19 | 20 | this.cellsBuffer = document.createElement('canvas'); 21 | this.cellsBufferCtx = this.cellsBuffer.getContext('2d'); 22 | this.cellsBuffer.width = this.canvasWidth; 23 | this.cellsBuffer.height = this.canvasHeight; 24 | 25 | // Create grid 26 | this.numCols = cfg.sequencer.numCols; 27 | this.numRows = this.voices.length; 28 | this.grid = []; 29 | for (var j = 0; j < this.numRows; ++j) { 30 | var row = []; 31 | for (var i = 0; i < this.numCols; ++i) { 32 | row.push(cfg.sequencer.pattern[j][i]); 33 | } 34 | this.grid.push(row); 35 | } 36 | 37 | // Render 38 | this._redrawLines(); 39 | this._redrawCells(); 40 | 41 | // Bind click event 42 | var that = this; 43 | this.canvas.addEventListener('click', function (event) { 44 | var x = event.offsetX; 45 | var y = event.offsetY; 46 | var grid = that._absToGrid(x, y); 47 | 48 | var gridValid = true; 49 | gridValid &= grid.i >= 0; 50 | gridValid &= grid.i < that.numCols; 51 | gridValid &= grid.j >= 0; 52 | gridValid &= grid.j < that.numRows; 53 | 54 | if (gridValid) { 55 | var i = Math.floor(grid.i); 56 | var j = Math.floor(grid.j); 57 | that.grid[j][i] = 1 - that.grid[j][i]; 58 | that._redrawCells(); 59 | that.render(); 60 | } 61 | }); 62 | 63 | // Playback state 64 | this.delayMs = null; 65 | this.swing = 0.5; 66 | this.setTempoBpm(120); 67 | this.playing = false; 68 | this.tick = 0; 69 | }; 70 | 71 | Sequencer.prototype._tick = function () { 72 | if (!this.playing) { 73 | return; 74 | } 75 | 76 | // Audio playback 77 | for (var j = 0; j < this.voices.length; ++j) { 78 | if (this.grid[j][this.tick] > 0) { 79 | this.voices[j].bang(); 80 | } 81 | } 82 | 83 | // Render grid 84 | this.render(); 85 | 86 | // Calculate swing delay 87 | var totalDelay = this.delayMs * 2; 88 | if (this.tick % 2 == 0) { 89 | var delay = this.swing * totalDelay; 90 | } 91 | else { 92 | var delay = (1 - this.swing) * totalDelay; 93 | } 94 | 95 | var that = this; 96 | setTimeout(function () {that._tick();}, delay); 97 | this.tick += 1; 98 | this.tick = this.tick % this.numCols; 99 | }; 100 | Sequencer.prototype._absToGrid = function (x, y) { 101 | var labelWidth = cfg.sequencer.labelWidth; 102 | var gridWidth = this.canvasWidth - labelWidth; 103 | var gridHeight = this.canvasHeight; 104 | 105 | var cellWidth = gridWidth / this.numCols; 106 | var cellHeight = gridHeight / this.numRows; 107 | 108 | return { 109 | i: (x - labelWidth) / cellWidth, 110 | j: y / cellHeight 111 | }; 112 
| }; 113 | Sequencer.prototype._gridToAbs = function (i, j) { 114 | var labelWidth = cfg.sequencer.labelWidth; 115 | var gridWidth = this.canvasWidth - labelWidth; 116 | var gridHeight = this.canvasHeight; 117 | 118 | var cellWidth = gridWidth / this.numCols; 119 | var cellHeight = gridHeight / this.numRows; 120 | 121 | return { 122 | x: (i * cellWidth) + labelWidth, 123 | y: j * cellHeight 124 | }; 125 | }; 126 | Sequencer.prototype._redrawCells = function () { 127 | var ctx = this.cellsBufferCtx; 128 | var w = this.canvasWidth; 129 | var h = this.canvasHeight; 130 | 131 | // Draw buffer 132 | ctx.clearRect(0, 0, w, h); 133 | var topLeft = this._gridToAbs(0, 0); 134 | var bottomRight = this._gridToAbs(this.numCols, this.numRows); 135 | 136 | // Draw grid 137 | for (var j = 0; j < this.numRows; ++j) { 138 | for (var i = 0; i < this.numCols; ++i) { 139 | if (this.grid[j][i] > 0) { 140 | var cellTopLeft = this._gridToAbs(i, j); 141 | var cellBottomRight = this._gridToAbs(i + 1, j + 1); 142 | var cellWidth = cellBottomRight.x - cellTopLeft.x; 143 | var cellHeight = cellBottomRight.y - cellTopLeft.y; 144 | 145 | var hue = (j / (this.numRows - 1)) * 255; 146 | var hsl = 'hsl(' + String(hue) + ', 80%, 60%)'; 147 | ctx.fillStyle = hsl; 148 | ctx.fillRect(cellTopLeft.x, cellTopLeft.y, cellWidth, cellHeight); 149 | } 150 | } 151 | } 152 | 153 | // Draw grid lines 154 | ctx.drawImage(this.linesBuffer, 0, 0); 155 | }; 156 | Sequencer.prototype._redrawLines = function () { 157 | var ctx = this.linesBufferCtx; 158 | var w = this.canvasWidth; 159 | var h = this.canvasHeight; 160 | 161 | // Clear background 162 | ctx.clearRect(0, 0, w, h); 163 | var topLeft = this._gridToAbs(0, 0); 164 | var bottomRight = this._gridToAbs(this.numCols, this.numRows); 165 | 166 | // Draw row lines 167 | ctx.strokeStyle = '#ffffff'; 168 | ctx.lineWidth = 1; 169 | ctx.font = '18px sans-serif'; 170 | ctx.fillStyle = '#ffffff'; 171 | var rowStart = topLeft.x; 172 | var rowEnd = bottomRight.x; 173 | for (var j = 0; j < this.numRows + 1; ++j) { 174 | var y = this._gridToAbs(0, j).y; 175 | ctx.beginPath(); 176 | ctx.moveTo(rowStart, y) 177 | ctx.lineTo(rowEnd, y); 178 | ctx.stroke(); 179 | ctx.fillText('Drum ' + String(j + 1), 0, y + 26); 180 | } 181 | 182 | // Draw columns 183 | var colStart = topLeft.y; 184 | var colEnd = bottomRight.y 185 | for (var i = 0; i < this.numCols + 1; ++i) { 186 | if (i % 4 == 0) { 187 | ctx.strokeStyle = '#ffffff'; 188 | ctx.lineWidth = 4; 189 | } 190 | else { 191 | ctx.strokeStyle = '#ffffff'; 192 | ctx.lineWidth = 1; 193 | } 194 | 195 | var x = this._gridToAbs(i, 0).x 196 | ctx.beginPath(); 197 | ctx.moveTo(x, colStart) 198 | ctx.lineTo(x, colEnd); 199 | ctx.stroke(); 200 | } 201 | }; 202 | 203 | Sequencer.prototype.render = function () { 204 | var ctx = this.canvasCtx; 205 | var w = this.canvasWidth; 206 | var h = this.canvasHeight; 207 | 208 | // Draw background 209 | ctx.clearRect(0, 0, w, h); 210 | var topLeft = this._gridToAbs(0, 0); 211 | var bottomRight = this._gridToAbs(this.numCols, this.numRows); 212 | ctx.fillStyle = '#000000'; 213 | ctx.fillRect(topLeft.x, topLeft.y, bottomRight.x - topLeft.x, bottomRight.y - topLeft.y); 214 | 215 | // Draw cells 216 | ctx.drawImage(this.cellsBuffer, 0, 0); 217 | 218 | // Draw lines 219 | ctx.drawImage(this.linesBuffer, 0, 0); 220 | 221 | if (this.playing) { 222 | var topLeft = this._gridToAbs(this.tick, 0); 223 | var bottomRight = this._gridToAbs(this.tick + 1, this.numRows); 224 | ctx.fillStyle = '#ff0000'; 225 | ctx.globalAlpha = 0.5; 226 | 
ctx.fillRect(topLeft.x, topLeft.y, bottomRight.x - topLeft.x, bottomRight.y - topLeft.y); 227 | ctx.globalAlpha = 1; 228 | } 229 | }; 230 | 231 | Sequencer.prototype.setTempoBpm = function (bpm) { 232 | var bps = bpm / 60; 233 | var cellsPerBeat = this.numCols / 4; 234 | var cellsPerSecond = bps * cellsPerBeat; 235 | var secondsPerCell = 1 / cellsPerSecond; 236 | this.delayMs = secondsPerCell * 1000; 237 | }; 238 | Sequencer.prototype.setSwing = function (swing) { 239 | this.swing = swing; 240 | }; 241 | Sequencer.prototype.play = function () { 242 | if (!this.playing) { 243 | this.playing = true; 244 | this.tick = 0; 245 | this._tick(); 246 | } 247 | }; 248 | Sequencer.prototype.stop = function () { 249 | this.playing = false; 250 | this.render(); 251 | }; 252 | Sequencer.prototype.toggle = function () { 253 | if (this.playing) { 254 | this.stop(); 255 | } 256 | else { 257 | this.play(); 258 | } 259 | }; 260 | Sequencer.prototype.clear = function () { 261 | for (var j = 0; j < this.numRows; ++j) { 262 | for (var i = 0; i < this.numCols; ++i) { 263 | this.grid[j][i] = 0; 264 | } 265 | } 266 | this._redrawCells(); 267 | this.render(); 268 | }; 269 | 270 | // Exports 271 | wavegan.sequencer = {}; 272 | wavegan.sequencer.Sequencer = Sequencer; 273 | 274 | })(window.wavegan); 275 | -------------------------------------------------------------------------------- /generator/generator.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import loader 3 | from IPython.display import display, Audio 4 | import math 5 | import os 6 | import numpy as np 7 | import librosa 8 | 9 | 10 | def generate_real(fps2,args): 11 | 12 | no_samples = len(fps2) 13 | no_set =int(no_samples/64) 14 | no_remain = no_samples%64 15 | 16 | for k in range (no_set+1): 17 | 18 | if(no_set == k): 19 | print("k is ",k) 20 | s_fps2 = fps2[no_samples-64:no_samples] 21 | else: 22 | s_fps2 = fps2[(64*k):(64*(k+1))] 23 | tf.reset_default_graph() 24 | saver = tf.train.import_meta_graph('infer.meta') 25 | graph = tf.get_default_graph() 26 | sess = tf.InteractiveSession() 27 | saver.restore(sess, 'model.ckpt') 28 | 29 | with tf.name_scope('samp_x_synthetic'): 30 | x_synthetic = loader.decode_extract_and_batch( 31 | s_fps2, 32 | batch_size=args.train_batch_size, 33 | slice_len=args.data2_slice_len, 34 | decode_fs=args.data2_sample_rate, 35 | decode_num_channels=args.data2_num_channels, 36 | decode_fast_wav=args.data2_fast_wav, 37 | decode_parallel_calls=4, 38 | slice_randomize_offset=False if args.data2_first_slice else True, 39 | slice_first_only=args.data2_first_slice, 40 | slice_overlap_ratio=0. if args.data2_first_slice else args.data2_overlap_ratio, 41 | slice_pad_end=True if args.data2_first_slice else args.data2_pad_end, 42 | repeat=True, 43 | shuffle=True, 44 | shuffle_buffer_size=4096, 45 | prefetch_size=args.train_batch_size * 4, 46 | prefetch_gpu_num=args.data1_prefetch_gpu_num)[:, :, 0] 47 | 48 | 49 | _x_synthetic = x_synthetic.eval(session=tf.Session()) 50 | print("input ", len(_x_synthetic), len(_x_synthetic[1])) 51 | # _z = (np.random.rand(1000, 100) * 2.) 
- 1 52 | # Synthesize G(z) 53 | x_synthetic = graph.get_tensor_by_name('x_synthetic:0') 54 | x_real = graph.get_tensor_by_name('x_real:0') 55 | G_real = graph.get_tensor_by_name('G_real_x:0') 56 | _G_real = sess.run(G_real, {x_synthetic: _x_synthetic,x_real: _x_synthetic}) 57 | print("G_S" , len(_G_real), len(_G_real[1])) 58 | for i in range (64): 59 | print("i ",i) 60 | wav=_G_real[i][0:16000] 61 | name = 'IRs_for_GAN/' + s_fps2[i] 62 | print("name ",name) 63 | librosa.output.write_wav(path=name,y=wav,sr=16000) 64 | 65 | 66 | if __name__ == '__main__': 67 | import argparse 68 | import glob 69 | import sys 70 | 71 | parser = argparse.ArgumentParser() 72 | 73 | 74 | data1_args = parser.add_argument_group('Data1') 75 | data1_args.add_argument('--data1_dir', type=str, 76 | help='Data directory containing *only* audio files to load') 77 | data1_args.add_argument('--data1_sample_rate', type=int, 78 | help='Number of audio samples per second') 79 | data1_args.add_argument('--data1_slice_len', type=int, choices=[16384, 32768, 65536], 80 | help='Number of audio samples per slice (maximum generation length)') 81 | data1_args.add_argument('--data1_num_channels', type=int, 82 | help='Number of audio channels to generate (for >2, must match that of data)') 83 | data1_args.add_argument('--data1_overlap_ratio', type=float, 84 | help='Overlap ratio [0, 1) between slices') 85 | data1_args.add_argument('--data1_first_slice', action='store_true', dest='data1_first_slice', 86 | help='If set, only use the first slice each audio example') 87 | data1_args.add_argument('--data1_pad_end', action='store_true', dest='data1_pad_end', 88 | help='If set, use zero-padded partial slices from the end of each audio file') 89 | data1_args.add_argument('--data1_normalize', action='store_true', dest='data1_normalize', 90 | help='If set, normalize the training examples') 91 | data1_args.add_argument('--data1_fast_wav', action='store_true', dest='data1_fast_wav', 92 | help='If your data is comprised of standard WAV files (16-bit signed PCM or 32-bit float), use this flag to decode audio using scipy (faster) instead of librosa') 93 | data1_args.add_argument('--data1_prefetch_gpu_num', type=int, 94 | help='If nonnegative, prefetch examples to this GPU (Tensorflow device num)') 95 | 96 | data2_args = parser.add_argument_group('Data2') 97 | data2_args.add_argument('--data2_dir', type=str, 98 | help='Data directory containing *only* audio files to load') 99 | data2_args.add_argument('--data2_sample_rate', type=int, 100 | help='Number of audio samples per second') 101 | data2_args.add_argument('--data2_slice_len', type=int, choices=[16384, 32768, 65536], 102 | help='Number of audio samples per slice (maximum generation length)') 103 | data2_args.add_argument('--data2_num_channels', type=int, 104 | help='Number of audio channels to generate (for >2, must match that of data)') 105 | data2_args.add_argument('--data2_overlap_ratio', type=float, 106 | help='Overlap ratio [0, 1) between slices') 107 | data2_args.add_argument('--data2_first_slice', action='store_true', dest='data2_first_slice', 108 | help='If set, only use the first slice each audio example') 109 | data2_args.add_argument('--data2_pad_end', action='store_true', dest='data2_pad_end', 110 | help='If set, use zero-padded partial slices from the end of each audio file') 111 | data2_args.add_argument('--data2_normalize', action='store_true', dest='data2_normalize', 112 | help='If set, normalize the training examples') 113 | data2_args.add_argument('--data2_fast_wav', 
action='store_true', dest='data2_fast_wav', 114 | help='If your data is comprised of standard WAV files (16-bit signed PCM or 32-bit float), use this flag to decode audio using scipy (faster) instead of librosa') 115 | data2_args.add_argument('--data2_prefetch_gpu_num', type=int, 116 | help='If nonnegative, prefetch examples to this GPU (Tensorflow device num)') 117 | 118 | 119 | TSRIRgan_args = parser.add_argument_group('TSRIRGAN') 120 | TSRIRgan_args.add_argument('--TSRIRgan_latent_dim', type=int, 121 | help='Number of dimensions of the latent space') 122 | TSRIRgan_args.add_argument('--TSRIRgan_kernel_len', type=int, 123 | help='Length of 1D filter kernels') 124 | TSRIRgan_args.add_argument('--TSRIRgan_dim', type=int, 125 | help='Dimensionality multiplier for model of G and D') 126 | TSRIRgan_args.add_argument('--TSRIRgan_batchnorm', action='store_true', dest='TSRIRgan_batchnorm', 127 | help='Enable batchnorm') 128 | TSRIRgan_args.add_argument('--TSRIRgan_disc_nupdates', type=int, 129 | help='Number of discriminator updates per generator update') 130 | TSRIRgan_args.add_argument('--TSRIRgan_loss', type=str, choices=['cycle-gan'], 131 | help='Which GAN loss to use') 132 | TSRIRgan_args.add_argument('--TSRIRgan_genr_upsample', type=str, choices=['zeros', 'nn'], 133 | help='Generator upsample strategy') 134 | TSRIRgan_args.add_argument('--TSRIRgan_genr_pp', action='store_true', dest='TSRIRgan_genr_pp', 135 | help='If set, use post-processing filter') 136 | TSRIRgan_args.add_argument('--TSRIRgan_genr_pp_len', type=int, 137 | help='Length of post-processing filter for DCGAN') 138 | TSRIRgan_args.add_argument('--TSRIRgan_disc_phaseshuffle', type=int, 139 | help='Radius of phase shuffle operation') 140 | 141 | train_args = parser.add_argument_group('Train') 142 | train_args.add_argument('--train_batch_size', type=int, 143 | help='Batch size') 144 | train_args.add_argument('--train_save_secs', type=int, 145 | help='How often to save model') 146 | train_args.add_argument('--train_summary_secs', type=int, 147 | help='How often to report summaries') 148 | 149 | preview_args = parser.add_argument_group('Preview') 150 | preview_args.add_argument('--preview_n', type=int, 151 | help='Number of samples to preview') 152 | 153 | # incept_args = parser.add_argument_group('Incept') 154 | # incept_args.add_argument('--incept_metagraph_fp', type=str, 155 | # help='Inference model for inception score') 156 | # incept_args.add_argument('--incept_ckpt_fp', type=str, 157 | # help='Checkpoint for inference model') 158 | # incept_args.add_argument('--incept_n', type=int, 159 | # help='Number of generated examples to test') 160 | # incept_args.add_argument('--incept_k', type=int, 161 | # help='Number of groups to test') 162 | 163 | parser.set_defaults( 164 | data1_dir=None, 165 | data1_sample_rate=16000, 166 | data1_slice_len=16384, 167 | data1_num_channels=1, 168 | data1_overlap_ratio=0., 169 | data1_first_slice=False, 170 | data1_pad_end=False, 171 | data1_normalize=False, 172 | data1_fast_wav=False, 173 | data1_prefetch_gpu_num=0, 174 | data2_dir=None, 175 | data2_sample_rate=16000, 176 | data2_slice_len=16384, 177 | data2_num_channels=1, 178 | data2_overlap_ratio=0., 179 | data2_first_slice=False, 180 | data2_pad_end=False, 181 | data2_normalize=False, 182 | data2_fast_wav=False, 183 | data2_prefetch_gpu_num=0, 184 | TSRIRgan_latent_dim=100, 185 | TSRIRgan_kernel_len=25, 186 | TSRIRgan_dim=64, 187 | TSRIRgan_batchnorm=False, 188 | TSRIRgan_disc_nupdates=5, 189 | TSRIRgan_loss='cycle-gan', 190 | 
TSRIRgan_genr_upsample='zeros', 191 | TSRIRgan_genr_pp=False, 192 | TSRIRgan_genr_pp_len=512, 193 | TSRIRgan_disc_phaseshuffle=2, 194 | train_batch_size=64, 195 | train_save_secs=300, 196 | train_summary_secs=120, 197 | preview_n=32)#, 198 | # incept_metagraph_fp='./eval/inception/infer.meta', 199 | # incept_ckpt_fp='./eval/inception/best_acc-103005', 200 | # incept_n=5000, 201 | # incept_k=10) 202 | 203 | args = parser.parse_args() 204 | 205 | 206 | 207 | # Make model kwarg dicts 208 | setattr(args, 'TSRIRgan_g_kwargs', { 209 | 'slice_len': args.data1_slice_len, 210 | 'nch': args.data1_num_channels, 211 | 'kernel_len': args.TSRIRgan_kernel_len, 212 | 'dim': args.TSRIRgan_dim, 213 | 'use_batchnorm': args.TSRIRgan_batchnorm, 214 | 'upsample': args.TSRIRgan_genr_upsample 215 | }) 216 | setattr(args, 'TSRIRgan_d_kwargs', { 217 | 'kernel_len': args.TSRIRgan_kernel_len, 218 | 'dim': args.TSRIRgan_dim, 219 | 'use_batchnorm': args.TSRIRgan_batchnorm, 220 | 'phaseshuffle_rad': args.TSRIRgan_disc_phaseshuffle 221 | }) 222 | 223 | 224 | fps1 = glob.glob(os.path.join(args.data1_dir, '*')) 225 | if len(fps1) == 0: 226 | raise Exception('Did not find any audio files in specified directory(real_IR)') 227 | print('Found {} audio files in specified directory'.format(len(fps1))) 228 | fps2 = glob.glob(os.path.join(args.data2_dir, '*')) 229 | if len(fps2) == 0: 230 | raise Exception('Did not find any audio files in specified directory(synthetic_IR)') 231 | print('Found {} audio files in specified directory'.format(len(fps2))) 232 | 233 | generate_real(fps2, args) -------------------------------------------------------------------------------- /web/js/wavegan_ui.js: -------------------------------------------------------------------------------- 1 | window.wavegan = window.wavegan || {}; 2 | 3 | (function (deeplearn, wavegan) { 4 | // Config 5 | var cfg = wavegan.cfg; 6 | if (cfg.reqs.userCanceled) { 7 | document.getElementById('demo').setAttribute('hidden', ''); 8 | document.getElementById('canceled').removeAttribute('hidden'); 9 | return; 10 | } 11 | 12 | // Make a new random vector 13 | var random_vector = function () { 14 | var d = wavegan.cfg.net.zDim; 15 | var z = new Float32Array(d); 16 | for (var i = 0; i < d; ++i) { 17 | z[i] = (Math.random() * 2.) - 1.; 18 | } 19 | return z; 20 | }; 21 | 22 | // Linear interpolation between two vectors 23 | var z_lerp = function (z0, z1, a) { 24 | if (z0.length !== z1.length) { 25 | throw 'Vector length differs'; 26 | } 27 | 28 | var interp = new Float32Array(z0.length); 29 | for (var i = 0; i < z0.length; ++i) { 30 | interp[i] = (1. 
- a) * z0[i] + a * z1[i]; 31 | } 32 | 33 | return interp; 34 | }; 35 | 36 | // Class to handle UI interactions with player/visualizer 37 | var globalAudioCtxChromeWorkaround = null; 38 | var globalAudioCtxHasBeenResumed = false; 39 | var Zactor = function (fs, div, name, color) { 40 | this.canvas = div.children[0]; 41 | this.button = div.children[1]; 42 | this.player = new wavegan.player.ResamplingPlayer(fs); 43 | this.visualizer = new wavegan.visualizer.WaveformVisualizer(this.canvas, name, color); 44 | this.animFramesRemaining = 0; 45 | this.z = null; 46 | this.Gz = null; 47 | this.filename = null; 48 | 49 | var that = this; 50 | this.canvas.onclick = function (event) { 51 | that.bang(); 52 | }; 53 | 54 | // Change button 55 | div.children[1].onclick = function (event) { 56 | that.randomize(); 57 | }; 58 | 59 | // Save button 60 | div.children[2].onclick = function (event) { 61 | if (that.Gz !== null) { 62 | if (that.filename === null) { 63 | that.filename = wavegan.savewav.randomFilename(); 64 | } 65 | wavegan.savewav.saveWav(that.filename, that.Gz); 66 | } 67 | }; 68 | }; 69 | Zactor.prototype.setPrerendered = function (z, Gz) { 70 | this.z = z; 71 | this.Gz = Gz; 72 | this.filename = null; 73 | this.player.setSample(Gz, 16000); 74 | this.visualizer.setSample(Gz); 75 | }; 76 | Zactor.prototype.setZ = function (z) { 77 | var Gz = wavegan.net.eval([z])[0]; 78 | this.setPrerendered(z, Gz); 79 | }; 80 | Zactor.prototype.randomize = function () { 81 | var oldGain = gainNode.gain.value; 82 | gainNode.gain.value = 0; 83 | 84 | var z = random_vector(); 85 | this.setZ(z); 86 | 87 | gainNode.gain.value = oldGain; 88 | }; 89 | Zactor.prototype.readBlock = function (buffer) { 90 | this.player.readBlock(buffer); 91 | }; 92 | Zactor.prototype.bang = function () { 93 | if (!globalAudioCtxHasBeenResumed && globalAudioCtxChromeWorkaround !== null) { 94 | globalAudioCtxChromeWorkaround.resume(); 95 | globalAudioCtxHasBeenResumed = true; 96 | } 97 | 98 | this.player.bang(); 99 | 100 | var animFramesTot = Math.round(1024 / cfg.ui.rmsAnimDelayMs); 101 | this.animFramesRemaining = animFramesTot; 102 | var lastRemaining = this.animFramesRemaining; 103 | var that = this; 104 | var animFrame = function () { 105 | var rms = that.player.getRmsAmplitude(); 106 | var initPeriod = animFramesTot - that.animFramesRemaining; 107 | if (initPeriod < 8) { 108 | var fade = initPeriod / 8; 109 | rms = (1 - fade) * 0.25 + fade * rms; 110 | } 111 | that.visualizer.render(rms); 112 | 113 | if (that.animFramesRemaining > 0 && lastRemaining === that.animFramesRemaining) { 114 | --that.animFramesRemaining; 115 | --lastRemaining; 116 | setTimeout(animFrame, cfg.ui.rmsAnimDelayMs); 117 | } 118 | }; 119 | 120 | animFrame(); 121 | }; 122 | 123 | // Initializer for waveform players/visualizers 124 | var zactors = null; 125 | var initZactors = function (audioCtx, cherries) { 126 | var nzactors = cfg.ui.zactorNumRows * cfg.ui.zactorNumCols; 127 | 128 | // Create zactors 129 | zactors = []; 130 | for (var i = 0; i < nzactors; ++i) { 131 | var div = document.getElementById('zactor' + String(i)); 132 | var name = 'Drum ' + String(i + 1); 133 | var hue = (i / (nzactors - 1)) * 255; 134 | var hsl = 'hsl(' + String(hue) + ', 80%, 60%)'; 135 | zactors.push(new Zactor(audioCtx.sampleRate, div, name, hsl)); 136 | } 137 | 138 | // Render initial batch 139 | var zs = []; 140 | if (cherries === null || cfg.net.cherries.length != nzactors) { 141 | for (var i = 0; i < nzactors; ++i) { 142 | zs.push(random_vector()); 143 | } 144 | } 145 | else { 146 
| for (var i = 0; i < nzactors; ++i) { 147 | zs.push(cherries[cfg.net.cherries[i]]); 148 | } 149 | } 150 | 151 | var Gzs = wavegan.net.eval(zs); 152 | for (var i = 0; i < nzactors; ++i) { 153 | zactors[i].setPrerendered(zs[i], Gzs[i]); 154 | } 155 | 156 | // Hook up audio 157 | var scriptProcessor = audioCtx.createScriptProcessor(512, 0, 1); 158 | scriptProcessor.onaudioprocess = function (event) { 159 | var buffer = event.outputBuffer.getChannelData(0); 160 | for (var i = 0; i < buffer.length; ++i) { 161 | buffer[i] = 0; 162 | } 163 | for (var i = 0; i < nzactors; ++i) { 164 | zactors[i].readBlock(buffer); 165 | } 166 | }; 167 | 168 | return scriptProcessor; 169 | }; 170 | 171 | // Sequencer state 172 | var sequencer = null; 173 | 174 | // Global resize callback 175 | var onResize = function (event) { 176 | var demo = document.getElementById('demo'); 177 | var demoHeight = demo.offsetTop + demo.offsetHeight; 178 | var viewportHeight = Math.max(document.documentElement.clientHeight, window.innerHeight || 0); 179 | return; 180 | }; 181 | 182 | // Global keyboard callback 183 | var onKeydown = function (event) { 184 | var key = event.keyCode; 185 | var digit = key - 48; 186 | var zactorid = digit - 1; 187 | var shifted = event.getModifierState('Shift'); 188 | if (zactorid >= 0 && zactorid < 8) { 189 | if (shifted) { 190 | zactors[zactorid].randomize(); 191 | } 192 | else { 193 | zactors[zactorid].bang(); 194 | } 195 | } 196 | 197 | // Space bar 198 | if (key == 32) { 199 | sequencer.toggle(); 200 | } 201 | }; 202 | 203 | var initSlider = function (sliderId, sliderMin, sliderMax, sliderDefault, callback) { 204 | var slider = document.getElementById(sliderId); 205 | slider.value = 10000 * ((sliderDefault - sliderMin) / (sliderMax - sliderMin)); 206 | callback(sliderDefault); 207 | slider.addEventListener('input', function (event) { 208 | var valUi = slider.value / 10000; 209 | var val = (valUi * (sliderMax - sliderMin)) + sliderMin; 210 | callback(val); 211 | }, true); 212 | }; 213 | 214 | var createReverb = function (audioCtx) { 215 | var sampleRate = audioCtx.sampleRate; 216 | var reverbLen = Math.floor(sampleRate * cfg.audio.reverbLen); 217 | var reverbDcy = cfg.audio.reverbDecay; 218 | var impulse = audioCtx.createBuffer(2, reverbLen, sampleRate); 219 | var impulseL = impulse.getChannelData(0); 220 | var impulseR = impulse.getChannelData(1); 221 | for (var i = 0; i < reverbLen; ++i) { 222 | impulseL[i] = (Math.random() * 2 - 1) * Math.pow(1 - i / reverbLen, reverbDcy); 223 | impulseR[i] = (Math.random() * 2 - 1) * Math.pow(1 - i / reverbLen, reverbDcy); 224 | } 225 | var reverbNode = audioCtx.createConvolver(); 226 | reverbNode.buffer = impulse; 227 | return reverbNode 228 | }; 229 | 230 | // Run once DOM loads 231 | var gainNode = null; 232 | var domReady = function () { 233 | cfg.debugMsg('DOM ready'); 234 | 235 | // Create grid 236 | var cellTemplate = document.getElementById('zactor-template').innerHTML; 237 | var i = 0; 238 | var gridHtml = ''; 239 | for (var j = 0; j < cfg.ui.zactorNumRows; ++j) { 240 | gridHtml += '
'; 241 | for (var k = 0; k < cfg.ui.zactorNumCols; ++k) { 242 | gridHtml += cellTemplate.replace('{ID}', 'zactor' + String(i)); 243 | ++i; 244 | } 245 | gridHtml += '
'; 246 | } 247 | document.getElementById('zactors').innerHTML = gridHtml; 248 | 249 | // Initialize audio 250 | var audioCtx = new window.AudioContext(); 251 | globalAudioCtxChromeWorkaround = audioCtx; 252 | 253 | var reverbNode = createReverb(audioCtx); 254 | var wet = audioCtx.createGain(); 255 | var dry = audioCtx.createGain(); 256 | gainNode = audioCtx.createGain(); 257 | reverbNode.connect(wet); 258 | wet.connect(gainNode); 259 | dry.connect(gainNode); 260 | gainNode.connect(audioCtx.destination); 261 | 262 | // (Gross) wait for net to be ready 263 | var wait = function() { 264 | if (wavegan.net.isReady()) { 265 | var scriptProcessor = initZactors(audioCtx, wavegan.net.getCherries()); 266 | scriptProcessor.connect(reverbNode); 267 | scriptProcessor.connect(dry); 268 | 269 | var seqCanvas = document.getElementById('sequencer-canvas'); 270 | sequencer = new wavegan.sequencer.Sequencer(seqCanvas, zactors); 271 | sequencer.render(); 272 | 273 | document.getElementById('overlay').setAttribute('hidden', ''); 274 | document.getElementById('content').removeAttribute('hidden'); 275 | } 276 | else { 277 | setTimeout(wait, 5); 278 | } 279 | }; 280 | setTimeout(wait, 5); 281 | 282 | // Sequencer button callbacks 283 | document.getElementById('sequencer-play').addEventListener('click', function () { 284 | sequencer.play(); 285 | }); 286 | document.getElementById('sequencer-stop').addEventListener('click', function () { 287 | sequencer.stop(); 288 | }); 289 | document.getElementById('sequencer-clear').addEventListener('click', function () { 290 | sequencer.clear(); 291 | }); 292 | 293 | // Slider callbacks 294 | initSlider('gain', 295 | 0, 1, 296 | cfg.audio.gainDefault, 297 | function (val) { 298 | gainNode.gain.value = val * val * val * val; 299 | }); 300 | initSlider('reverb', 301 | 0, 1, 302 | cfg.audio.reverbDefault, 303 | function (val) { 304 | dry.gain.value = (1 - val); 305 | wet.gain.value = val; 306 | }); 307 | initSlider('sequencer-tempo', 308 | cfg.sequencer.tempoMin, cfg.sequencer.tempoMax, 309 | cfg.sequencer.tempoDefault, 310 | function (val) { 311 | if (sequencer !== null) { 312 | sequencer.setTempoBpm(val); 313 | } 314 | }); 315 | initSlider('sequencer-swing', 316 | cfg.sequencer.swingMin, cfg.sequencer.swingMax, 317 | cfg.sequencer.swingDefault, 318 | function (val) { 319 | if (sequencer !== null) { 320 | sequencer.setSwing(val); 321 | } 322 | }); 323 | 324 | // Global resize callback 325 | window.addEventListener('resize', onResize, true); 326 | onResize(); 327 | 328 | // Global key listener callback 329 | window.addEventListener('keydown', onKeydown, true); 330 | }; 331 | 332 | // DOM load callbacks 333 | if (document.addEventListener) document.addEventListener("DOMContentLoaded", domReady, false); 334 | else if (document.attachEvent) document.attachEvent("onreadystatechange", domReady); 335 | else window.onload = domReady; 336 | 337 | // Exports 338 | wavegan.ui = {}; 339 | 340 | })(window.deeplearn, window.wavegan); 341 | -------------------------------------------------------------------------------- /TSRIRgan.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | 4 | def conv1d_transpose( 5 | inputs, 6 | filters, 7 | kernel_width, 8 | stride=4, 9 | padding='same', 10 | upsample='zeros'): 11 | if upsample == 'zeros': 12 | return tf.layers.conv2d_transpose( 13 | tf.expand_dims(inputs, axis=1), 14 | filters, 15 | (1, kernel_width), 16 | strides=(1, stride), 17 | padding='same' 18 | )[:, 0] 19 | elif upsample == 
'nn': 20 | batch_size = tf.shape(inputs)[0] 21 | _, w, nch = inputs.get_shape().as_list() 22 | 23 | x = inputs 24 | 25 | x = tf.expand_dims(x, axis=1) 26 | x = tf.image.resize_nearest_neighbor(x, [1, w * stride]) 27 | x = x[:, 0] 28 | 29 | return tf.layers.conv1d( 30 | x, 31 | filters, 32 | kernel_width, 33 | 1, 34 | padding='same') 35 | else: 36 | raise NotImplementedError 37 | 38 | 39 | """ 40 | Input: [None, 100] 41 | Output: [None, slice_len, 1] 42 | """ 43 | def TSRIRGANGenerator_real( 44 | z, 45 | slice_len=16384, 46 | nch=1, 47 | kernel_len=25, 48 | dim=64, 49 | use_batchnorm=False, 50 | upsample='zeros', 51 | train=False): 52 | assert slice_len in [16384, 32768, 65536] 53 | batch_size = tf.shape(z)[0] 54 | 55 | if use_batchnorm: 56 | batchnorm = lambda x: tf.layers.batch_normalization(x, training=train) 57 | else: 58 | batchnorm = lambda x: x 59 | 60 | # FC and reshape for convolution 61 | # [100] -> [16, 1024] 62 | dim_mul = 16 if slice_len == 16384 else 32 63 | # output = z 64 | # with tf.variable_scope('z_project'): 65 | # output = tf.layers.dense(output, 4 * 4 * dim * dim_mul) 66 | # output = tf.reshape(output, [batch_size, 16, dim * dim_mul]) 67 | # output = batchnorm(output) 68 | # output = tf.nn.relu(output) 69 | # dim_mul //= 2 70 | 71 | # Layer 0 72 | # [16384, 1] -> [4096, 64] 73 | output = z 74 | with tf.variable_scope('downconv_0'): 75 | output = tf.layers.conv1d(output, dim, kernel_len, 4, padding='SAME') 76 | output = tf.nn.relu(output) 77 | # output = phaseshuffle(output) 78 | 79 | # Layer 1 80 | # [4096, 64] -> [1024, 128] 81 | with tf.variable_scope('downconv_1'): 82 | output = tf.layers.conv1d(output, dim * 2, kernel_len, 4, padding='SAME') 83 | output = batchnorm(output) 84 | output = tf.nn.relu(output) 85 | # output = phaseshuffle(output) 86 | 87 | # Layer 2 88 | # [1024, 128] -> [256, 256] 89 | with tf.variable_scope('downconv_2'): 90 | output = tf.layers.conv1d(output, dim * 4, kernel_len, 4, padding='SAME') 91 | output = batchnorm(output) 92 | output = tf.nn.relu(output) 93 | # output = phaseshuffle(output) 94 | 95 | # Layer 3 96 | # [256, 256] -> [64, 512] 97 | with tf.variable_scope('downconv_3'): 98 | output = tf.layers.conv1d(output, dim * 8, kernel_len, 4, padding='SAME') 99 | output = batchnorm(output) 100 | output = tf.nn.relu(output) 101 | # output = phaseshuffle(output) 102 | 103 | # Layer 4 104 | # [64, 512] -> [16, 1024] %[32,1024] 105 | with tf.variable_scope('downconv_4'): 106 | output = tf.layers.conv1d(output, dim * 8, kernel_len, 1, padding='SAME') 107 | output = batchnorm(output) 108 | output = tf.nn.relu(output) 109 | 110 | with tf.variable_scope('downconv_5'): 111 | output = tf.layers.conv1d(output, dim * 8, kernel_len, 1, padding='SAME') 112 | output = batchnorm(output) 113 | output = tf.nn.relu(output) 114 | 115 | 116 | with tf.variable_scope('downconv_6'): 117 | output = tf.layers.conv1d(output, dim * 8, kernel_len, 1, padding='SAME') 118 | output = batchnorm(output) 119 | output = tf.nn.relu(output) 120 | 121 | with tf.variable_scope('downconv_7'): 122 | output = tf.layers.conv1d(output, dim * 8, kernel_len, 1, padding='SAME') 123 | output = batchnorm(output) 124 | output = tf.nn.relu(output) 125 | with tf.variable_scope('downconv_8'): 126 | output = tf.layers.conv1d(output, dim * 8, kernel_len, 1, padding='SAME') 127 | output = batchnorm(output) 128 | output = tf.nn.relu(output) 129 | 130 | # # Layer 0 131 | # # [16, 1024] -> [64, 512] %[32, 1024] 132 | # with tf.variable_scope('upconv_0'): 133 | # output = conv1d_transpose(output, dim 
* dim_mul, kernel_len, 2, upsample=upsample) 134 | # output = batchnorm(output) 135 | # output = tf.nn.relu(output) 136 | # dim_mul //= 2 137 | 138 | 139 | 140 | #Up Layer 1 141 | # [64, 512] -> [256, 256] 142 | with tf.variable_scope('upconv_1'): 143 | output = conv1d_transpose(output, dim * dim_mul, kernel_len, 4, upsample=upsample) 144 | output = batchnorm(output) 145 | output = tf.nn.relu(output) 146 | dim_mul //= 2 147 | 148 | #Up Layer 2 149 | # [256, 256] -> [1024, 128] 150 | with tf.variable_scope('upconv_2'): 151 | output = conv1d_transpose(output, dim * dim_mul, kernel_len, 4, upsample=upsample) 152 | output = batchnorm(output) 153 | output = tf.nn.relu(output) 154 | dim_mul //= 2 155 | 156 | #Up Layer 3 157 | # [1024, 128] -> [4096, 64] 158 | with tf.variable_scope('upconv_3'): 159 | output = conv1d_transpose(output, dim * dim_mul, kernel_len, 4, upsample=upsample) 160 | output = batchnorm(output) 161 | output = tf.nn.relu(output) 162 | 163 | if slice_len == 16384: 164 | #Up Layer 4 165 | # [4096, 64] -> [16384, nch] 166 | with tf.variable_scope('upconv_4'): 167 | output = conv1d_transpose(output, nch, kernel_len, 4, upsample=upsample) 168 | output = tf.nn.tanh(output) 169 | elif slice_len == 32768: 170 | #Up Layer 4 171 | # [4096, 128] -> [16384, 64] 172 | with tf.variable_scope('upconv_4'): 173 | output = conv1d_transpose(output, dim, kernel_len, 4, upsample=upsample) 174 | output = batchnorm(output) 175 | output = tf.nn.relu(output) 176 | 177 | #Up Layer 5 178 | # [16384, 64] -> [32768, nch] 179 | with tf.variable_scope('upconv_5'): 180 | output = conv1d_transpose(output, nch, kernel_len, 2, upsample=upsample) 181 | output = tf.nn.tanh(output) 182 | elif slice_len == 65536: 183 | #Up Layer 4 184 | # [4096, 128] -> [16384, 64] 185 | with tf.variable_scope('upconv_4'): 186 | output = conv1d_transpose(output, dim, kernel_len, 4, upsample=upsample) 187 | output = batchnorm(output) 188 | output = tf.nn.relu(output) 189 | 190 | #Up Layer 5 191 | # [16384, 64] -> [65536, nch] 192 | with tf.variable_scope('upconv_5'): 193 | output = conv1d_transpose(output, nch, kernel_len, 4, upsample=upsample) 194 | output = tf.nn.tanh(output) 195 | 196 | # Automatically update batchnorm moving averages every time G is used during training 197 | if train and use_batchnorm: 198 | update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope=tf.get_variable_scope().name) 199 | if slice_len == 16384: 200 | assert len(update_ops) == 10 201 | else: 202 | assert len(update_ops) == 12 203 | with tf.control_dependencies(update_ops): 204 | output = tf.identity(output) 205 | 206 | return output 207 | 208 | def TSRIRGANGenerator_synthetic( 209 | z, 210 | slice_len=16384, 211 | nch=1, 212 | kernel_len=25, 213 | dim=64, 214 | use_batchnorm=False, 215 | upsample='zeros', 216 | train=False): 217 | assert slice_len in [16384, 32768, 65536] 218 | batch_size = tf.shape(z)[0] 219 | 220 | if use_batchnorm: 221 | batchnorm = lambda x: tf.layers.batch_normalization(x, training=train) 222 | else: 223 | batchnorm = lambda x: x 224 | 225 | # FC and reshape for convolution 226 | # [100] -> [16, 1024] 227 | dim_mul = 16 if slice_len == 16384 else 32 228 | # output = z 229 | # with tf.variable_scope('z_project'): 230 | # output = tf.layers.dense(output, 4 * 4 * dim * dim_mul) 231 | # output = tf.reshape(output, [batch_size, 16, dim * dim_mul]) 232 | # output = batchnorm(output) 233 | # output = tf.nn.relu(output) 234 | # dim_mul //= 2 235 | 236 | # Layer 0 237 | # [16384, 1] -> [4096, 64] 238 | output = z 239 | with 
tf.variable_scope('downconv_0'): 240 | output = tf.layers.conv1d(output, dim, kernel_len, 4, padding='SAME') 241 | output = tf.nn.relu(output) 242 | # output = phaseshuffle(output) 243 | 244 | # Layer 1 245 | # [4096, 64] -> [1024, 128] 246 | with tf.variable_scope('downconv_1'): 247 | output = tf.layers.conv1d(output, dim * 2, kernel_len, 4, padding='SAME') 248 | output = batchnorm(output) 249 | output = tf.nn.relu(output) 250 | # output = phaseshuffle(output) 251 | 252 | # Layer 2 253 | # [1024, 128] -> [256, 256] 254 | with tf.variable_scope('downconv_2'): 255 | output = tf.layers.conv1d(output, dim * 4, kernel_len, 4, padding='SAME') 256 | output = batchnorm(output) 257 | output = tf.nn.relu(output) 258 | # output = phaseshuffle(output) 259 | 260 | # Layer 3 261 | # [256, 256] -> [64, 512] 262 | with tf.variable_scope('downconv_3'): 263 | output = tf.layers.conv1d(output, dim * 8, kernel_len, 4, padding='SAME') 264 | output = batchnorm(output) 265 | output = tf.nn.relu(output) 266 | # output = phaseshuffle(output) 267 | 268 | # Layer 4 269 | # [64, 512] -> [16, 1024] %[32,1024] 270 | with tf.variable_scope('downconv_4'): 271 | output = tf.layers.conv1d(output, dim * 8, kernel_len, 1, padding='SAME') 272 | output = batchnorm(output) 273 | output = tf.nn.relu(output) 274 | 275 | with tf.variable_scope('downconv_5'): 276 | output = tf.layers.conv1d(output, dim * 8, kernel_len, 1, padding='SAME') 277 | output = batchnorm(output) 278 | output = tf.nn.relu(output) 279 | 280 | 281 | with tf.variable_scope('downconv_6'): 282 | output = tf.layers.conv1d(output, dim * 8, kernel_len, 1, padding='SAME') 283 | output = batchnorm(output) 284 | output = tf.nn.relu(output) 285 | 286 | with tf.variable_scope('downconv_7'): 287 | output = tf.layers.conv1d(output, dim * 8, kernel_len, 1, padding='SAME') 288 | output = batchnorm(output) 289 | output = tf.nn.relu(output) 290 | with tf.variable_scope('downconv_8'): 291 | output = tf.layers.conv1d(output, dim * 8, kernel_len, 1, padding='SAME') 292 | output = batchnorm(output) 293 | output = tf.nn.relu(output) 294 | # # Layer 0 295 | # # [16, 1024] -> [64, 512] %[32, 1024] 296 | # with tf.variable_scope('upconv_0'): 297 | # output = conv1d_transpose(output, dim * dim_mul, kernel_len, 2, upsample=upsample) 298 | # output = batchnorm(output) 299 | # output = tf.nn.relu(output) 300 | # dim_mul //= 2 301 | 302 | 303 | 304 | 305 | #Up Layer 1 306 | # [64, 512] -> [256, 256] 307 | with tf.variable_scope('upconv_1'): 308 | output = conv1d_transpose(output, dim * dim_mul, kernel_len, 4, upsample=upsample) 309 | output = batchnorm(output) 310 | output = tf.nn.relu(output) 311 | dim_mul //= 2 312 | 313 | #Up Layer 2 314 | # [256, 256] -> [1024, 128] 315 | with tf.variable_scope('upconv_2'): 316 | output = conv1d_transpose(output, dim * dim_mul, kernel_len, 4, upsample=upsample) 317 | output = batchnorm(output) 318 | output = tf.nn.relu(output) 319 | dim_mul //= 2 320 | 321 | #Up Layer 3 322 | # [1024, 128] -> [4096, 64] 323 | with tf.variable_scope('upconv_3'): 324 | output = conv1d_transpose(output, dim * dim_mul, kernel_len, 4, upsample=upsample) 325 | output = batchnorm(output) 326 | output = tf.nn.relu(output) 327 | 328 | if slice_len == 16384: 329 | #Up Layer 4 330 | # [4096, 64] -> [16384, nch] 331 | with tf.variable_scope('upconv_4'): 332 | output = conv1d_transpose(output, nch, kernel_len, 4, upsample=upsample) 333 | output = tf.nn.tanh(output) 334 | elif slice_len == 32768: 335 | #Up Layer 4 336 | # [4096, 128] -> [16384, 64] 337 | with 
tf.variable_scope('upconv_4'): 338 | output = conv1d_transpose(output, dim, kernel_len, 4, upsample=upsample) 339 | output = batchnorm(output) 340 | output = tf.nn.relu(output) 341 | 342 | #Up Layer 5 343 | # [16384, 64] -> [32768, nch] 344 | with tf.variable_scope('upconv_5'): 345 | output = conv1d_transpose(output, nch, kernel_len, 2, upsample=upsample) 346 | output = tf.nn.tanh(output) 347 | elif slice_len == 65536: 348 | #Up Layer 4 349 | # [4096, 128] -> [16384, 64] 350 | with tf.variable_scope('upconv_4'): 351 | output = conv1d_transpose(output, dim, kernel_len, 4, upsample=upsample) 352 | output = batchnorm(output) 353 | output = tf.nn.relu(output) 354 | 355 | #Up Layer 5 356 | # [16384, 64] -> [65536, nch] 357 | with tf.variable_scope('upconv_5'): 358 | output = conv1d_transpose(output, nch, kernel_len, 4, upsample=upsample) 359 | output = tf.nn.tanh(output) 360 | 361 | # Automatically update batchnorm moving averages every time G is used during training 362 | if train and use_batchnorm: 363 | update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope=tf.get_variable_scope().name) 364 | if slice_len == 16384: 365 | assert len(update_ops) == 10 366 | else: 367 | assert len(update_ops) == 12 368 | with tf.control_dependencies(update_ops): 369 | output = tf.identity(output) 370 | 371 | return output 372 | 373 | 374 | def lrelu(inputs, alpha=0.2): 375 | return tf.maximum(alpha * inputs, inputs) 376 | 377 | 378 | def apply_phaseshuffle(x, rad, pad_type='reflect'): 379 | b, x_len, nch = x.get_shape().as_list() 380 | 381 | phase = tf.random_uniform([], minval=-rad, maxval=rad + 1, dtype=tf.int32) 382 | pad_l = tf.maximum(phase, 0) 383 | pad_r = tf.maximum(-phase, 0) 384 | phase_start = pad_r 385 | x = tf.pad(x, [[0, 0], [pad_l, pad_r], [0, 0]], mode=pad_type) 386 | 387 | x = x[:, phase_start:phase_start+x_len] 388 | x.set_shape([b, x_len, nch]) 389 | 390 | return x 391 | 392 | 393 | """ 394 | Input: [None, slice_len, nch] 395 | Output: [None] (linear output) 396 | """ 397 | def TSRIRGANDiscriminator_synthetic( 398 | x, 399 | kernel_len=25, 400 | dim=64, 401 | use_batchnorm=False, 402 | phaseshuffle_rad=0): 403 | batch_size = tf.shape(x)[0] 404 | slice_len = int(x.get_shape()[1]) 405 | 406 | if use_batchnorm: 407 | batchnorm = lambda x: tf.layers.batch_normalization(x, training=True) 408 | else: 409 | batchnorm = lambda x: x 410 | 411 | # if phaseshuffle_rad > 0: 412 | # phaseshuffle = lambda x: apply_phaseshuffle(x, phaseshuffle_rad) 413 | # else: 414 | # phaseshuffle = lambda x: x 415 | phaseshuffle = lambda x: x 416 | 417 | 418 | # Layer 0 419 | # [16384, 1] -> [4096, 64] 420 | output = x 421 | with tf.variable_scope('downconv_0'): 422 | output = tf.layers.conv1d(output, dim, kernel_len, 4, padding='SAME') 423 | output = lrelu(output) 424 | output = phaseshuffle(output) 425 | 426 | # Layer 1 427 | # [4096, 64] -> [1024, 128] 428 | with tf.variable_scope('downconv_1'): 429 | output = tf.layers.conv1d(output, dim * 2, kernel_len, 4, padding='SAME') 430 | output = batchnorm(output) 431 | output = lrelu(output) 432 | output = phaseshuffle(output) 433 | 434 | # Layer 2 435 | # [1024, 128] -> [256, 256] 436 | with tf.variable_scope('downconv_2'): 437 | output = tf.layers.conv1d(output, dim * 4, kernel_len, 4, padding='SAME') 438 | output = batchnorm(output) 439 | output = lrelu(output) 440 | output = phaseshuffle(output) 441 | 442 | # Layer 3 443 | # [256, 256] -> [64, 512] 444 | with tf.variable_scope('downconv_3'): 445 | output = tf.layers.conv1d(output, dim * 8, kernel_len, 4, 
padding='SAME') 446 | output = batchnorm(output) 447 | output = lrelu(output) 448 | output = phaseshuffle(output) 449 | 450 | # Layer 4 451 | # [64, 512] -> [16, 1024] 452 | with tf.variable_scope('downconv_4'): 453 | output = tf.layers.conv1d(output, dim * 16, kernel_len, 4, padding='SAME') 454 | output = batchnorm(output) 455 | output = lrelu(output) 456 | 457 | if slice_len == 32768: 458 | # Layer 5 459 | # [32, 1024] -> [16, 2048] 460 | with tf.variable_scope('downconv_5'): 461 | output = tf.layers.conv1d(output, dim * 32, kernel_len, 2, padding='SAME') 462 | output = batchnorm(output) 463 | output = lrelu(output) 464 | elif slice_len == 65536: 465 | # Layer 5 466 | # [64, 1024] -> [16, 2048] 467 | with tf.variable_scope('downconv_5'): 468 | output = tf.layers.conv1d(output, dim * 32, kernel_len, 4, padding='SAME') 469 | output = batchnorm(output) 470 | output = lrelu(output) 471 | 472 | # Flatten 473 | output = tf.reshape(output, [batch_size, -1]) 474 | 475 | # Connect to single logit 476 | with tf.variable_scope('output'): 477 | output = tf.layers.dense(output, 1)[:, 0] 478 | 479 | # Don't need to aggregate batchnorm update ops like we do for the generator because we only use the discriminator for training 480 | 481 | return output 482 | 483 | 484 | def TSRIRGANDiscriminator_real( 485 | x, 486 | kernel_len=25, 487 | dim=64, 488 | use_batchnorm=False, 489 | phaseshuffle_rad=0): 490 | batch_size = tf.shape(x)[0] 491 | slice_len = int(x.get_shape()[1]) 492 | 493 | if use_batchnorm: 494 | batchnorm = lambda x: tf.layers.batch_normalization(x, training=True) 495 | else: 496 | batchnorm = lambda x: x 497 | 498 | # if phaseshuffle_rad > 0: 499 | # phaseshuffle = lambda x: apply_phaseshuffle(x, phaseshuffle_rad) 500 | # else: 501 | # phaseshuffle = lambda x: x 502 | phaseshuffle = lambda x: x 503 | 504 | # Layer 0 505 | # [16384, 1] -> [4096, 64] 506 | output = x 507 | with tf.variable_scope('downconv_0'): 508 | output = tf.layers.conv1d(output, dim, kernel_len, 4, padding='SAME') 509 | output = lrelu(output) 510 | output = phaseshuffle(output) 511 | 512 | # Layer 1 513 | # [4096, 64] -> [1024, 128] 514 | with tf.variable_scope('downconv_1'): 515 | output = tf.layers.conv1d(output, dim * 2, kernel_len, 4, padding='SAME') 516 | output = batchnorm(output) 517 | output = lrelu(output) 518 | output = phaseshuffle(output) 519 | 520 | # Layer 2 521 | # [1024, 128] -> [256, 256] 522 | with tf.variable_scope('downconv_2'): 523 | output = tf.layers.conv1d(output, dim * 4, kernel_len, 4, padding='SAME') 524 | output = batchnorm(output) 525 | output = lrelu(output) 526 | output = phaseshuffle(output) 527 | 528 | # Layer 3 529 | # [256, 256] -> [64, 512] 530 | with tf.variable_scope('downconv_3'): 531 | output = tf.layers.conv1d(output, dim * 8, kernel_len, 4, padding='SAME') 532 | output = batchnorm(output) 533 | output = lrelu(output) 534 | output = phaseshuffle(output) 535 | 536 | # Layer 4 537 | # [64, 512] -> [16, 1024] 538 | with tf.variable_scope('downconv_4'): 539 | output = tf.layers.conv1d(output, dim * 16, kernel_len, 4, padding='SAME') 540 | output = batchnorm(output) 541 | output = lrelu(output) 542 | 543 | if slice_len == 32768: 544 | # Layer 5 545 | # [32, 1024] -> [16, 2048] 546 | with tf.variable_scope('downconv_5'): 547 | output = tf.layers.conv1d(output, dim * 32, kernel_len, 2, padding='SAME') 548 | output = batchnorm(output) 549 | output = lrelu(output) 550 | elif slice_len == 65536: 551 | # Layer 5 552 | # [64, 1024] -> [16, 2048] 553 | with tf.variable_scope('downconv_5'): 554 | 
output = tf.layers.conv1d(output, dim * 32, kernel_len, 4, padding='SAME') 555 | output = batchnorm(output) 556 | output = lrelu(output) 557 | 558 | # Flatten 559 | output = tf.reshape(output, [batch_size, -1]) 560 | 561 | # Connect to single logit 562 | with tf.variable_scope('output'): 563 | output = tf.layers.dense(output, 1)[:, 0] 564 | 565 | # Don't need to aggregate batchnorm update ops like we do for the generator because we only use the discriminator for training 566 | 567 | return output 568 | -------------------------------------------------------------------------------- /train_TSRIRgan.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | try: 4 | import cPickle as pickle 5 | except: 6 | import pickle 7 | from functools import reduce 8 | import os 9 | import time 10 | import matplotlib.pyplot as plt 11 | from IPython.display import clear_output 12 | import numpy as np 13 | import tensorflow as tf 14 | from six.moves import xrange 15 | import RT60 16 | import loader 17 | from TSRIRgan import TSRIRGANGenerator_synthetic, TSRIRGANGenerator_real, TSRIRGANDiscriminator_synthetic,TSRIRGANDiscriminator_real 18 | 19 | 20 | """ 21 | Trains a TSRIRGAN 22 | """ 23 | # loss_obj = tf.keras.losses.binary_crossentropy(from_logits=True) 24 | 25 | def mae_criterion(pred, target): 26 | return tf.reduce_mean((pred - target) ** 2) 27 | 28 | LAMBDA = 10 29 | def discriminator_loss(real, generated): 30 | # real_loss = tf.keras.losses.binary_crossentropy(tf.ones_like(real), real)#,from_logits=True) 31 | real_loss = mae_criterion(tf.ones_like(real), real) 32 | # generated_loss = tf.keras.losses.binary_crossentropy(tf.zeros_like(generated), generated)#,from_logits=True) 33 | generated_loss = mae_criterion(tf.zeros_like(generated), generated) 34 | total_disc_loss = real_loss + generated_loss 35 | return total_disc_loss * 0.5*LAMBDA 36 | 37 | def generator_loss(generated): 38 | # return tf.keras.losses.binary_crossentropy(tf.ones_like(generated), generated)#,from_logits=True) 39 | return mae_criterion(tf.ones_like(generated), generated) 40 | 41 | def calc_cycle_loss(real_image, cycled_image): 42 | loss1 = tf.reduce_mean(tf.abs(real_image - cycled_image)) 43 | return LAMBDA * loss1 44 | 45 | def identity_loss(real_image, same_image): 46 | loss = tf.reduce_mean(tf.abs(real_image - same_image)) 47 | return LAMBDA * 0.5 * loss 48 | 49 | def RT_60_loss(real,generated,sess): 50 | # sess = tf.Session() 51 | # sess.run(tf.global_variables_initializer()) 52 | _real = real.eval(session = sess) 53 | _generated = generated.eval(session = sess) 54 | # _generated = generated 55 | no_samples = len(_real) 56 | sampling_rate = 16000 57 | t60_loss_list =list() 58 | 59 | for i in range (no_samples): 60 | real_wav = _real[i] 61 | generated_wav = _generated[i] 62 | real_t60_val = RT60.t60_impulse(real_wav,sampling_rate) 63 | generated_t60_val = RT60.t60_impulse(generated_wav,sampling_rate) 64 | # print("real t60 ", real_t60_val) 65 | # print("generated t60 ", generated_t60_val) 66 | t60_loss = abs(real_t60_val-generated_t60_val) 67 | t60_loss_list.append(t60_loss) 68 | 69 | mean_t60_loss = sum(t60_loss_list)/len(t60_loss_list) 70 | 71 | return mean_t60_loss 72 | 73 | 74 | 75 | def train(fps1,fps2, args): 76 | with tf.name_scope('loader'): 77 | x_real = loader.decode_extract_and_batch( 78 | fps1, 79 | batch_size=args.train_batch_size, 80 | slice_len=args.data1_slice_len, 81 | decode_fs=args.data1_sample_rate, 82 | 
decode_num_channels=args.data1_num_channels, 83 | decode_fast_wav=args.data1_fast_wav, 84 | decode_parallel_calls=4, 85 | slice_randomize_offset=False if args.data1_first_slice else True, 86 | slice_first_only=args.data1_first_slice, 87 | slice_overlap_ratio=0. if args.data1_first_slice else args.data1_overlap_ratio, 88 | slice_pad_end=True if args.data1_first_slice else args.data1_pad_end, 89 | repeat=True, 90 | shuffle=True, 91 | shuffle_buffer_size=4096, 92 | prefetch_size=args.train_batch_size * 4, 93 | prefetch_gpu_num=args.data1_prefetch_gpu_num)[:, :, 0] 94 | 95 | x_synthetic = loader.decode_extract_and_batch( 96 | fps2, 97 | batch_size=args.train_batch_size, 98 | slice_len=args.data2_slice_len, 99 | decode_fs=args.data2_sample_rate, 100 | decode_num_channels=args.data2_num_channels, 101 | decode_fast_wav=args.data2_fast_wav, 102 | decode_parallel_calls=4, 103 | slice_randomize_offset=False if args.data2_first_slice else True, 104 | slice_first_only=args.data2_first_slice, 105 | slice_overlap_ratio=0. if args.data2_first_slice else args.data2_overlap_ratio, 106 | slice_pad_end=True if args.data2_first_slice else args.data2_pad_end, 107 | repeat=True, 108 | shuffle=True, 109 | shuffle_buffer_size=4096, 110 | prefetch_size=args.train_batch_size * 4, 111 | prefetch_gpu_num=args.data1_prefetch_gpu_num)[:, :, 0] 112 | 113 | # print('length check', len(x_real)) 114 | # Make z vector 115 | # z = tf.random_uniform([args.train_batch_size, args.TSRIRgan_latent_dim], -1., 1., dtype=tf.float32) 116 | 117 | # Make generator_synthetic 118 | with tf.variable_scope('G_synthetic'): 119 | G_synthetic = TSRIRGANGenerator_synthetic(x_real, train=True, **args.TSRIRgan_g_kwargs) 120 | if args.TSRIRgan_genr_pp: 121 | with tf.variable_scope('s_pp_filt'): 122 | G_synthetic = tf.layers.conv1d(G_synthetic, 1, args.TSRIRgan_genr_pp_len, use_bias=False, padding='same') 123 | G_synthetic_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='G_synthetic') 124 | 125 | # Print G_synthetic summary 126 | print('-' * 80) 127 | print('Generator_synthetic vars') 128 | nparams = 0 129 | for v in G_synthetic_vars: 130 | v_shape = v.get_shape().as_list() 131 | v_n = reduce(lambda x, y: x * y, v_shape) 132 | nparams += v_n 133 | print('{} ({}): {}'.format(v.get_shape().as_list(), v_n, v.name)) 134 | print('Total params: {} ({:.2f} MB)'.format(nparams, (float(nparams) * 4) / (1024 * 1024))) 135 | 136 | # Summarize 137 | tf.summary.audio('x_real', x_real, args.data1_sample_rate) 138 | tf.summary.audio('G_synthetic', G_synthetic, args.data1_sample_rate) 139 | G_synthetic_rms = tf.sqrt(tf.reduce_mean(tf.square(G_synthetic[:, :, 0]), axis=1)) 140 | x_real_rms = tf.sqrt(tf.reduce_mean(tf.square(x_real[:, :, 0]), axis=1)) 141 | tf.summary.histogram('x_real_rms_batch', x_real_rms) 142 | tf.summary.histogram('G_synthetic_rms_batch', G_synthetic_rms) 143 | tf.summary.scalar('x_real_rms', tf.reduce_mean(x_real_rms)) 144 | tf.summary.scalar('G_synthetic_rms', tf.reduce_mean(G_synthetic_rms)) 145 | 146 | # Make generator_real 147 | with tf.variable_scope('G_real'): 148 | G_real = TSRIRGANGenerator_real(x_synthetic, train=True, **args.TSRIRgan_g_kwargs) 149 | if args.TSRIRgan_genr_pp: 150 | with tf.variable_scope('r_pp_filt'): 151 | G_real = tf.layers.conv1d(G_real, 1, args.TSRIRgan_genr_pp_len, use_bias=False, padding='same') 152 | G_real_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='G_real') 153 | 154 | # Print G_real summary 155 | print('-' * 80) 156 | print('Generator_real vars') 157 | nparams = 0 158 
  # Make generator_real
  with tf.variable_scope('G_real'):
    G_real = TSRIRGANGenerator_real(x_synthetic, train=True, **args.TSRIRgan_g_kwargs)
    if args.TSRIRgan_genr_pp:
      with tf.variable_scope('r_pp_filt'):
        G_real = tf.layers.conv1d(G_real, 1, args.TSRIRgan_genr_pp_len, use_bias=False, padding='same')
  G_real_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='G_real')

  # Print G_real summary
  print('-' * 80)
  print('Generator_real vars')
  nparams = 0
  for v in G_real_vars:
    v_shape = v.get_shape().as_list()
    v_n = reduce(lambda x, y: x * y, v_shape)
    nparams += v_n
    print('{} ({}): {}'.format(v.get_shape().as_list(), v_n, v.name))
  print('Total params: {} ({:.2f} MB)'.format(nparams, (float(nparams) * 4) / (1024 * 1024)))

  # Summarize
  tf.summary.audio('x_synthetic', x_synthetic, args.data1_sample_rate)
  tf.summary.audio('G_real', G_real, args.data1_sample_rate)
  G_real_rms = tf.sqrt(tf.reduce_mean(tf.square(G_real[:, :, 0]), axis=1))
  x_synthetic_rms = tf.sqrt(tf.reduce_mean(tf.square(x_synthetic[:, :, 0]), axis=1))
  tf.summary.histogram('x_synthetic_rms_batch', x_synthetic_rms)
  tf.summary.histogram('G_real_rms_batch', G_real_rms)
  tf.summary.scalar('x_synthetic_rms', tf.reduce_mean(x_synthetic_rms))
  tf.summary.scalar('G_real_rms', tf.reduce_mean(G_real_rms))

  # Generate cycle-reconstructed IRs
  with tf.variable_scope('G_synthetic', reuse=True):
    cycle_synthetic = TSRIRGANGenerator_synthetic(G_real, train=True, **args.TSRIRgan_g_kwargs)
    if args.TSRIRgan_genr_pp:
      with tf.variable_scope('s_pp_filt'):
        cycle_synthetic = tf.layers.conv1d(cycle_synthetic, 1, args.TSRIRgan_genr_pp_len, use_bias=False, padding='same')
  G_synthetic_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='G_synthetic')

  with tf.variable_scope('G_real', reuse=True):
    cycle_real = TSRIRGANGenerator_real(G_synthetic, train=True, **args.TSRIRgan_g_kwargs)
    if args.TSRIRgan_genr_pp:
      with tf.variable_scope('r_pp_filt'):
        cycle_real = tf.layers.conv1d(cycle_real, 1, args.TSRIRgan_genr_pp_len, use_bias=False, padding='same')
  G_real_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='G_real')

  # Generate identity-mapped IRs
  with tf.variable_scope('G_synthetic', reuse=True):
    same_synthetic = TSRIRGANGenerator_synthetic(x_synthetic, train=True, **args.TSRIRgan_g_kwargs)
    if args.TSRIRgan_genr_pp:
      with tf.variable_scope('s_pp_filt'):
        same_synthetic = tf.layers.conv1d(same_synthetic, 1, args.TSRIRgan_genr_pp_len, use_bias=False, padding='same')
  G_synthetic_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='G_synthetic')

  with tf.variable_scope('G_real', reuse=True):
    same_real = TSRIRGANGenerator_real(x_real, train=True, **args.TSRIRgan_g_kwargs)
    if args.TSRIRgan_genr_pp:
      with tf.variable_scope('r_pp_filt'):
        same_real = tf.layers.conv1d(same_real, 1, args.TSRIRgan_genr_pp_len, use_bias=False, padding='same')
  G_real_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='G_real')
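  # NOTE: the cycle and identity passes above reuse the generator scopes
  # (reuse=True), so they add no new trainable variables:
  #   cycle_real      = G_real(G_synthetic(x_real))      (should reconstruct x_real)
  #   cycle_synthetic = G_synthetic(G_real(x_synthetic)) (should reconstruct x_synthetic)
  #   same_real       = G_real(x_real)                   (identity mapping)
  #   same_synthetic  = G_synthetic(x_synthetic)         (identity mapping)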
  # Synthetic
  # Make real discriminator
  with tf.name_scope('D_synthetic_x'), tf.variable_scope('D_synthetic'):
    D_synthetic_x = TSRIRGANDiscriminator_synthetic(x_synthetic, **args.TSRIRgan_d_kwargs)
  D_synthetic_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='D_synthetic')

  # Print D summary
  print('-' * 80)
  print('Discriminator_synthetic vars')
  nparams = 0
  for v in D_synthetic_vars:
    v_shape = v.get_shape().as_list()
    v_n = reduce(lambda x, y: x * y, v_shape)
    nparams += v_n
    print('{} ({}): {}'.format(v.get_shape().as_list(), v_n, v.name))
  print('Total params: {} ({:.2f} MB)'.format(nparams, (float(nparams) * 4) / (1024 * 1024)))
  print('-' * 80)

  # Make fake discriminator
  with tf.name_scope('D_G_synthetic'), tf.variable_scope('D_synthetic', reuse=True):
    D_G_synthetic = TSRIRGANDiscriminator_synthetic(G_synthetic, **args.TSRIRgan_d_kwargs)

  # Real
  # Make real discriminator
  with tf.name_scope('D_real_x'), tf.variable_scope('D_real'):
    D_real_x = TSRIRGANDiscriminator_real(x_real, **args.TSRIRgan_d_kwargs)
  D_real_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='D_real')

  # Print D summary
  print('-' * 80)
  print('Discriminator_real vars')
  nparams = 0
  for v in D_real_vars:
    v_shape = v.get_shape().as_list()
    v_n = reduce(lambda x, y: x * y, v_shape)
    nparams += v_n
    print('{} ({}): {}'.format(v.get_shape().as_list(), v_n, v.name))
  print('Total params: {} ({:.2f} MB)'.format(nparams, (float(nparams) * 4) / (1024 * 1024)))
  print('-' * 80)

  # Make fake discriminator
  with tf.name_scope('D_G_real'), tf.variable_scope('D_real', reuse=True):
    D_G_real = TSRIRGANDiscriminator_real(G_real, **args.TSRIRgan_d_kwargs)

  # Create loss
  D_clip_weights = None
  # This session is only needed by the optional RT60 loss terms below
  # (currently commented out); training itself runs in the
  # MonitoredTrainingSession further down.
  sess = tf.Session()
  sess.run(tf.global_variables_initializer())
  if args.TSRIRgan_loss == 'cycle-gan':
    # Real IR
    gen_real_loss = generator_loss(D_G_real)
    gen_synthetic_loss = generator_loss(D_G_synthetic)

    cycle_loss_real = calc_cycle_loss(x_real, cycle_real)
    cycle_loss_synthetic = calc_cycle_loss(x_synthetic, cycle_synthetic)

    total_cycle_loss = cycle_loss_real + cycle_loss_synthetic

    same_real_loss = identity_loss(x_real, same_real)
    same_synthetic_loss = identity_loss(x_synthetic, same_synthetic)

    # RT60_loss_real = RT_60_loss(x_real, G_real, sess)
    # RT60_loss_synthetic = RT_60_loss(x_synthetic, G_synthetic, sess)

    total_gen_real_loss = gen_real_loss + 25 * total_cycle_loss + 35 * same_real_loss  # + RT60_loss_real
    total_gen_synthetic_loss = gen_synthetic_loss + 25 * total_cycle_loss + 35 * same_synthetic_loss  # + RT60_loss_synthetic

    disc_synthetic_loss = discriminator_loss(D_synthetic_x, D_G_synthetic)
    disc_real_loss = discriminator_loss(D_real_x, D_G_real)

  else:
    raise NotImplementedError()

  # tf.summary.scalar('RT60_loss_real', RT60_loss_real)
  # tf.summary.scalar('RT60_loss_synthetic', RT60_loss_synthetic)
  tf.summary.scalar('G_real_loss', total_gen_real_loss)
  tf.summary.scalar('G_synthetic_loss', total_gen_synthetic_loss)
  tf.summary.scalar('D_real_loss', disc_real_loss)
  tf.summary.scalar('D_synthetic_loss', disc_synthetic_loss)

  tf.summary.scalar('Generator_real_loss', gen_real_loss)
  tf.summary.scalar('Generator_synthetic_loss', gen_synthetic_loss)
  # Log the weighted cycle/identity terms with the same coefficients used in
  # the totals above
  tf.summary.scalar('Cycle_loss_real', 25 * cycle_loss_real)
  tf.summary.scalar('Cycle_loss_synthetic', 25 * cycle_loss_synthetic)
  tf.summary.scalar('Same_loss_real', 35 * same_real_loss)
  tf.summary.scalar('Same_loss_synthetic', 35 * same_synthetic_loss)
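  # Objective as wired above (lambda_cycle = 25, lambda_identity = 35):
  #   L_G(d) = generator_loss(D_d(G_d)) + 25 * (L_cyc_real + L_cyc_syn) + 35 * L_id(d)
  #   L_D(d) = discriminator_loss(D_d(x_d), D_d(G_d))
  # The commented-out RT60 terms are an optional reverberation-time loss
  # (see RT60.py in this repo); they are disabled here.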
  # Create (recommended) optimizer
  if args.TSRIRgan_loss == 'cycle-gan':
    # G_real_opt = tf.train.AdamOptimizer(
    #     learning_rate=2e-4,
    #     beta1=0.5)
    # G_synthetic_opt = tf.train.AdamOptimizer(
    #     learning_rate=2e-4,
    #     beta1=0.5)
    # D_real_opt = tf.train.AdamOptimizer(
    #     learning_rate=2e-4,
    #     beta1=0.5)
    # D_synthetic_opt = tf.train.AdamOptimizer(
    #     learning_rate=2e-4,
    #     beta1=0.5)
    G_real_opt = tf.train.RMSPropOptimizer(
        learning_rate=3e-5)
    G_synthetic_opt = tf.train.RMSPropOptimizer(
        learning_rate=3e-5)
    D_real_opt = tf.train.RMSPropOptimizer(
        learning_rate=3e-5)
    D_synthetic_opt = tf.train.RMSPropOptimizer(
        learning_rate=3e-5)
  else:
    raise NotImplementedError()

  # Create training ops (note: both generator updates advance the shared
  # global step)
  G_real_train_op = G_real_opt.minimize(total_gen_real_loss, var_list=G_real_vars,
      global_step=tf.train.get_or_create_global_step())
  G_synthetic_train_op = G_synthetic_opt.minimize(total_gen_synthetic_loss, var_list=G_synthetic_vars,
      global_step=tf.train.get_or_create_global_step())
  D_real_train_op = D_real_opt.minimize(disc_real_loss, var_list=D_real_vars)
  D_synthetic_train_op = D_synthetic_opt.minimize(disc_synthetic_loss, var_list=D_synthetic_vars)

  # Run training
  with tf.train.MonitoredTrainingSession(
      checkpoint_dir=args.train_dir,
      save_checkpoint_secs=args.train_save_secs,
      save_summaries_secs=args.train_summary_secs) as sess:
    print('-' * 80)
    print('Training has started. Please use \'tensorboard --logdir={}\' to monitor.'.format(args.train_dir))
    # RT60_loss_real = RT_60_loss(x_real, G_real, sess)
    # RT60_loss_synthetic = RT_60_loss(x_synthetic, G_synthetic, sess)
    while True:
      # Train discriminators
      for i in xrange(args.TSRIRgan_disc_nupdates):
        sess.run(D_real_train_op)
        sess.run(D_synthetic_train_op)

        # Enforce Lipschitz constraint for WGAN
        # if D_clip_weights is not None:
        #   sess.run(D_clip_weights)

      # Train generators
      sess.run(G_real_train_op)
      sess.run(G_synthetic_train_op)
      # RT60_loss_real = RT_60_loss(x_real, G_real, sess)
      # RT60_loss_synthetic = RT_60_loss(x_synthetic, G_synthetic, sess)


def infer(args):
  infer_dir = os.path.join(args.train_dir, 'infer')
  if not os.path.isdir(infer_dir):
    os.makedirs(infer_dir)

  samp_x_synthetic_n = tf.placeholder(tf.int32, [], name='samp_x_synthetic_n')
  samp_x_real_n = tf.placeholder(tf.int32, [], name='samp_x_real_n')

  # samp_z = tf.random_uniform([samp_z_n, args.TSRIRgan_latent_dim], -1.0, 1.0, dtype=tf.float32, name='samp_z')

  # Input placeholders (shape is hardcoded to the default
  # train_batch_size=64 and slice_len=16384)
  x_real = tf.placeholder(tf.float32, [64, 16384, 1], name='x_real')
  x_synthetic = tf.placeholder(tf.float32, [64, 16384, 1], name='x_synthetic')

  synthetic_flat_pad = tf.placeholder(tf.int32, [], name='synthetic_flat_pad')
  x_synthetic_flat_pad = tf.placeholder(tf.int32, [], name='x_synthetic_flat_pad')
  real_flat_pad = tf.placeholder(tf.int32, [], name='real_flat_pad')
  x_real_flat_pad = tf.placeholder(tf.int32, [], name='x_real_flat_pad')
  print("shape ", x_real.shape)
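  # NOTE: each *_flat_pad placeholder sets how many zero samples are appended
  # to every example before the batch is flattened into one long waveform
  # below; preview() feeds half a second (sample_rate / 2) of silence so the
  # concatenated previews are audibly separated.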
  # Execute generator
  with tf.variable_scope('G_synthetic'):
    G_synthetic_x = TSRIRGANGenerator_synthetic(x_real, train=False, **args.TSRIRgan_g_kwargs)
    if args.TSRIRgan_genr_pp:
      with tf.variable_scope('s_pp_filt'):
        G_synthetic_x = tf.layers.conv1d(G_synthetic_x, 1, args.TSRIRgan_genr_pp_len, use_bias=False, padding='same')
  G_synthetic_x = tf.identity(G_synthetic_x, name='G_synthetic_x')

  with tf.variable_scope('G_real'):
    G_real_x = TSRIRGANGenerator_real(x_synthetic, train=False, **args.TSRIRgan_g_kwargs)
    if args.TSRIRgan_genr_pp:
      with tf.variable_scope('r_pp_filt'):
        G_real_x = tf.layers.conv1d(G_real_x, 1, args.TSRIRgan_genr_pp_len, use_bias=False, padding='same')
  G_real_x = tf.identity(G_real_x, name='G_real_x')

  # Flatten batch
  synthetic_nch = int(G_synthetic_x.get_shape()[-1])
  G_synthetic_x_padded = tf.pad(G_synthetic_x, [[0, 0], [0, synthetic_flat_pad], [0, 0]])
  G_synthetic_x_flat = tf.reshape(G_synthetic_x_padded, [-1, synthetic_nch], name='G_synthetic_x_flat')

  xs_nch = int(x_synthetic.get_shape()[-1])
  x_synthetic_padded = tf.pad(x_synthetic, [[0, 0], [0, x_synthetic_flat_pad], [0, 0]])
  x_synthetic_flat = tf.reshape(x_synthetic_padded, [-1, xs_nch], name='x_synthetic_flat')

  real_nch = int(G_real_x.get_shape()[-1])
  G_real_x_padded = tf.pad(G_real_x, [[0, 0], [0, real_flat_pad], [0, 0]])
  G_real_x_flat = tf.reshape(G_real_x_padded, [-1, real_nch], name='G_real_x_flat')

  xr_nch = int(x_real.get_shape()[-1])
  x_real_padded = tf.pad(x_real, [[0, 0], [0, x_real_flat_pad], [0, 0]])
  x_real_flat = tf.reshape(x_real_padded, [-1, xr_nch], name='x_real_flat')

  # Encode to int16
  def float_to_int16(x, name=None):
    x_int16 = x * 32767.
    x_int16 = tf.clip_by_value(x_int16, -32767., 32767.)
    x_int16 = tf.cast(x_int16, tf.int16, name=name)
    return x_int16

  G_synthetic_x_int16 = float_to_int16(G_synthetic_x, name='G_synthetic_x_int16')
  G_synthetic_x_flat_int16 = float_to_int16(G_synthetic_x_flat, name='G_synthetic_x_flat_int16')
  G_real_x_int16 = float_to_int16(G_real_x, name='G_real_x_int16')
  G_real_x_flat_int16 = float_to_int16(G_real_x_flat, name='G_real_x_flat_int16')

  x_synthetic_int16 = float_to_int16(x_synthetic, name='x_synthetic_int16')
  x_synthetic_flat_int16 = float_to_int16(x_synthetic_flat, name='x_synthetic_flat_int16')
  x_real_int16 = float_to_int16(x_real, name='x_real_int16')
  x_real_flat_int16 = float_to_int16(x_real_flat, name='x_real_flat_int16')
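  # NOTE: float_to_int16 maps [-1., 1.] float audio to 16-bit PCM by scaling
  # by 32767 and clipping; preview() writes the *_flat_int16 tensors directly
  # to WAV files via scipy.io.wavfile.write.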
  # Create saver
  G_synthetic_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='G_synthetic')
  global_step = tf.train.get_or_create_global_step()

  G_real_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='G_real')

  saver = tf.train.Saver(G_synthetic_vars + G_real_vars + [global_step])

  # Export graph
  tf.train.write_graph(tf.get_default_graph(), infer_dir, 'infer.pbtxt')

  # Export MetaGraph
  infer_metagraph_fp = os.path.join(infer_dir, 'infer.meta')
  tf.train.export_meta_graph(
      filename=infer_metagraph_fp,
      clear_devices=True,
      saver_def=saver.as_saver_def())

  # Reset graph (in case training afterwards)
  tf.reset_default_graph()


def preview(fps1, fps2, args):
  """Generates preview audio files every time a checkpoint is saved."""
  import matplotlib
  matplotlib.use('Agg')
  import matplotlib.pyplot as plt
  from scipy.io.wavfile import write as wavwrite
  from scipy.signal import freqz

  preview_dir = os.path.join(args.train_dir, 'preview')
  if not os.path.isdir(preview_dir):
    os.makedirs(preview_dir)

  ####################################################
  s_fps1 = fps1[0:args.preview_n]
  s_fps2 = fps2[0:args.preview_n]
  with tf.name_scope('samp_x_real'):
    x_real = loader.decode_extract_and_batch(
        s_fps1,
        batch_size=args.train_batch_size,
        slice_len=args.data1_slice_len,
        decode_fs=args.data1_sample_rate,
        decode_num_channels=args.data1_num_channels,
        decode_fast_wav=args.data1_fast_wav,
        decode_parallel_calls=4,
        slice_randomize_offset=False if args.data1_first_slice else True,
        slice_first_only=args.data1_first_slice,
        slice_overlap_ratio=0. if args.data1_first_slice else args.data1_overlap_ratio,
        slice_pad_end=True if args.data1_first_slice else args.data1_pad_end,
        repeat=True,
        shuffle=True,
        shuffle_buffer_size=4096,
        prefetch_size=args.train_batch_size * 4,
        prefetch_gpu_num=args.data1_prefetch_gpu_num)[:, :, 0]

  with tf.name_scope('samp_x_synthetic'):
    x_synthetic = loader.decode_extract_and_batch(
        s_fps2,
        batch_size=args.train_batch_size,
        slice_len=args.data2_slice_len,
        decode_fs=args.data2_sample_rate,
        decode_num_channels=args.data2_num_channels,
        decode_fast_wav=args.data2_fast_wav,
        decode_parallel_calls=4,
        slice_randomize_offset=False if args.data2_first_slice else True,
        slice_first_only=args.data2_first_slice,
        slice_overlap_ratio=0. if args.data2_first_slice else args.data2_overlap_ratio,
        slice_pad_end=True if args.data2_first_slice else args.data2_pad_end,
        repeat=True,
        shuffle=True,
        shuffle_buffer_size=4096,
        prefetch_size=args.train_batch_size * 4,
        prefetch_gpu_num=args.data2_prefetch_gpu_num)[:, :, 0]

  ####################################################
  # Materialize one batch of preview inputs as numpy arrays
  x_synthetic = x_synthetic.eval(session=tf.Session())
  x_real = x_real.eval(session=tf.Session())

  # Load graph
  infer_metagraph_fp = os.path.join(args.train_dir, 'infer', 'infer.meta')
  graph = tf.get_default_graph()
  saver = tf.train.import_meta_graph(infer_metagraph_fp)
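  # NOTE: preview() never rebuilds the model. It re-imports the MetaGraph
  # exported by infer() and looks tensors up by the names assigned there
  # (e.g. 'x_real:0', 'G_synthetic_x_flat_int16:0'), so infer() must have
  # been run for this train_dir first.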
  # Set up graph for generating preview images
  feeds = {}
  feeds[graph.get_tensor_by_name('x_synthetic:0')] = x_synthetic
  feeds[graph.get_tensor_by_name('synthetic_flat_pad:0')] = int(args.data1_sample_rate / 2)
  feeds[graph.get_tensor_by_name('x_synthetic_flat_pad:0')] = int(args.data1_sample_rate / 2)
  feeds[graph.get_tensor_by_name('x_real:0')] = x_real
  feeds[graph.get_tensor_by_name('real_flat_pad:0')] = int(args.data1_sample_rate / 2)
  feeds[graph.get_tensor_by_name('x_real_flat_pad:0')] = int(args.data1_sample_rate / 2)
  fetches = {}
  fetches['step'] = tf.train.get_or_create_global_step()
  fetches['G_synthetic_x'] = graph.get_tensor_by_name('G_synthetic_x:0')
  fetches['G_synthetic_x_flat_int16'] = graph.get_tensor_by_name('G_synthetic_x_flat_int16:0')
  fetches['x_synthetic_flat_int16'] = graph.get_tensor_by_name('x_synthetic_flat_int16:0')
  fetches['G_real_x'] = graph.get_tensor_by_name('G_real_x:0')
  fetches['G_real_x_flat_int16'] = graph.get_tensor_by_name('G_real_x_flat_int16:0')
  fetches['x_real_flat_int16'] = graph.get_tensor_by_name('x_real_flat_int16:0')
  if args.TSRIRgan_genr_pp:
    fetches['s_pp_filter'] = graph.get_tensor_by_name('G_synthetic_x/s_pp_filt/conv1d/kernel:0')[:, 0, 0]
    fetches['r_pp_filter'] = graph.get_tensor_by_name('G_real_x/r_pp_filt/conv1d/kernel:0')[:, 0, 0]

  # Summarize
  G_synthetic_x = graph.get_tensor_by_name('G_synthetic_x_flat:0')
  s_summaries = [
      tf.summary.audio('preview', tf.expand_dims(G_synthetic_x, axis=0), args.data1_sample_rate, max_outputs=1)
  ]
  fetches['s_summaries'] = tf.summary.merge(s_summaries)
  s_summary_writer = tf.summary.FileWriter(preview_dir)

  G_real_x = graph.get_tensor_by_name('G_real_x_flat:0')
  r_summaries = [
      tf.summary.audio('preview', tf.expand_dims(G_real_x, axis=0), args.data1_sample_rate, max_outputs=1)
  ]
  fetches['r_summaries'] = tf.summary.merge(r_summaries)
  r_summary_writer = tf.summary.FileWriter(preview_dir)

  # PP Summarize
  if args.TSRIRgan_genr_pp:
    s_pp_fp = tf.placeholder(tf.string, [])
    s_pp_bin = tf.read_file(s_pp_fp)
    s_pp_png = tf.image.decode_png(s_pp_bin)
    s_pp_summary = tf.summary.image('s_pp_filt', tf.expand_dims(s_pp_png, axis=0))

    r_pp_fp = tf.placeholder(tf.string, [])
    r_pp_bin = tf.read_file(r_pp_fp)
    r_pp_png = tf.image.decode_png(r_pp_bin)
    r_pp_summary = tf.summary.image('r_pp_filt', tf.expand_dims(r_pp_png, axis=0))
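  # NOTE: the loop below polls tf.train.latest_checkpoint() once per second
  # and regenerates previews only when a new checkpoint appears, so this mode
  # can run as a separate process alongside training.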
  # Loop, waiting for checkpoints
  ckpt_fp = None
  while True:
    latest_ckpt_fp = tf.train.latest_checkpoint(args.train_dir)
    if latest_ckpt_fp != ckpt_fp:
      print('Preview: {}'.format(latest_ckpt_fp))

      with tf.Session() as sess:
        saver.restore(sess, latest_ckpt_fp)

        _fetches = sess.run(fetches, feeds)

      _step = _fetches['step']

      s_preview_fp = os.path.join(preview_dir, '{}.wav'.format(str(_step).zfill(8) + 'synthetic'))
      wavwrite(s_preview_fp, args.data1_sample_rate, _fetches['G_synthetic_x_flat_int16'])
      s_original_fp = os.path.join(preview_dir, '{}.wav'.format('synthetic_original'))
      wavwrite(s_original_fp, args.data1_sample_rate, _fetches['x_synthetic_flat_int16'])

      s_summary_writer.add_summary(_fetches['s_summaries'], _step)

      r_preview_fp = os.path.join(preview_dir, '{}.wav'.format(str(_step).zfill(8) + 'real'))
      wavwrite(r_preview_fp, args.data1_sample_rate, _fetches['G_real_x_flat_int16'])
      r_original_fp = os.path.join(preview_dir, '{}.wav'.format('real_original'))
      wavwrite(r_original_fp, args.data1_sample_rate, _fetches['x_real_flat_int16'])

      r_summary_writer.add_summary(_fetches['r_summaries'], _step)

      # TODO: re-enable the post-processing filter response plot below
      # (freqz of the learned pp kernel, written out as a PNG image summary)
      # if args.TSRIRgan_genr_pp:
      #   s_w, s_h = freqz(_fetches['s_pp_filter'])
      #
      #   fig = plt.figure()
      #   plt.title('Digital filter frequency response')
      #   ax1 = fig.add_subplot(111)
      #
      #   plt.plot(s_w, 20 * np.log10(abs(s_h)), 'b')
      #   plt.ylabel('Amplitude [dB]', color='b')
      #   plt.xlabel('Frequency [rad/sample]')
      #
      #   ax2 = ax1.twinx()
      #   angles = np.unwrap(np.angle(s_h))
      #   plt.plot(s_w, angles, 'g')
      #   plt.ylabel('Angle (radians)', color='g')
      #   plt.grid()
      #   plt.axis('tight')
      #
      #   _pp_fp = os.path.join(preview_dir, '{}_ppfilt.png'.format(str(_step).zfill(8)))
      #   plt.savefig(_pp_fp)
      #
      #   with tf.Session() as sess:
      #     _summary = sess.run(s_pp_summary, {s_pp_fp: _pp_fp})
      #     s_summary_writer.add_summary(_summary, _step)

      print('Done')

      ckpt_fp = latest_ckpt_fp

    time.sleep(1)


if __name__ == '__main__':
  import argparse
  import glob
  import sys

  parser = argparse.ArgumentParser()

  parser.add_argument('mode', type=str, choices=['train', 'preview', 'infer'])
  parser.add_argument('train_dir', type=str,
      help='Training directory')

  data1_args = parser.add_argument_group('Data1')
  data1_args.add_argument('--data1_dir', type=str,
      help='Data directory containing *only* audio files to load')
  data1_args.add_argument('--data1_sample_rate', type=int,
      help='Number of audio samples per second')
  data1_args.add_argument('--data1_slice_len', type=int, choices=[16384, 32768, 65536],
      help='Number of audio samples per slice (maximum generation length)')
  data1_args.add_argument('--data1_num_channels', type=int,
      help='Number of audio channels to generate (for >2, must match that of data)')
  data1_args.add_argument('--data1_overlap_ratio', type=float,
      help='Overlap ratio [0, 1) between slices')
  data1_args.add_argument('--data1_first_slice', action='store_true', dest='data1_first_slice',
      help='If set, only use the first slice of each audio example')
  data1_args.add_argument('--data1_pad_end', action='store_true', dest='data1_pad_end',
      help='If set, use zero-padded partial slices from the end of each audio file')
  data1_args.add_argument('--data1_normalize', action='store_true', dest='data1_normalize',
      help='If set, normalize the training examples')
  data1_args.add_argument('--data1_fast_wav', action='store_true', dest='data1_fast_wav',
      help='If your data is comprised of standard WAV files (16-bit signed PCM or 32-bit float), use this flag to decode audio using scipy (faster) instead of librosa')
  data1_args.add_argument('--data1_prefetch_gpu_num', type=int,
      help='If nonnegative, prefetch examples to this GPU (Tensorflow device num)')
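  # NOTE: by this script's convention, the Data1 flags describe the real-IR
  # corpus and the Data2 flags the synthetic-IR corpus (see the fps1/fps2
  # globs at the bottom of this file).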
  data2_args = parser.add_argument_group('Data2')
  data2_args.add_argument('--data2_dir', type=str,
      help='Data directory containing *only* audio files to load')
  data2_args.add_argument('--data2_sample_rate', type=int,
      help='Number of audio samples per second')
  data2_args.add_argument('--data2_slice_len', type=int, choices=[16384, 32768, 65536],
      help='Number of audio samples per slice (maximum generation length)')
  data2_args.add_argument('--data2_num_channels', type=int,
      help='Number of audio channels to generate (for >2, must match that of data)')
  data2_args.add_argument('--data2_overlap_ratio', type=float,
      help='Overlap ratio [0, 1) between slices')
  data2_args.add_argument('--data2_first_slice', action='store_true', dest='data2_first_slice',
      help='If set, only use the first slice of each audio example')
  data2_args.add_argument('--data2_pad_end', action='store_true', dest='data2_pad_end',
      help='If set, use zero-padded partial slices from the end of each audio file')
  data2_args.add_argument('--data2_normalize', action='store_true', dest='data2_normalize',
      help='If set, normalize the training examples')
  data2_args.add_argument('--data2_fast_wav', action='store_true', dest='data2_fast_wav',
      help='If your data is comprised of standard WAV files (16-bit signed PCM or 32-bit float), use this flag to decode audio using scipy (faster) instead of librosa')
  data2_args.add_argument('--data2_prefetch_gpu_num', type=int,
      help='If nonnegative, prefetch examples to this GPU (Tensorflow device num)')

  TSRIRgan_args = parser.add_argument_group('TSRIRGAN')
  TSRIRgan_args.add_argument('--TSRIRgan_latent_dim', type=int,
      help='Number of dimensions of the latent space')
  TSRIRgan_args.add_argument('--TSRIRgan_kernel_len', type=int,
      help='Length of 1D filter kernels')
  TSRIRgan_args.add_argument('--TSRIRgan_dim', type=int,
      help='Dimensionality multiplier for model of G and D')
  TSRIRgan_args.add_argument('--TSRIRgan_batchnorm', action='store_true', dest='TSRIRgan_batchnorm',
      help='Enable batchnorm')
  TSRIRgan_args.add_argument('--TSRIRgan_disc_nupdates', type=int,
      help='Number of discriminator updates per generator update')
  TSRIRgan_args.add_argument('--TSRIRgan_loss', type=str, choices=['cycle-gan'],
      help='Which GAN loss to use')
  TSRIRgan_args.add_argument('--TSRIRgan_genr_upsample', type=str, choices=['zeros', 'nn'],
      help='Generator upsample strategy')
  TSRIRgan_args.add_argument('--TSRIRgan_genr_pp', action='store_true', dest='TSRIRgan_genr_pp',
      help='If set, use post-processing filter')
  TSRIRgan_args.add_argument('--TSRIRgan_genr_pp_len', type=int,
      help='Length of post-processing filter for DCGAN')
  TSRIRgan_args.add_argument('--TSRIRgan_disc_phaseshuffle', type=int,
      help='Radius of phase shuffle operation')

  train_args = parser.add_argument_group('Train')
  train_args.add_argument('--train_batch_size', type=int,
      help='Batch size')
  train_args.add_argument('--train_save_secs', type=int,
      help='How often to save model')
  train_args.add_argument('--train_summary_secs', type=int,
      help='How often to report summaries')

  preview_args = parser.add_argument_group('Preview')
  preview_args.add_argument('--preview_n', type=int,
      help='Number of samples to preview')

  # incept_args = parser.add_argument_group('Incept')
  # incept_args.add_argument('--incept_metagraph_fp', type=str,
  #     help='Inference model for inception score')
  # incept_args.add_argument('--incept_ckpt_fp', type=str,
  #     help='Checkpoint for inference model')
  # incept_args.add_argument('--incept_n', type=int,
  #     help='Number of generated examples to test')
  # incept_args.add_argument('--incept_k', type=int,
  #     help='Number of groups to test')
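  # Example invocation (hypothetical dataset paths; the defaults below fill
  # in every other flag):
  #   python train_TSRIRgan.py train ./train \
  #       --data1_dir ./datasets/real_IRs \
  #       --data2_dir ./datasets/synthetic_IRs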
  parser.set_defaults(
      data1_dir=None,
      data1_sample_rate=16000,
      data1_slice_len=16384,
      data1_num_channels=1,
      data1_overlap_ratio=0.,
      data1_first_slice=False,
      data1_pad_end=False,
      data1_normalize=False,
      data1_fast_wav=False,
      data1_prefetch_gpu_num=0,
      data2_dir=None,
      data2_sample_rate=16000,
      data2_slice_len=16384,
      data2_num_channels=1,
      data2_overlap_ratio=0.,
      data2_first_slice=False,
      data2_pad_end=False,
      data2_normalize=False,
      data2_fast_wav=False,
      data2_prefetch_gpu_num=0,
      TSRIRgan_latent_dim=100,
      TSRIRgan_kernel_len=25,
      TSRIRgan_dim=64,
      TSRIRgan_batchnorm=False,
      TSRIRgan_disc_nupdates=2,
      TSRIRgan_loss='cycle-gan',
      TSRIRgan_genr_upsample='zeros',
      TSRIRgan_genr_pp=False,
      TSRIRgan_genr_pp_len=512,
      TSRIRgan_disc_phaseshuffle=2,
      train_batch_size=64,
      train_save_secs=300,
      train_summary_secs=120,
      preview_n=32)
      # incept_metagraph_fp='./eval/inception/infer.meta',
      # incept_ckpt_fp='./eval/inception/best_acc-103005',
      # incept_n=5000,
      # incept_k=10)

  args = parser.parse_args()

  # Make train dir
  if not os.path.isdir(args.train_dir):
    os.makedirs(args.train_dir)

  # Save args
  with open(os.path.join(args.train_dir, 'args.txt'), 'w') as f:
    f.write('\n'.join([str(k) + ',' + str(v) for k, v in sorted(vars(args).items(), key=lambda x: x[0])]))

  # Make model kwarg dicts
  setattr(args, 'TSRIRgan_g_kwargs', {
      'slice_len': args.data1_slice_len,
      'nch': args.data1_num_channels,
      'kernel_len': args.TSRIRgan_kernel_len,
      'dim': args.TSRIRgan_dim,
      'use_batchnorm': args.TSRIRgan_batchnorm,
      'upsample': args.TSRIRgan_genr_upsample
  })
  setattr(args, 'TSRIRgan_d_kwargs', {
      'kernel_len': args.TSRIRgan_kernel_len,
      'dim': args.TSRIRgan_dim,
      'use_batchnorm': args.TSRIRgan_batchnorm,
      'phaseshuffle_rad': args.TSRIRgan_disc_phaseshuffle
  })

  fps1 = glob.glob(os.path.join(args.data1_dir, '*'))
  if len(fps1) == 0:
    raise Exception('Did not find any audio files in specified directory (real IRs)')
  print('Found {} audio files in specified directory'.format(len(fps1)))
  fps2 = glob.glob(os.path.join(args.data2_dir, '*'))
  if len(fps2) == 0:
    raise Exception('Did not find any audio files in specified directory (synthetic IRs)')
  print('Found {} audio files in specified directory'.format(len(fps2)))

  if args.mode == 'train':
    infer(args)
    train(fps1, fps2, args)
  elif args.mode == 'preview':
    preview(fps1, fps2, args)
  elif args.mode == 'infer':
    infer(args)
  else:
    raise NotImplementedError()
--------------------------------------------------------------------------------