├── .gitignore ├── LICENSE ├── README.md ├── apt-packages.txt ├── data ├── preprocess_data.py └── trumpet │ └── input.wav ├── requirements.txt ├── res ├── 2024-string.gif └── precorrect.png ├── run.py └── src ├── callbacks.py ├── configs ├── callbacks │ └── base.yaml ├── config.yaml ├── experiment │ ├── all-fixed.yaml │ ├── base.yaml │ ├── evaluate.yaml │ ├── linear-string.yaml │ ├── nonlinear-string.yaml │ ├── nsynth-like.yaml │ ├── process_training_data.yaml │ └── synth-dmsp.yaml ├── framework │ └── supervised.yaml ├── model │ ├── base.yaml │ ├── bow.yaml │ ├── dmsp.yaml │ ├── fdtd.yaml │ ├── hammer.yaml │ ├── pluck.yaml │ └── trainer.yaml ├── optimizer │ ├── adam.yaml │ ├── adamw.yaml │ ├── lamb.yaml │ ├── radam.yaml │ └── sgd.yaml ├── scheduler │ ├── constant.yaml │ ├── constant_warmup.yaml │ ├── cosine.yaml │ ├── cosine_warmup.yaml │ ├── linear_warmup.yaml │ ├── multistep.yaml │ ├── noam.yaml │ ├── plateau.yaml │ ├── step.yaml │ └── timm_cosine.yaml └── task │ ├── evaluate.yaml │ ├── process_training_data.yaml │ ├── simulate.yaml │ └── synthesize.yaml ├── dataset └── synthesize.py ├── model ├── analytic.py ├── cpp │ ├── bow.cpp │ ├── bow.h │ ├── hammer.cpp │ ├── hammer.h │ ├── misc.cpp │ ├── misc.h │ ├── simulator.cpp │ ├── string.cpp │ ├── string.h │ ├── vnv.cpp │ └── vnv.h ├── nn │ ├── blocks.py │ ├── ddsp.py │ ├── dmsp.py │ └── synthesizer.py └── simulator.py ├── task ├── evaluate.py ├── process_training_data.py ├── simulate.py └── synthesize.py ├── trainer.py └── utils ├── analysis └── frequency.py ├── audio.py ├── config.py ├── control.py ├── data.py ├── ddsp.py ├── fdm.py ├── loss.py ├── misc.py ├── objective.py ├── optimizer.py ├── plot.py └── vnv.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | */__pycache__/ 3 | */*/__pycache__/ 4 | */*/*/__pycache__/ 5 | */*/*/*/__pycache__/ 6 | results/ 7 | check/ 8 | log/ 9 | vnv/ 10 | 11 | data/libri* 12 | data/shortviolin 13 | data/violin 14 | 
data/trumpet2 15 | data/*/*.npy 16 | data/*/*.pdf 17 | data/*/sine-* 18 | data/*/sample-* 19 | 20 | src/configs/experiment/demo-* 21 | src/configs/experiment/discretize-convergence.yaml 22 | src/configs/experiment/friction-coefficient.yaml 23 | src/configs/experiment/variation-* 24 | src/configs/experiment/analyze-convnext*.yaml 25 | src/configs/experiment/synth-ddsp-*.yaml 26 | src/configs/experiment/test-*.yaml 27 | src/configs/model/mlp.yaml 28 | src/configs/model/wavenet.yaml 29 | src/configs/model/convnext.yaml 30 | src/configs/model/ddsp.yaml 31 | src/configs/model/gru.yaml 32 | src/configs/model/transformer.yaml 33 | src/configs/model/unet.yaml 34 | src/configs/framework/pinn.yaml 35 | src/configs/framework/diffusion.yaml 36 | src/configs/task/analyze.yaml 37 | 38 | src/model/cpp/build 39 | src/model/nn/analyzer.py 40 | src/model/nn/convnext*.py 41 | src/model/nn/gru.py 42 | src/model/nn/mlp.py 43 | src/model/nn/mpd.py 44 | src/model/nn/transformer.py 45 | src/model/nn/unet.py 46 | src/model/nn/wavenet*.py 47 | 48 | src/task/summarize.py 49 | src/task/adv.py 50 | src/task/analyze.py 51 | src/utils/diffusion.py 52 | src/utils/pde.py 53 | src/dataset/analyze.py 54 | 55 | cmd 56 | 57 | *.swp 58 | 59 | -------------------------------------------------------------------------------- /apt-packages.txt: -------------------------------------------------------------------------------- 1 | ninja-build 2 | ffmpeg 3 | 4 | -------------------------------------------------------------------------------- /data/preprocess_data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import librosa 4 | import librosa.display 5 | 6 | import torch 7 | import torch.nn.functional as F 8 | import torchaudio.transforms as TAT 9 | import torchaudio.transforms as TAF 10 | import crepe 11 | 12 | import numpy as np 13 | import soundfile as sf 14 | import matplotlib.pyplot as plt 15 | 16 | def plot_spectrogram( 17 | 
save_path, out, sr, n_fft=2**13, hop_length=None, 18 | f0_input=None, f0_estimate=None, modes=None, colorbar=True, 19 | ): 20 | L = 32 21 | if out.shape[-1] > 2*n_fft: 22 | hop_length = n_fft // L if hop_length is None else hop_length 23 | else: 24 | n_fft = out.shape[-1] // 2 25 | hop_length = n_fft // L 26 | t_max = out.shape[-1] / sr 27 | 28 | D = librosa.stft(out, n_fft=n_fft, hop_length=hop_length, pad_mode='reflect') 29 | mag, phase = librosa.magphase(D) 30 | 31 | freqs = librosa.fft_frequencies(sr=sr, n_fft=n_fft) 32 | times = librosa.times_like(D, sr=sr, hop_length=hop_length) 33 | 34 | logmag = librosa.amplitude_to_db(mag, ref=np.max) 35 | 36 | #width = 2.5; height = 1.9 37 | width = 30; height = 5 38 | plt.figure(figsize=(width,height)) 39 | spec = librosa.display.specshow( 40 | logmag, 41 | n_fft=n_fft, hop_length=hop_length, sr=sr, 42 | y_axis='log', x_axis='time', 43 | ) 44 | if colorbar: 45 | cbar = plt.colorbar(spec, ticks=[-np.pi, -np.pi/2, 0, np.pi/2, np.pi]) 46 | cbar.ax.set(yticklabels=['$-\pi$', '$-\pi/2$', "$0$", '$\pi/2$', '$\pi$']); 47 | 48 | def add_plot(freqs, label=None, ls=None, lw=2., dashes=(None,None)): 49 | x = np.linspace(1/sr, t_max, freqs.shape[-1]) 50 | freqs = np.interp(times, x, freqs) 51 | line, = plt.plot(times - times[0], freqs, label=label, color='white', lw=lw, ls=ls, dashes=dashes) 52 | return line 53 | 54 | freq_ticks = [0, 128, 512, 2048, 8192, sr // 2] 55 | time_ticks = [0, 1, 2] 56 | if f0_input is not None: 57 | add_plot(f0_input, "f0_input", dashes=(10,5)) 58 | freq_ticks += [f0_input[0]] 59 | 60 | if f0_estimate is not None: 61 | add_plot(f0_estimate, "f0_estimate", dashes=(2,5)) 62 | freq_ticks += [] if f0_input is not None else [f0_estimate[0]] 63 | 64 | if modes is not None: 65 | for im, m in enumerate(modes): 66 | l = add_plot(m, f"mode {im}") 67 | l.set_dashes([5,10,1,10]) 68 | 69 | plt.xticks([]) 70 | plt.yticks([]) 71 | plt.xlabel('') 72 | plt.ylabel('') 73 | #plt.xaxis.set_visible(False) 74 | 
#plt.yaxis.set_visible(False) 75 | 76 | plt.tight_layout() 77 | plt.savefig(save_path, bbox_inches='tight', pad_inches=-1e-6) 78 | plt.clf() 79 | plt.close("all") 80 | 81 | def spectrogram(x, n_fft=1024, hop_length=None, logscale=False): 82 | L = 4 83 | if x.shape[-1] > 2*n_fft: 84 | hop_length = n_fft // L if hop_length is None else hop_length 85 | else: 86 | n_fft = x.shape[-1] // 2 87 | hop_length = n_fft // L 88 | 89 | X = librosa.stft(x, n_fft=n_fft, hop_length=hop_length, pad_mode='reflect') 90 | mag, phase = librosa.magphase(X) 91 | mag = librosa.amplitude_to_db(mag, ref=np.max) if logscale else mag 92 | return mag, phase 93 | 94 | def load_wav(root_dir, filename, target_sr): 95 | wav_path = f"{root_dir}/{filename}/input.wav" 96 | if os.path.exists(wav_path): 97 | x, sr = sf.read(wav_path) 98 | else: 99 | x, sr = librosa.load(librosa.example(filename)) 100 | 101 | if sr != target_sr: 102 | resampler = TAT.Resample(sr, target_sr, resampling_method='kaiser_window') 103 | x = resampler(torch.from_numpy(x)).numpy() 104 | sr = target_sr 105 | if not os.path.exists(wav_path): 106 | os.makedirs(f'{root_dir}/{filename}', exist_ok=True) 107 | sf.write(f"{root_dir}/{filename}/input.wav", x, samplerate=sr) 108 | return x, sr 109 | 110 | def inverse_spectrogram(mag, phase, n_fft=1024, hop_length=None): 111 | hop_length = n_fft // 4 if hop_length is None else hop_length 112 | X = mag * phase 113 | x = librosa.istft(X, n_fft=n_fft, hop_length=hop_length) 114 | return x 115 | 116 | def get_amplitude(x): 117 | X_mag, X_phs = spectrogram(x) # (F, T) 118 | X_rms = np.sqrt(np.mean(X_mag**2, axis=0)+1e-5) 119 | return X_rms # (T,) 120 | 121 | def sine_like(freqs, length, sr): 122 | time_axis_1 = np.arange(length) / sr 123 | time_axis_2 = np.linspace(1/sr, length / sr, freqs.shape[-1]) 124 | freqs = np.interp(time_axis_1, time_axis_2, freqs) 125 | phase = np.cumsum(freqs) 126 | return np.sin(2 * np.pi * phase / sr) 127 | 128 | def AM(x, amp, sr): 129 | X_mag, X_phs = 
spectrogram(x) # (F, T) 130 | X_rms = np.sqrt(np.mean(X_mag**2, axis=0, keepdims=True)+1e-5) 131 | X_mag = X_mag / X_rms 132 | X_mag = X_mag * amp[None,:] 133 | x = inverse_spectrogram(X_mag, X_phs) 134 | return x 135 | 136 | def running_avg(x, N=1024, threshold=0.3): 137 | w = np.pad(np.ones(N)/N, (N,0)) 138 | x = np.where(x > threshold, x, np.zeros(x.shape)) 139 | x = np.convolve(x, w, mode='same') 140 | return x 141 | 142 | def process_f0(root_dir, filename, target_sr): 143 | x, sr = load_wav(root_dir, filename, target_sr) 144 | 145 | f0_path = f'{root_dir}/{filename}/string-f0.npy' 146 | if os.path.exists(f0_path): 147 | f0 = np.load(f0_path) 148 | else: 149 | os.makedirs(f'{root_dir}/{filename}', exist_ok=True) 150 | time, f0, confidence, activation = crepe.predict(x, sr, viterbi=True) 151 | np.save(f0_path, f0) 152 | return x, f0 153 | 154 | def process_amp(root_dir, filename, target_sr): 155 | x, sr = load_wav(root_dir, filename, target_sr) 156 | 157 | ''' get f0 ''' 158 | f0_path = f'{root_dir}/{filename}/string-f0.npy' 159 | f0 = np.load(f0_path) 160 | if len(f0) != len(x): 161 | time_axis_1 = np.arange(len(x)) / sr 162 | time_axis_2 = np.linspace(1/sr, len(x) / sr, len(f0)) 163 | f0 = np.interp(time_axis_1, time_axis_2, f0) 164 | np.save(f0_path, f0) 165 | 166 | ''' get amplitude ''' 167 | amp_path = f'{root_dir}/{filename}/amp.npy' 168 | amp = get_amplitude(x) 169 | 170 | y1 = sine_like(f0, x.shape[-1], sr) 171 | y2 = AM(y1, amp, sr) 172 | 173 | if len(amp) != len(x): 174 | time_axis_1 = np.arange(len(x)) / sr 175 | time_axis_2 = np.linspace(1/sr, len(x) / sr, len(amp)) 176 | amp = np.interp(time_axis_1, time_axis_2, amp) 177 | 178 | force = running_avg(amp) 179 | force = 100 * (force/2+ 1e-5)**.1 180 | force = np.where(force > 40, force, np.zeros(force.shape)) 181 | force_path = f'{root_dir}/{filename}/bow-F_b.npy' 182 | 183 | 184 | o_env = librosa.onset.onset_strength(y=x, sr=sr) 185 | time_axis_f = librosa.times_like(o_env, sr=sr) 186 | onset_frames = 
librosa.onset.onset_detect(onset_envelope=o_env, sr=sr) 187 | scale = x.shape[-1] / time_axis_f.shape[-1] 188 | hammer = np.zeros(x.shape[-1]) 189 | onset_sample = np.array(onset_frames * scale).astype(int) 190 | hammer[onset_sample] = 1 191 | hammer_path = f'{root_dir}/{filename}/hammer-v_H.npy' 192 | 193 | y3 = x * running_avg(hammer) 194 | 195 | np.save(force_path, force) 196 | np.save(hammer_path, hammer) 197 | 198 | sf.write(f"{root_dir}/{filename}/sine-f0.wav", y1, sr) 199 | sf.write(f"{root_dir}/{filename}/sine-f0-amp.wav", y2, sr) 200 | sf.write(f"{root_dir}/{filename}/sine-f0-ham.wav", y3, sr) 201 | 202 | return y1, y2, y3 203 | 204 | if __name__=='__main__': 205 | root_dir = 'data' 206 | filename = 'trumpet' 207 | sr = 48000 208 | 209 | x, f0 = process_f0(root_dir, filename, sr) 210 | y1, y2, y3 = process_amp(root_dir, filename, sr) 211 | 212 | plot_spectrogram(f'{root_dir}/{filename}/spec.pdf', x, sr, f0_input=f0, colorbar=False) 213 | plot_spectrogram(f'{root_dir}/{filename}/spec-f0.pdf', y1, sr, colorbar=False) 214 | plot_spectrogram(f'{root_dir}/{filename}/spec-f0-amp.pdf', y2, sr, colorbar=False) 215 | plot_spectrogram(f'{root_dir}/{filename}/spec-f0-ham.pdf', y3, sr, colorbar=False) 216 | 217 | 218 | sample_list = glob.glob(f'{root_dir}/{filename}/sample-*.wav') 219 | for sp in sample_list: 220 | x, sr = sf.read(sp) 221 | sample_name = sp.split('/')[-1].split('.')[0] 222 | plot_spectrogram(f'{root_dir}/{filename}/{sample_name}.pdf', x, sr, colorbar=False) 223 | 224 | 225 | -------------------------------------------------------------------------------- /data/trumpet/input.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jin-woo-lee/torch-fdtd-string/77374b2e506ceeb9be14a0095e57158de5280672/data/trumpet/input.wav -------------------------------------------------------------------------------- /requirements.txt: 
-------------------------------------------------------------------------------- 1 | --extra-index-url https://download.pytorch.org/whl/cu113 2 | torch==1.12.1+cu113 3 | torchaudio==0.12.1 4 | tensorflow==2.15.0 5 | numpy==1.25.2 6 | lightning==2.1 7 | hydra-core 8 | torchinfo 9 | wandb 10 | timm==0.4.12 11 | auraloss 12 | matplotlib 13 | librosa 14 | einops 15 | rich 16 | crepe 17 | -------------------------------------------------------------------------------- /res/2024-string.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jin-woo-lee/torch-fdtd-string/77374b2e506ceeb9be14a0095e57158de5280672/res/2024-string.gif -------------------------------------------------------------------------------- /res/precorrect.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jin-woo-lee/torch-fdtd-string/77374b2e506ceeb9be14a0095e57158de5280672/res/precorrect.png -------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import argparse 3 | import os 4 | import sys 5 | import glob 6 | import hydra 7 | import traceback 8 | from shutil import copyfile 9 | from datetime import datetime 10 | from omegaconf import DictConfig, OmegaConf 11 | from hydra.core.hydra_config import HydraConfig 12 | 13 | from src.utils import config as cf 14 | 15 | class ConfigArgument: 16 | def __getitem__(self,key): 17 | return getattr(self, key) 18 | def __setitem__(self,key,value): 19 | return setattr(self, key, value) 20 | 21 | def get_object(config, m): 22 | for key in config.keys(): 23 | if isinstance(config[key], DictConfig): 24 | m[key] = ConfigArgument() 25 | get_object(config[key], m[key]) 26 | else: 27 | m[key] = config[key] 28 | return m 29 | 30 | def backup_code(args): 31 | # Copy directory sturcture and files 32 | 
exclude_dir = ['data', '__pycache__', 'log', '.git', 'res', 'check'] 33 | exclude_file = ['cfg', 'cmd', '.gitignore'] 34 | exclude_ext = ['.png', '.jpg', '.pt', '.npz'] 35 | filepath = [] 36 | cwd_name = os.path.dirname(args.cwd) 37 | for dirpath, dirnames, filenames in os.walk(args.cwd, topdown=True): 38 | subdirs = dirpath.split(cwd_name)[-1] 39 | if not any(dir in subdirs for dir in exclude_dir): 40 | filtered_files=[name for name in filenames if (os.path.splitext(name)[-1] not in exclude_ext) and (name not in exclude_file)] 41 | filepath.append({'dir': dirpath, 'files': filtered_files}) 42 | 43 | num_strip = len(args.cwd) 44 | for path in filepath: 45 | dirname = path['dir'][num_strip+1:] 46 | for filename in path['files']: 47 | if '.swp' in filename or '.onnx' in filename: 48 | continue 49 | file2copy = os.path.join(path['dir'], filename) 50 | os.makedirs(f"codes/{dirname}", exist_ok=True) 51 | filepath2save = os.path.join(f"codes/{dirname}", filename) 52 | copyfile(file2copy, filepath2save) 53 | 54 | @hydra.main(config_path="src/configs", config_name="config.yaml") 55 | def main(config: OmegaConf): 56 | config = cf.process_config(config) 57 | cf.print_config(config, resolve=True) 58 | args = get_object(config, ConfigArgument()) 59 | 60 | # os.environ['MASTER_ADDR'] = "127.0.0.1" 61 | # os.environ['MASTER_PORT'] = f"{args.proc.port}" 62 | os.environ['CUDA_LAUNCH_BLOCKING'] = "1" 63 | os.environ["CUDA_VISIBLE_DEVICES"] = "-1" if args.proc.cpu \ 64 | else ','.join([str(gpu_num) for gpu_num in args.proc.gpus]) 65 | ''' The CUDA_VISIBLE_DEVICES environment variable is read by the cuda driver. 66 | So it needs to be set before the cuda driver is initialized. 67 | It is best if you make sure it is set **before** importing torch 68 | (or at least before you do anything cuda related in torch). 
69 | source: https://discuss.pytorch.org/t/os-environ-cuda-visible-devices-not-functioning/105545/4 70 | ''' 71 | import torch 72 | if not args.proc.train: 73 | # This is redundant in the case of being `args.proc.train == True`, 74 | # since Lightning will seed everything (see `src/trainer.py`.) 75 | torch.manual_seed(args.proc.seed) 76 | 77 | args.cwd = HydraConfig.get().runtime.cwd 78 | 79 | if args.task.save_name is not None: 80 | save_dir_name = args.task.save_name 81 | elif args.proc.debug or args.task.result_dir=='debug': 82 | args.proc.debug = True 83 | save_dir_name = 'debug' 84 | else: 85 | save_dir_name = args.task.result_dir 86 | 87 | if not os.path.isabs(args.task.root_dir): 88 | # If the root_dir is relative, make it absolute. 89 | args.task.root_dir = os.path.join(args.cwd, args.task.root_dir) 90 | if not os.path.isabs(args.task.load_dir): 91 | # If the root_dir is relative, make it absolute. 92 | args.task.load_dir = os.path.join(args.cwd, args.task.load_dir) 93 | 94 | save_dir = f'{args.task.root_dir}/{save_dir_name}' 95 | 96 | if args.task.measure_time: 97 | args.task.plot = False 98 | args.task.save = False 99 | args.task.plot_state = False 100 | 101 | if args.task.result_dir == "debug": 102 | args.proc.debug = True 103 | 104 | if args.proc.simulate or args.proc.train: 105 | backup_code(args) 106 | 107 | if args.proc.simulate: 108 | model_name = 'random' if args.model.excitation is None else args.model.excitation 109 | n_samples = args.task.num_samples // args.task.batch_size 110 | from src.task import simulate 111 | # run simulation 112 | simulate.run(args, save_dir, model_name, n_samples=n_samples) 113 | 114 | if args.proc.evaluate: 115 | from src.task import evaluate 116 | # evaluate simulation results 117 | load_dir = save_dir if args.task.load_dir is None else args.task.load_dir 118 | evaluate.evaluate(load_dir) 119 | 120 | if args.proc.summarize: 121 | from src.task import summarize 122 | # summarize evaluation results 123 | load_dir = 
save_dir if args.task.load_dir is None else args.task.load_dir 124 | summarize.summarize(load_dir) 125 | 126 | if args.proc.process_training_data: 127 | from src.task import process_training_data 128 | # preprocess simulation results as training data 129 | process_training_data.process(args) 130 | #from src.task import process_alpha_data 131 | #process_alpha_data.process(args) 132 | 133 | if args.proc.train: 134 | from src import trainer 135 | # train neural network 136 | trainer.train(args) 137 | 138 | if args.proc.test: 139 | from src import trainer 140 | # test neural network 141 | hydra_cfg = hydra.core.hydra_config.HydraConfig.get() 142 | msg = 'Do not pass args.task.ckpt_dir=...\n' 143 | msg += 'Indicate the test directory as: hydra.run.dir=...' 144 | assert args.task.ckpt_dir is None, msg 145 | output_dir = hydra_cfg['runtime']['output_dir'] 146 | if args.task._name_ in output_dir: 147 | ckpt_dir = args.task._name_ + output_dir.split(args.task._name_)[-1] 148 | else: 149 | ckpt_dir = output_dir 150 | args.task.ckpt_dir = ckpt_dir 151 | trainer.eval(args) 152 | 153 | if __name__=='__main__': 154 | #os.environ["CUDA_VISIBLE_DEVICES"] = '-1' 155 | main() 156 | 157 | -------------------------------------------------------------------------------- /src/callbacks.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import torch 4 | import wandb 5 | import pytorch_lightning as pl 6 | from pytorch_lightning.loggers import WandbLogger 7 | from pytorch_lightning.callbacks import Callback 8 | import numpy as np 9 | import soundfile as sf 10 | 11 | from src.utils import plot as plot 12 | from src.utils import audio as audio 13 | 14 | class PlotResults(Callback): 15 | 16 | def __init__(self, args): 17 | self.debug = args.proc.debug 18 | self.plot_dirs = f'plot' 19 | self.wave_dirs = f'wave' 20 | os.makedirs(f"valid/{self.plot_dirs}", exist_ok=True) 21 | os.makedirs(f"valid/{self.wave_dirs}", 
exist_ok=True) 22 | os.makedirs(f"test/{self.plot_dirs}", exist_ok=True) 23 | os.makedirs(f"test/{self.wave_dirs}", exist_ok=True) 24 | 25 | self.sr = args.task.sr 26 | self.n_fft = args.callbacks.plot.n_fft 27 | self.n_mel = args.callbacks.plot.n_mel 28 | self.hop_length = args.callbacks.plot.hop_length 29 | self.window = torch.hann_window(self.n_fft) 30 | mel_fbank = audio.mel_basis(args.task.sr, self.n_fft, self.n_mel) 31 | self.mel_basis = torch.from_numpy(mel_fbank) 32 | 33 | def stft(self, x): 34 | x = torch.stft(x, n_fft = self.n_fft, hop_length = self.hop_length, win_length = self.n_fft, window = self.window) 35 | return torch.view_as_complex(x).transpose(-1,-2) 36 | 37 | def logmag(self, spec): 38 | eps = torch.finfo(spec.abs().dtype).eps 39 | return 20 * (spec.abs() + eps).log10() 40 | 41 | def logmel(self, spec): 42 | eps = torch.finfo(spec.abs().dtype).eps 43 | mag = spec.abs() + eps 44 | mel = torch.matmul(self.mel_basis, mag.transpose(-1,-2)).transpose(-1,-2) 45 | return 20 * (mel + eps).log10() 46 | 47 | def summary(self, outputs, prefix, epoch, it): 48 | plot_path = f'{prefix}/{self.plot_dirs}/{epoch}-{it}.png' 49 | wave_path = f'{prefix}/{self.wave_dirs}/{epoch}-{it}.wav' 50 | inp_wave, tar_wave, est_wave = outputs 51 | 52 | batch_size = est_wave.shape[0] 53 | N = min(batch_size, 2) 54 | n = np.random.randint(batch_size-N) if batch_size > N else 0 55 | est_wave = est_wave[n:n+N].squeeze(1) 56 | inp_wave = inp_wave[n:n+N].squeeze(1) 57 | tar_wave = tar_wave[n:n+N].squeeze(1) 58 | 59 | est_spec = self.stft(est_wave) 60 | tar_spec = self.stft(tar_wave) 61 | inp_spec = self.stft(inp_wave) 62 | 63 | est_logmag = self.logmag(est_spec) 64 | est_logmel = self.logmel(est_spec) 65 | tar_logmag = self.logmag(tar_spec) 66 | tar_logmel = self.logmel(tar_spec) 67 | inp_logmag = self.logmag(inp_spec) 68 | inp_logmel = self.logmel(inp_spec) 69 | 70 | inp = { 71 | "state" : inp_wave, 72 | "wav" : inp_wave, 73 | "logmag" : inp_logmag, 74 | "logmel" : inp_logmel, 75 
| } 76 | est = { 77 | "state" : est_wave, 78 | "wav" : est_wave, 79 | "logmag" : est_logmag, 80 | "logmel" : est_logmel, 81 | } 82 | tar = { 83 | "state" : est_wave, 84 | "wav" : tar_wave, 85 | "logmag": tar_logmag, 86 | "logmel": tar_logmel, 87 | } 88 | return plot.est_tar_specs(est, tar, inp, plot_path, wave_path, self.sr) 89 | 90 | def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): 91 | if batch_idx == 0: 92 | #if not trainer.sanity_checking: 93 | prefix = "valid" if dataloader_idx == 0 else "test" 94 | summary = self.summary(outputs, prefix, trainer.current_epoch, batch_idx) 95 | if not self.debug: 96 | key = "valid" if dataloader_idx==0 else "test" 97 | trainer.logger.log_table(key=key, **summary) 98 | 99 | class SaveTestResults(Callback): 100 | 101 | def __init__(self, args): 102 | self.debug = args.proc.debug 103 | self.load_name = args.task.load_name 104 | if not os.path.isabs(args.task.ckpt_dir): 105 | ckpt_dir = f"{args.task.root_dir}/{args.task.ckpt_dir}" 106 | else: 107 | ckpt_dir = args.task.ckpt_dir 108 | self.video_dirs = f"{ckpt_dir}/test/{self.load_name}/video" 109 | self.state_dirs = f"{ckpt_dir}/test/{self.load_name}/state" 110 | self.score_dirs = f"{ckpt_dir}/test/{self.load_name}/score" 111 | os.makedirs(self.video_dirs, exist_ok=True) 112 | os.makedirs(self.state_dirs, exist_ok=True) 113 | if os.path.exists(self.score_dirs): 114 | print(f"* Score file already exists! 
{self.score_dirs}") 115 | print(f" Replacing with a new score...") 116 | shutil.rmtree(self.score_dirs) 117 | os.makedirs(self.score_dirs, exist_ok=True) 118 | 119 | def write_eval_scores(self, scores, epoch, it): 120 | out_score_path = f'{self.score_dirs}/output.txt' 121 | byp_score_path = f'{self.score_dirs}/modals.txt' 122 | for i, score_path in enumerate([out_score_path, byp_score_path]): 123 | keys = list(scores[i].keys()) 124 | bs = scores[i][keys[0]].shape[0] 125 | if not os.path.exists(score_path): 126 | with open(score_path, 'a+') as f: 127 | f.write('\t'.join(['id'] + keys) + '\n') 128 | with open(score_path, 'a+') as f: 129 | for b in range(bs): 130 | line = [f"{epoch}-{it}-{b}"] + [f"{scores[i][key][b]:.8f}" for key in keys] 131 | f.write('\t'.join(line) + '\n') 132 | 133 | def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0): 134 | data, scores = outputs 135 | self.write_eval_scores(scores, trainer.current_epoch, batch_idx) 136 | 137 | class PlotStateVideo(Callback): 138 | 139 | def __init__(self, args): 140 | self.debug = args.proc.debug 141 | self.load_name = args.task.load_name 142 | if not os.path.isabs(args.task.ckpt_dir): 143 | ckpt_dir = f"{args.task.root_dir}/{args.task.ckpt_dir}" 144 | else: 145 | ckpt_dir = args.task.ckpt_dir 146 | self.video_dirs = f"{ckpt_dir}/test/{self.load_name}/video" 147 | self.state_dirs = f"{ckpt_dir}/test/{self.load_name}/state" 148 | os.makedirs(self.video_dirs, exist_ok=True) 149 | os.makedirs(self.state_dirs, exist_ok=True) 150 | 151 | self.sr = args.task.sr 152 | 153 | def summary(self, outputs, epoch, it): 154 | k = f'{epoch}-{it}' 155 | inp_wave, tar_wave, est_wave = outputs 156 | 157 | batch_size = est_wave.shape[0] 158 | est = est_wave.squeeze(1).numpy().T # (Nt, Bs=Nx) 159 | inp = inp_wave.squeeze(1).numpy().T # (Nt, Bs=Nx) 160 | tar = tar_wave.squeeze(1).numpy().T # (Nt, Bs=Nx) 161 | 162 | sf.write(f'{self.video_dirs}/estimate.wav', est.mean(-1), samplerate=self.sr) 
163 | sf.write(f'{self.video_dirs}/analytic.wav', inp.mean(-1), samplerate=self.sr) 164 | sf.write(f'{self.video_dirs}/fdtd.wav', tar.mean(-1), samplerate=self.sr) 165 | 166 | np.savez_compressed(f"{self.state_dirs}/{k}.npz", analytic=inp, estimate=est, simulate=tar) 167 | plot.state_specs(f"{self.state_dirs}/{k}.pdf", inp, est, tar) 168 | 169 | plot.rainbowgram(f'{self.video_dirs}/{k}-estimate.pdf', est.mean(-1), self.sr, colorbar=False) 170 | plot.rainbowgram(f'{self.video_dirs}/{k}-analytic.pdf', inp.mean(-1), self.sr, colorbar=False) 171 | plot.rainbowgram(f'{self.video_dirs}/{k}-fdtd.pdf', tar.mean(-1), self.sr, colorbar=False) 172 | 173 | plot.state_video(self.video_dirs, est, self.sr, prefix=k, trim_front=True, fname='estimate') 174 | plot.state_video(self.video_dirs, inp, self.sr, prefix=k, trim_front=True, fname='analytic') 175 | plot.state_video(self.video_dirs, tar, self.sr, prefix=k, trim_front=True, fname='fdtd') 176 | 177 | def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0): 178 | data, scores = outputs 179 | self.summary(data, trainer.current_epoch, batch_idx) 180 | 181 | 182 | class PlotRDE(Callback): 183 | 184 | def __init__(self, args): 185 | self.debug = args.proc.debug 186 | self.plot_dirs = f"{args.task.root_dir}/{args.task.ckpt_dir}/test/plot" 187 | self.wave_dirs = f"{args.task.root_dir}/{args.task.ckpt_dir}/test/wave" 188 | os.makedirs(self.plot_dirs, exist_ok=True) 189 | os.makedirs(self.wave_dirs, exist_ok=True) 190 | 191 | self.sr = args.task.sr 192 | self.n_fft = args.callbacks.plot.n_fft 193 | self.n_mel = args.callbacks.plot.n_mel 194 | self.hop_length = args.callbacks.plot.hop_length 195 | self.window = torch.hann_window(self.n_fft) 196 | mel_fbank = audio.mel_basis(args.task.sr, self.n_fft, self.n_mel) 197 | self.mel_basis = torch.from_numpy(mel_fbank) 198 | 199 | def stft(self, x): 200 | x = torch.stft(x, n_fft = self.n_fft, hop_length = self.hop_length, win_length = self.n_fft, window = 
self.window) 201 | return torch.view_as_complex(x).transpose(-1,-2) 202 | 203 | def logmag(self, spec): 204 | eps = torch.finfo(spec.abs().dtype).eps 205 | return 20 * (spec.abs() + eps).log10() 206 | 207 | def logmel(self, spec): 208 | eps = torch.finfo(spec.abs().dtype).eps 209 | mag = spec.abs() + eps 210 | mel = torch.matmul(self.mel_basis, mag.transpose(-1,-2)).transpose(-1,-2) 211 | return 20 * (mel + eps).log10() 212 | 213 | def summary(self, outputs, epoch, it): 214 | plot_path = f'{self.plot_dirs}/rde.png' 215 | wave_path = f'{self.wave_dirs}/rde.wav' 216 | sim_list, est_list, factors = outputs 217 | sim_wave, est_wave = [], [] 218 | for i, (sim, est) in enumerate(zip(sim_list, est_list)): 219 | sim = audio.state_to_wav(sim.sum(-1)) 220 | est = audio.state_to_wav(est.sum(-1)) 221 | sim_wave.append(sim) 222 | est_wave.append(est) 223 | 224 | est_list = [x.detach().squeeze().cpu() for x in est_list] 225 | sim_list = [x.detach().squeeze().cpu() for x in sim_list] 226 | est_wave = [x.detach().squeeze().cpu() for x in est_wave] 227 | sim_wave = [x.detach().squeeze().cpu() for x in sim_wave] 228 | est_spec = [self.stft(x) for x in est_wave] 229 | sim_spec = [self.stft(x) for x in sim_wave] 230 | est_logmag = [self.logmag(x) for x in est_spec] 231 | est_logmel = [self.logmel(x) for x in est_spec] 232 | sim_logmag = [self.logmag(x) for x in sim_spec] 233 | sim_logmel = [self.logmel(x) for x in sim_spec] 234 | 235 | est = { 236 | "state" : est_list, 237 | "wav" : est_wave, 238 | "spec" : est_spec, 239 | "logmag" : est_logmag, 240 | "logmel" : est_logmel, 241 | } 242 | sim = { 243 | "state" : sim_list, 244 | "wav" : sim_wave, 245 | "spec" : sim_spec, 246 | "logmag": sim_logmag, 247 | "logmel": sim_logmel, 248 | } 249 | return plot.rde_specs(factors, est, sim, plot_path, wave_path, self.sr) 250 | 251 | def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): 252 | if batch_idx == 0: 253 | if not trainer.sanity_checking: 254 | 
summary = self.summary(outputs, trainer.current_epoch, batch_idx) 255 | trainer.logger.log_table(key="rde", **summary) 256 | 257 | class SaveResults(Callback): 258 | 259 | def __init__(self, args): 260 | self.debug = args.proc.debug 261 | self.sr = args.task.sr 262 | if isinstance(args.task.testset, str) or len(args.task.testset) == 1: 263 | self.save_dirs = f'eval/{args.task.testset[0]}' 264 | print(f"... Saving results under {self.save_dirs}") 265 | os.makedirs(f"{self.save_dirs}", exist_ok=True) 266 | else: 267 | self.save_dirs = f'eval/{args.task.testset[0]}' 268 | print(f"*** Mulitiple arguments provided for --testset: {args.task.testset}") 269 | print(f" Saving results under {self.save_dirs}") 270 | os.makedirs(f"{self.save_dirs}", exist_ok=True) 271 | 272 | def summary(self, outputs, epoch, it): 273 | outputs = outputs.detach().cpu().numpy() 274 | save_dir = f"{self.save_dirs}/wave" 275 | os.makedirs(save_dir, exist_ok=True) 276 | audio.save_waves(outputs, save_dir, self.sr) 277 | 278 | def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): 279 | self.summary(outputs, trainer.current_epoch, batch_idx) 280 | 281 | -------------------------------------------------------------------------------- /src/configs/callbacks/base.yaml: -------------------------------------------------------------------------------- 1 | #learning_rate_monitor: 2 | # # _target_: pytorch_lightning.callbacks.LearningRateMonitor 3 | # logging_interval: ${train.interval} 4 | 5 | timer: 6 | # _target_: callbacks.timer.Timer 7 | step: True 8 | inter_step: False 9 | epoch: True 10 | val: True 11 | 12 | params: 13 | # _target_: callbacks.params.ParamsLog 14 | total: True 15 | trainable: True 16 | fixed: True 17 | 18 | plot: 19 | n_mel: 128 20 | n_fft: 1024 21 | hop_length: 256 22 | -------------------------------------------------------------------------------- /src/configs/config.yaml: 
-------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - _self_ 4 | - experiment: base 5 | - callbacks: base # Extra pytorch-lightning features 6 | 7 | task: 8 | root_dir: './results' 9 | result_dir: ${task._name_}-${framework._name_}-${model._name_}-${task.run}-${now:%Y%m%d-%H%M%S-%f} 10 | save_name: null 11 | measure_time: false 12 | 13 | proc: 14 | cpu: false 15 | gpus: [0, ] 16 | seed: 1234 17 | port: ${now:%M%S} 18 | num_workers: 3 19 | 20 | simulate: true # run simulation 21 | debug: false # run simulation in debug mode 22 | evaluate: false # evaluate the simulated output 23 | summarize: false # summarize the evaluation 24 | process_training_data: false # preprocess simulation results as training data 25 | train: false # train the neural network 26 | test : false # test the neural network 27 | 28 | hydra: 29 | run: 30 | dir: ${task.root_dir}/${task.result_dir} 31 | 32 | -------------------------------------------------------------------------------- /src/configs/experiment/all-fixed.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - /model: base 4 | - /task: simulate 5 | 6 | model: 7 | excitation: pluck 8 | 9 | proc: 10 | cpu : true 11 | 12 | task: 13 | num_samples: 1 14 | batch_size: 1 15 | relative_order: 8 16 | precision: double 17 | length: .5 18 | chunk_length: 0.01 19 | skip_silence: false 20 | 21 | sampling_f0 : fix 22 | sampling_kappa : fix 23 | sampling_alpha : fix 24 | sampling_pickup : fix 25 | sampling_T60 : fix 26 | 27 | string_condition: 28 | #- f0_fixed : 110 29 | - f0_fixed : 55.0 30 | - kappa_fixed : 0.08 31 | - alpha_fixed : 20. 
32 | - lossless : false 33 | #f0_inf: 110 34 | f0_inf: 55.0 35 | alpha_inf: 20 36 | 37 | hammer_condition: 38 | - x_H_min : 0.1 39 | - x_H_max : 0.1 40 | - v_H_min : 4.00 41 | - v_H_max : 4.00 42 | - M_r_min : 1.50 43 | - M_r_max : 1.50 44 | - w_H_min : 2000 45 | - w_H_max : 2000 46 | 47 | pluck_condition: 48 | - sampling_p_a : fix 49 | - p_a_fixed : 0.02 50 | - sampling_p_x : fix 51 | - p_x_fixed : 0.2 52 | 53 | bow_condition: 54 | - x_b_min : 0.2 55 | - x_b_max : 0.2 56 | - v_b_min : 0.35 57 | - v_b_max : 0.35 58 | - F_b_min : 90 59 | - F_b_max : 90. 60 | - phi_0_max : 9. 61 | - phi_0_min : 9. 62 | - phi_1_max : 0.01 63 | - phi_1_min : 0.01 64 | - wid_min : 4 65 | - wid_max : 4 66 | 67 | load_config: null 68 | 69 | plot: true 70 | #plot_state: true 71 | 72 | -------------------------------------------------------------------------------- /src/configs/experiment/base.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - /model: base 4 | 5 | model: 6 | excitation: null 7 | 8 | 9 | -------------------------------------------------------------------------------- /src/configs/experiment/evaluate.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - /task: evaluate 4 | 5 | task: 6 | result_dir: null 7 | 8 | proc: 9 | simulate: false 10 | evaluate: true 11 | num_workers: 16 12 | 13 | 14 | -------------------------------------------------------------------------------- /src/configs/experiment/linear-string.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - /model: base 4 | - /task: simulate 5 | 6 | model: 7 | excitation: pluck 8 | 9 | proc: 10 | num_workers: 2 11 | 12 | task: 13 | num_samples: 1 14 | dont_save_silence: false 15 | batch_size: 1 16 | relative_order: 8 17 | precision: double 18 | length: .2 19 | chunk_length: 0.001 20 | 
write_during_process: true 21 | normalize_output: true 22 | 23 | #sampling_f0 : random 24 | sampling_f0 : fix 25 | sampling_kappa : fix 26 | sampling_alpha : fix 27 | sampling_pickup : random 28 | sampling_T60 : fix 29 | 30 | manufactured: true 31 | precorrect: false 32 | #f0_inf: 60 33 | #===== 34 | f0_inf: 55 # 67 33 16 (kappa 0.03) 35 | #f0_inf: 110 # 46 23 11 (kappa 0.03) 36 | #f0_inf: 220 # 32 16 8 (kappa 0.03) 37 | #f0_inf: 440 # 21 10 5 (kappa 0.03) 38 | #===== 39 | #f0_inf: 59 # 64 32 16 (kappa 0.03) 40 | #f0_inf: 220 # 32 16 8 (kappa 0.03) 41 | #f0_inf: 650 # 16 8 4 (kappa 0.03) 42 | #===== 43 | #f0_inf: 145 # 64 32 16 (kappa 0.01) 44 | #f0_inf: 400 # 32 16 8 (kappa 0.01) 45 | #f0_inf: 880 # 16 8 4 (kappa 0.01) 46 | #===== 47 | string_condition: 48 | #- f0_fixed: 60 49 | #===== 50 | - f0_fixed: 55 # 64 32 16 (kappa 0.03) 51 | #- f0_fixed: 110 # 64 32 16 (kappa 0.03) 52 | #- f0_fixed: 220 # 64 32 16 (kappa 0.03) 53 | #- f0_fixed: 440 # 64 32 16 (kappa 0.03) 54 | #===== 55 | #- f0_fixed: 59 # 64 32 16 (kappa 0.03) 56 | #- f0_fixed: 220 # 32 16 8 (kappa 0.03) 57 | #- f0_fixed: 650 # 16 8 4 (kappa 0.03) 58 | #===== 59 | #- f0_fixed: 145 # 64 32 16 (kappa 0.01) 60 | #- f0_fixed: 400 # 32 16 8 (kappa 0.01) 61 | #- f0_fixed: 880 # 16 8 4 (kappa 0.01) 62 | #===== 63 | #===== 64 | - f0_mod_max : 0 65 | #------------------------------ 66 | - lossless : false 67 | #===== 68 | - t60_fixed : 20. 69 | #===== 70 | #- t60_fixed : 10. 71 | #===== 72 | #- t60_fixed : 5 73 | #------------------------------ 74 | - kappa_min : 0.03 75 | - kappa_max : 0.03 76 | - kappa_fixed : 0.03 77 | #===== 78 | #- kappa_min : 0.02 79 | #- kappa_max : 0.02 80 | #- kappa_fixed : 0.02 81 | #===== 82 | #- kappa_min : 0.01 83 | #- kappa_max : 0.01 84 | #- kappa_fixed : 0.01 85 | #===== 86 | #------------------------------ 87 | - alpha_fixed : 1. 88 | - alpha_min : 1. 89 | - alpha_max : 1. 
90 | 91 | alpha_inf: 1 92 | 93 | pluck_condition: 94 | - sampling_p_a : fix 95 | - p_a_fixed: 0.01 96 | - sampling_p_x : fix 97 | - p_x_fixed : 0.3 98 | - pluck_profile : smooth 99 | #- pluck_profile : raised_cosine 100 | #- pluck_profile : triangular 101 | 102 | hammer_condition: 103 | - x_H_min : 0.5 104 | - x_H_max : 0.5 105 | - v_H_min : 2.50 106 | - v_H_max : 2.50 107 | - M_r_min : 10. 108 | - M_r_max : 10. 109 | - w_H_min : 3000 110 | - w_H_max : 3000 111 | - alpha_fixed: 3 112 | 113 | load_config: null 114 | 115 | plot: true 116 | plot_state: true 117 | 118 | -------------------------------------------------------------------------------- /src/configs/experiment/nonlinear-string.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - /model: base 4 | - /task: simulate 5 | 6 | model: 7 | excitation: pluck 8 | 9 | proc: 10 | num_workers: 2 11 | 12 | task: 13 | num_samples: 1 14 | dont_save_silence: false 15 | batch_size: 1 16 | relative_order: 8 17 | precision: double 18 | length: .2 19 | chunk_length: 0.001 20 | write_during_process: true 21 | normalize_output: true 22 | 23 | #sampling_f0 : random 24 | sampling_f0 : fix 25 | sampling_kappa : fix 26 | sampling_alpha : fix 27 | sampling_pickup : random 28 | sampling_T60 : fix 29 | 30 | precorrect: false 31 | string_condition: 32 | - f0_fixed: 60.0 33 | - f0_mod_max : 0 34 | - lossless : false 35 | - t60_fixed : 20. 36 | - kappa_min : 0.03 37 | - kappa_max : 0.03 38 | - kappa_fixed : 0.03 39 | - alpha_fixed : 1. 
40 | #===== 41 | f0_inf: 60.0 42 | alpha_inf: 1 43 | 44 | pluck_condition: 45 | - sampling_p_a : fix 46 | - p_a_fixed: 0.01 47 | - sampling_p_x : fix 48 | - p_x_fixed : 0.25 49 | - pluck_profile : smooth 50 | #- pluck_profile : raised_cosine 51 | #- pluck_profile : triangular 52 | 53 | bow_condition: 54 | - x_b_min : 0.3 55 | - x_b_max : 0.3 56 | - v_b_min : 0.35 57 | - v_b_max : 0.35 58 | - F_b_min : 94.013 59 | - F_b_max : 94.013 60 | - phi_0_min : 4.695 61 | - phi_0_max : 4.695 62 | - phi_1_min : 0.166 63 | - phi_1_max : 0.166 64 | - wid_min : 5.3 65 | - wid_max : 5.3 66 | 67 | hammer_condition: 68 | - x_H_min : 0.3 69 | - x_H_max : 0.3 70 | - v_H_min : 2.50 71 | - v_H_max : 2.50 72 | - M_r_min : 10. 73 | - M_r_max : 10. 74 | - w_H_min : 3000 75 | - w_H_max : 3000 76 | - alpha_fixed: 3 77 | 78 | load_config: null 79 | 80 | plot: true 81 | plot_state: true 82 | 83 | -------------------------------------------------------------------------------- /src/configs/experiment/nsynth-like.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - /model: fdtd 4 | - /task: simulate 5 | 6 | model: 7 | excitation: pluck 8 | 9 | proc: 10 | num_workers: 4 11 | 12 | task: 13 | randomize_name: true 14 | num_samples: 32000 15 | batch_size: 24 16 | relative_order: 4 17 | precision: single 18 | length: 1.0 19 | chunk_length: 1.0 20 | 21 | write_during_process: false 22 | normalize_output: true 23 | 24 | randomize_each : batch 25 | sampling_f0 : random 26 | sampling_kappa : random 27 | sampling_alpha : random 28 | sampling_pickup : random 29 | sampling_T60 : random 30 | 31 | string_condition: 32 | - f0_min: 98.00 # G2 33 | - f0_max: 440.0 # A4 34 | - f0_diff_max : 30 35 | - f0_mod_max : 0.08 36 | - kappa_min : 0.01 37 | - kappa_max : 0.03 38 | - alpha_min : 1. 39 | - alpha_max : 25. 40 | - t60_min_1 : 10. 41 | - t60_max_1 : 25. 42 | - t60_min_2 : 10. 43 | - t60_max_2 : 30. 
44 | f0_inf: 98.00 # G2 45 | alpha_inf: 1 46 | 47 | pluck_condition: 48 | - sampling_p_a : random 49 | - p_a_max: 0.02 50 | - sampling_p_x : random 51 | - p_x_max: 0.5 52 | 53 | hammer_condition: 54 | - M_r_min : 1.0 55 | - M_r_max : 10. 56 | - alpha_fixed: 3 57 | 58 | load_config: null 59 | 60 | plot: false 61 | #plot_state: true 62 | 63 | -------------------------------------------------------------------------------- /src/configs/experiment/process_training_data.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - /task: process_training_data 4 | 5 | proc: 6 | cpu : false 7 | gpus: [0, ] 8 | simulate: false 9 | process_training_data: true 10 | 11 | task: 12 | data_split : 0 # splits the whole data list into `data_split` parts (set 0 to disable) 13 | split_n : 0 # only processes the `split_n`-th part of the split data list 14 | result_dir: 'my_fdtd_results' 15 | save_dir : 'my_dmsp_data' 16 | 17 | sr: 48000 18 | Nx: 256 # upsampled spatial grid size 19 | strict: false # whether to assert the optimality in the analytic solution 20 | 21 | -------------------------------------------------------------------------------- /src/configs/experiment/synth-dmsp.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - /task: synthesize 4 | - /model: dmsp 5 | - /framework: supervised 6 | - /optimizer: radam 7 | - /scheduler: noam 8 | 9 | model: 10 | n_bands : 65 11 | hidden_dim : 512 12 | n_modes : 40 13 | embed_dim : 128 14 | use_precomputed_mode: false 15 | 16 | scheduler: 17 | warmup_steps: 1000 18 | 19 | optimizer: 20 | lr: 0.001 21 | 22 | proc: 23 | cpu : false 24 | gpus: [0, ] 25 | simulate: false 26 | train: true 27 | 28 | task: 29 | load_dir : './results' 30 | load_name: 'my_dmsp_data' 31 | 32 | valid_batch_size : 256 33 | test_batch_size : 256 34 | batch_size: 128 35 | 36 | n_fft: 2048 37 | train_lens: 1 38 |
total_lens: 1 39 | valid_epoch: 1 40 | grad_clip: [null] 41 | loss_criteria: ['l1', 'magspec', 'melspec', 'f0', 'modefreq', 'modeamps'] 42 | eval_criteria: ['sisdr', 'modefreq', 'modeamps'] 43 | 44 | load_config: null 45 | 46 | plot: true 47 | #plot_state: true 48 | 49 | plot_test_video: false 50 | save_test_score: false 51 | 52 | -------------------------------------------------------------------------------- /src/configs/framework/supervised.yaml: -------------------------------------------------------------------------------- 1 | _name_: supervised 2 | 3 | -------------------------------------------------------------------------------- /src/configs/model/base.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - excitation: null # null for random excitation 3 | 4 | -------------------------------------------------------------------------------- /src/configs/model/bow.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - fdtd 3 | 4 | excitation: 'bow' 5 | -------------------------------------------------------------------------------- /src/configs/model/dmsp.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - trainer 3 | 4 | _name_: dmsp 5 | 6 | harmonic: 'inharmonic' 7 | ddsp_frequency_modulation: null 8 | use_precomputed_mode: false 9 | -------------------------------------------------------------------------------- /src/configs/model/fdtd.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - excitation: null # null for random excitation 3 | 4 | _name_: fdtd 5 | -------------------------------------------------------------------------------- /src/configs/model/hammer.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - fdtd 3 | 4 | excitation: 'hammer' 5 | 
-------------------------------------------------------------------------------- /src/configs/model/pluck.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - fdtd 3 | 4 | excitation: 'pluck' 5 | -------------------------------------------------------------------------------- /src/configs/model/trainer.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - excitation: null # null for random excitation 3 | 4 | _name_: trainer 5 | 6 | # for NN (ddsp, dmsp) 7 | n_bands : ??? 8 | hidden_dim: ??? 9 | embed_dim : ??? 10 | n_modes : ??? 11 | block_size: 256 12 | x_scale: [0., 1.] 13 | t_scale: [0., .3] 14 | gamma_scale: [196, 880] 15 | kappa_scale: [.01, .03] 16 | alpha_scale: [1., 30.] 17 | sig_0_scale: [0., .7] 18 | sig_1_scale: [0., 1e-5] 19 | ddsp_frequency_modulation: ??? 20 | use_precomputed_mode: ??? 21 | -------------------------------------------------------------------------------- /src/configs/optimizer/adam.yaml: -------------------------------------------------------------------------------- 1 | # _target_: torch.optim.Adam 2 | _name_: adam 3 | lr: 0.0002 # Initial learning rate 4 | # weight_decay: 0.0 # Weight decay for adam|lamb; should use AdamW instead if desired 5 | betas: [0.9, 0.999] 6 | -------------------------------------------------------------------------------- /src/configs/optimizer/adamw.yaml: -------------------------------------------------------------------------------- 1 | # _target_: torch.optim.AdamW 2 | _name_: adamw 3 | lr: 0.0004 # Initial learning rate 4 | weight_decay: 0.001 # Weight decay 5 | betas: [0.9, 0.999] 6 | -------------------------------------------------------------------------------- /src/configs/optimizer/lamb.yaml: -------------------------------------------------------------------------------- 1 | # _target_: utils.lamb.JITLamb 2 | _name_: lamb 3 | lr: 0.0004 # Initial learning rate 4 | weight_decay: 0.0 # Weight 
decay for adam|lamb 5 | -------------------------------------------------------------------------------- /src/configs/optimizer/radam.yaml: -------------------------------------------------------------------------------- 1 | # _target_: torch.optim.Adam 2 | _name_: radam 3 | lr: 0.0001 # Initial learning rate 4 | -------------------------------------------------------------------------------- /src/configs/optimizer/sgd.yaml: -------------------------------------------------------------------------------- 1 | # _target_: torch.optim.SGD 2 | _name_: sgd 3 | lr: 0.0004 # Initial learning rate 4 | momentum: 0.9 5 | weight_decay: 0.0 # Weight decay 6 | -------------------------------------------------------------------------------- /src/configs/scheduler/constant.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | train: 3 | interval: epoch 4 | scheduler: 5 | # _target_: transformers.get_constant_schedule 6 | _name_: constant 7 | -------------------------------------------------------------------------------- /src/configs/scheduler/constant_warmup.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | train: 3 | interval: step 4 | scheduler: 5 | # _target_: transformers.get_constant_schedule_with_warmup 6 | _name_: constant_warmup 7 | num_warmup_steps: 1000 # Number of iterations for LR warmup 8 | -------------------------------------------------------------------------------- /src/configs/scheduler/cosine.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | train: 3 | interval: step 4 | #interval: epoch 5 | scheduler: 6 | # _target_: torch.optim.lr_scheduler.CosineAnnealingLR 7 | _name_: cosine 8 | T_max: 6718 # Max number of epochs/steps for LR scheduler 9 | eta_min: 1e-6 # Min learning rate for cosine scheduler 10 |
-------------------------------------------------------------------------------- /src/configs/scheduler/cosine_warmup.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | train: 3 | interval: step 4 | scheduler: 5 | # _target_: transformers.get_cosine_schedule_with_warmup 6 | _name_: cosine_warmup 7 | num_warmup_steps: 500 8 | num_training_steps: 6718 9 | -------------------------------------------------------------------------------- /src/configs/scheduler/linear_warmup.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | train: 3 | interval: step 4 | scheduler: 5 | # _target_: transformers.get_linear_schedule_with_warmup 6 | _name_: linear_warmup 7 | num_warmup_steps: 1000 8 | num_training_steps: 40000 9 | -------------------------------------------------------------------------------- /src/configs/scheduler/multistep.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | train: 3 | interval: epoch 4 | # _target_: torch.optim.lr_scheduler.MultiStepLR 5 | scheduler: 6 | _name_: multistep 7 | milestones: [80,140,180] 8 | gamma: 0.2 9 | -------------------------------------------------------------------------------- /src/configs/scheduler/noam.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | train: 3 | interval: step 4 | scheduler: 5 | # _target_: transformers.get_constant_schedule_with_warmup 6 | _name_: noam 7 | warmup_steps: 1000 # Number of iterations for LR warmup 8 | -------------------------------------------------------------------------------- /src/configs/scheduler/plateau.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | train: 3 | interval: epoch 4 | monitor: ??? 
# must be specified 5 | scheduler: 6 | # _target_: torch.optim.lr_scheduler.ReduceLROnPlateau 7 | _name_: plateau 8 | mode: ${train.mode} # Which metric to monitor 9 | factor: 0.2 # Decay factor when ReduceLROnPlateau is used 10 | patience: 20 11 | min_lr: 0.0 # Minimum learning rate during annealing 12 | -------------------------------------------------------------------------------- /src/configs/scheduler/step.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | train: 3 | interval: epoch 4 | scheduler: 5 | # _target_: torch.optim.lr_scheduler.StepLR 6 | _name_: step 7 | step_size: 20 8 | gamma: 0.99 9 | -------------------------------------------------------------------------------- /src/configs/scheduler/timm_cosine.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | train: 3 | interval: epoch 4 | monitor: ??? # must be specified 5 | scheduler: 6 | _name_: timm_cosine 7 | t_initial: 300 8 | lr_min: 1e-5 9 | cycle_decay: 0.1 # changed from decay_rate in timm 0.5.4 10 | warmup_lr_init: 1e-6 11 | warmup_t: 10 12 | cycle_limit: 1 -------------------------------------------------------------------------------- /src/configs/task/evaluate.yaml: -------------------------------------------------------------------------------- 1 | _name_: evaluate 2 | 3 | load_dir: null 4 | save_name: null 5 | 6 | # number of data 7 | num_samples: 1600 # number of samples for simulation 8 | batch_size: 16 # simulation batch size 9 | 10 | # number of samples in time 11 | length: 1. # length (in sec) to simulate 12 | chunk_length: -1. # length (in sec) to simulate in a chunk 13 | sr: 48000 # temporal sampling rate 14 | 15 | # number of samples in space 16 | f0_inf: 20. # minimum fundamental frequency (in Hz); can reduce redundancy in u. 17 | alpha_inf: 1.0 # minimum nonlinearity; can reduce redundancy in z (or zeta). 
18 | 19 | fix_f0: false # fix input f0 20 | fix_kappa: false # fix string stiffness 21 | fix_alpha: false # fix nonlinearity 22 | fix_pickup: false # fix pickup position 23 | fix_T60: false # fix string lossiness 24 | 25 | save: true # save results in npz file 26 | plot: false # plot results 27 | plot_state: false # plot string video (can slow down the run while plotting) 28 | measure_time: false # measure the process time 29 | write_during_process: true # write the output wav file every time a chunk finishes processing 30 | 31 | -------------------------------------------------------------------------------- /src/configs/task/process_training_data.yaml: -------------------------------------------------------------------------------- 1 | _name_: process_training_data 2 | 3 | load_dir: null 4 | sr: 48000 5 | order: null 6 | strict: null 7 | 8 | -------------------------------------------------------------------------------- /src/configs/task/simulate.yaml: -------------------------------------------------------------------------------- 1 | _name_: simulate 2 | 3 | # number of data 4 | num_samples: 1600 # number of samples for simulation 5 | batch_size: 16 # simulation batch size 6 | skip_nan: true # skip diverged simulations while saving; otherwise raise an error 7 | skip_silence: true # skip silent simulations while saving; otherwise save the silent results 8 | silence_threshold: -23. # dB 9 | 10 | randomize_name: false # randomize save names 11 | 12 | # error configuration 13 | relative_order: 4 # order of the discretization error relative to the spatial grid size 14 | precision: single # single/double precision 15 | manufactured: false # simulate for the manufactured solution (used for verification purposes) 16 | 17 | # number of samples in time 18 | length: 1. # length (in sec) to simulate 19 | chunk_length: -1. # length (in sec) to simulate in a chunk 20 | sr: 48000 # temporal sampling rate 21 | 22 | # number of samples in space 23 | f0_inf: 20.
# minimum fundamental frequency (in Hz); can reduce redundancy in u. 24 | alpha_inf: 1.0 # minimum nonlinearity; can reduce redundancy in z (or zeta). 25 | 26 | randomize_each : 'batch' # randomize over [batch/iter] 27 | sampling_f0 : random # (fix, span, random) input f0 28 | sampling_kappa : random # (fix, span, random) string stiffness 29 | sampling_alpha : random # (fix, span, random) nonlinearity 30 | sampling_pickup : random # (fix, span, random) pickup position 31 | sampling_T60 : random # (fix, span, random) string lossiness 32 | 33 | precorrect: true # pre-correct fundamental frequency with string stiffness 34 | 35 | lambda_c: 1 # grid resolution factor (should be >= 1; best to be 1) 36 | theta_t: null # implicit scheme free parameter (should be > 1/2 for a stable simulation.) set this value by `null` if you don't understand this (then it will automatically be set by an appropriate one). 37 | 38 | string_condition: 39 | - f0_min : null 40 | - f0_max : null 41 | - f0_diff_max : null 42 | - f0_mod_max : null 43 | - f0_fixed : null 44 | - kappa_min : null 45 | - kappa_max : null 46 | - kappa_fixed : null 47 | - alpha_min : null 48 | - alpha_max : null 49 | - alpha_fixed : null 50 | - pos_min : null 51 | - pos_max : null 52 | - pos_fixed : null 53 | - lossless : null 54 | - t60_min_1 : null 55 | - t60_max_1 : null 56 | - t60_min_2 : null 57 | - t60_max_2 : null 58 | - t60_fixed : null 59 | - t60_diff_max : null 60 | 61 | pluck_condition: 62 | - sampling_p_a : random 63 | - p_a_min : null 64 | - p_a_max : null 65 | - p_a_fixed : null 66 | - sampling_p_x : random 67 | - p_x_min : null 68 | - p_x_max : null 69 | - p_x_fixed : null 70 | - pluck_profile: null 71 | 72 | hammer_condition: 73 | - x_H_min : null 74 | - x_H_max : null 75 | - v_H_min : null 76 | - v_H_max : null 77 | - M_r_min : null 78 | - M_r_max : null 79 | - w_H_min : null 80 | - w_H_max : null 81 | - alpha_fixed : null 82 | 83 | bow_condition: 84 | - x_b_min : null 85 | - x_b_max : null 86 | - 
x_b_maxdiff: null 87 | - v_b_min : null 88 | - v_b_max : null 89 | - F_b_min : null 90 | - F_b_max : null 91 | - do_pulloff : null 92 | - F_b_maxdiff: null 93 | - phi_0_max : null 94 | - phi_0_min : null 95 | - phi_1_max : null 96 | - phi_1_min : null 97 | - wid_min : null 98 | - wid_max : null 99 | 100 | save: true # save results in npz file 101 | plot: true # plot results 102 | plot_state: false # plot string video (can slow down the run while plotting) 103 | measure_time: false # measure the process time 104 | write_during_process: true # write the output wav file every time a chunk finishes processing 105 | normalize_output: false # whether to normalize the pickup output while saving the .wav file 106 | surface_integral: true # pickup using surface integral of velocities (if false: state interpolation) 107 | 108 | load_dir: null 109 | save_name: null 110 | load_config: ??? 111 | 112 | -------------------------------------------------------------------------------- /src/configs/task/synthesize.yaml: -------------------------------------------------------------------------------- 1 | _name_: synthesize 2 | 3 | sr: 48000 4 | batch_size: 16 # training batch size 5 | valid_batch_size: 4 # validation batch size 6 | test_batch_size: 4 # test batch size 7 | train_lens: 1. # train sample time length (set `null` for whole length) 8 | 9 | ckpt_dir: null 10 | project: 'string' 11 | run: null 12 | 13 | load_dir: './results' 14 | load_name: 'nsynth-*' 15 | 16 | total_epoch: 50 17 | valid_epoch: 1 18 | overfit: false 19 | 20 | num_valid_samples: -1 # number of samples for validation (set as float to use as a portion of whole data samples) 21 | num_test_samples: -1 # number of samples for test (set as float to use as a portion of whole data samples) 22 | 23 | grad_clip: [null,] # value to clip gradient norm that corresponds to each optimizer 24 | loss_criteria: ['mse', 'mrstft', ] 25 | eval_criteria: ['sisdr', 'mrstft', ] 26 | 27 | 28 | save_name: null 29 | load_config: ???
30 | 31 | -------------------------------------------------------------------------------- /src/dataset/synthesize.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from torch.utils.data import DataLoader 4 | import os 5 | import soundfile as sf 6 | import pickle 7 | import glob 8 | import scipy 9 | import random 10 | import librosa 11 | from tqdm import tqdm 12 | import json 13 | import time 14 | import torch.nn.functional as F 15 | import torchaudio.functional as TAF 16 | import src.utils.audio as audio 17 | import src.utils.data as data 18 | import src.utils.misc as ms 19 | 20 | import os 21 | import sys 22 | 23 | class GenericDataset(torch.utils.data.Dataset): 24 | 25 | def __init__( 26 | self, 27 | data_dir, 28 | load_name, 29 | split='train', 30 | trim=None, 31 | alpha=None, 32 | Nx=256, 33 | ): 34 | np.random.seed(0) 35 | self.alpha = '*' if alpha is None else alpha 36 | self.trim = trim if trim is not None else None 37 | 38 | self.keys = ['x', 't'] 39 | self.keys += ['kappa', 'alpha', 'f0', 'T60',] 40 | self.keys += ['u0'] 41 | self.keys += ['mode_freq', 'mode_amps'] 42 | self.keys += ['gain'] 43 | self.keys += ['ua_f0',] 44 | self.keys += ['ut_f0',] 45 | data_expr = lambda split: f"{data_dir}/{load_name}/{split}/*/ut-0.wav" 46 | # set `load_name` to be the directory containing 47 | # data preprocessed by `src/task/process_training_data.py` 48 | def get_string_id(path): return path.split('/')[-2] 49 | def get_space_idx(path): return int(os.path.splitext(os.path.basename(path))[0].split('-')[-1]) 50 | def get_data_list(split): 51 | wp = f"{data_dir}/{load_name}/{split}/*/ut-0.wav" 52 | total_data = [p for p in glob.glob(data_expr(split))] 53 | return sorted(total_data, key=lambda i: \ 54 | (get_string_id(i), get_space_idx(i))) 55 | dl = get_data_list(split.lower()) 56 | assert len(dl) > 0, f"[Loader] No data found in the directory {data_expr(split.lower())}." 
57 | self.Nx = Nx 58 | self.tgt_list = dl 59 | 60 | self.n_data = len(self.tgt_list) * Nx 61 | 62 | def load_data(self, tgt_path): 63 | ''' simulation.npz 64 | uout, zout, state_u, state_z 65 | v_r_out, F_H_out, u_H_out, 66 | bow_mask, hammer_mask, pluck_mask 67 | 68 | string.npz 69 | kappa, alpha, u0, v0, f0 70 | pos, T60, target_f0 71 | 72 | bow.npz 73 | x_B, v_B, F_B, phi_0, phi_1, wid_B 74 | 75 | hammer.npz 76 | x_H, v_H, u_H, w_H, M_r, alpha 77 | ''' 78 | # {data_dir}/{load_name}/{split}/{string_id}-*/ut-{nx}.wav 79 | tgt_path_list = tgt_path.split('/') 80 | string_dir = '/'.join(tgt_path_list[:-1]) 81 | filename = tgt_path_list[-1] 82 | x_idx = int(filename.split('.')[0].split('-')[-1]) 83 | 84 | npz_path = os.path.join(string_dir, 'parameters.npz') 85 | lin_path = tgt_path.replace('ut-', 'ua-') 86 | linear_wave = sf.read(lin_path)[0] 87 | 88 | Nt = len(linear_wave) 89 | if self.trim is not None: 90 | st = np.random.randint(Nt-self.trim) 91 | et = st + self.trim 92 | linear_wave = linear_wave[st:et] 93 | _tgt = data.load_wav(tgt_path, npz_path, [st, et], keys=self.keys) 94 | else: 95 | _tgt = data.load_wav(tgt_path, npz_path, keys=self.keys) 96 | xval = _tgt['x'][0,x_idx] 97 | coef = _tgt['mode_amps'][:,x_idx][None,None,:] 98 | _tgt.update(dict(x=xval)) 99 | _tgt.update(dict(mode_coef=coef)) 100 | _tgt.update(dict(analytic=linear_wave)) 101 | return _tgt 102 | 103 | def __len__(self): 104 | return self.n_data 105 | 106 | def __getitem__(self, index): 107 | anchor_index = index // self.Nx 108 | spaces_index = index % self.Nx 109 | anchor_path = self.tgt_list[anchor_index] 110 | target_path = anchor_path.replace('ut-0.wav', f'ut-{spaces_index}.wav') 111 | return self.load_data(target_path) 112 | 113 | class Trainset(GenericDataset): 114 | 115 | def __init__( 116 | self, 117 | data_dir, 118 | load_name, 119 | trim=None, 120 | ): 121 | super().__init__( 122 | data_dir, 123 | load_name, 124 | split='Train', 125 | trim=trim, 126 | ) 127 | print(f"[Loader] Train 
samples:") 128 | print(f"\t(total) {len(self)}") 129 | 130 | class Testset(GenericDataset): 131 | 132 | def __init__( 133 | self, 134 | data_dir, 135 | load_name, 136 | split='Test', 137 | trim=None, 138 | ): 139 | super().__init__( 140 | data_dir, 141 | load_name, 142 | split=split, 143 | trim=trim, 144 | ) 145 | print(f"[Loader] {split} samples:") 146 | print(f"\t(total) {len(self)}") 147 | 148 | 149 | if __name__=='__main__': 150 | dset = Trainset('/data2/private/szin/dfdm', 'cvg-*') 151 | data = dset[0] 152 | for key, value in data.items(): 153 | print(key, value.shape) 154 | 155 | 156 | 157 | -------------------------------------------------------------------------------- /src/model/cpp/bow.cpp: -------------------------------------------------------------------------------- 1 | using namespace std; 2 | # define M_PI 3.14159265358979323846 /* pi */ 3 | 4 | #include 5 | #include 6 | #include 7 | 8 | #include "misc.h" 9 | 10 | torch::Tensor hard_bow(torch::Tensor v_rel, torch::Tensor a, torch::Tensor eps) { 11 | return torch::sign(v_rel) * (eps + (1-eps) * torch::exp(-a * v_rel.abs())); 12 | } 13 | torch::Tensor soft_bow(torch::Tensor v_rel, torch::Tensor a, torch::Tensor eps) { 14 | return (2*a).pow(.5) * v_rel * torch::exp(-a * v_rel.pow(2) + 0.5); 15 | } 16 | 17 | vector bow_term_rhs( 18 | torch::Tensor N, 19 | torch::Tensor h, 20 | float k, 21 | torch::Tensor u, 22 | torch::Tensor u1, 23 | torch::Tensor u2, 24 | torch::Tensor x_B, 25 | torch::Tensor v_B, 26 | torch::Tensor F_B, 27 | torch::Tensor wid, 28 | torch::Tensor phi_0, 29 | torch::Tensor phi_1, 30 | int iter) { 31 | 32 | auto rc = raised_cosine(N-1, x_B, wid, u1.size(1)); // (batch_size, max(N), 1) 33 | auto I = rc; 34 | auto J = rc / h.view({-1,1,1}); 35 | 36 | torch::Tensor v_rel; 37 | if (iter == 0) { v_rel = torch::matmul(I.transpose(1,2), ((u1 - u2) / k - v_B.view({-1,1,1}))); } 38 | else { v_rel = torch::matmul(I.transpose(1,2), ((u - u1) / k - v_B.view({-1,1,1}))); } 39 | auto Gamma = J * 
F_B.view({-1,1,1}) * hard_bow(v_rel, phi_0.view({-1,1,1}), phi_1.view({-1,1,1})); 40 | return { - pow(k, 2.) * Gamma, v_rel }; // {(batch_size, 1, 1), (batch_size, 1, 1)} 41 | } 42 | 43 | -------------------------------------------------------------------------------- /src/model/cpp/bow.h: -------------------------------------------------------------------------------- 1 | #ifndef BOW_H 2 | #define BOW_H 3 | 4 | torch::Tensor hard_bow(torch::Tensor, torch::Tensor, torch::Tensor); 5 | torch::Tensor soft_bow(torch::Tensor, torch::Tensor, torch::Tensor); 6 | vector bow_term_rhs( 7 | torch::Tensor, torch::Tensor, 8 | float, 9 | torch::Tensor, torch::Tensor, torch::Tensor, 10 | torch::Tensor, torch::Tensor, torch::Tensor, 11 | torch::Tensor, torch::Tensor, torch::Tensor, 12 | int 13 | ); 14 | 15 | #endif 16 | -------------------------------------------------------------------------------- /src/model/cpp/hammer.cpp: -------------------------------------------------------------------------------- 1 | using namespace std; 2 | # define M_PI 3.14159265358979323846 /* pi */ 3 | const float M_HD = -0.01; /* max hammer displacement; hammer_loop clamps u_H from below at this value */ 4 | 5 | #include 6 | #include 7 | #include 8 | 9 | #include "misc.h" 10 | /* hammer_loop: fixed-point iteration for the hammer contact. Starting from eta = eta_1 * mask, each pass recomputes the contact force F_H (zero wherever eta_1 <= 0), the next hammer displacement u_H (clamped below at M_HD), and the estimate eta = (u_H - eps_u) * mask; it stops once |eta - eta_estimate| <= threshold for every element and returns {F_H, u_H}. NOTE(review): there is no iteration cap, so an element that never converges loops forever. */ 11 | vector hammer_loop( 12 | torch::Tensor u_H1, // hammer displacement, curr. time (batch_size,) 13 | torch::Tensor u_H2, // hammer displacement, prev. time (batch_size,) 14 | torch::Tensor eta_1, // relative hammer displacement, curr. time (batch_size,) 15 | torch::Tensor eta_2, // relative hammer displacement, prev.
time (batch_size,) 16 | torch::Tensor alpha, // nonlinear exponent (batch_size,) 17 | torch::Tensor w_H, // hammer stiffness parameter (batch_size,) 18 | torch::Tensor M_r, // hammer / string mass ratio (batch_size,); not read inside hammer_loop 19 | torch::Tensor eps_u, // hammering point string displacement, future time (batch_size,) 20 | float k, 21 | torch::Tensor threshold, 22 | torch::Tensor mask) { 23 | int iter = 0; /* never incremented or read */ 24 | bool not_converged = true; 25 | torch::Tensor F_H; 26 | torch::Tensor u_H; 27 | torch::Tensor eta; 28 | torch::Tensor eta_estimate; 29 | 30 | int batch_size = u_H1.size(0); /* not used below */ 31 | eta = eta_1 * mask; /* mask zeroes the estimate for batch elements that are not hammer-excited */ 32 | eta_estimate = eta_1 * mask; 33 | while (not_converged) { 34 | eta = eta_estimate; 35 | 36 | // hammering force 37 | auto f_H = w_H.pow(1+alpha) 38 | * torch::relu(eta_1).pow(alpha-1) 39 | * (eta + eta_2) / 2; 40 | F_H = torch::where(eta_1.gt(0), f_H, torch::zeros_like(f_H)); 41 | 42 | // u_tt = - F_H 43 | // (u_H - 2u_H1 + u_H2) / k^2 = - F_H 44 | u_H = 2*u_H1 - u_H2 - pow(k, 2.) * F_H; // hammer displacement, future time (explicit update from u_H1, u_H2 and F_H) 45 | u_H = torch::relu(u_H - M_HD) + M_HD; /* equivalent to max(u_H, M_HD) */ 46 | 47 | eta_estimate = (u_H - eps_u) * mask; 48 | 49 | torch::Tensor residual = (eta - eta_estimate).abs(); // (batch,) 50 | 51 | not_converged = residual.gt(threshold).any().item(); /* keep iterating while any element's residual exceeds threshold */ 52 | } 53 | return { F_H, u_H }; 54 | } 55 | 56 | vector hammer_term_rhs( 57 | torch::Tensor N, 58 | torch::Tensor h, 59 | float k, 60 | torch::Tensor u, 61 | torch::Tensor u1, 62 | torch::Tensor u2, 63 | torch::Tensor x_H, // hammer position 64 | torch::Tensor u_H1, // hammer displacement, curr. time (batch_size,) 65 | torch::Tensor u_H2, // hammer displacement, prev.
time (batch_size,) 66 | torch::Tensor w_H, // hammer stiffness parameter (batch_size,) 67 | torch::Tensor M_r, // hammer / string mass ratio (batch_size,) 68 | torch::Tensor alpha, // nonlinear exponent (batch_size,) 69 | torch::Tensor threshold, // threshold (batch_size,) 70 | torch::Tensor mask) { // zero-mask updates on batches that are not hammer-excited (batch_size,) 71 | auto eps = floor_dirac_delta(N-1, x_H, u1.size(1)).transpose(1,2); 72 | auto eps_u = torch::matmul(eps, u).view(-1); // (batch_size,) 73 | auto eta_1 = u_H1 - torch::matmul(eps, u1).view(-1); // (batch_size,) 74 | auto eta_2 = u_H2 - torch::matmul(eps, u2).view(-1); // (batch_size,) 75 | 76 | auto loop_out = hammer_loop( 77 | u_H1, u_H2, eta_1, eta_2, 78 | alpha, w_H, M_r, eps_u, k, threshold, mask); 79 | auto F_H = loop_out[0].view({-1,1,1}); 80 | auto u_H = loop_out[1].view({-1,1,1}); 81 | auto Gamma = eps.transpose(1,2) * M_r.view({-1,1,1}) * F_H; 82 | auto d_H = eps.transpose(1,2) * torch::relu(u_H - eps_u.view({-1,1,1})); 83 | 84 | return { - pow(k, 2.) 
* Gamma, F_H, u_H, d_H }; 85 | } 86 | 87 | 88 | -------------------------------------------------------------------------------- /src/model/cpp/hammer.h: -------------------------------------------------------------------------------- 1 | #ifndef HAMMER_H 2 | #define HAMMER_H 3 | 4 | vector hammer_loop( 5 | torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, 6 | torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, 7 | float, torch::Tensor, torch::Tensor 8 | ); 9 | vector hammer_term_rhs( 10 | torch::Tensor, torch::Tensor, float, // N, h, k 11 | torch::Tensor, torch::Tensor, torch::Tensor, // u, u1, u2 12 | torch::Tensor, torch::Tensor, torch::Tensor, // x_H, u_H1, u_H2, w_H 13 | torch::Tensor, torch::Tensor, torch::Tensor, // w_H, M_r, alpha 14 | torch::Tensor, torch::Tensor // threshold, hammer_mask 15 | ); 16 | 17 | #endif 18 | -------------------------------------------------------------------------------- /src/model/cpp/misc.cpp: -------------------------------------------------------------------------------- 1 | using namespace std; 2 | # define M_PI 3.14159265358979323846 /* pi */ 3 | 4 | #include 5 | #include 6 | #include 7 | 8 | #include 9 | using namespace at; 10 | 11 | namespace F = torch::nn::functional; 12 | 13 | torch::Device device() { 14 | return torch::cuda::is_available() ? torch::kCUDA : torch::kCPU; 15 | } 16 | torch::TensorOptions ntopt() { // new tensor options 17 | return torch::TensorOptions().device(device()); 18 | } 19 | 20 | torch::Tensor raised_cosine( 21 | torch::Tensor n, // number of samples in space 22 | torch::Tensor ctr, // center point of raised cosine curve in (0, 1] 23 | torch::Tensor wid, // width of raised cosine curve in (0, 1] 24 | int N) { 25 | 26 | float h = 1. 
/ N; 27 | auto xax = torch::linspace(h, 1, N, ntopt()).view({1,-1,1}).to(n.dtype()); // (1, N, 1) 28 | ctr = (ctr * n / N).view({-1,1,1}); // abs portion -> rel portion (batch_size, 1, 1) 29 | wid = (wid * n / N).view({-1,1,1}); // abs portion -> rel portion (batch_size, 1, 1) 30 | auto ind = torch::sign(torch::relu(-(xax - ctr - wid / 2) * (xax - ctr + wid / 2))); 31 | auto out = 0.5 * ind * (1 + torch::cos(2 * M_PI * (xax - ctr) / wid)); 32 | out = out / out.abs().sum(1, /*keepdim*/true); 33 | return out; // (batch_size, N, 1) 34 | } 35 | 36 | torch::Tensor floor_dirac_delta( 37 | torch::Tensor n, // number of samples in space 38 | torch::Tensor ctr, // center point of raised cosine curve 39 | int N) { 40 | auto xax = torch::ones_like(ctr).view({-1,1,1}).repeat({1,N,1}).cumsum(1) - 1; 41 | auto idx = torch::floor(ctr * n).view({-1,1,1}); 42 | return torch::floor(xax).eq(idx).to(n.dtype()); // (batch_size, N, 1) 43 | } 44 | 45 | torch::Tensor domain_x(int N, torch::Tensor n) { 46 | /* N (int): number of maximal samples in space 47 | * n (B, 1, 1): number of actual samples in space 48 | */ 49 | auto v = 2 / n; 50 | v = (v * torch::ones_like(v).repeat({1,1,N})).cumsum(2) - v; 51 | return (v.clamp(0,2).transpose(2,1) - 1)/2; 52 | } 53 | 54 | torch::Tensor triangular( 55 | int N, 56 | torch::Tensor n, 57 | torch::Tensor p_x, 58 | torch::Tensor p_a) { 59 | /* N (int): number of maximal samples in space 60 | * n (B, 1, 1): number of actual samples in space 61 | * p_x (B, Nt, 1): peak position 62 | * p_a (B, Nt, 1): peak amplitude 63 | */ 64 | auto vel_l = p_a / p_x / n; 65 | auto vel_r = p_a / (1-p_x) / n; 66 | vel_l = (vel_l * torch::ones_like(vel_l).repeat({1,1,N})).cumsum(2) - vel_l; 67 | vel_r = ((vel_r * torch::ones_like(vel_r).repeat({1,1,N})).cumsum(2) - vel_r * (N-n+1)).clamp(/*min*/0).flip(2); 68 | return torch::minimum(vel_l, vel_r); 69 | } 70 | 71 | torch::Tensor expand(torch::Tensor X, int N_w, int N_h) { 72 | int n_h = X.size(-2); int n_w = X.size(-1); 73 | 
auto kwargs = F::PadFuncOptions({0, N_w-n_w, 0, N_h-n_h}).mode(torch::kConstant); 74 | return F::pad(X.unsqueeze(1), kwargs).squeeze(1); 75 | } 76 | 77 | /* Interpolator */ 78 | torch::Tensor Interpolator(int dim_i, int dim_o) { 79 | /* dim_i (int) : input dimension 80 | * dim_o (int) : output dimension 81 | * Returns a tensor with shape (dim_o, dim_i) 82 | * Be sure to match the right dtype, e.g., Interpolator(...).to(foo.dtype()) 83 | */ 84 | auto diagonal = torch::diag_embed(torch::ones(dim_i, ntopt())).view({1,dim_i,dim_i}); 85 | auto kwargs = F::InterpolateFuncOptions().size(vector({dim_o})) 86 | .mode(torch::kLinear).align_corners(true); 87 | return F::interpolate(diagonal, kwargs).transpose(1,2); 88 | } 89 | 90 | /* Interpolator operator */ 91 | torch::Tensor batched_interpolator(torch::Tensor N_i, torch::Tensor N_o) { 92 | /* N_i (batch_size,) : input dimension 93 | * N_o (batch_size,) : output dimension 94 | * returns a Interpolator tensor with shape (batch_size, max_o, max_i) 95 | */ 96 | int batch_size = N_i.size(0); 97 | int max_i = torch::max(N_i).item(); 98 | int max_o = torch::max(N_o).item(); 99 | auto out = torch::zeros({batch_size, max_o, max_i}, ntopt()).to(N_i.dtype()); 100 | for (int b=0; b < batch_size; b++) { 101 | int dim_i = N_i[b].item(); int dim_o = N_o[b].item(); 102 | out[b] = expand(Interpolator(dim_i, dim_o).to(N_i.dtype()), max_i, max_o).squeeze(0); 103 | } 104 | return out; 105 | } 106 | 107 | /* Diagonalizing operator */ 108 | torch::Tensor batched_diag(torch::Tensor lam) { 109 | // lam : (batch_size, N, 1) diagonal entries 110 | // return (batch_size, N, N) tensor with each diagonal element specified by lam 111 | auto maps = lam * torch::ones_like(lam).repeat({1,1,lam.size(1)}); 112 | auto cr = torch::ones_like(lam).cumsum(1).repeat({1,1,lam.size(1)}); // (batch_size, [N], N) 113 | auto rc = torch::ones_like(lam).repeat({1,1,lam.size(1)}).cumsum(2); // (batch_size, N, [N]) 114 | auto mask = cr.eq(rc); 115 | return 
torch::where(mask, maps, torch::zeros_like(maps)); 116 | } 117 | 118 | /* Identity operator */ 119 | torch::Tensor I(torch::Tensor n, int diagonal=0) { 120 | // n (batch_size,) : width of identity matrix 121 | // return identity matrices of maximum width 122 | int l; 123 | if (diagonal == 0) { l = torch::max(n).item(); } 124 | else { l = torch::max(n).item() - abs(diagonal); } 125 | auto i = torch::ones(l, ntopt()).to(n.dtype()); 126 | return torch::diag(i, diagonal).unsqueeze(0).repeat({n.size(0),1,1}); 127 | } 128 | 129 | /* Difference operator */ 130 | torch::Tensor Dxx(torch::Tensor n, torch::Tensor h) { 131 | auto Dx = I(n, +1) - 2*I(n) + I(n, -1); 132 | return Dx / h.pow(2).view({-1,1,1}); 133 | } 134 | torch::Tensor Dxf(torch::Tensor n, torch::Tensor h) { 135 | auto Dx = I(n, +1) - I(n); 136 | return Dx / h.view({-1,1,1}); 137 | } 138 | torch::Tensor Dxb(torch::Tensor n, torch::Tensor h) { 139 | auto Dx = I(n) - I(n, -1); 140 | return Dx / h.view({-1,1,1}); 141 | } 142 | torch::Tensor Dxxxx(torch::Tensor n, torch::Tensor h) { 143 | auto Dx = I(n, +2) - 4*I(n, +1) + 6*I(n) - 4*I(n, -1) + I(n, -2); 144 | return Dx / h.pow(4).view({-1,1,1}); 145 | } 146 | torch::Tensor Dxxxx_clamped(torch::Tensor n, torch::Tensor h) { 147 | // Fourth-order difference operator for the boundary condition u_{-1} == u_{1}. 
148 | /* [[[ 6., -4., 1., 0., , ], 149 | [-4., 7., -4., 1., 0., ], 150 | [ 1., -4., 6., -4., 1., 0.], 151 | [ 0., 1., -4., 6., -4., 1.], 152 | [ , 0., 1., -4., 7., -4.], 153 | [ , , 0., 1., -4., 6.]]] / h^4 */ 154 | int n_max = torch::max(n).item(); 155 | auto ones = torch::ones(n_max, ntopt()).to(n.dtype()).view({1,n_max,1}); 156 | auto maps = ones.cumsum(1)-1; 157 | auto rpos = (n-2).view({-1,1,1}).repeat({1,n_max,1}); // (batch_size, n_max, 1) : filled with (n-2) 158 | auto mask_l = maps.eq(ones); // true only at index 1 159 | auto mask_r = maps.eq(rpos); // true only at index n-2 160 | auto SM = I(n) * mask_l.logical_or(mask_r); 161 | auto Dx = I(n, +2) - 4*I(n, +1) + 6*I(n) - 4*I(n, -1) + I(n, -2); 162 | return (Dx + SM) / h.pow(4).view({-1,1,1}); 163 | } 164 | torch::Tensor Mxc(torch::Tensor n) { 165 | return (I(n, +1) + I(n, -1)) / 2; 166 | } 167 | 168 | torch::Tensor block_matrices(vector< vector > X) { 169 | int n_rows = X.size(); int n_cols = X[0].size(); 170 | torch::Tensor out; 171 | for (int i=0; i < n_rows; i++) { 172 | torch::Tensor row = X[i][0]; 173 | for (int j=1; j < n_cols; j++) { 174 | row = torch::cat({row, X[i][j]}, 2); 175 | } 176 | if (i==0) { out = row; } 177 | else { out = torch::cat({out, row}, 1); } 178 | } 179 | return out; 180 | } 181 | 182 | torch::Tensor mask_1d(torch::Tensor u, torch::Tensor N, int N_max) { 183 | // u : (batch_size, N_max, 1) 184 | auto maps = torch::ones_like(u).cumsum(1); // (batch_size, N_max, 1) : arange for N_max 185 | auto cons = N.view({-1,1,1}).repeat({1,N_max,1}); // (batch_size, N_max, 1) : filled with N 186 | auto mask = maps.le(cons); // (batch_size, N_max, 1) : boolean mask 187 | // mask to actual length 188 | return u * mask; 189 | } 190 | torch::Tensor mask_2d(torch::Tensor X, torch::Tensor N, int N_max) { 191 | // X : (batch_size, N_max, N_max) 192 | auto maps = torch::ones_like(X).cumsum(1).cumsum(2); // (batch_size, N_max, N_max) 193 | auto cons = N.view({-1,1,1}).repeat({1,N_max,N_max}); // 
(batch_size, N_max, N_max) 194 | auto mask = maps.le(cons); // (batch_size, N_max, N_max) 195 | // mask to actual length 196 | return X * mask; 197 | } 198 | torch::Tensor dirichlet_boundary(torch::Tensor u, torch::Tensor N, int N_max) { 199 | // u : (batch_size, N_max, 1) 200 | // zero-out u at position index 0 and N 201 | auto maps = torch::ones_like(u).cumsum(1)-1; // (batch_size, N_max, 1) : arange for N_max 202 | auto zero = torch::zeros_like(u); // (batch_size, N_max, 1) : filled with 0 203 | auto rpos = N.view({-1,1,1}).repeat({1,N_max,1}); // (batch_size, N_max, 1) : filled with N 204 | auto mask_l = maps.eq(zero).logical_not(); // false only at index 0 205 | auto mask_r = maps.eq(rpos).logical_not(); // false only at index N 206 | return u * mask_l * mask_r; 207 | } 208 | 209 | torch::Tensor inverse_like(torch::Tensor A) { 210 | // Compute the Moore-Penrose pseudo-inverse 211 | c10::optional atol; c10::optional rtol; 212 | return at::linalg_pinv(A, atol, rtol, /*hermitian*/true); 213 | } 214 | vector split_blocks(torch::Tensor X, int N_t, int N_l) { 215 | auto X_split = X.split({ N_t, N_l }, /*dim*/-2); 216 | auto X_01 = X_split[0].split({ N_t, N_l }, /*dim*/-1); 217 | auto X_23 = X_split[1].split({ N_t, N_l }, /*dim*/-1); 218 | return { X_01[0], X_01[1], X_23[0], X_23[1] }; 219 | } 220 | torch::Tensor sparse_blocks(vector X, int N_t_max, int N_l_max) { 221 | auto X_0 = expand(X[0], /*width*/N_t_max, /*height*/N_t_max); 222 | auto X_1 = expand(X[1], /*width*/N_l_max, /*height*/N_t_max); 223 | auto X_2 = expand(X[2], /*width*/N_t_max, /*height*/N_l_max); 224 | auto X_3 = expand(X[3], /*width*/N_l_max, /*height*/N_l_max); 225 | return block_matrices({ { X_0, X_1 }, { X_2, X_3 } }); // N_t_max+N_l_max, N_t_max+N_l_max 226 | } 227 | 228 | torch::Tensor tridiagonal_inverse(torch::Tensor X, torch::Tensor N) { 229 | // X (batch_size, n, n) : tridiagonal matrix to invert 230 | // N (batch_size, ) : actual width + 1 of the matrix 231 | int batch_size = X.size(0); 
int n = X.size(1); 232 | auto k = 1 + torch::arange(n, ntopt()).to(X.dtype()); // (batch_size, ) 233 | auto jk = torch::outer(k,k).unsqueeze(0).repeat({batch_size,1,1}); // (batch_size, n, n) 234 | auto kb = k.unsqueeze(0).repeat({batch_size,1}); // (batch_size, n) 235 | 236 | auto a = X.select(2,0).select(1,1).unsqueeze(1); // (batch_size, 1) 237 | auto b = X.select(2,0).select(1,0).unsqueeze(1); // (batch_size, 1) 238 | auto c = X.select(2,1).select(1,0).unsqueeze(1); // (batch_size, 1) 239 | auto Nb = N.unsqueeze(1); // (batch_size, 1) 240 | 241 | auto lam = b + (a+c) * torch::cos(kb * M_PI / Nb); // (batch_size, 1) 242 | auto Lid = torch::diag(torch::ones_like(k)).unsqueeze(0); // (1, n, n) 243 | auto Lam = 1 / lam.unsqueeze(1); // (batch_size, 1, 1) 244 | auto L = Lid * Lam; // (batch_size, n, n) 245 | auto V = (2. / Nb.unsqueeze(-1)).pow(0.5) 246 | * torch::sin(jk * M_PI / Nb.unsqueeze(-1)); // (batch_size, n, n) 247 | 248 | // apply mask 249 | L = mask_2d(L, N, n); 250 | 251 | return torch::matmul(V, torch::matmul(L, V.transpose(-1,-2))); 252 | } 253 | 254 | torch::Tensor assign( 255 | torch::Tensor x, 256 | torch::Tensor y, 257 | int index, int dim) { 258 | x = x.transpose(0,dim); // transpose target dim with dim 0 259 | x[index] = y; // assign 260 | return x.transpose(0,dim); // revert by transpose 261 | } 262 | 263 | torch::Tensor add_in( 264 | torch::Tensor x, 265 | torch::Tensor y, 266 | int index, int dim) { 267 | x = x.transpose(0,dim); // transpose target dim with dim 0 268 | x[index] += y; // assign 269 | return x.transpose(0,dim); // revert by transpose 270 | } 271 | torch::Tensor lstsq( 272 | torch::Tensor LHS, 273 | torch::Tensor RHS, 274 | torch::Tensor pseudo_inverse, 275 | float threshold=1e-4) { 276 | torch::Tensor solution; 277 | 278 | // Compute the solution 279 | solution = torch::matmul(pseudo_inverse, RHS); 280 | torch::Tensor residual = torch::matmul(LHS, solution) - RHS; 281 | float err = residual.norm().item(); 282 | int iter = 0; 283 
| int max_iter = 100; 284 | while ((err > threshold) and (iter++ < max_iter)) { 285 | torch::Tensor update = torch::matmul(pseudo_inverse, residual); 286 | solution -= update; 287 | residual = torch::matmul(LHS, solution) - RHS; 288 | err = residual.norm().item(); 289 | } 290 | return solution; 291 | } 292 | 293 | -------------------------------------------------------------------------------- /src/model/cpp/misc.h: -------------------------------------------------------------------------------- 1 | #ifndef MISC_H 2 | #define MISC_H 3 | 4 | torch::Device device(); 5 | torch::Device ntopt(); 6 | torch::Tensor raised_cosine(torch::Tensor, torch::Tensor, torch::Tensor, int); 7 | torch::Tensor floor_dirac_delta(torch::Tensor, torch::Tensor, int); 8 | torch::Tensor domain_x(int, torch::Tensor); 9 | torch::Tensor triangular(int, torch::Tensor, torch::Tensor, torch::Tensor); 10 | torch::Tensor expand(torch::Tensor, int, int); 11 | torch::Tensor Interpolator(int, int); 12 | torch::Tensor batched_interpolator(torch::Tensor, torch::Tensor); 13 | torch::Tensor batched_diag(torch::Tensor); 14 | torch::Tensor I(torch::Tensor, int); 15 | torch::Tensor Dxb(torch::Tensor, torch::Tensor); 16 | torch::Tensor Dxf(torch::Tensor, torch::Tensor); 17 | torch::Tensor Dxx(torch::Tensor, torch::Tensor); 18 | torch::Tensor Dxxxx(torch::Tensor, torch::Tensor); 19 | torch::Tensor Dxxxx_clamped(torch::Tensor, torch::Tensor); 20 | torch::Tensor Mxc(torch::Tensor); 21 | torch::Tensor block_matrices(vector< vector >); 22 | torch::Tensor mask_1d(torch::Tensor, torch::Tensor, int); 23 | torch::Tensor mask_2d(torch::Tensor, torch::Tensor, int); 24 | torch::Tensor dirichlet_boundary(torch::Tensor, torch::Tensor, int); 25 | torch::Tensor inverse_like(torch::Tensor); 26 | vector split_blocks(torch::Tensor, int, int); 27 | torch::Tensor sparse_blocks(vector, int, int); 28 | torch::Tensor tridiagonal_inverse(torch::Tensor, torch::Tensor); 29 | torch::Tensor assign(torch::Tensor, torch::Tensor, int, int); 
30 | torch::Tensor add_in(torch::Tensor, torch::Tensor, int, int); 31 | torch::Tensor lstsq(torch::Tensor, torch::Tensor, torch::Tensor, float); 32 | 33 | #endif 34 | -------------------------------------------------------------------------------- /src/model/cpp/simulator.cpp: -------------------------------------------------------------------------------- 1 | using namespace std; 2 | # define M_PI 3.14159265358979323846 /* pi */ 3 | const float M_HD = -0.01; /* max hammer displacement */ 4 | 5 | #include 6 | #include 7 | #include 8 | 9 | #include "misc.h" 10 | #include "string.h" 11 | 12 | namespace F = torch::nn::functional; 13 | 14 | vector forward_fn( 15 | torch::Tensor state_u, // transverse displacement state 16 | torch::Tensor state_z, // longitudinal displacement state 17 | vector string_params, // string parameters 18 | vector bow_params, // bow excitation parameters 19 | vector hammer_params, // hammer excitation parameters 20 | torch::Tensor bow_mask, // bow excitation mask 21 | torch::Tensor hammer_mask, // hammer excitation mask 22 | vector constant, // global constants 23 | float relative_error, // order of the discretization error 24 | bool surface_integral, // pickup configuration 25 | bool manufactured, // verification configuration (using manufactured solution) 26 | int n_0, // time index for global step 27 | int Nt) { // number of simulation samples 28 | 29 | int batch_size = state_u.size(0); 30 | float k = constant[0]; 31 | 32 | auto uout = torch::zeros({batch_size,Nt}, state_u.dtype()).to(device()); // pickup displacement for output 33 | auto zout = torch::zeros({batch_size,Nt}, state_u.dtype()).to(device()); // pickup displacement for output 34 | auto v_b = torch::zeros({batch_size,Nt}, state_u.dtype()).to(device()); // relative velocity at the bowing point 35 | auto F_H = torch::zeros({batch_size,Nt}, state_u.dtype()).to(device()); // hammering force profile 36 | auto u_H = hammer_params[2]; // hammer displacement 37 | torch::Tensor sig0; // 
freq-independent loss term 38 | torch::Tensor sig1; // freq-dependent loss term 39 | 40 | for (int n=2; n < Nt; n++) { 41 | auto results = string_step( 42 | uout, zout, state_u, state_z, v_b, F_H, 43 | string_params, bow_params, hammer_params, 44 | bow_mask, hammer_mask, 45 | constant, n+n_0, n, relative_error, surface_integral, 46 | manufactured); 47 | uout = results[0]; 48 | zout = results[1]; 49 | state_u = results[2]; 50 | state_z = results[3]; 51 | v_b = results[4]; 52 | F_H = results[5]; 53 | u_H = results[6]; 54 | sig0 = results[7]; 55 | sig1 = results[8]; 56 | } 57 | u_H = u_H / k; 58 | return { uout, zout, state_u, state_z, v_b, F_H, u_H, sig0, sig1 }; 59 | } 60 | 61 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 62 | m.def("forward_fn", &forward_fn, "string-bow forward iteration"); 63 | } 64 | 65 | -------------------------------------------------------------------------------- /src/model/cpp/string.cpp: -------------------------------------------------------------------------------- 1 | using namespace std; 2 | # define M_PI 3.14159265358979323846 /* pi */ 3 | 4 | #include 5 | #include 6 | #include 7 | 8 | #include "misc.h" 9 | #include "vnv.h" 10 | #include "bow.h" 11 | #include "hammer.h" 12 | 13 | #include 14 | using namespace at; 15 | 16 | vector get_derived_vars( 17 | torch::Tensor f0, 18 | torch::Tensor kappa_rel, 19 | float k, float theta_t, float lambda_c, 20 | torch::Tensor alpha) { 21 | 22 | // Derived variables 23 | auto gamma = 2 * f0; // set parameters 24 | auto kappa = gamma * kappa_rel; // stiffness parameter 25 | auto IHP = (M_PI * kappa / gamma).pow(2); // inharmonicity parameter (>0); eq 7.21 26 | auto K = IHP.pow(0.5) * (gamma / M_PI); // set parameters 27 | 28 | torch::Tensor h_1, h_2; 29 | h_1 = lambda_c * ( 30 | (gamma.pow(2) * pow(k, 2.) + pow(gamma.pow(4) * pow(k, 4.) + 16 * K.pow(2) * pow(k, 2.) 
* (2 * theta_t - 1),.5)) 31 | / (2 * (2 * theta_t - 1)) 32 | ).pow(.5); 33 | auto N_t = torch::floor(1 / h_1).to(kappa_rel.dtype()); 34 | auto h_t = 1 / N_t; 35 | 36 | h_2 = lambda_c * gamma * alpha * k; 37 | auto N_l = torch::floor(1 / h_2).to(kappa_rel.dtype()); 38 | auto h_l = 1 / N_l; 39 | 40 | return { gamma, K, N_t, h_t, N_l, h_l }; 41 | } 42 | 43 | vector string_step( 44 | torch::Tensor uout, // pickup transverse displacement 45 | torch::Tensor zout, // pickup longitudinal displacement 46 | torch::Tensor state_u, // transverse displacement state 47 | torch::Tensor state_z, // longitudinal displacement state 48 | torch::Tensor v_r_out, // relative velocity 49 | torch::Tensor F_H_out, // hammering force profile 50 | vector string_params, // string parameters 51 | vector bow_params, // bow excitation parameters 52 | vector hammer_params, // hammer excitation parameters 53 | torch::Tensor bow_mask, // bow excitation mask 54 | torch::Tensor hammer_mask, // hammer excitation mask 55 | vector constant, // global constants 56 | int global_step, // global simulation time step 57 | int local_step, // local simulation time step (just for TBPTT) 58 | float relative_error, // discretization error relative to the spatial grid size 59 | bool surface_integral, 60 | bool manufactured) { // use manufactured solution (used for verification purposes) 61 | int batch_size = uout.size(0); 62 | 63 | //============================== 64 | // Setup variables 65 | //============================== 66 | // string parameters 67 | auto kappa_rel = string_params[0]; auto alpha = string_params[1]; 68 | auto u0 = string_params[2]; auto v0 = string_params[3]; auto p_a = string_params[4]; 69 | auto f0 = string_params[5]; auto rp = string_params[6]; auto T60 = string_params[7]; 70 | 71 | // bow control parameters 72 | auto x_bow = bow_params[0]; auto v_bow = bow_params[1]; auto F_bow = bow_params[2]; 73 | auto phi_0 = bow_params[3]; auto phi_1 = bow_params[4]; auto wid_b = bow_params[5]; 74 | 75 
| // hammer control parameters 76 | auto x_H = hammer_params[0]; auto v_H = hammer_params[1]; auto u_H_out = hammer_params[2]; 77 | auto w_H = hammer_params[3]; auto M_r = hammer_params[4]; auto alpha_H = hammer_params[5]; 78 | 79 | // constants 80 | float k = constant[0]; float theta_t = constant[1]; float lambda_c = constant[2]; 81 | 82 | // derived variables 83 | auto vars = get_derived_vars(f0.select(1,local_step), kappa_rel, k, theta_t, lambda_c, alpha); 84 | auto gamma = vars[0]; auto K = vars[1]; 85 | auto N_t = vars[2]; auto h_t = vars[3]; // transverse (u) 86 | auto N_l = vars[4]; auto h_l = vars[5]; // longitudinal (zeta) 87 | 88 | auto bow_wid_length = wid_b.select(1,local_step) * h_t; 89 | auto tol_t = h_t.pow(relative_error); 90 | auto tol_l = h_l.pow(relative_error); 91 | 92 | //============================== 93 | // Simulation step 94 | //============================== 95 | 96 | // Scheme loss parameters; eq 7.29 97 | torch::Tensor zeta1; 98 | torch::Tensor zeta2; 99 | 100 | zeta1 = torch::where(K.gt(0), 101 | - gamma.pow(2) + (gamma.pow(4) + 4 * K.pow(2) * (2 * M_PI * T60.select(2,0).select(1,0)).pow(2)).pow(.5), // if is stiff string 102 | T60.select(2,0).select(1,0).pow(2) / gamma.pow(2) // otherwise 103 | ); 104 | zeta2 = torch::where(K.gt(0), 105 | - gamma.pow(2) + (gamma.pow(4) + 4 * K.pow(2) * (2 * M_PI * T60.select(2,0).select(1,1)).pow(2)).pow(.5), // if is stiff string 106 | 107 | T60.select(2,0).select(1,1).pow(2) / gamma.pow(2) // otherwise 108 | ); 109 | 110 | auto T60_mask = T60.prod(2).prod(1).ne(0); 111 | auto sig0 = torch::where(T60_mask, 112 | - zeta2 / T60.select(2,1).select(1,0) + zeta1 / T60.select(2,1).select(1,1), // lossy string 113 | T60_mask // lossless string 114 | ); 115 | auto sig1 = torch::where(T60_mask, 116 | 1 / T60.select(2,1).select(1,0) - 1 / T60.select(2,1).select(1,1), // lossy string 117 | T60_mask // lossless string 118 | ); 119 | sig0 = (6 * log(10) * sig0 / (zeta1 - zeta2)).view({-1,1,1}); // freq-independent 
loss term 120 | sig1 = (6 * log(10) * sig1 / (zeta1 - zeta2)).view({-1,1,1}); // freq-dependent loss term 121 | 122 | // setup displacements 123 | int N_t_max = state_u.size(-1); 124 | int N_l_max = state_z.size(-1); 125 | auto u1 = state_u.narrow(1,local_step-1,1).transpose(2,1); // (batch_size, N_t_max, 1) 126 | auto u2 = state_u.narrow(1,local_step-2,1).transpose(2,1); // (batch_size, N_t_max, 1) 127 | auto z1 = state_z.narrow(1,local_step-1,1).transpose(2,1); // (batch_size, N_l_max, 1) 128 | auto z2 = state_z.narrow(1,local_step-2,1).transpose(2,1); // (batch_size, N_l_max, 1) 129 | u1 = mask_1d(u1, N_t+1, N_t_max); 130 | u2 = mask_1d(u2, N_t+1, N_t_max); 131 | z1 = mask_1d(z1, N_l+1, N_l_max); 132 | z2 = mask_1d(z2, N_l+1, N_l_max); 133 | 134 | auto w1 = torch::cat({u1, z1}, 1); 135 | auto w2 = torch::cat({u2, z2}, 1); 136 | 137 | // setup operators 138 | auto Id_tt = I(N_t+1, 0); auto Id_ll = I(N_l+1, 0); 139 | auto Dxf_tt = Dxf(N_t+1, h_t); auto Dxf_ll = Dxf(N_l+1, h_l); 140 | auto Dxb_tt = Dxb(N_t+1, h_t);// auto Dxb_ll = Dxb(N_l+1, h_l); 141 | auto Dxx_tt = Dxx(N_t+1, h_t); auto Dxx_ll = Dxx(N_l+1, h_l); 142 | //auto Dxxxx_tt = Dxxxx(N_t+1, h_t);// auto Dxxxx_ll = Dxxxx(N_l+1, h_l); 143 | auto Dxxxx_tt = Dxxxx_clamped(N_t+1, h_t); 144 | auto Int_tl = batched_interpolator(N_l+1, N_t+1); 145 | auto Int_lt = batched_interpolator(N_t+1, N_l+1); 146 | auto Mxc_tt = Mxc(N_t+1); 147 | 148 | auto Theta_tt = theta_t * Id_tt + (1-theta_t) * Mxc_tt; 149 | 150 | // setup recursion 151 | auto gamma_k = gamma.pow(2).view({-1,1,1}) * pow(k, 2.); 152 | auto phi_pow = gamma_k * (alpha.pow(2).view({-1,1,1}) - 1) / 4; 153 | auto Lam = batched_diag(torch::matmul(Dxb_tt, u1.narrow(1,0,Dxb_tt.size(-1)))); 154 | auto Qp_tt = Theta_tt + 2 * sig0 * k * Id_tt - 2 * sig1 * k * Dxx_tt; 155 | auto Qm_tt = Theta_tt - 2 * sig0 * k * Id_tt + 2 * sig1 * k * Dxx_tt; 156 | auto Qp_ll = (1 + 2 * sig0 * k) * Id_ll - 2 * sig1 * k * Dxx_ll; 157 | auto Qm_ll = (1 - 2 * sig0 * k) * Id_ll + 2 * 
sig1 * k * Dxx_ll; 158 | auto K_tl = - phi_pow * torch::matmul(Dxf_tt, torch::matmul(Lam, torch::matmul(Dxb_tt, Int_tl))); 159 | auto K_lt = - phi_pow * torch::matmul(Dxf_ll, torch::matmul(Int_lt, torch::matmul(Lam, Dxb_tt))); 160 | auto V_tt = - phi_pow * torch::matmul(Dxf_tt, torch::matmul(Lam.pow(2), Dxb_tt)); 161 | 162 | auto B_1 = -2 * Theta_tt - gamma_k * Dxx_tt + K.pow(2).view({-1,1,1}) * pow(k,2.) * Dxxxx_tt; 163 | auto B_2 = 2 * K_tl; 164 | auto B_3 = torch::zeros_like(B_2).transpose(1,2); 165 | auto B_4 = -2 * Id_ll - gamma_k * alpha.pow(2).view({-1,1,1}) * Dxx_ll; 166 | 167 | /* A @ w^{n+1} + B @ w^{n} + C @ w^{n-1} = 0 */ 168 | // matrices with size (batch, N_t+N_l, N_t+N_l) 169 | auto A_1 = Qp_tt + V_tt; auto A_2 = K_tl; auto A_3 = K_lt; auto A_4 = Qp_ll; 170 | auto C_1 = Qm_tt + V_tt; auto C_2 = K_tl; auto C_3 = K_lt; auto C_4 = Qm_ll; 171 | 172 | // inverse A before it gets zero-padded 173 | int t_wid = A_1.size(-1); int l_wid = A_2.size(-1); 174 | auto A_b = block_matrices({ { A_1, A_2 }, { A_3, A_4 } }); 175 | auto A_p = torch::linalg::inv(A_b); 176 | 177 | // zero-pad to maximal size (batch, N_t_max+N_l_max, N_t_max+N_l_max) 178 | auto A = sparse_blocks({A_1, A_2, A_3, A_4}, N_t_max, N_l_max); 179 | auto B = sparse_blocks({B_1, B_2, B_3, B_4}, N_t_max, N_l_max); 180 | auto C = sparse_blocks({C_1, C_2, C_3, C_4}, N_t_max, N_l_max); 181 | auto A_P = sparse_blocks(split_blocks(A_p, t_wid, l_wid), N_t_max, N_l_max); 182 | 183 | auto u_H1 = u_H_out.narrow(1,local_step-1,1).view(-1); 184 | auto u_H2 = u_H_out.narrow(1,local_step-2,1).view(-1); 185 | 186 | // iterate for implicit scheme 187 | int iter = 0; 188 | bool not_converged_t = true; 189 | bool not_converged_l = true; 190 | torch::Tensor u = state_u.narrow(1,local_step-1,1).transpose(2,1); // initialize by u1 191 | torch::Tensor z = state_z.narrow(1,local_step-1,1).transpose(2,1); // initialize by z1; 192 | torch::Tensor u_H; 193 | torch::Tensor F_H; 194 | torch::Tensor d_H; 195 | torch::Tensor 
v_rel; 196 | 197 | M_r = M_r / lambda_c; 198 | w_H = w_H / lambda_c; 199 | //bow_wid_length = bow_wid_length / lambda_c; 200 | while (not_converged_t or not_converged_l) { 201 | /* Bow excitation */ 202 | auto Bow = bow_term_rhs( 203 | N_t, h_t, k, u, u1, u2, 204 | x_bow.select(1,local_step), 205 | v_bow.select(1,local_step), 206 | F_bow.select(1,local_step), 207 | bow_wid_length, phi_0, phi_1, 208 | iter); 209 | auto G_B = Bow[0]; v_rel = Bow[1]; 210 | 211 | /* Hammer excitation */ 212 | auto Hammer = hammer_term_rhs( 213 | N_t, h_t, k, u, u1, u2, 214 | x_H, u_H1, u_H2, w_H, M_r, 215 | alpha_H, tol_t, hammer_mask.view(-1)); 216 | auto G_H = Hammer[0]; F_H = Hammer[1]; u_H = Hammer[2]; d_H = Hammer[3]; 217 | 218 | G_B = expand(G_B, 1, N_t_max+N_l_max); 219 | G_H = expand(G_H, 1, N_t_max+N_l_max); 220 | 221 | // solve 222 | auto LHS = A; 223 | auto RHS = torch::matmul(B, w1) 224 | + torch::matmul(C, w2) 225 | + bow_mask * G_B.nan_to_num() 226 | + hammer_mask * G_H.nan_to_num(); 227 | if (manufactured) { // using manufactured solution 228 | auto x = domain_x(N_t_max+N_l_max, N_t); 229 | auto t = global_step * k; 230 | auto f = manufactured_solution_forcing_term(gamma, sig0, K, p_a, x, t); 231 | RHS -= f * pow(k, 2); 232 | } 233 | RHS = mask_1d(RHS, N_t+N_l+2, N_t_max+N_l_max); 234 | 235 | //auto w = lstsq(LHS, - RHS, A_P, 1e-8); 236 | //auto w = get<0>(torch::linalg::lstsq(LHS, - RHS, at::nullopt, at::nullopt)); 237 | //auto w = torch::linalg::solve(LHS, - RHS); 238 | auto w = torch::matmul(A_P, - RHS); 239 | 240 | auto new_u = w.narrow(1,0,N_t_max); 241 | auto new_z = w.narrow(1,N_t_max,N_l_max); 242 | new_u = mask_1d(new_u, N_t+1, N_t_max); 243 | new_z = mask_1d(new_z, N_l+1, N_l_max); 244 | 245 | new_u = dirichlet_boundary(new_u, N_t, N_t_max); 246 | new_z = dirichlet_boundary(new_z, N_l, N_l_max); 247 | 248 | torch::Tensor residual_u = u - new_u; 249 | torch::Tensor residual_z = z - new_z; 250 | auto res_u = get<0>(residual_u.flatten(1).abs().max(1)); // 
\ell_\infty norm (values, index) 251 | auto res_z = get<0>(residual_z.flatten(1).abs().max(1)); // \ell_\infty norm (values, index) 252 | not_converged_t = res_u.gt(tol_t).any().item(); 253 | not_converged_l = res_z.gt(tol_l).any().item(); 254 | 255 | u = new_u; 256 | z = new_z; 257 | iter++; 258 | } 259 | 260 | u = u.squeeze(2); 261 | z = z.squeeze(2); 262 | 263 | // save and readout 264 | state_u = add_in(state_u, u, local_step, 1); 265 | state_z = add_in(state_z, z, local_step, 1); 266 | auto u_rp_int = 1 + torch::floor(N_t * rp).view({-1,1}).to(torch::kLong); // rounded grid index for readout 267 | auto u_rp_frac = 1 + rp.view({-1,1}) / h_t.view({-1,1}) - u_rp_int; // fractional part of readout location 268 | auto z_rp_int = 1 + torch::floor(N_l * rp).view({-1,1}).to(torch::kLong); // rounded grid index for readout 269 | auto z_rp_frac = 1 + rp.view({-1,1}) / h_l.view({-1,1}) - z_rp_int; // fractional part of readout location 270 | 271 | torch::Tensor u_out; 272 | torch::Tensor z_out; 273 | if (surface_integral) { // using surface integral of velocities 274 | auto r_w = 0.5 * h_t.view({-1,1,1}); 275 | auto r_H = r_w; 276 | auto r_B = r_w; 277 | u_rp_frac = u_rp_frac.unsqueeze(2); // (B, 1, 1) 278 | z_rp_frac = z_rp_frac.unsqueeze(2); // (B, 1, 1) 279 | u_out = (u - state_u.narrow(1,local_step-1,1).squeeze(1)); 280 | z_out = (z - state_z.narrow(1,local_step-1,1).squeeze(1)); 281 | 282 | // Naive weighting. 
TODO: use distance-based weighting 283 | auto w_u = r_w * torch::ones_like(u_rp_frac) // (B, 1, 1) 284 | + r_H * hammer_mask // (B, 1, 1) 285 | + r_B * bow_mask; // (B, 1, 1) 286 | auto w_z = r_w * torch::ones_like(z_rp_frac) // (B, 1, 1) 287 | + r_H * hammer_mask // (B, 1, 1) 288 | + r_B * bow_mask; // (B, 1, 1) 289 | 290 | u_out = (u_out * w_u.squeeze(1) / k).sum(-1); // (B, ) 291 | z_out = (z_out * w_z.squeeze(1) / k).sum(-1); // (B, ) 292 | } 293 | else { // using interpolated pickup point 294 | u_out = (1 - u_rp_frac) * u.gather(1, u_rp_int ).view({-1,1}) 295 | + u_rp_frac * u.gather(1, u_rp_int+1).view({-1,1}); 296 | z_out = (1 - z_rp_frac) * z.gather(1, z_rp_int ).view({-1,1}) 297 | + z_rp_frac * z.gather(1, z_rp_int+1).view({-1,1}); 298 | } 299 | uout = assign(uout, u_out.view(-1), local_step, 1); 300 | zout = assign(zout, z_out.view(-1), local_step, 1); 301 | v_r_out = assign(v_r_out, v_rel.view(-1), local_step, /*dim*/1); 302 | F_H_out = assign(F_H_out, F_H.view(-1), local_step, /*dim*/1); 303 | u_H_out = add_in(u_H_out, u_H.view(-1), local_step, /*dim*/1); 304 | 305 | return { uout, zout, state_u, state_z, v_r_out, F_H_out, u_H_out, sig0, sig1 }; 306 | } 307 | 308 | 309 | -------------------------------------------------------------------------------- /src/model/cpp/string.h: -------------------------------------------------------------------------------- 1 | #ifndef FDM_H 2 | #define FDM_H 3 | 4 | vector get_derived_vars( 5 | torch::Tensor, torch::Tensor, float, float, torch::Tensor 6 | ); 7 | vector string_step( 8 | torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, // uout, zout, state_u, state_z 9 | torch::Tensor, torch::Tensor, // v_b, F_H 10 | vector, 11 | vector, 12 | vector, 13 | torch::Tensor, 14 | torch::Tensor, 15 | vector, int, int, float, 16 | bool, bool 17 | ); 18 | 19 | #endif 20 | -------------------------------------------------------------------------------- /src/model/cpp/vnv.cpp: 
-------------------------------------------------------------------------------- 1 | using namespace std; 2 | 3 | # define M_PI 3.14159265358979323846 /* pi */ 4 | 5 | #include 6 | #include 7 | #include 8 | 9 | #include "misc.h" 10 | 11 | torch::Tensor manufactured_solution_forcing_term( 12 | torch::Tensor gamma, 13 | torch::Tensor sig0, 14 | torch::Tensor K, 15 | torch::Tensor p_a, 16 | torch::Tensor x, 17 | double t 18 | ) { 19 | /* returns the forcing term for the manufactured solution 20 | * sigma == sig0 21 | * omega == gamma 22 | * mu == pi 23 | */ 24 | auto sigma = sig0; 25 | auto omega = gamma; 26 | auto mu = M_PI; 27 | auto mu_sq = pow(M_PI,2); 28 | 29 | auto coeff_1 = (sigma.pow(2) - omega.pow(2) - 2*sig0*sigma) * torch::cos(mu * x).pow(2); 30 | auto coeff_2 = (2*mu_sq * (4*K.pow(2)*mu_sq + gamma.pow(2))) * torch::cos(2*mu * x); 31 | auto coeff_3 = 2*omega*(sigma - sig0) * torch::cos(mu*x).pow(2); 32 | 33 | auto cos_term = (coeff_1 + coeff_2) * torch::cos(omega*t); 34 | auto sin_term = coeff_3 * torch::sin(omega*t); 35 | 36 | return p_a * (cos_term + sin_term) * torch::exp(-1 * sigma * t); 37 | } 38 | 39 | -------------------------------------------------------------------------------- /src/model/cpp/vnv.h: -------------------------------------------------------------------------------- 1 | #ifndef VNV_H 2 | #define VNV_H 3 | 4 | torch::Tensor manufactured_solution_forcing_term( 5 | torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, double 6 | ); 7 | 8 | #endif 9 | -------------------------------------------------------------------------------- /src/model/nn/blocks.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import numpy as np 6 | from einops import rearrange 7 | 8 | from src.utils import misc as utils 9 | class Identity(nn.Module): 10 | def __init__(self): 11 | super().__init__() 12 | def 
forward(self, x): 13 | return x 14 | 15 | def swish(x): 16 | return x * torch.sigmoid(x) 17 | 18 | def get_activation(name): 19 | if name is None or name == 'linear': 20 | return nn.Identity(), 'linear' 21 | elif name.lower() == 'relu': 22 | return nn.ReLU(), 'relu' 23 | elif name.lower() == 'leaky_relu': 24 | return nn.LeakyReLU(), 'leaky_relu' 25 | elif name.lower() == 'tanh': 26 | return nn.Tanh(), 'tanh' 27 | elif name.lower() == 'sin': 28 | return torch.sin, 'tanh' 29 | elif name.lower() == 'sigmoid': 30 | return nn.Sigmoid(), 'sigmoid' 31 | elif name.lower() == 'swish': 32 | return swish, 'linear' 33 | else: 34 | raise NotImplementedError(name) 35 | 36 | def apply_gain(x, gain, fn=None): 37 | gain = fn(gain) if fn is not None else gain 38 | x_list = x.chunk(len(gain), -1) 39 | x_list = [gain[i] * x_i for i, x_i in enumerate(x_list)] 40 | return torch.cat(x_list, dim=-1) 41 | 42 | class FMBlock(nn.Module): 43 | def __init__(self, input_dim, embed_dim, num_features): 44 | super().__init__() 45 | concat_size = embed_dim * num_features + embed_dim 46 | feature_dim = embed_dim * num_features 47 | self.rff2 = RFF2(input_dim, embed_dim//2) 48 | self.tmlp = mlp(concat_size, feature_dim, 5) 49 | self.proj = nn.Linear(concat_size, 2*input_dim) 50 | self.activation = nn.GLU(dim=-1) 51 | 52 | gain_in = torch.randn(num_features) / 2 53 | gain_out = torch.Tensor([0.1]) 54 | self.register_parameter('gain_in', nn.Parameter(gain_in, requires_grad=True)) 55 | self.register_parameter('gain_out', nn.Parameter(gain_out, requires_grad=True)) 56 | 57 | def forward(self, input, feature, slider, omega): 58 | ''' input : (B T input_dim) 59 | feature: (B T feature_dim) 60 | slider : (B T 1) 61 | ''' 62 | _input = input / (1.3*math.pi) - 1 63 | _input = self.rff2(_input) 64 | feature = apply_gain(feature, self.gain_in, torch.tanh) 65 | 66 | x = torch.cat((_input, feature), dim=-1) 67 | x = torch.cat((self.tmlp(x), _input), dim=-1) 68 | x = self.activation(self.proj(x)) 69 | 70 | gate = 
torch.tanh((slider - 1) * self.gain_out) 71 | return input + omega * x * gate 72 | 73 | class AMBlock(nn.Module): 74 | def __init__(self, input_dim, embed_dim, num_features): 75 | super().__init__() 76 | concat_size = embed_dim * num_features + embed_dim 77 | feature_dim = embed_dim * num_features 78 | self.rff2 = RFF2(input_dim, embed_dim//2) 79 | self.tmlp = mlp(concat_size, feature_dim, 5) 80 | self.proj = nn.Linear(concat_size, 2*input_dim) 81 | self.activation = nn.GLU(dim=-1) 82 | 83 | gain_in = torch.randn(num_features) / 2 84 | self.register_parameter('gain_in', nn.Parameter(gain_in, requires_grad=True)) 85 | 86 | def forward(self, input, feature, slider): 87 | ''' input : (B T input_dim) 88 | feature: (B T feature_dim) 89 | slider : (B T 1) 90 | ''' 91 | _input = input * 110 - 0.55 92 | _input = self.rff2(_input) 93 | feature = apply_gain(feature, self.gain_in, torch.tanh) 94 | 95 | x = torch.cat((_input, feature), dim=-1) 96 | x = torch.cat((self.tmlp(x), _input), dim=-1) 97 | x = self.activation(self.proj(x)) 98 | 99 | return input * (1 + x) 100 | 101 | class ModBlock(nn.Module): 102 | def __init__(self, input_dim, feature_dim, embed_dim): 103 | super().__init__() 104 | cat_size = 1+feature_dim 105 | self.tmlp = mlp(cat_size, feature_dim, 2) 106 | self.proj = nn.Linear(cat_size, 2) 107 | self.activation = nn.GLU(dim=-1) 108 | 109 | def forward(self, input, feature, slider): 110 | ''' input : (B T input_dim) 111 | feature: (B T feature_dim) 112 | slider : (B T 1) 113 | ''' 114 | input = input.unsqueeze(-1) # (B T input_dim 1) 115 | feature = feature.unsqueeze(-2).repeat(1,1,input.size(-2),1) 116 | x = torch.cat((input, feature), dim=-1) 117 | x = torch.cat((self.tmlp(x), input), dim=-1) 118 | x = self.activation(self.proj(x)) 119 | return (input * (1 + x)).squeeze(-1) 120 | 121 | def mlp(in_size, hidden_size, n_layers): 122 | channels = [in_size] + (n_layers) * [hidden_size] 123 | net = [] 124 | for i in range(n_layers): 125 | 
net.append(nn.Linear(channels[i], channels[i + 1])) 126 | #net.append(nn.LayerNorm(channels[i + 1])) 127 | net.append(nn.PReLU()) 128 | return nn.Sequential(*net) 129 | 130 | class RFF2(nn.Module): 131 | """ Random Fourier Features Module """ 132 | def __init__(self, input_dim, embed_dim, scale=1.): 133 | super().__init__() 134 | N = torch.ones((input_dim, embed_dim)) / input_dim / embed_dim 135 | N = nn.Parameter(N, requires_grad=False) 136 | e = torch.Tensor([scale]) 137 | e = nn.Parameter(e, requires_grad=True) 138 | self.register_buffer('N', N) 139 | self.register_parameter('e', e) 140 | 141 | def forward(self, x): 142 | ''' x: (Bs, Nt, input_dim) 143 | -> (Bs, Nt, embed_dim) 144 | ''' 145 | B = self.e * self.N 146 | x_embd = utils.fourier_feature(x, B) 147 | return x_embd 148 | 149 | class RFF(nn.Module): 150 | """ Random Fourier Features Module """ 151 | def __init__(self, scales, embed_dim): 152 | super().__init__() 153 | input_dim = len(scales) 154 | N = torch.randn(input_dim, embed_dim) 155 | N = nn.Parameter(N, requires_grad=False) 156 | e = torch.Tensor(scales).view(-1,1) 157 | e = nn.Parameter(e, requires_grad=True) 158 | self.register_buffer('N', N) 159 | self.register_parameter('e', e) 160 | 161 | def forward(self, x): 162 | ''' x: (Bs, Nt, input_dim) 163 | -> (Bs, Nt, input_dim*embed_dim) 164 | ''' 165 | xs = x.chunk(self.N.size(0), -1) # (Bs, Nt, 1) * input_dim 166 | Ns = self.N.chunk(self.N.size(0), 0) # (1, embed_dim) * input_dim 167 | Bs = [torch.pow(10, self.e[i]) * N for i, N in enumerate(Ns)] 168 | x_embd = [utils.fourier_feature(xs[i], B) for i, B in enumerate(Bs)] 169 | return torch.cat(x_embd, dim=-1) 170 | 171 | class ModeEstimator(nn.Module): 172 | def __init__(self, n_modes, hidden_dim, kappa_scale=None, gamma_scale=None, inharmonic=True, sr=48000): 173 | super().__init__() 174 | self.sr = sr 175 | self.kappa_scale = kappa_scale 176 | self.gamma_scale = gamma_scale 177 | self.rff = RFF([1.]*5, hidden_dim//2) 178 | self.a_mlp = 
mlp(5*hidden_dim, hidden_dim, 2) 179 | self.a_proj = nn.Linear(hidden_dim, n_modes) 180 | self.tanh = nn.Tanh() 181 | if inharmonic: 182 | self.f_mlp = mlp(5*hidden_dim, hidden_dim, 2) 183 | self.f_proj = nn.Linear(hidden_dim, n_modes) 184 | self.sigmoid = nn.Sigmoid() 185 | else: 186 | self.f_mlp = None 187 | self.f_proj = None 188 | self.sigmoid = nn.Sigmoid() 189 | 190 | def forward(self, u_0, x_p, kappa, gamma): 191 | ''' u_0 : (b, 1, x) 192 | x_p : (b, 1, 1) 193 | kappa : (b, 1, 1) 194 | gamma : (b, 1, 1) 195 | ''' 196 | p_x = torch.argmax(u_0, dim=-1, keepdim=True) / 255. # (b, 1, 1) 197 | p_a = torch.max(u_0, dim=-1, keepdim=True).values / 0.02 # (b, 1, 1) 198 | kappa = self.normalize_kappa(kappa) 199 | gamma = self.normalize_gamma(gamma) 200 | con = torch.cat((p_x, p_a, x_p, kappa, gamma), dim=-1) # (b, 1, 5) 201 | con = self.rff(con) # (b, 1, 3*hidden_dim) 202 | 203 | mode_amps = self.a_mlp(con) # (b, 1, k) 204 | mode_amps = self.tanh(1e-3 * self.a_proj(mode_amps)) # (b, 1, m) 205 | 206 | if self.f_mlp is not None: 207 | mode_freq = self.f_mlp(con) # (b, 1, k) 208 | mode_freq = 0.3 * self.sigmoid(self.f_proj(mode_freq)) # (b, 1, m) 209 | mode_freq = mode_freq.cumsum(-1) 210 | else: 211 | int_mults = torch.ones_like(mode_amps).cumsum(-1) # (b, 1, k) 212 | omega = gamma / self.sr * (2*math.pi) 213 | mode_freq = omega * int_mults 214 | 215 | return mode_amps, mode_freq 216 | 217 | def normalize_gamma(self, x): 218 | if self.gamma_scale is not None: 219 | minval = min(self.gamma_scale) 220 | denval = max(self.gamma_scale) - minval 221 | x = (x - minval) / denval 222 | return x 223 | 224 | def normalize_kappa(self, x): 225 | if self.kappa_scale is not None: 226 | minval = min(self.kappa_scale) 227 | denval = max(self.kappa_scale) - minval 228 | x = (x - minval) / denval 229 | return x 230 | 231 | -------------------------------------------------------------------------------- /src/model/nn/ddsp.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from src.model.nn.blocks import FMBlock, AMBlock 4 | from src.utils.ddsp import upsample 5 | from src.utils.ddsp import remove_above_nyquist_mode 6 | from src.utils.ddsp import amp_to_impulse_response, fft_convolve 7 | from src.utils.ddsp import modal_synth 8 | from src.utils.ddsp import resample 9 | import math 10 | 11 | class DDSP(nn.Module): 12 | def __init__(self, 13 | feature_size, hidden_size, 14 | n_modes, n_bands, sampling_rate, block_size, 15 | fm=False, 16 | ): 17 | super().__init__() 18 | self.n_modes = n_modes 19 | 20 | self.freq_modulator = FMBlock(n_modes, feature_size) if fm else None 21 | self.coef_modulator = AMBlock(n_modes, feature_size) 22 | self.noise_proj = nn.Linear(feature_size, n_bands) 23 | 24 | noise_gate = nn.Parameter(torch.tensor([1e-2]), requires_grad=True) 25 | self.register_parameter("noise_gate", noise_gate) 26 | self.register_buffer("sampling_rate", torch.tensor(sampling_rate)) 27 | self.register_buffer("block_size", torch.tensor(block_size)) 28 | 29 | def forward(self, hidden, mode_freq, mode_coef, times, alpha, lengths): 30 | ''' hidden : (Bs, 1, hidden_size) 31 | mode_freq : (Bs, Nt, n_modes) 32 | mode_coef : (Bs, 1, n_modes) 33 | times : (Bs, Nt, 1) 34 | ''' 35 | if self.freq_modulator is None: 36 | freq_m = mode_freq # integer multiples 37 | else: 38 | freq_m = self.freq_modulator(mode_freq, hidden) 39 | coef_m = self.coef_modulator(mode_coef, hidden, times) 40 | 41 | #============================== 42 | # harmonic part 43 | #============================== 44 | freqs = freq_m / (2*math.pi) * self.sampling_rate 45 | coef_m = remove_above_nyquist_mode(coef_m, freqs, self.sampling_rate) # (Bs, Nt, n_modes) 46 | freq_s = upsample(freq_m, self.block_size).narrow(1,0,lengths) 47 | coef_s = upsample(coef_m, self.block_size).narrow(1,0,lengths) 48 | harmonic = modal_synth(freq_s, coef_s, self.sampling_rate) 
49 | 50 | #============================== 51 | # noise part 52 | #============================== 53 | ngate = torch.tanh((alpha - 1) * self.noise_gate) 54 | param = ngate * torch.sigmoid(self.noise_proj(hidden) - 5) 55 | 56 | impulse = amp_to_impulse_response(param, self.block_size) 57 | noise = torch.rand( 58 | impulse.shape[0], 59 | impulse.shape[1], 60 | self.block_size, 61 | ).to(impulse) * 2 - 1 62 | noise = fft_convolve(noise, impulse).contiguous() 63 | noise = noise.reshape(noise.shape[0], -1, 1).narrow(1,0,lengths) 64 | 65 | signal = harmonic + noise 66 | return signal.squeeze(-1), freq_m, coef_m 67 | 68 | 69 | 70 | -------------------------------------------------------------------------------- /src/model/nn/dmsp.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from src.model.nn.blocks import FMBlock, AMBlock, ModBlock 4 | from src.utils.ddsp import scale_function, remove_above_nyquist, upsample 5 | from src.utils.ddsp import remove_above_nyquist_mode 6 | from src.utils.ddsp import harmonic_synth, amp_to_impulse_response, fft_convolve 7 | from src.utils.ddsp import modal_synth 8 | from src.utils.ddsp import resample 9 | import math 10 | 11 | class DMSP(nn.Module): 12 | def __init__(self, 13 | embed_dim, hidden_size, n_features, 14 | n_modes, n_bands, sampling_rate, block_size, 15 | ): 16 | super().__init__() 17 | self.n_modes = n_modes 18 | 19 | self.freq_modulator = FMBlock(n_modes, embed_dim, n_features) 20 | self.coef_modulator = AMBlock(n_modes, embed_dim, n_features) 21 | self.proj_noise = nn.Linear(n_features*embed_dim, n_bands) 22 | 23 | self.register_buffer("sampling_rate", torch.tensor(sampling_rate)) 24 | self.register_buffer("block_size", torch.tensor(block_size)) 25 | 26 | def forward(self, hidden, mode_freq, mode_coef, times, alpha, omega, lengths): 27 | ''' hidden : (Bs, 1, hidden_size) 28 | mode_freq : (Bs, Nt, n_modes) 29 | mode_coef : (Bs, 1, n_modes) 30 | 
times : (Bs, Nt, 1) 31 | ''' 32 | freq_m = self.freq_modulator(mode_freq, hidden, alpha, omega) 33 | coef_m = self.coef_modulator(mode_coef, hidden, times) 34 | 35 | #============================== 36 | # harmonic part 37 | #============================== 38 | freqs = freq_m / (2*math.pi) * self.sampling_rate 39 | coef_m = remove_above_nyquist_mode(coef_m, freqs, self.sampling_rate) # (Bs, Nt, n_modes) 40 | freq_s = upsample(freq_m, self.block_size).narrow(1,0,lengths) 41 | coef_s = upsample(coef_m, self.block_size).narrow(1,0,lengths) 42 | harmonic = modal_synth(freq_s, coef_s, self.sampling_rate) 43 | 44 | #============================== 45 | # noise part 46 | #============================== 47 | param = scale_function(self.proj_noise(hidden) - 5) 48 | 49 | impulse = amp_to_impulse_response(param, self.block_size) 50 | noise = torch.rand( 51 | impulse.shape[0], 52 | impulse.shape[1], 53 | self.block_size, 54 | ).to(impulse) * 2 - 1 55 | noise = fft_convolve(noise, impulse).contiguous() 56 | noise = noise.reshape(noise.shape[0], -1, 1).narrow(1,0,lengths) 57 | 58 | signal = harmonic + noise 59 | return signal.squeeze(-1), freq_m, coef_m 60 | 61 | 62 | 63 | 64 | -------------------------------------------------------------------------------- /src/model/nn/synthesizer.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from einops import rearrange 6 | 7 | from src.utils import audio as audio 8 | 9 | class Synthesizer(nn.Module): 10 | """ Synthesizer Network """ 11 | def __init__(self, 12 | embed_dim=64, 13 | x_scale=1, t_scale=1, 14 | gamma_scale=0, kappa_scale=0, alpha_scale=0, sig_0_scale=0, sig_1_scale=0, 15 | **kwargs): 16 | super().__init__() 17 | self.sr=kwargs['sr'] 18 | hidden_dim=kwargs['hidden_dim'] 19 | self.n_modes = kwargs['n_modes'] 20 | inharmonic = kwargs['harmonic'].lower() == 'inharmonic' 21 | 22 | self.x_scale = 
x_scale 23 | self.t_scale = t_scale 24 | self.gamma_scale = gamma_scale 25 | self.kappa_scale = kappa_scale 26 | self.alpha_scale = alpha_scale 27 | self.sig_0_scale = sig_0_scale 28 | self.sig_1_scale = sig_1_scale 29 | 30 | from src.model.nn.blocks import RFF, ModeEstimator 31 | n_feats = 7 32 | self.material_encoder = RFF([1.]*n_feats, embed_dim // 2) 33 | feature_size = embed_dim * n_feats 34 | self.mode_estimator = ModeEstimator( 35 | self.n_modes, embed_dim, kappa_scale, gamma_scale, 36 | inharmonic=inharmonic, 37 | ) 38 | if inharmonic: 39 | from src.model.nn.dmsp import DMSP 40 | self.net = DMSP( 41 | embed_dim=embed_dim, 42 | hidden_size=hidden_dim, 43 | n_features=n_feats, 44 | n_modes=kwargs['n_modes'], 45 | n_bands=kwargs['n_bands'], 46 | block_size=kwargs['block_size'], 47 | sampling_rate=kwargs['sr'], 48 | ) 49 | else: 50 | from src.model.nn.ddsp import DDSP 51 | self.net = DDSP( 52 | feature_size=feature_size, 53 | hidden_size=hidden_dim, 54 | n_modes=kwargs['n_modes'], 55 | n_bands=kwargs['n_bands'], 56 | block_size=kwargs['block_size'], 57 | sampling_rate=kwargs['sr'], 58 | fm=kwargs['ddsp_frequency_modulation'], 59 | ) 60 | 61 | def forward(self, params, pitch, initial): 62 | ''' params : input parameters 63 | pitch : fundamental frequency in Hz 64 | initial: initial condition 65 | ''' 66 | space, times, kappa, alpha, t60, mode_freq, mode_coef = params 67 | 68 | f_0 = pitch.unsqueeze(2) # (b, frames, 1) 69 | times = times.unsqueeze(-1) # (b, sample, 1) 70 | kappa = kappa.unsqueeze(-1) # (b, 1, 1) 71 | alpha = alpha.unsqueeze(-1) # (b, 1, 1) 72 | space = space.unsqueeze(-1) # (b, 1, 1) 73 | gamma = 2*f_0 # (b, frames, 1) 74 | omega = f_0 / self.sr * (2*math.pi) # (b, t, 1) 75 | relf0 = omega - omega.narrow(1,0,1) # (b, t, 1) 76 | 77 | in_coef, in_freq = self.mode_estimator(initial, space, kappa, gamma.narrow(1,9,1)) 78 | mode_coef = in_coef if mode_coef is None else mode_coef 79 | mode_freq = in_freq if mode_freq is None else mode_freq 80 | 
mode_freq = mode_freq + relf0 # linear FM 81 | 82 | Nt = times.size(1) # total number of samples 83 | Nf = mode_freq.size(1) # total number of frames 84 | frames = self.get_frame_time(times, Nf) 85 | 86 | space = space.repeat(1,f_0.size(1),1) # (b, frames, 1) 87 | alpha = alpha.repeat(1,f_0.size(1),1) # (b, frames, 1) 88 | kappa = kappa.repeat(1,f_0.size(1),1) # (b, frames, 1) 89 | sigma = audio.T60_to_sigma(t60, f_0, 2*f_0*kappa) # (b, frames, 2) 90 | 91 | # fourier features 92 | feat = [space, frames, kappa, alpha, sigma, gamma] 93 | feat = self.normalize_params(feat) 94 | feat = self.material_encoder(feat) # (b, frames, n_feats * embed_dim) 95 | 96 | damping = torch.exp(- frames * sigma.narrow(-1,0,1)) 97 | mode_coef = mode_coef * damping 98 | ut, ut_freq, ut_coef = self.net(feat, mode_freq, mode_coef, frames, alpha, omega, Nt) 99 | return ut, [in_freq, in_coef], [ut_freq, ut_coef] 100 | 101 | def get_frame_time(self, times, Nf): 102 | t_0 = times.narrow(1,0,1) # (Bs, 1, 1) 103 | t_k = torch.ones_like(t_0).repeat(1,Nf,1).cumsum(1) / self.sr 104 | t_k = t_k + t_0 # (Bs, Nt, 1) 105 | return t_k 106 | 107 | def normalize_params(self, params): 108 | def rescale(var, scale): 109 | minval = min(scale) 110 | denval = max(scale) - minval 111 | return (var - minval) / denval 112 | space, times, kappa, alpha, sigma, gamma = params 113 | sig_0, sig_1 = sigma.chunk(2, -1) 114 | space = rescale(space, self.x_scale) 115 | times = rescale(times - max(self.t_scale), self.t_scale) 116 | kappa = rescale(kappa, self.kappa_scale) 117 | alpha = rescale(alpha, self.alpha_scale) 118 | sig_0 = rescale(sig_0, self.sig_0_scale) 119 | sig_1 = rescale(sig_1, self.sig_1_scale) 120 | gamma = rescale(gamma, self.gamma_scale) 121 | sigma = torch.cat((sig_0, sig_1), dim=-1) 122 | return torch.cat([space, times, kappa, alpha, sigma, gamma], dim=-1) 123 | 124 | 125 | 126 | -------------------------------------------------------------------------------- /src/task/evaluate.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import glob 4 | import tqdm 5 | import torch 6 | import numpy as np 7 | import soundfile as sf 8 | 9 | from src.utils import plot as plot 10 | from src.utils.misc import f0_interpolate 11 | from src.utils.analysis.frequency import compute_harmonic_parameters 12 | from src.utils.fdm import stiff_string_modes 13 | 14 | def evaluate(load_dir): 15 | data_list = sorted(glob.glob(f"{load_dir}/*/string_params.npz")) 16 | iterator = tqdm.tqdm(data_list) 17 | iterator.set_description("Evaluating") 18 | for path in iterator: 19 | iterator.set_postfix(path=path) 20 | subd = path.split('/')[-2] 21 | string_data = np.load(path) 22 | bow_data = np.load(path.replace('string_params.npz', 'bow_params.npz')) 23 | hammer_data = np.load(path.replace('string_params.npz', 'hammer_params.npz')) 24 | 25 | uout, sr = sf.read(path.replace('string_params.npz', 'output-u.wav')) 26 | zout, sr = sf.read(path.replace('string_params.npz', 'output-z.wav')) 27 | k = 1 / sr 28 | theta_t = 0.5 + 2/(np.pi**2) 29 | 30 | f0_input = string_data["f0"] 31 | T60 = string_data["T60"] 32 | kappa_rel = string_data["kappa"] 33 | alpha = string_data["alpha"] 34 | f0_target = string_data["target_f0"] 35 | 36 | kappa = (2 * f0_input * kappa_rel).mean() 37 | modes = stiff_string_modes(f0_input, kappa_rel, 10)[0] 38 | 39 | h_params = compute_harmonic_parameters(uout, sr) 40 | f0_estimate = h_params['f0'] 41 | f0_input_interpolated = f0_interpolate(f0_input, len(f0_estimate), len(uout) / sr) 42 | f0_target_interpolated = f0_interpolate(f0_target, len(f0_estimate), len(uout) / sr) 43 | modes_interpolated = [f0_interpolate(m, len(f0_estimate), len(uout) / sr) for m in modes] 44 | f0_diff_input = np.mean(np.abs(f0_input_interpolated - f0_estimate)) 45 | f0_diff_target = np.mean(np.abs(f0_target_interpolated - f0_estimate)) 46 | f0_diff_modes = np.mean(np.abs(modes_interpolated[0] - f0_estimate)) 47 | 
f0_diff_ground = np.mean(np.abs(modes_interpolated[0] - f0_input_interpolated)) 48 | 49 | front = int(len(f0_estimate) * 0.2) # 0.2 sec 50 | f0_diff_input_front = np.mean(np.abs(f0_input_interpolated[:front] - f0_estimate[:front])) 51 | f0_diff_modes_front = np.mean(np.abs(modes_interpolated[0][:front] - f0_estimate[:front])) 52 | 53 | with open(f"{load_dir}/{subd}/string_params.txt", 'w') as f: 54 | f.write(f"f0 diff (input)\t{f0_diff_input:.2f}\n") 55 | f.write(f"f0 diff (target)\t{f0_diff_target:.2f}\n") 56 | f.write(f"f0 diff (modes)\t{f0_diff_modes:.2f}\n") 57 | f.write(f"f0 diff (ground)\t{f0_diff_ground:.2f}\n") 58 | f.write(f"f0 diff input front\t{f0_diff_input_front:.2f}\n") 59 | f.write(f"f0 diff modes front\t{f0_diff_modes_front:.2f}\n") 60 | #plot_spectrum_uz(f'{load_dir}/{subd}/spectrum.pdf', uout, zout, f0_input, f0_estimate, modes, sr) 61 | #plot.rainbowgram(f'{load_dir}/{subd}/spec.pdf', uout, sr, colorbar=False) 62 | plot.rainbowgram(f'{load_dir}/{subd}/f0-naive.pdf', uout, sr, f0_input=f0_input, colorbar=False) 63 | plot.rainbowgram(f'{load_dir}/{subd}/f0-precorrected.pdf', uout, sr, f0_input=f0_target, colorbar=False) 64 | 65 | 66 | -------------------------------------------------------------------------------- /src/task/process_training_data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import yaml 4 | from glob import glob 5 | from tqdm import tqdm 6 | import numpy as np 7 | import torch 8 | import math 9 | 10 | import src.utils.fdm as fdm 11 | import src.utils.misc as ms 12 | import src.utils.data as data 13 | import src.utils.audio as audio 14 | import src.model.analytic as analytic 15 | from src.utils.analysis.frequency import compute_harmonic_parameters 16 | 17 | def is_processed(directory, N): 18 | if not os.path.exists(directory): return False 19 | ut_list = glob(f"{directory}/ut-*.wav") 20 | ua_list = glob(f"{directory}/ua-*.wav") 21 | vt_list = 
glob(f"{directory}/vt.wav") 22 | parameters = f"{directory}/parameters.npz" 23 | if len(ut_list) != N: return False 24 | if len(ua_list) != N: return False 25 | if len(vt_list) != 1: return False 26 | if not os.path.exists(parameters): return False 27 | return True 28 | 29 | def rms(x, eps=1e-18): 30 | mean_val = np.mean(x ** 2) 31 | return 1 if mean_val < eps else np.sqrt(np.mean(x ** 2)) 32 | 33 | def load_data(dirs): 34 | _sim = np.load(f"{dirs}/simulation.npz"); _sim_dict = dict() 35 | _str = np.load(f"{dirs}/string_params.npz"); _str_dict = dict() 36 | _bow = np.load(f"{dirs}/bow_params.npz"); _bow_dict = dict() 37 | _ham = np.load(f"{dirs}/hammer_params.npz"); _ham_dict = dict() 38 | 39 | for key in _sim.keys(): _sim_dict[key] = _sim[key] 40 | for key in _str.keys(): _str_dict[key] = _str[key] 41 | for key in _bow.keys(): _bow_dict[key] = _bow[key] 42 | for key in _ham.keys(): _ham_dict[key] = _ham[key] 43 | return _sim_dict, _str_dict, _bow_dict, _ham_dict 44 | 45 | def remove_above_nyquist_mode(amplitudes, frequencies, sampling_rate): 46 | ''' amplitudes: (batch, Nt, n_harmoincs) 47 | frequencies: (batch, Nt, n_harmonics) 48 | ''' 49 | aa = (frequencies < sampling_rate / 2).float() + 1e-4 50 | return amplitudes * aa 51 | 52 | def synth(freq, coef, damp, n_chunks=100): 53 | freqs = freq.chunk(n_chunks, 1) 54 | coefs = coef.chunk(n_chunks, 1) 55 | damps = damp.chunk(n_chunks, 1) 56 | lastf = torch.zeros_like(freqs[0]) 57 | sols = [] 58 | for f, c, d in zip(freqs, coefs, damps): 59 | fcs = f.cumsum(1) + lastf 60 | sol = (torch.cos(fcs) * c * d).sum(-1, keepdim=True) 61 | lastf = fcs.narrow(1,-1,1) 62 | sols.append(sol) 63 | return torch.cat(sols, 1) 64 | 65 | def T60_to_sigma(T60, f_0, K): 66 | ''' T60 : (Bs, 2, 2) [[T60_freq_1, T60_1], [T60_freq_2, T60_2]] 67 | f_0 : (Bs, Nt, 1) fundamental frequency 68 | K : (Bs, Nt, 1) kappa (K == gamma * kappa_rel) 69 | -> sig : (Bs, Nt, 2) 70 | ''' 71 | gamma = f_0 * 2 72 | freq1, time1 = T60.narrow(1,0,1).chunk(2,-1) 73 
| freq2, time2 = T60.narrow(1,1,1).chunk(2,-1) 74 | 75 | zeta1 = - gamma.pow(2) + (gamma.pow(4) + 4 * K.pow(2) * (2 * math.pi * freq1).pow(2)).pow(.5) 76 | zeta2 = - gamma.pow(2) + (gamma.pow(4) + 4 * K.pow(2) * (2 * math.pi * freq2).pow(2)).pow(.5) 77 | sig0 = - zeta2 / time1 + zeta1 / time2 78 | sig0 = 6 * math.log(10) * sig0 / (zeta1 - zeta2) 79 | 80 | sig1 = 1 / time1 - 1 / time2 81 | sig1 = 6 * math.log(10) * sig1 / (zeta1 - zeta2) 82 | 83 | sig = torch.cat((sig0, sig1), dim=-1) 84 | return sig 85 | 86 | def get_analytic_solution(u0, f0, kr, ts, sr, new_Nx, strict=True, device='cuda:0'): 87 | Nt, Nx = u0.shape 88 | 89 | if isinstance(u0, np.ndarray): u0 = torch.from_numpy(u0) 90 | if isinstance(f0, np.ndarray): f0 = torch.from_numpy(f0) 91 | if isinstance(kr, np.ndarray): kr = torch.from_numpy(kr) 92 | if isinstance(ts, np.ndarray): ts = torch.from_numpy(ts) 93 | dtype = u0.dtype 94 | ti = torch.arange(Nt, dtype=dtype).view(1,-1,1) / sr 95 | xi = torch.linspace(0,1,Nx, dtype=dtype).view(1,1,-1) 96 | xvals = np.linspace(0,1,new_Nx) 97 | _u0 = torch.from_numpy(ms.interpolate( 98 | u0.squeeze(0).numpy(), ti, xi, xvals 99 | )).narrow(0,0,1) 100 | 101 | _, mode_freq, mode_amps = analytic.lossy_stiff_string(_u0, f0, kr, ts, Nt, new_Nx, sr, strict=strict, device=device) 102 | 103 | return mode_freq, mode_amps 104 | 105 | 106 | def save_upsampled_data(load_dir, save_dir, sr, Nx, strict=True): 107 | try: 108 | _sim, _str, _bow, _ham = load_data(load_dir) 109 | except FileNotFoundError as err: 110 | print("*"*30) 111 | print(f"File Not Found in {load_dir}") 112 | print("*"*30) 113 | return 0 114 | 115 | ut = _sim['state_u'] # (time, Nu) 116 | f0 = _str['f0'] # (time, ) 117 | kr = _str['kappa'] 118 | al = _str['alpha'] 119 | ts = _str['T60'] # (2, 2) 120 | k = 1 / sr 121 | with open(f"{load_dir}/simulation_config.yaml", 'r') as f: 122 | constants = yaml.load(f, Loader=yaml.FullLoader) 123 | theta_t = constants["theta_t"] 124 | lambda_c = constants["lambda_c"] 125 | nx_t, 
_, nx_l, _ = fdm.get_derived_vars( 126 | torch.from_numpy(f0), 127 | torch.from_numpy(kr), k, theta_t, lambda_c, 128 | torch.from_numpy(al))[2:6] 129 | 130 | dtype = np.float64 if ut.dtype == 'float64' else np.float32 131 | Nt, Nu = list(ut.shape) 132 | ki = max(min([5, int(min(nx_t))-1]), 1) 133 | xi = np.linspace(0,1,Nx)[None,:] 134 | ti = np.arange(Nt, dtype=dtype)[:,None] / sr 135 | 136 | ''' Upsample ut, zt to the spatial resolution with Nx 137 | ''' 138 | if np.abs(f0 - np.mean(f0)).sum() < 0.1: # Hz 139 | # constant f0 140 | xu = np.linspace(0,1,Nu, dtype=dtype)[None,:] 141 | ut = ms.interpolate(ut, ti, xu, xi, kx=ki, ky=ki) # (time, Nx) 142 | else: 143 | # time-varying f0 144 | _ut = np.zeros((Nt, Nx)) 145 | for t in range(Nt): 146 | _Nu = int(nx_t[t]) + 1 147 | _xu = np.linspace(0,1,_Nu, dtype=dtype)[None,:] 148 | _ut[t] += ms.interpolate1d(ut[t,:_Nu][None,:], _xu, xi, k=ki)[0] # (time, Nx) 149 | ut = _ut 150 | 151 | Na = 1024 152 | xa = np.linspace(0,1,Na, dtype=dtype)[None,:] 153 | xi = np.linspace(0,1,Nx)[None,:] 154 | 155 | pitch = torch.from_numpy(f0).cuda() 156 | kappa = torch.from_numpy(kr).view(1,1,1).cuda() # (1,1,1) 157 | t60_s = torch.from_numpy(ts[None,:,:]).cuda() # (1,2,2) 158 | times = torch.from_numpy(ti).view(1,-1,1).cuda() # (1,Nt,1) 159 | 160 | ''' Calculate analytic solution and downsample to the spatial resolution with Nx 161 | ''' 162 | mode_freq, mode_amps = get_analytic_solution(ut, pitch, kr, ts, sr, new_Nx=Na, strict=strict) # (time, Na) 163 | mode_amps_nx = np.zeros((mode_amps.shape[0], Nx)) 164 | for n in range(mode_amps.shape[0]): # (n_modes, Na) --> (n_modes, Nx) 165 | mode_amps_nx[n] = ms.interpolate1d(mode_amps[n][None,:], xa, xi)[0] 166 | mode_amps = mode_amps_nx 167 | 168 | omega = pitch / sr * (2*math.pi) 169 | romg = (omega - omega[0]).view(1,-1,1) # ( 1, Nt, 1) 170 | mode_freq = torch.from_numpy( mode_freq[None,None,:]).cuda() # ( 1, 1, n_modes) 171 | mode_amps = torch.from_numpy((mode_amps.T)[:,None,:]).cuda() # (Nx, 
1, n_modes) 172 | mode_freq_tv = mode_freq + romg # ( 1, Nt, n_modes) 173 | 174 | sigma = T60_to_sigma(t60_s, pitch, 2*pitch*kappa) # (1, Nt, 2) 175 | damping = torch.exp(- times * sigma.narrow(-1,0,1)) # (1, Nt, 1) 176 | 177 | mode_freq_hz = mode_freq_tv / (2*math.pi) * sr # (Hz) 178 | mode_amps_tv = remove_above_nyquist_mode(mode_amps, mode_freq_hz, sr) 179 | 180 | # (Nx, Nt, 1) 181 | ua = synth(mode_freq_tv, mode_amps_tv, damping).cpu() 182 | ua = ua.squeeze(-1).numpy().T # (time, Nx) 183 | 184 | mode_freq = mode_freq.squeeze().cpu() # (n_modes,) 185 | mode_amps = mode_amps.squeeze().transpose(0,1).cpu() # (n_modes, Nx) 186 | 187 | uas = np.sum(ua, axis=1); _ua = uas / rms(uas) 188 | uts = np.sum(ut, axis=1); _ut = uts / rms(uts) 189 | ua_f0 = compute_harmonic_parameters(_ua, sr)['f0'] # (101,) 190 | ut_f0 = compute_harmonic_parameters(_ut, sr)['f0'] # (101,) 191 | 192 | gain = audio.ell_infty_normalize(ut.flatten())[1] 193 | u0 = ut[0,:][None,:] 194 | _str.pop("v0") 195 | _sim.pop("state_u") 196 | _sim.pop("state_z") 197 | 198 | vt = torch.from_numpy(ut).unsqueeze(0) # (Nt, Nx) 199 | vt = audio.state_to_wav(vt).squeeze(0).numpy() # (Nt) 200 | 201 | _sim.update(dict(ua_f0=ua_f0)) 202 | _sim.update(dict(ut_f0=ut_f0)) 203 | 204 | _sim.update(dict(mode_freq=mode_freq, mode_amps=mode_amps)) 205 | _sim.update(dict(x=xi, t=ti)) 206 | _sim.update(dict(ut=ut, ua=ua, vt=vt)) 207 | _sim.update(dict(gain=gain.squeeze().item())) 208 | _str.update(dict(u0=u0)) 209 | #---------- 210 | _bow.update(dict(ph0_B=_bow.pop('phi_0'))) 211 | _bow.update(dict(ph1_B=_bow.pop('phi_1'))) 212 | _bow.update(dict(wid_B=_bow.pop('wid_B'))) 213 | #---------- 214 | _ham.update(dict(M_H=_ham.pop("M_r"))) 215 | _ham.update(dict(a_H=_ham.pop("alpha"))) 216 | #---------- 217 | 218 | _ovr = {} 219 | _ovr.update(_sim) 220 | _ovr.update(_str) 221 | _ovr.update(_bow) 222 | _ovr.update(_ham) 223 | data.save(save_dir, _ovr) 224 | 225 | def process(args): 226 | path_to_dir = 
os.path.join(args.task.root_dir, args.task.result_dir) 227 | subdirs = sorted([d for d in glob(f'{path_to_dir}/*') if os.path.isdir(d) and not 'codes' in d]) 228 | 229 | if args.task.data_split > 1: 230 | subdirs = subdirs[args.task.split_n::args.task.data_split] 231 | 232 | iterator = tqdm(subdirs) 233 | iterator.set_description("Preprocess Data (Simulation --> Training)") 234 | for subdir in iterator: 235 | iterator.set_postfix( 236 | load_dir=subdir, 237 | Nx=args.task.Nx) 238 | save_dir = subdir.replace(args.task.result_dir, args.task.save_dir) 239 | os.makedirs(save_dir, exist_ok=True) 240 | 241 | if is_processed(save_dir, args.task.Nx): continue 242 | save_upsampled_data(subdir, save_dir, args.task.sr, args.task.Nx, args.task.strict) 243 | 244 | -------------------------------------------------------------------------------- /src/trainer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from torch import nn 4 | import numpy as np 5 | import torch.nn.functional as F 6 | import sys 7 | import glob 8 | 9 | import pytorch_lightning as pl 10 | from lightning.pytorch import seed_everything 11 | from pytorch_lightning.loggers import WandbLogger 12 | from pytorch_lightning.callbacks import LearningRateMonitor 13 | from torchinfo import summary 14 | 15 | def get_proj_dir(args): 16 | if not os.path.isabs(args.task.ckpt_dir): 17 | return f"{args.task.root_dir}/{args.task.ckpt_dir}" 18 | else: 19 | return args.task.ckpt_dir 20 | 21 | def get_checkpoint(args): 22 | ''' args.task.ckpt_dir <-- args.task.result_dir ''' 23 | proj_dir = get_proj_dir(args) 24 | ckpt_dir = f'{proj_dir}/{args.task.project}/*/checkpoints' 25 | best_ckpt_path = glob.glob(f'{ckpt_dir}/*.ckpt') 26 | assert len(best_ckpt_path) == 1, [best_ckpt_path, ckpt_dir] 27 | return os.path.join(best_ckpt_path[0]) 28 | 29 | def train(args): 30 | seed_everything(args.proc.seed, workers=True) 31 | 32 | scr = 
__import__(f'src.task.{args.task._name_}', fromlist=['']) 33 | model = scr.Trainer(args) 34 | if args.task.ckpt_dir is not None: 35 | ckpt_path = get_checkpoint(args) 36 | model = model.load_from_checkpoint(ckpt_path) 37 | summary(model) 38 | 39 | os.environ['WANDB_SILENT']="true" 40 | logger = WandbLogger( 41 | project=args.task.project, 42 | name=args.task.run, 43 | group=args.task._name_, 44 | dir=f".", 45 | mode = 'disabled' if args.proc.debug else 'online', 46 | allow_val_change=True, 47 | resume='allow', 48 | anonymous='allow', 49 | config=args, 50 | ) 51 | 52 | from src import callbacks 53 | pl_callbacks = [ ] 54 | pl_callbacks += [ callbacks.PlotResults(args), ] 55 | pl_callbacks += [ LearningRateMonitor(logging_interval='step'), ] if not args.proc.debug else [] 56 | 57 | mnum = min(args.proc.gpus) 58 | gpus=[gpu_num - mnum for gpu_num in args.proc.gpus] 59 | num_sanity_val_steps = 1 if args.proc.debug else 0 60 | #num_sanity_val_steps = 0 61 | 62 | pl_conf = dict( 63 | devices='auto', 64 | strategy='ddp_find_unused_parameters_true', 65 | num_sanity_val_steps=num_sanity_val_steps, 66 | logger=logger, 67 | callbacks=pl_callbacks, 68 | profiler="simple", 69 | max_epochs=args.task.total_epoch, 70 | check_val_every_n_epoch=args.task.valid_epoch, 71 | detect_anomaly=True if args.proc.debug else False, 72 | ) 73 | #if args.task.overfit: 74 | # pl_conf.update(dict(limit_train_batches=1)) 75 | 76 | # train model 77 | trainer = pl.Trainer(**pl_conf) 78 | trainer.fit(model) 79 | 80 | def eval(args): 81 | seed_everything(args.proc.seed, workers=True) 82 | 83 | print("*** Running in test mode!") 84 | 85 | proj_dir = get_proj_dir(args) 86 | sys.path.append(proj_dir) 87 | trainers_src = __import__(f'codes.src.task.{args.task._name_}', fromlist=['']) 88 | callback_src = __import__(f'codes.src.callbacks', fromlist=['']) 89 | 90 | ckpt_path = get_checkpoint(args) 91 | print(f"... 
load model ckpt from : {ckpt_path}") 92 | #hpar_path = f"results.{args.task.result_dir}.lightning_logs" 93 | model = trainers_src.Trainer.load_from_checkpoint( 94 | checkpoint_path=ckpt_path, 95 | #hparams_file=f"", 96 | map_location=None, 97 | args=args, 98 | ) 99 | 100 | pl_callbacks = [] 101 | if args.task.save_test_score: 102 | pl_callbacks += [ callback_src.SaveTestResults(args), ] 103 | if args.task.plot_test_video: 104 | pl_callbacks += [ callback_src.PlotStateVideo(args), ] 105 | mnum = min(args.proc.gpus) 106 | gpus=[gpu_num - mnum for gpu_num in args.proc.gpus] 107 | pl_conf = dict( 108 | devices='auto', 109 | strategy='ddp_find_unused_parameters_true', 110 | default_root_dir=proj_dir, 111 | logger=None, 112 | callbacks=pl_callbacks, 113 | profiler="simple", 114 | max_epochs=args.task.total_epoch, 115 | check_val_every_n_epoch=args.task.valid_epoch, 116 | detect_anomaly=True if args.proc.debug else False, 117 | ) 118 | trainer = pl.Trainer(**pl_conf) 119 | trainer.test(model) 120 | 121 | 122 | -------------------------------------------------------------------------------- /src/utils/analysis/frequency.py: -------------------------------------------------------------------------------- 1 | import crepe 2 | from src.utils.misc import suppress_stdout_stderr 3 | 4 | def compute_harmonic_parameters(x, sr): 5 | with suppress_stdout_stderr(): 6 | time, f0, confidence, activation = crepe.predict(x, sr, viterbi=True) 7 | return dict( 8 | f0=f0, 9 | ) 10 | 11 | -------------------------------------------------------------------------------- /src/utils/audio.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn.functional as F 4 | import numpy as np 5 | import librosa 6 | import soundfile as sf 7 | from einops import rearrange 8 | 9 | eps = np.finfo(np.float32).eps 10 | 11 | def calculate_rms(amp): 12 | if isinstance(amp, torch.Tensor): 13 | return amp.pow(2).mean(-1, 
keepdim=True).pow(.5) 14 | elif isinstance(amp, np.ndarray): 15 | return np.sqrt(np.mean(np.square(amp), axis=-1) + eps) 16 | else: 17 | raise TypeError(f"argument 'amp' must be torch.Tensor or np.ndarray. got: {type(amp)}") 18 | 19 | def dB2amp(dB): 20 | return np.power(10., dB/20.) 21 | 22 | def amp2dB(amp): 23 | return 20. * np.log10(amp) 24 | 25 | def rms_normalize(wav, ref_dBFS=-23.0, skip_nan=True): 26 | exists_nan = np.isnan(np.sum(wav)) 27 | if not skip_nan: 28 | assert not exists_nan, np.isnan(wav) 29 | if exists_nan: 30 | return wav, 1. 31 | # RMS normalize 32 | # value_dBFS = 20*log10(rms(signal) * sqrt(2)) = 20*log10(rms(signal)) + 3.0103 33 | rms = calculate_rms(wav) 34 | if isinstance(ref_dBFS, torch.Tensor): 35 | ref_linear = torch.pow(10, (ref_dBFS-3.0103)/20.) 36 | else: 37 | ref_linear = np.power(10, (ref_dBFS-3.0103)/20.) 38 | gain = ref_linear / (rms + eps) 39 | wav = gain * wav 40 | return wav, gain 41 | 42 | def ell_infty_normalize(wav, skip_nan=True): 43 | if isinstance(wav, np.ndarray): 44 | ''' numpy ''' 45 | exists_nan = np.isnan(np.sum(wav)) 46 | if not skip_nan: 47 | assert not exists_nan, np.isnan(wav) 48 | if exists_nan: 49 | return wav, 1. 50 | maxv = np.max(np.abs(wav), axis=-1) 51 | # 1 if maxv == 0 else 1. / maxv 52 | if len(list(maxv.shape)) == 0: 53 | gain = 1 if maxv==0 else 1. / maxv 54 | else: 55 | gain = 1. / maxv; gain[maxv==0] = 1 56 | elif isinstance(wav, torch.Tensor): 57 | ''' torch ''' 58 | exists_nan = torch.isnan(wav.sum()) 59 | if not skip_nan: 60 | assert not exists_nan, torch.isnan(wav) 61 | if exists_nan: 62 | return wav, 1. 63 | maxv = wav.abs().max(-1).values.unsqueeze(-1) 64 | # 1 if maxv == 0 else 1. / maxv 65 | gain = torch.where(maxv.eq(0), 66 | torch.ones_like(maxv), 1. 
/ maxv) 67 | else: 68 | assert False, wav 69 | wav = gain * wav 70 | return wav, gain 71 | 72 | def dB_RMS(wav): 73 | if isinstance(wav, torch.Tensor): 74 | return 20 * torch.log10(calculate_rms(wav)) 75 | elif isinstance(wav, np.ndarray): 76 | return 20 * np.log10(calculate_rms(wav)) 77 | 78 | def mel_basis(sr, n_fft, n_mel): 79 | return librosa.filters.mel(sr=sr,n_fft=n_fft,n_mels=n_mel,fmin=0,fmax=sr//2,norm=1) 80 | 81 | def inv_mel_basis(sr, n_fft, n_mel): 82 | return librosa.filters.mel( 83 | sr=sr, n_fft=n_fft, n_mels=n_mel, norm=None, fmin=0, fmax=sr//2, 84 | ).T 85 | 86 | def lin_to_mel(linspec, sr, n_fft, n_mel=80): 87 | basis = mel_basis(sr, n_fft, n_mel) 88 | return basis @ linspec 89 | 90 | def save_waves(est, save_dir, sr=16000): 91 | data = [] 92 | batch_size = inp.shape[0] 93 | for b in range(batch_size): 94 | est_wav = est[b,0].squeeze() 95 | wave_path = f"{save_dir}/{b}.wav" 96 | sf.write(wave_path, est_wav, samplerate=sr) 97 | 98 | def get_inverse_window(forward_window, frame_length, frame_step): 99 | denom = torch.square(forward_window) 100 | overlaps = -(-frame_length // frame_step) # Ceiling division. 
101 | denom = F.pad(denom, (0, overlaps * frame_step - frame_length)) 102 | denom = denom.reshape(overlaps, frame_step) 103 | denom = denom.sum(0, keepdims=True) 104 | denom = denom.tile(overlaps, 1) 105 | denom = denom.reshape(overlaps * frame_step) 106 | return forward_window / denom[:frame_length] 107 | 108 | def state_to_wav(state, normalize=True, sr=48000): 109 | ''' state: (Bs, Nt, Nx) ''' 110 | assert len(list(state.shape)) == 3, state.shape 111 | Nt = state.size(1) 112 | vel = ((state.narrow(1,1,Nt-1) - state.narrow(1,0,Nt-1)) * sr).sum(-1) 113 | return ell_infty_normalize(vel)[0] if normalize else vel 114 | 115 | def state_to_spec(x, window): 116 | ''' x: (Bs, Nt, Nx, Ch) 117 | -> (Bs, Nt, Nx, Ch*n_fft*2) 118 | ''' 119 | Bs, Nt, Nx, Ch = x.shape 120 | n_ffts = window.size(-1) 121 | n_freq = n_ffts // 2 + 1 122 | hop_length = n_ffts // 4 123 | x = rearrange(x, 'b t x c -> (b x c) t') 124 | s = torch.stft(x, n_ffts, hop_length=hop_length, window=window) 125 | s = rearrange(s, '(b x c) f t k -> b t x (c f k)', 126 | b=Bs, x=Nx, c=Ch, f=n_freq, k=2) 127 | return s 128 | 129 | def spec_to_state(x, window, length): 130 | ''' x: (Bs, Nt, Nx, Ch*n_fft*2) 131 | -> (Bs, Nt, Nx, Ch) 132 | ''' 133 | Bs, Nt, Nx, _ = x.shape 134 | n_ffts = window.size(-1) 135 | n_freq = n_ffts // 2 + 1 136 | 137 | x = rearrange(x, 'b t x (c f k) -> (b x c) f t k', f=n_freq, k=2) 138 | x = torch.istft(x, n_ffts, length=length, window=window) 139 | x = rearrange(x, '(b x c) t -> b t x c', b=Bs, x=Nx) 140 | return x 141 | 142 | 143 | def to_spec(x, window, reduce_channel=True): 144 | ''' x: (Bs, Nt) 145 | -> (Bs, Nt, Nf*2) if reduce_channel==True 146 | -> (Bs, Nt, Nf,2) otherwise 147 | ''' 148 | Bs, Nt = x.shape 149 | n_ffts = window.size(-1) 150 | n_freq = n_ffts // 2 + 1 151 | hop_length = n_ffts // 4 152 | s = torch.stft(x, n_ffts, hop_length=hop_length, window=window) 153 | s = s.transpose(1,2) 154 | if reduce_channel: 155 | s = rearrange(s, 'b t f k -> b t (f k)', 156 | b=Bs, 
f=n_freq, k=2) 157 | return s 158 | 159 | def from_spec(x, window, length): 160 | ''' x: (Bs, Nt, Nf*2) 161 | -> (Bs, Nt) 162 | ''' 163 | Bs, Nt, _ = x.shape 164 | n_ffts = window.size(-1) 165 | n_freq = n_ffts // 2 + 1 166 | 167 | x = rearrange(x, 'b t (f k) -> b f t k', f=n_freq, k=2) 168 | x = torch.istft(x, n_ffts, length=length, window=window) 169 | return x 170 | 171 | def adjust_gain(y, x, minmax, ref_dBFS=-23.0): 172 | ran_gain = (minmax[1] - minmax[0]) * torch.rand_like(y.narrow(-1,0,1)) + minmax[0] 173 | ref_linear = np.power(10, (ref_dBFS-3.0103)/20.) 174 | ran_linear = torch.pow(10, (ran_gain-3.0103)/20.) 175 | x_rms = calculate_rms(x) 176 | y_rms = calculate_rms(y) 177 | x_gain = ref_linear / (x_rms + eps) 178 | y_gain = ref_linear / (y_rms + eps) 179 | 180 | y_xscale = y * y_gain / x_gain 181 | return y_xscale / ran_linear 182 | 183 | def degrade(x, rir, noise): 184 | ''' x : (Bs, Nt) 185 | rir : (Bs, Nt) 186 | noise: (Bs, Nt) 187 | ''' 188 | x_pad = F.pad(x, (0,rir.size(-1))) 189 | w_pad = F.pad(rir, (0,rir.size(-1))) 190 | x_fft = torch.fft.rfft(x_pad) 191 | w_fft = torch.fft.rfft(w_pad) 192 | wet_x = torch.fft.irfft(x_fft * w_fft).narrow(-1,0,x.size(-1)) 193 | 194 | y = adjust_gain(wet_x, x, [-0, 30]) # ser 195 | n = adjust_gain(noise, y, [10, 30]) # snr 196 | return y + n 197 | 198 | def T60_to_sigma(T60, f_0, K): 199 | ''' T60 : (Bs, 2, 2) [[T60_freq_1, T60_1], [T60_freq_2, T60_2]] 200 | f_0 : (Bs, Nt, 1) fundamental frequency 201 | K : (Bs, Nt, 1) kappa (K == gamma * kappa_rel) 202 | -> sig : (Bs, Nt, 2) 203 | ''' 204 | gamma = f_0 * 2 205 | freq1, time1 = T60.narrow(1,0,1).chunk(2,-1) 206 | freq2, time2 = T60.narrow(1,1,1).chunk(2,-1) 207 | 208 | zeta1 = - gamma.pow(2) + (gamma.pow(4) + 4 * K.pow(2) * (2 * math.pi * freq1).pow(2)).pow(.5) 209 | zeta2 = - gamma.pow(2) + (gamma.pow(4) + 4 * K.pow(2) * (2 * math.pi * freq2).pow(2)).pow(.5) 210 | sig0 = - zeta2 / time1 + zeta1 / time2 211 | sig0 = 6 * math.log(10) * sig0 / (zeta1 - zeta2) 212 | 213 
| sig1 = 1 / time1 - 1 / time2 214 | sig1 = 6 * math.log(10) * sig1 / (zeta1 - zeta2) 215 | 216 | sig = torch.cat((sig0, sig1), dim=-1) 217 | return sig 218 | 219 | 220 | -------------------------------------------------------------------------------- /src/utils/config.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Sequence, Mapping, Optional, Callable 3 | import functools 4 | import hydra 5 | from omegaconf import ListConfig, DictConfig, OmegaConf 6 | import rich.syntax 7 | import rich.tree 8 | 9 | # TODO this is usually used in a pattern where it's turned into a list, so can just do that here 10 | def is_list(x): 11 | return isinstance(x, Sequence) and not isinstance(x, str) 12 | 13 | def is_dict(x): 14 | return isinstance(x, Mapping) 15 | 16 | def to_dict(x, recursive=True): 17 | """Convert Sequence or Mapping object to dict 18 | 19 | lists get converted to {0: x[0], 1: x[1], ...} 20 | """ 21 | if is_list(x): 22 | x = {i: v for i, v in enumerate(x)} 23 | if is_dict(x): 24 | if recursive: 25 | return {k: to_dict(v, recursive=recursive) for k, v in x.items()} 26 | else: 27 | return dict(x) 28 | else: 29 | return x 30 | 31 | 32 | def to_list(x, recursive=False): 33 | """Convert an object to list. 34 | 35 | If Sequence (e.g. list, tuple, Listconfig): just return it 36 | 37 | Special case: If non-recursive and not a list, wrap in list 38 | """ 39 | if is_list(x): 40 | if recursive: 41 | return [to_list(_x) for _x in x] 42 | else: 43 | return list(x) 44 | else: 45 | if recursive: 46 | return x 47 | else: 48 | return [x] 49 | 50 | 51 | def extract_attrs_from_obj(obj, *attrs): 52 | if obj is None: 53 | assert len(attrs) == 0 54 | return [] 55 | return [getattr(obj, attr, None) for attr in attrs] 56 | 57 | 58 | def instantiate(registry, config, *args, partial=False, wrap=None, **kwargs): 59 | """ 60 | registry: Dictionary mapping names to functions or target paths (e.g. 
{'model': 'models.SequenceModel'}) 61 | config: Dictionary with a '_name_' key indicating which element of the registry to grab, and kwargs to be passed into the target constructor 62 | wrap: wrap the target class (e.g. ema optimizer or tasks.wrap) 63 | *args, **kwargs: additional arguments to override the config to pass into the target constructor 64 | """ 65 | 66 | # Case 1: no config 67 | if config is None: 68 | return None 69 | # Case 2a: string means _name_ was overloaded 70 | if isinstance(config, str): 71 | _name_ = None 72 | _target_ = registry[config] 73 | config = {} 74 | # Case 2b: grab the desired callable from name 75 | else: 76 | _name_ = config.pop("_name_") 77 | try: 78 | _target_ = registry[_name_] 79 | except KeyError as err: 80 | print(f"Key error '{_name_}'. Check if {_name_} exists in src/utils/registry.py") 81 | raise err 82 | 83 | # Retrieve the right constructor automatically based on type 84 | if isinstance(_target_, str): 85 | fn = hydra.utils.get_method(path=_target_) 86 | elif isinstance(_target_, Callable): 87 | fn = _target_ 88 | else: 89 | raise NotImplementedError("instantiate target must be string or callable") 90 | 91 | # Instantiate object 92 | if wrap is not None: 93 | fn = wrap(fn) 94 | obj = functools.partial(fn, *args, **config, **kwargs) 95 | 96 | # Restore _name_ 97 | if _name_ is not None: 98 | config["_name_"] = _name_ 99 | 100 | if partial: 101 | return obj 102 | else: 103 | return obj() 104 | 105 | 106 | def get_class(registry, _name_): 107 | return hydra.utils.get_class(path=registry[_name_]) 108 | 109 | 110 | def omegaconf_filter_keys(d, fn=None): 111 | """Only keep keys where fn(key) is True. Support nested DictConfig. 112 | # TODO can make this inplace? 
113 | """ 114 | if fn is None: 115 | fn = lambda _: True 116 | if is_list(d): 117 | return ListConfig([omegaconf_filter_keys(v, fn) for v in d]) 118 | elif is_dict(d): 119 | return DictConfig( 120 | {k: omegaconf_filter_keys(v, fn) for k, v in d.items() if fn(k)} 121 | ) 122 | else: 123 | return d 124 | 125 | 126 | def process_config(config: DictConfig) -> DictConfig: # TODO because of filter_keys, this is no longer in place 127 | """A couple of optional utilities, controlled by main config file: 128 | - disabling warnings 129 | - easier access to debug mode 130 | - forcing debug friendly configuration 131 | Modifies DictConfig in place. 132 | Args: 133 | config (DictConfig): Configuration composed by Hydra. 134 | """ 135 | OmegaConf.register_new_resolver('eval', eval) 136 | 137 | # Filter out keys that were used just for interpolation 138 | # config = dictconfig_filter_keys(config, lambda k: not k.startswith('__')) 139 | config = omegaconf_filter_keys(config, lambda k: not k.startswith('__')) 140 | 141 | # enable adding new keys to config 142 | OmegaConf.set_struct(config, False) 143 | 144 | # disable python warnings if 145 | if config.get("ignore_warnings"): 146 | warnings.filterwarnings("ignore") 147 | 148 | if config.get("debug"): 149 | config.trainer.fast_dev_run = True 150 | 151 | # force debugger friendly configuration 152 | # Debuggers don't like GPUs or multiprocessing 153 | if config.trainer.get("gpus"): 154 | config.trainer.gpus = 0 155 | if config.loader.get("pin_memory"): 156 | config.loader.pin_memory = False 157 | if config.loader.get("num_workers"): 158 | config.loader.num_workers = 0 159 | 160 | # disable adding new keys to config 161 | # OmegaConf.set_struct(config, True) # [21-09-17 AG] I need this for .pop(_name_) pattern among other things 162 | 163 | return config 164 | 165 | def print_config( 166 | config: DictConfig, 167 | resolve: bool = True, 168 | save_cfg=True, 169 | ) -> None: 170 | """Prints content of DictConfig using Rich library and 
its tree structure. 171 | Args: 172 | config (DictConfig): Configuration composed by Hydra. 173 | fields (Sequence[str], optional): Determines which main fields from config will 174 | be printed and in what order. 175 | resolve (bool, optional): Whether to resolve reference fields of DictConfig. 176 | """ 177 | 178 | style = "dim" 179 | tree = rich.tree.Tree("CONFIG", style=style, guide_style=style) 180 | 181 | fields = config.keys() 182 | for field in fields: 183 | branch = tree.add(field, style=style, guide_style=style) 184 | 185 | config_section = config.get(field) 186 | branch_content = str(config_section) 187 | if isinstance(config_section, DictConfig): 188 | branch_content = OmegaConf.to_yaml(config_section, resolve=resolve) 189 | 190 | branch.add(rich.syntax.Syntax(branch_content, "yaml")) 191 | 192 | rich.print(tree) 193 | 194 | if save_cfg: 195 | with open("config_tree.txt", "w") as fp: 196 | rich.print(tree, file=fp) 197 | 198 | 199 | 200 | -------------------------------------------------------------------------------- /src/utils/control.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import torch.nn.functional as F 4 | 5 | def constant(f0, n, dtype=None): 6 | ''' f0 (batch_size,) 7 | n (int) 8 | ''' 9 | return f0.unsqueeze(-1) * torch.ones(1,n, dtype=dtype) 10 | 11 | def linear(f1, f2, n): 12 | ''' f1 (batch_size,) 13 | f2 (batch_size,) 14 | n (int) 15 | ''' 16 | out = torch.cat((f1.unsqueeze(-1),f2.unsqueeze(-1)), dim=-1) # (batch_size, 2) 17 | out = F.interpolate(out.unsqueeze(1), size=n, mode='linear', align_corners=True).squeeze(1) # (batch_size, n) 18 | return out 19 | 20 | def glissando(f1, f2, n, mode='linear'): 21 | if mode == 'linear': 22 | return linear(f1, f2, n) 23 | else: 24 | raise NotImplementedError(mode) 25 | 26 | def vibrato(f0, k, mf=[3,5], ma=0.05, ma_in_hz=False): 27 | ''' f0 (batch_size, n) 28 | k (int): 1/sr 29 | mf (list): modulation frequency ([min, 
max]) 30 | ma (float): modulation amplitude (in Hz) 31 | ma_in_hz (bool): ma is given in Hz (else: ma is given as a weighting factor of f0) 32 | ''' 33 | ff = f0.narrow(-1,0,1) 34 | def get_new_vibrato(f0, k, mf, ma, ma_in_hz): 35 | mod_frq = mf[1] * torch.rand_like(ff) + mf[0] # (B, 1) 36 | mod_amp = ma * torch.rand_like(ff) # (B, 1) 37 | 38 | nt = f0.size(-1) # total time 39 | vt = torch.floor((nt // 2) * torch.rand(f0.size(0)).view(-1,1)) # vibrato time 40 | t = torch.ones_like(f0).cumsum(-1) 41 | m = t.gt(vt) # mask `t` for n <= vt 42 | vibra = m * mod_amp * (1 - torch.cos(2 * np.pi * mod_frq * (t - vt) * k)) / 2 43 | if not ma_in_hz: vibra *= f0 44 | return vibra * torch.randn_like(ff).sign() 45 | return f0 + get_new_vibrato(f0, k, mf, ma, ma_in_hz) 46 | 47 | def triangle_with_velocity(vel, n, sr_t, sr_x, max_u=.1): 48 | ''' vel (batch_size,) velocity 49 | n (int) number of samples 50 | sr_t (int) sampling rate in time 51 | sr_x (int) sampling rate in space 52 | max_u (float) maximum displacement 53 | ''' 54 | vel = vel.view(-1,1) * sr_x / sr_t # m/s to non-dimensional quantity 55 | vel = vel * torch.ones_like(vel).repeat(1,n) 56 | u_H = torch.relu(max_u - (max_u - vel.cumsum(1)).abs() - vel) 57 | u_H = u_H.pow(5).clamp(max=0.01) 58 | return u_H 59 | 60 | 61 | 62 | -------------------------------------------------------------------------------- /src/utils/data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn.functional as F 4 | import soundfile as sf 5 | import numpy as np 6 | import src.utils.misc as ms 7 | import glob 8 | 9 | def load_wav(wav_path, npz_path, trim=None, keys=['t', 'kappa', 'alpha'], gain=1.): 10 | _ovr = {} 11 | _res = np.load(npz_path) 12 | if trim is None: 13 | for key in keys: _ovr[key] = _res[key] 14 | _ovr['target'] = gain * sf.read(wav_path)[0] 15 | else: 16 | for key in keys: 17 | _res_val = _res[key] 18 | if key == 't': 19 | _res_val = 
_res_val[trim[0]:trim[1]] 20 | _ovr[key] = _res_val 21 | _ovr['target'] = gain * sf.read(wav_path)[0][trim[0]:trim[1]] 22 | return _ovr 23 | 24 | def load(dir_path, n_subsample=None, sr=48000, wav_keys=['ut', 'zt', 'ua'], subsample_method='sequential'): 25 | # load preprocessed files (wav files) 26 | path_wav = sorted(glob.glob(f"{dir_path}/*.wav")) 27 | _ovr = {} 28 | for prefix in wav_keys: 29 | _wav = [] 30 | _wav_path_unsrt = glob.glob(f"{dir_path}/{prefix}-*.wav") 31 | max_N = len(_wav_path_unsrt) 32 | _wav_paths = [f"{dir_path}/{prefix}-{i}.wav" for i in range(max_N)] 33 | if n_subsample is not None: 34 | # spatial subsampling 35 | if subsample_method == 'random': 36 | # randomly sample `n_subsample` number of wavs 37 | assert isinstance(n_subsample, int), n_subsample 38 | if max_N < n_subsample: 39 | # with replacement 40 | x_idx = np.random.randint(0, max_N, size=n_subsample, ) 41 | else: 42 | # without replacement 43 | x_idx = np.random.permutation(max_N)[:n_subsample] 44 | else: 45 | # sequentially sample `n_subsample` number of wavs 46 | # with randomly drawn starting points 47 | #print(max_N, dir_path, prefix) 48 | r = np.random.randint(0, max_N-n_subsample) 49 | x_idx = np.array([r + i for i in range(n_subsample)]) 50 | _wav_paths = [_wav_paths[i] for i in x_idx] 51 | for path in _wav_paths: _wav.append(sf.read(path)[0][:,None]) 52 | _ovr[prefix] = np.concatenate(_wav, 1) 53 | 54 | _res = np.load(f"{dir_path}/parameters.npz") 55 | for key in _res.keys(): 56 | _ovr[key] = _res[key] 57 | return _ovr 58 | 59 | def save(dir_path, data_dict, sr=48000, chunk_length=.1): 60 | cl = int(sr * chunk_length) 61 | new_data_dict = data_dict.copy() 62 | for key, data in data_dict.items(): 63 | file_path = f"{dir_path}/{key}" 64 | data = data_dict[key] 65 | if isinstance(data, float) or isinstance(data, int): 66 | continue 67 | data = data.squeeze() 68 | data_dim = len(list(data.shape)) 69 | if key in ['ut', 'zt', 'ua']: 70 | Nt, Nx = data.shape 71 | assert min(Nt, 
Nx) > 1, [key, data.shape] 72 | for xi in range(Nx): 73 | sf.write(file_path+f"-{xi}.wav", data[:,xi], samplerate=sr, subtype='PCM_24') 74 | new_data_dict.pop(key) 75 | if key in ['vt']: 76 | sf.write(file_path+f".wav", data, samplerate=sr, subtype='PCM_24') 77 | new_data_dict.pop(key) 78 | 79 | np.savez_compressed(f"{dir_path}/parameters.npz", **new_data_dict) 80 | 81 | def set_length(x, size, method='pad', mode='linear', idx_x=None): 82 | if method == 'interpolate': 83 | x_shape = list(x.shape) 84 | if x_shape[-1] == size: 85 | return x 86 | new_shape = x_shape[:-1] + [size]; res = 3 - len(x_shape) 87 | unsqueezed_shape = [1]*res + x_shape if len(x_shape) < 3 else x_shape 88 | x = x.view(unsqueezed_shape) 89 | return F.interpolate(x, size=size, mode=mode).view(new_shape) 90 | 91 | elif method == 'pad': 92 | x_shape = list(x.shape) 93 | assert x_shape[-1] <= size, f"set Nx (={size}) geq to {x_shape[-1]}. To do this, set smaller args.task.f0_inf." 94 | if x_shape[-1] == size: 95 | return x 96 | new_shape = x_shape[:-1] + [size] 97 | new_x = torch.zeros(new_shape, device=x.device, dtype=x.dtype) 98 | new_x[...,:x_shape[-1]] = x 99 | return new_x 100 | 101 | elif method == 'random': 102 | assert idx_x is not None, idx_x 103 | new_x = ms.batched_index_select(x, -1, idx_x) # (Bs, Nt, size) 104 | return new_x 105 | 106 | else: 107 | assert False, method 108 | 109 | def stack_batch( 110 | batch, Nx, Nt=None, sr=48000, 111 | x_method='interpolate', t_method='sequential', 112 | start_time=None, end_time=None, 113 | domain='waveform', n_fft=None, 114 | ): 115 | ''' batch: list of dictionaries of each data 116 | Nx: number of samples in space 117 | Nt: number of samples in time 118 | sr: sampling rate 119 | x_method: method to subsample in space 120 | t_method: method to subsample in time 121 | start_time: specific time stamp to start 122 | end_time: specific time stamp to end 123 | domain: type of domain to model (waveform/stft) 124 | n_fft: number of FFT points 125 | (only 
applicable when the domain is `stft`) 126 | ''' 127 | assert x_method in ['interpolate', 'pad', 'random'], x_method 128 | assert t_method in ['interpolate', 'sequential', 'interleave'], t_method 129 | ''' interpolate: conduct linear interpolation to subsample 130 | sequential : sequentially subsample without interleave 131 | interleave : subsample with uniform interleaving 132 | ''' 133 | assert domain in ['waveform', 'stft'], domain 134 | keys = batch[0].keys() 135 | stacked_data_dict = dict() 136 | Bs = len(batch) 137 | hop_size = n_fft // 4 if n_fft is not None else None 138 | 139 | idx_x = None 140 | if x_method == 'random': 141 | ut_shape = list(batch[0]['u_in'].shape) 142 | idx_x = ms.random_index(ut_shape[-1], Nx) 143 | 144 | T = batch[0]['u_in'].shape[0] 145 | if Nt is not None: 146 | if start_time is None: 147 | st = np.random.randint(T-Nt, size=Bs) if T-Nt > 0 else np.zeros(Bs, dtype=int) 148 | else: 149 | assert isinstance(start_time, float), start_time 150 | st = int(start_time * sr) * np.ones(Bs, dtype=int) 151 | if end_time is None: 152 | et = np.random.randint(st+Nt, T, size=Bs) if (st+Nt < T).all() else (T-1) * np.zeros(Bs, dtype=int) 153 | # let (et-st) to be divisible by Nt, 154 | # to match the shape for `interleave` 155 | et = Nt * ((et - st) // Nt) + st 156 | else: 157 | assert isinstance(end_time, float), end_time 158 | et = int(end_time * sr) * np.ones(Bs, dtype=int) 159 | else: 160 | st = np.zeros(Bs, dtype=int); Nt = T 161 | et = T * np.ones(Bs, dtype=int) 162 | 163 | n_frames_t = (Nt + n_fft // 2) // hop_size - 1 if n_fft is not None else None 164 | 165 | time_varying_vars = ['u_gt', 'z_gt', 'u_in', 'z_in'] 166 | time_varying_vars += ['f0', 'Nu', 'Nz'] 167 | time_varying_vars += ['x_B', 'v_B', 'F_B', 'wid_B', ] 168 | time_varying_vars += ['v_H', 'u_H'] 169 | time_varying_vars += ['uat', 'uar', 'tt'] 170 | 171 | space_varying_vars = ['u_gt', 'z_gt', 'u_in', 'z_in'] 172 | space_varying_vars += ['uat', 'uar', 'u0', 'z0', 'xt'] 173 | 174 | for 
key in keys: 175 | data_list = [torch.from_numpy(x[key]) for x in batch] 176 | 177 | #============================== 178 | # Handle temporal resolution 179 | #============================== 180 | # randomize temporal initial point 181 | if key in time_varying_vars: 182 | # if domain == 'waveform', set whole length by `Nt` 183 | # if domain == 'stft', set whole length by `n_frames_t` 184 | TS = Nt 185 | TF = n_frames_t 186 | ''' batch * (time, space) ''' 187 | if t_method == 'sequential': 188 | data_list = [x.narrow(0,st[i],TS) for i, x in enumerate(data_list)] if TS is not None else None 189 | elif t_method == 'interpolate': 190 | data_list = [x.narrow(0,st[i],T-st[i]) for i, x in enumerate(data_list)] 191 | if len(list(data_list[0].shape)) < 2: 192 | data_list = [set_length(x, TS, t_method) for i, x in enumerate(data_list)] if TS is not None else None 193 | else: 194 | data_list = [set_length(x.transpose(0,1), TS, t_method).transpose(0,1) for i, x in enumerate(data_list)] if TS is not None else None 195 | elif t_method == 'interleave': 196 | data_list = [x.narrow(0,st[i],et[i]-st[i])[0::(et[i]-st[i]) // TS] for i, x in enumerate(data_list)] if TS is not None else None 197 | else: 198 | assert False, t_method 199 | 200 | #============================== 201 | # Handle spatial resolution 202 | #============================== 203 | # interpolate to the maximal spatial grid size 204 | ''' batch * (time, space) ''' 205 | if key in space_varying_vars: 206 | data_list = [set_length(x, Nx, x_method, idx_x=idx_x) for x in data_list] 207 | 208 | data_batch = torch.stack(data_list) 209 | stacked_data_dict.update({key: data_batch}) 210 | 211 | return stacked_data_dict 212 | 213 | def get_rir_list(root_dir, subdirs, prefix='rir', sep='RIRS_NOISES'): 214 | assert prefix in ['rir', 'noise'] 215 | rir_list_paths = [] 216 | for sd in subdirs: 217 | rir_dir = os.path.join(root_dir, sd) 218 | if os.path.exists(f"{rir_dir}/{prefix}_list"): 219 | rir_list_paths += 
[f"{rir_dir}/{prefix}_list"] 220 | else: 221 | roomdirs = [d for d in glob.glob(f"{rir_dir}/*") if os.path.isdir(d)] 222 | for d in roomdirs: 223 | rir_list_paths += glob.glob(f"{d}/{prefix}_list") 224 | rir_paths = [] 225 | for rir_list_path in rir_list_paths: 226 | with open(rir_list_path, 'r') as f: 227 | lines = f.readlines() 228 | rir_paths += [l.split(sep)[-1].split('\n')[0] for l in lines] 229 | return rir_paths 230 | 231 | def get_noise_list(root_dir, subdirs): 232 | for sd in subdirs: 233 | rir_dir = os.path.join(root_dir, sd) 234 | if os.path.exists(f"{rir_dir}/rir_list"): 235 | pass 236 | 237 | 238 | # def old_stack_batch( 239 | # batch, Nx, Nt=None, Nr=None, sr=48000, 240 | # x_method='interpolate', t_method='sequential', 241 | # start_time=None, end_time=None, 242 | # domain='waveform', n_fft=None, 243 | # ): 244 | # ''' batch: list of dictionaries of each data 245 | # Nx: number of samples in space 246 | # Nt: number of samples in time 247 | # Nr: number of samples to randomly draw in time 248 | # sr: sampling rate 249 | # x_method: method to subsample in space 250 | # t_method: method to subsample in time 251 | # start_time: specific time stamp to start 252 | # end_time: specific time stamp to end 253 | # domain: type of domain to model (waveform/stft) 254 | # n_fft: number of FFT points 255 | # (only applicable when the domain is `stft`) 256 | # ''' 257 | # assert x_method in ['interpolate', 'pad', 'random'], x_method 258 | # assert t_method in ['interpolate', 'sequential', 'interleave'], t_method 259 | # ''' interpolate: conduct linear interpolation to subsample 260 | # sequential : sequentially subsample without interleave 261 | # interleave : subsample with uniform interleaving 262 | # ''' 263 | # assert domain in ['waveform', 'stft'], domain 264 | # keys = batch[0].keys() 265 | # stacked_data_dict = dict() 266 | # Bs = len(batch) 267 | # hop_size = n_fft // 4 268 | # 269 | # idx_x = None 270 | # if x_method == 'random': 271 | # ut_shape = 
list(batch[0]['ut'].shape) 272 | # idx_x = ms.random_index(ut_shape[-1], Nx) 273 | # 274 | # T = batch[0]['ut'].shape[0] 275 | # if Nt is not None: 276 | # if start_time is None: 277 | # st = np.random.randint(T-Nt, size=Bs) if T-Nt > 0 else np.zeros(Bs, dtype=int) 278 | # else: 279 | # assert isinstance(start_time, float), start_time 280 | # st = int(start_time * sr) * np.ones(Bs, dtype=int) 281 | # if end_time is None: 282 | # et = np.random.randint(st+Nt, T, size=Bs) if (st+Nt < T).all() else (T-1) * np.zeros(Bs, dtype=int) 283 | # # let (et-st) to be divisible by Nt, 284 | # # to match the shape for `interleave` 285 | # et = Nt * ((et - st) // Nt) + st 286 | # else: 287 | # assert isinstance(end_time, float), end_time 288 | # et = int(end_time * sr) * np.ones(Bs, dtype=int) 289 | # else: 290 | # st = np.zeros(Bs, dtype=int); Nt = T 291 | # et = T * np.ones(Bs, dtype=int) 292 | # 293 | # n_frames_t = (Nt + n_fft // 2) // hop_size - 1 294 | # n_frames_r = (Nr + n_fft // 2) // hop_size - 1 295 | # 296 | # time_varying_vars = ['ut', 'zt', 'f0', 'Nu', 'Nz'] 297 | # time_varying_vars += ['x_B', 'v_B', 'F_B', 'wid_B', ] 298 | # time_varying_vars += ['v_H', 'u_H', 'uf0'] 299 | # time_varying_vars += ['uat', 'uar', 'tt', 'tr'] 300 | # 301 | # space_varying_vars = ['ut', 'zt', 'uat', 'uar', 'u0', 'z0', 'xt', 'xr'] 302 | # 303 | # for key in keys: 304 | # data_list = [torch.from_numpy(x[key]) for x in batch] 305 | # 306 | # #============================== 307 | # # Handle temporal resolution 308 | # #============================== 309 | # # randomize temporal initial point 310 | # if key in time_varying_vars: 311 | # if key in ['ut', 'zt', 'uf0', 'uat', 'tt']: 312 | # # these variables will be used along with `ut` and `zt` 313 | # # if domain == 'waveform', set whole length by `Nt` 314 | # # if domain == 'stft', set whole length by `n_frames_t` 315 | # TS = Nt 316 | # TF = n_frames_t 317 | # else: 318 | # # these variables will be used along with `ur` and `zr` 319 | # # 
if domain == 'waveform', set whole length by `Nr` 320 | # # if domain == 'stft', set whole length by `n_frames_r` 321 | # TS = Nr 322 | # TF = n_frames_r 323 | # ''' batch * (time, space) ''' 324 | # if t_method == 'sequential': 325 | # data_list = [x.narrow(0,st[i],TS) for i, x in enumerate(data_list)] if TS is not None else None 326 | # elif t_method == 'interpolate': 327 | # data_list = [x.narrow(0,st[i],T-st[i]) for i, x in enumerate(data_list)] 328 | # if len(list(data_list[0].shape)) < 2: 329 | # data_list = [set_length(x, TS, t_method) for i, x in enumerate(data_list)] if TS is not None else None 330 | # else: 331 | # data_list = [set_length(x.transpose(0,1), TS, t_method).transpose(0,1) for i, x in enumerate(data_list)] if TS is not None else None 332 | # elif t_method == 'interleave': 333 | # data_list = [x.narrow(0,st[i],et[i]-st[i])[0::(et[i]-st[i]) // TS] for i, x in enumerate(data_list)] if TS is not None else None 334 | # else: 335 | # assert False, t_method 336 | # 337 | # #============================== 338 | # # Handle spatial resolution 339 | # #============================== 340 | # # interpolate to the maximal spatial grid size 341 | # ''' batch * (time, space) ''' 342 | # if key in space_varying_vars: 343 | # data_list = [set_length(x, Nx, x_method, idx_x=idx_x) for x in data_list] 344 | # 345 | # data_batch = torch.stack(data_list) 346 | # stacked_data_dict.update({key: data_batch}) 347 | # 348 | # #============================== 349 | # # Randomized input 350 | # #============================== 351 | # if Nr is not None: 352 | # rand_t = stacked_data_dict['tr'] \ 353 | # + torch.rand(size=(Bs,1,1)) / sr 354 | # rand_x = stacked_data_dict['xr'] \ 355 | # + torch.rand(size=(Bs,1,1)) / Nx 356 | # 357 | # template = torch.ones_like(rand_t * rand_x) # (Bs, Nt, Nx) 358 | # rand_t = (rand_t * template).requires_grad_() 359 | # rand_x = (rand_x * template).requires_grad_() 360 | # stacked_data_dict.update(dict(xr=rand_x, tr=rand_t)) 361 | # 362 | # 
return stacked_data_dict 363 | # 364 | # 365 | -------------------------------------------------------------------------------- /src/utils/ddsp.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.fft as fft 4 | import numpy as np 5 | import librosa as li 6 | import math 7 | 8 | from src.model.nn.blocks import get_activation 9 | 10 | def safe_log(x): 11 | return torch.log(x + 1e-7) 12 | 13 | 14 | @torch.no_grad() 15 | def mean_std_loudness(dataset): 16 | mean = 0 17 | std = 0 18 | n = 0 19 | for _, _, l in dataset: 20 | n += 1 21 | mean += (l.mean().item() - mean) / n 22 | std += (l.std().item() - std) / n 23 | return mean, std 24 | 25 | 26 | def multiscale_fft(signal, scales, overlap): 27 | stfts = [] 28 | for s in scales: 29 | S = torch.stft( 30 | signal, 31 | s, 32 | int(s * (1 - overlap)), 33 | s, 34 | torch.hann_window(s).to(signal), 35 | True, 36 | normalized=True, 37 | return_complex=True, 38 | ).abs() 39 | stfts.append(S) 40 | return stfts 41 | 42 | 43 | def resample(x, factor: int): 44 | batch, frame, channel = x.shape 45 | x = x.permute(0, 2, 1).reshape(batch * channel, 1, frame) 46 | 47 | window = torch.hann_window( 48 | factor * 2, 49 | dtype=x.dtype, 50 | device=x.device, 51 | ).reshape(1, 1, -1) 52 | y = torch.zeros(x.shape[0], x.shape[1], factor * x.shape[2]).to(x) 53 | y[..., ::factor] = x 54 | y[..., -1:] = x[..., -1:] 55 | y = torch.nn.functional.pad(y, [factor, factor]) 56 | y = torch.nn.functional.conv1d(y, window)[..., :-1] 57 | 58 | y = y.reshape(batch, channel, factor * frame).permute(0, 2, 1) 59 | 60 | return y 61 | 62 | 63 | 64 | def upsample(signal, factor): 65 | signal = signal.permute(0,2,1) 66 | signal = nn.functional.interpolate(signal, size=signal.shape[-1] * factor, mode='linear') 67 | return signal.permute(0,2,1) 68 | 69 | 70 | def remove_above_nyquist(amplitudes, pitch, sampling_rate): 71 | ''' amplitudes: (batch, frames, n_harmoincs) 72 | 
pitch: (batch, frames, 1) 73 | ''' 74 | n_harm = amplitudes.shape[-1] 75 | pitches = pitch.repeat(1,1,n_harm).cumsum(-1) 76 | aa = (pitches < sampling_rate / 2).float() + 1e-4 77 | return amplitudes * aa 78 | 79 | 80 | def remove_above_nyquist_mode(amplitudes, frequencies, sampling_rate): 81 | ''' amplitudes: (batch, frames, n_harmoincs) 82 | frequencies: (batch, frames, n_harmonics) 83 | ''' 84 | aa = (frequencies < sampling_rate / 2).float() + 1e-4 85 | return amplitudes * aa 86 | 87 | def scale_function(x): 88 | ''' 0 ~ 2''' 89 | return 2 * torch.sigmoid(x)**(math.log(10)) + 1e-7 90 | 91 | def extract_loudness(signal, sampling_rate, block_size, n_fft=2048): 92 | S = li.stft( 93 | signal, 94 | n_fft=n_fft, 95 | hop_length=block_size, 96 | win_length=n_fft, 97 | center=True, 98 | ) 99 | S = np.log(abs(S) + 1e-7) 100 | f = li.fft_frequencies(sampling_rate, n_fft) 101 | a_weight = li.A_weighting(f) 102 | 103 | S = S + a_weight.reshape(-1, 1) 104 | 105 | S = np.mean(S, 0)[..., :-1] 106 | 107 | return S 108 | 109 | 110 | def extract_pitch(signal, sampling_rate, block_size): 111 | length = signal.shape[-1] // block_size 112 | f0 = crepe.predict( 113 | signal, 114 | sampling_rate, 115 | step_size=int(1000 * block_size / sampling_rate), 116 | verbose=1, 117 | center=True, 118 | viterbi=True, 119 | ) 120 | f0 = f0[1].reshape(-1)[:-1] 121 | 122 | if f0.shape[-1] != length: 123 | f0 = np.interp( 124 | np.linspace(0, 1, length, endpoint=False), 125 | np.linspace(0, 1, f0.shape[-1], endpoint=False), 126 | f0, 127 | ) 128 | 129 | return f0 130 | 131 | 132 | def harmonic_synth(pitch, amplitudes, sampling_rate): 133 | n_harmonic = amplitudes.shape[-1] 134 | omega = torch.cumsum(2 * math.pi * pitch / sampling_rate, 1) 135 | omegas = omega * torch.arange(1, n_harmonic + 1).to(omega) 136 | signal = (torch.sin(omegas) * amplitudes).sum(-1, keepdim=True) 137 | return signal 138 | 139 | def modal_synth(modes, amplitude, sampling_rate, n_chunks=16): 140 | freqs = modes.chunk(n_chunks, 
1) 141 | coefs = amplitude.chunk(n_chunks, 1) 142 | lastf = torch.zeros_like(freqs[0]) 143 | sols = [] 144 | for f, c in zip(freqs, coefs): 145 | fcs = f.cumsum(1) + lastf 146 | sol = (torch.cos(fcs) * c).sum(-1, keepdim=True) 147 | lastf = fcs.narrow(1,-1,1) 148 | sols.append(sol) 149 | return torch.cat(sols, 1) 150 | 151 | 152 | def amp_to_impulse_response(amp, target_size): 153 | amp = torch.stack([amp, torch.zeros_like(amp)], -1) 154 | amp = torch.view_as_complex(amp) 155 | amp = fft.irfft(amp) 156 | 157 | filter_size = amp.shape[-1] 158 | 159 | amp = torch.roll(amp, filter_size // 2, -1) 160 | win = torch.hann_window(filter_size, dtype=amp.dtype, device=amp.device) 161 | 162 | amp = amp * win 163 | 164 | amp = nn.functional.pad(amp, (0, int(target_size) - int(filter_size))) 165 | amp = torch.roll(amp, -filter_size // 2, -1) 166 | 167 | return amp 168 | 169 | 170 | def fft_convolve(signal, kernel): 171 | signal = nn.functional.pad(signal, (0, signal.shape[-1])) 172 | kernel = nn.functional.pad(kernel, (kernel.shape[-1], 0)) 173 | 174 | output = fft.irfft(fft.rfft(signal) * fft.rfft(kernel)) 175 | output = output[..., output.shape[-1] // 2:] 176 | 177 | return output 178 | 179 | -------------------------------------------------------------------------------- /src/utils/fdm.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import numpy as np 4 | import torch.nn.functional as F 5 | #from src.utils.misc import raised_cosine, sqrt 6 | def sqrt(x): 7 | return x.pow(.5) if isinstance(x, torch.Tensor) else x**.5 8 | 9 | def tridiagonal_inverse(X, check_tridiagonal=False): 10 | n = X.size(0); ks = torch.arange(n).to(X.device) + 1; jk = torch.outer(ks,ks) 11 | a = X[1,0]; b = X[0,0]; c = X[0,1] 12 | if check_tridiagonal: 13 | assert list(X.shape) == [n,n], X.shape 14 | assert torch.allclose(torch.diag(X, diagonal=-1), a * torch.ones(n-1)) \ 15 | and torch.allclose(torch.diag(X, diagonal= 0), b * 
torch.ones(n)) \ 16 | and torch.allclose(torch.diag(X, diagonal=+1), c * torch.ones(n-1)) \ 17 | and all([torch.allclose(torch.diag(X, diagonal=i), torch.zeros(n-i)) for i in range(2,n)]) 18 | lam = b + (a+c) * torch.cos(ks * np.pi / (n+1)) 19 | L = torch.diag(1 / lam) 20 | V = np.sqrt(2 / (n+1)) * torch.sin(jk * np.pi / (n+1)) 21 | return V @ L @ V.T 22 | 23 | def I(n, diagonal=0): 24 | l = n - abs(diagonal) 25 | i = torch.ones(l) 26 | return torch.diag(i, diagonal) 27 | 28 | def Dxx(n, h): 29 | Dx = I(n, +1) - 2*I(n) + I(n, -1) 30 | Dx = Dx / (h**2) 31 | return Dx 32 | 33 | def Dxxxx(n, h): 34 | Dx = I(n, +2) - 4*I(n, +1) + 6*I(n) - 4*I(n, -1) + I(n, -2) 35 | Dx = Dx / (h**4) 36 | return Dx 37 | 38 | def D(n, xd, h): 39 | if xd == 'x-': 40 | Dx = I(n) - I(n, -1) 41 | Dx = Dx / h 42 | elif xd == 'x+': 43 | Dx = I(n, +1) - I(n) 44 | Dx = Dx / h 45 | elif xd == 'xc': 46 | Dx = I(n, +1) - I(n, -1) 47 | Dx = Dx / h 48 | elif xd == 'xx': 49 | Dx = I(n, +1) - 2*I(n) + I(n, -1) 50 | Dx = Dx / (h**2) 51 | elif xd == 'xxxx': 52 | Dx = I(n, +2) - 4*I(n, +1) + 6*I(n) - 4*I(n, -1) + I(n, -2) 53 | Dx = Dx / (h**4) 54 | else: 55 | assert False 56 | return Dx 57 | 58 | 59 | def displacement_update(u, n): 60 | # interpolate? 
61 | return F.interpolate(u.view(1,1,-1), size=n, mode='linear').view(-1) 62 | # pad/trim raises artifact waves 63 | #if u.size(-1) < n: 64 | # return F.pad(u, (0,n-u.size(-1))) 65 | #else: 66 | # return u[...,:n] 67 | 68 | def bow_term_rhs(N, h, k, u1, u2, x_B, v_B, F_B, wid, friction_fn): 69 | rc = raised_cosine(N-1, h, x_B, wid) 70 | rc = rc / rc.abs().sum() 71 | I = rc 72 | J = rc / h 73 | v_rel = I @ (u1 - u2) / k - v_B # using explicit scheme 74 | Gamma = J * F_B * friction_fn(v_rel) 75 | return - k**2 * Gamma, v_rel 76 | 77 | def initialize_state(u0, v0, Nt, Nx_t, Nx_l, k, dtype=None): 78 | ''' u0 (batch_size, Nt, Nx_t) predefined displacement for each time 79 | v0 (batch_size, Nt, Nx_t) predefined displacement for each time 80 | Nt (int) number of samples in time 81 | Nx_t (int) number of transverse samples in space 82 | Nx_l (int) number of longitudinal samples in space 83 | k (int) temporal spacing 84 | --- 85 | state_t (batch_size, Nt, Nx_t+1) 86 | state_l (batch_size, Nt, Nx_l+1) 87 | ''' 88 | batch_size = u0.size(0) 89 | u0 = torch.from_numpy(u0, dtype=dtype) if isinstance(u0, np.ndarray) else u0 90 | v0 = torch.from_numpy(v0, dtype=dtype) if isinstance(v0, np.ndarray) else v0 91 | 92 | u1 = u0 + k * v0 93 | u2 = u0 94 | 95 | state_t = torch.zeros(batch_size, Nt, Nx_t+1, dtype=dtype) 96 | state_l = torch.zeros(batch_size, Nt, Nx_l+1, dtype=dtype) 97 | state_t[:,:-1,:] = u2[:,:-1,:] 98 | state_t[:,+1:,:] = u1[:,:-1,:] 99 | return state_t, state_l 100 | 101 | def get_derived_vars(f0, kappa_rel, k, theta_t, lambda_c, alpha): 102 | # Derived variables 103 | gamma = 2 * f0 # set parameters 104 | kappa = gamma * kappa_rel # stiffness parameter 105 | IHP = (np.pi * kappa / gamma)**2 # inharmonicity parameter (>0); eq 7.21 106 | K = sqrt(IHP) * (gamma / np.pi) # set parameters 107 | if isinstance(lambda_c, torch.Tensor): 108 | lambda_c = torch.relu(lambda_c - 1) + int(1) # make sure >= 1 109 | else: 110 | lambda_c = int(1) if lambda_c <= 1 else lambda_c 111 | 
112 | h = lambda_c * sqrt( \ 113 | (gamma**2 * k**2 + sqrt(gamma**4 * k**4 + 16 * K**2 * k**2 * (2 * theta_t - 1))) \ 114 | / (2 * (2 * theta_t - 1)) \ 115 | ) 116 | N_t = torch.floor(1/h) if isinstance(h, torch.Tensor) else int(1 / h) 117 | h_t = 1 / N_t 118 | 119 | h = lambda_c * gamma * alpha * k 120 | N_l = torch.floor(1/h) if isinstance(h, torch.Tensor) else int(1 / h) 121 | h_l = 1 / N_l 122 | 123 | return gamma, K, N_t, h_t, N_l, h_l 124 | 125 | def get_theta(kappa_max, f0_inf, sr, lambda_c=1): 126 | ''' theta gets larger as... 127 | (1) f0 gets larger 128 | (2) kappa gets smaller 129 | ''' 130 | gamma = 2 * f0_inf 131 | kappa = gamma * kappa_max 132 | k = 1 / sr 133 | 134 | R = ((gamma**4 * k**2 + 4*kappa**2 * math.pi**2) / (gamma**4 * k**2))**.5 135 | S = gamma**4 * k**2 * lambda_c**2 / (4 * kappa**2 * math.pi**4) 136 | expr_1 = 2 * S * lambda_c**2 * (R-1)**2 137 | expr_2 = math.pi**2 * S * (R-1) 138 | theta = 0.5 + expr_1 + expr_2 139 | assert theta < 1, theta 140 | 141 | return theta 142 | 143 | def stiff_string_modes(f0, kappa_rel, p_max=1): 144 | ''' Returns a list of modes of an ideal lossless stiff string. 145 | This inharmonicity factor `B` is valid only if kappa_rel is small. 146 | c.f.; 147 | Fletcher `Normal Vibration Frequencies of a Stiff Piano String` 148 | Bilbao `Numerical Sound Synthesis` (pp. 
176) 149 | ''' 150 | B = (np.pi * kappa_rel)**2 151 | 152 | modes = [] 153 | factor = [] 154 | for p in range(1,p_max+1): 155 | w_p = p * (1 + (2/np.pi) * B**.5 + 4/np.pi**2 * B) * (1 + B*p**2)**.5 156 | factor.append(w_p) 157 | modes.append(f0 * w_p) 158 | return modes, factor 159 | 160 | if __name__=='__main__': 161 | sr = 48000 162 | k = 1/sr 163 | #tt = 0.5 + 2/(np.pi**2) 164 | 165 | #for f0 in [20,40,80,160,320]: 166 | # _, _, N_t, _, N_l, _ = get_derived_vars(f0=f0, kappa_rel=0.03, k=k, theta_t=tt, alpha=1) 167 | # print(f0, N_t, N_l) 168 | #for al in [1,2,3,4]: 169 | # _, _, N_t, _, N_l, _ = get_derived_vars(f0=96, kappa_rel=0.03, k=k, theta_t=tt, alpha=al) 170 | # print(al, N_t, N_l) 171 | 172 | #------------------------------ 173 | 174 | #def vibrato(f0, k, mf=3, ma=0.01, upward=True, ma_in_Hz=False, dtype=None): 175 | # nt = f0.size(-1) # total time 176 | # vt = torch.floor((nt // 2) * torch.rand(f0.size(0)).view(-1,1)) # vibrato time 177 | # t = torch.ones_like(f0).cumsum(-1) 178 | # m = t.gt(vt) # mask `t` for n <= vt 179 | # vibra = m * ma * (1 - torch.cos(2 * np.pi * mf * (t - vt) * k)) / 2 180 | # if not ma_in_Hz: vibra *= f0 181 | # return f0 + vibra if upward else f0 - vibra 182 | 183 | #import matplotlib.pyplot as plt 184 | 185 | #f0 = 40 * torch.ones(sr).view(1,-1) 186 | #f0 = vibrato(f0, 1/sr) 187 | 188 | #_, _, N_t, _, N_l, _ = get_derived_vars(f0=f0, kappa_rel=0., k=k, theta_t=tt, lambda_c=1, alpha=10) 189 | 190 | #fig, ax = plt.subplots(nrows=3) 191 | #ax[0].plot(f0[0]) 192 | #ax[1].plot(N_t[0]) 193 | #ax[2].plot(N_l[0]) 194 | #plt.savefig('asdf.png') 195 | 196 | #------------------------------ 197 | 198 | #tt = 1 199 | #------------------------------ 200 | f0 = 55 * torch.ones(sr).view(1,-1) 201 | lam = 1.01 202 | als = 1 203 | #kappa_rel = 0.03 204 | #kappa_rel = 0.02 205 | kappa_rel = 0.01 206 | #------------------------------ 207 | #f0 = 60. 
* torch.ones(sr).view(1,-1) 208 | #lam = 2 209 | #als = 5 210 | #kappa_rel = 0.03 211 | #------------------------------ 212 | tt = get_theta(kappa_rel, f0.min(), sr, lam) 213 | _, _, N_t, _, N_l, _ = get_derived_vars(f0=f0, kappa_rel=kappa_rel, k=k, theta_t=tt, lambda_c=1, alpha=als) 214 | print(N_t[0,0], N_l[0,0]) 215 | _, _, N_t, _, N_l, _ = get_derived_vars(f0=f0, kappa_rel=kappa_rel, k=k, theta_t=tt, lambda_c=lam, alpha=als) 216 | print(N_t[0,0], N_l[0,0]) 217 | _, _, N_t, _, N_l, _ = get_derived_vars(f0=f0, kappa_rel=kappa_rel, k=k, theta_t=tt, lambda_c=lam**2, alpha=als) 218 | print(N_t[0,0], N_l[0,0]) 219 | 220 | 221 | -------------------------------------------------------------------------------- /src/utils/loss.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import abc 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from auraloss.freq import MultiResolutionSTFTLoss 7 | from einops import rearrange 8 | from src.utils import misc as ms 9 | from src.utils import audio as audio 10 | 11 | def rms(x): 12 | out_shape = [x.size(0)] + [1] * len(list(x.shape)[1:]) 13 | x = x.flatten(1) 14 | x_rm = x.pow(2).mean(1, keepdim=True) 15 | x_rms = torch.where(x_rm.eq(0), torch.ones_like(x_rm), x_rm.sqrt()) 16 | return x_rms.view(out_shape) 17 | 18 | def stft(x, n_fft, hop_length): 19 | w = torch.hann_window(n_fft).to(x.device) 20 | x = torch.stft(x, n_fft, hop_length, window=w) 21 | if x.shape[-1] == 2: x = torch.view_as_complex(x) 22 | return x.transpose(-1,-2) # (batch_size, frames, num_freqs) 23 | 24 | def melscale(mag, n_fft, n_mel, sr): 25 | mel_fbank = audio.mel_basis(sr, n_fft, n_mel) 26 | mel_basis = torch.from_numpy(mel_fbank).to(mag.device) 27 | mel = torch.matmul(mel_basis, mag.transpose(-1,-2)).transpose(-1,-2) 28 | return mel 29 | 30 | def stft_loss(x, y, n_fft=1024, n_mel=128, sr=48000, eps=1e-5): 31 | ''' x: (batch_size, n_samples) 32 | y: (batch_size, n_samples) 33 
| ''' 34 | n_fft = n_fft if x.size(1) > n_fft else x.size(1) 35 | hop_length = n_fft // 4 36 | x_linmag = stft(x, n_fft, hop_length).abs() 37 | y_linmag = stft(y, n_fft, hop_length).abs() 38 | x_logmag = 20 * torch.log10(x_linmag + eps) 39 | y_logmag = 20 * torch.log10(y_linmag + eps) 40 | x_linmel = melscale(x_linmag, n_fft, n_mel, sr) 41 | y_linmel = melscale(y_linmag, n_fft, n_mel, sr) 42 | x_logmel = 20 * torch.log10(x_linmel + eps) 43 | y_logmel = 20 * torch.log10(y_linmel + eps) 44 | def l1_dist(x, y): 45 | return (x - y).abs().flatten(1).mean(1) 46 | scores = dict( 47 | linmag=l1_dist(x_linmag, y_linmag), 48 | logmag=l1_dist(x_logmag, y_logmag), 49 | linmel=l1_dist(x_linmel, y_linmel), 50 | logmel=l1_dist(x_logmel, y_logmel), 51 | ) 52 | return scores 53 | 54 | def mse_loss(preds, target): 55 | return F.mse_loss(preds, target) 56 | 57 | def dirichlet_bc(u, dim=-1): 58 | ''' u: (batch_size, time, space) ''' 59 | u_D = u.roll(1,dim).narrow(dim,0,2) # (b, t, x=0,-1) 60 | return u_D.abs().mean() 61 | 62 | def pde_loss( 63 | #---------- 64 | # string parameters 65 | ut, zt, u0, f0, kappa, alpha, sig0, sig1, masks, 66 | #---------- 67 | # excitation parameters 68 | bow_params, hammer_params, 69 | #---------- 70 | # grids and metrics 71 | x, t, f_ic, f_bc, f_r, w_ic=1., w_bc=1., w_r=1., 72 | #---------- 73 | ): 74 | ''' ut: (batch, time, space) ''' 75 | est_u0 = ut.narrow(1,0,1) 76 | val_ic = f_ic(est_u0, u0) 77 | val_bc = f_bc(ut) 78 | val_r, results = f_r( 79 | ut, zt, x, t, f0, kappa, alpha, sig0, sig1, 80 | masks, bow_params, hammer_params) 81 | return w_ic * val_ic \ 82 | + w_bc * val_bc \ 83 | + w_r * val_r, results 84 | 85 | def si_sdr(reference_signal, estimated_signal, scaling=True, eps=None): 86 | ''' reference_signal: (batch_size, channels, time) 87 | estimated_signal: (batch_size, channels, time) 88 | -> SISDR calculated for the last dim (batch_size, channels) 89 | ''' 90 | eps = torch.finfo(reference_signal.dtype).eps if eps is None else eps 91 | 
batch_size = estimated_signal.shape[0] 92 | 93 | if scaling: 94 | num = torch.sum(reference_signal*estimated_signal, dim=-1, keepdim=True) + eps 95 | den = reference_signal.pow(2).sum(-1, keepdim=True) + eps 96 | a = num / den 97 | else: 98 | a = torch.ones_like(reference_signal) 99 | 100 | e_true = a * reference_signal 101 | e_res = estimated_signal - e_true 102 | 103 | Sss = (e_true**2).sum(dim=-1) + eps 104 | Snn = (e_res**2).sum(dim=-1) + eps 105 | 106 | SDR = 10 * torch.log10(Sss / Snn) 107 | return SDR 108 | 109 | class MSELoss(nn.Module): 110 | def __init__(self): 111 | super().__init__() 112 | self.metric = nn.MSELoss() 113 | 114 | def forward(self, preds, target): 115 | preds = preds.permute(0,3,1,2) 116 | target = target.permute(0,3,1,2) 117 | return self.metric(preds, target) 118 | 119 | class FkLoss(nn.Module): 120 | def __init__(self, scale=1., weight=1.): 121 | super().__init__() 122 | self.scale = scale 123 | self.weight = weight 124 | self.metric = nn.L1Loss() 125 | 126 | def forward(self, preds_fk, target_fk): 127 | w = torch.ones_like(target_fk).cumsum(-1).flip(-1) / target_fk.size(-1) 128 | scale = self.scale * w 129 | preds_fk = scale * preds_fk 130 | target_fk = scale * target_fk 131 | #print("fk", self.metric(preds_fk, target_fk)) 132 | return self.weight * self.metric(preds_fk, target_fk) 133 | 134 | class ModeFreqLoss(nn.Module): 135 | def __init__(self, scale=1., weight=1., sr=48000): 136 | super().__init__() 137 | self.sr = sr 138 | self.scale = scale 139 | self.weight = weight 140 | self.metric = nn.L1Loss() 141 | 142 | def forward(self, preds_freq, target_fk): 143 | #w = torch.ones_like(target_fk).cumsum(-1).flip(-1) / target_fk.size(-1) 144 | #scale = self.scale * w 145 | preds_freq = self.scale * preds_freq 146 | target_fk = self.scale * target_fk 147 | return self.weight * self.metric(preds_freq, target_fk) 148 | 149 | class ModeAmpsLoss(nn.Module): 150 | def __init__(self, scale=1., weight=1.): 151 | super().__init__() 152 | 
self.scale = scale 153 | self.weight = weight 154 | self.metric = nn.L1Loss() 155 | 156 | def forward(self, preds_coef, target_ck): 157 | preds_coef = self.scale * preds_coef 158 | target_ck = self.scale * target_ck 159 | return self.weight * self.metric(preds_coef, target_ck) 160 | 161 | class L1Loss(nn.Module): 162 | def __init__(self, weight=1., scale_invariance=False): 163 | super().__init__() 164 | self.si = scale_invariance 165 | self.weight = weight 166 | self.metric = nn.L1Loss() 167 | 168 | def forward(self, preds, target): 169 | if self.si: 170 | eps = torch.finfo(target.dtype).eps 171 | preds_rms = preds.pow(2).mean(-1, keepdim=True).clamp(min=eps).sqrt() 172 | target_rms = target.pow(2).mean(-1, keepdim=True).clamp(min=eps).sqrt() 173 | preds = preds / preds_rms 174 | target = target / target_rms 175 | return self.weight * self.metric(preds, target) 176 | 177 | class SISDR(nn.Module): 178 | def __init__(self): 179 | super().__init__() 180 | self.metric = si_sdr 181 | 182 | def forward(self, preds, target): 183 | preds = rearrange( preds, 'b (1 t) -> b 1 t') 184 | target = rearrange(target, 'b (1 t) -> b 1 t') 185 | value = self.metric(preds, target, eps=1e-8) 186 | return - value.mean() / 20 187 | 188 | class FFTLoss(nn.Module): 189 | def __init__(self, weight=1.): 190 | super().__init__() 191 | self.weight = weight 192 | self.metric = nn.L1Loss() 193 | 194 | def forward(self, preds, target): 195 | preds = torch.fft.rfft( preds) 196 | target = torch.fft.rfft(target) 197 | return self.weight * self.metric(preds, target) 198 | 199 | class MRSTFT(nn.Module): 200 | def __init__(self, input_scale=5., weight=1., **kwargs): 201 | super().__init__() 202 | self.scale = input_scale 203 | self.weight = weight 204 | self.metric = MultiResolutionSTFTLoss(**kwargs) 205 | 206 | def forward(self, preds, target): 207 | target = target * self.scale 208 | preds = preds * self.scale 209 | if len(list(preds.shape)) == 4: 210 | preds = rearrange(preds, 'b t x c -> b (c x) 
t') 211 | target = rearrange(target, 'b t x c -> b (c x) t') 212 | elif len(list(preds.shape)) == 2: 213 | preds = preds.unsqueeze(1) 214 | target = target.unsqueeze(1) 215 | else: 216 | assert len(list(preds.shape)) == 3, preds.shape 217 | return self.weight * self.metric(preds, target) 218 | 219 | class PDELoss(nn.Module): 220 | def __init__(self, f_ic, f_bc, f_r, w_ic=1., w_bc=1., w_r=1.): 221 | super().__init__() 222 | self.f_ic = f_ic; self.f_bc = f_bc; self.f_r = f_r 223 | self.w_ic = w_ic; self.w_bc = w_bc; self.w_r = w_r 224 | 225 | def forward(self, 226 | pde_preds, 227 | u0, f0, kappa, alpha, sig0, sig1, 228 | bow_mask, hammer_mask, 229 | x_B, v_B, F_B, ph0_B, ph1_B, wid_B, 230 | x_H, v_H, u_H, w_H, M_H, a_H, 231 | xt, tt, 232 | ): 233 | ''' pde_preds: (batch, time, space, 2) ''' 234 | ut, zt = pde_preds.chunk(2, dim=-1) # (Bs, Nt, Nx, 1) 235 | ut = ut.squeeze(-1) 236 | zt = zt.squeeze(-1) 237 | ms = [bow_mask, hammer_mask] 238 | bp = [x_B, v_B, F_B, ph0_B, ph1_B, wid_B] 239 | hp = [x_H, v_H, u_H, w_H, M_H, a_H] 240 | return pde_loss( 241 | #---------- 242 | ut, zt, u0, f0, kappa, alpha, sig0, sig1, 243 | ms, bp, hp, xt, tt, 244 | #---------- 245 | self.f_ic, self.f_bc, self.f_r, 246 | self.w_ic, self.w_bc, self.w_r, 247 | ) 248 | 249 | class BCLoss(nn.Module): 250 | def __init__(self, weight=1.): 251 | super().__init__() 252 | self.weight = weight 253 | self.metric = nn.L1Loss() 254 | 255 | def forward(self, preds_bc): 256 | target = torch.zeros_like(preds_bc) 257 | return self.weight * self.metric(preds_bc, target) 258 | 259 | class ICLoss(nn.Module): 260 | def __init__(self, weight=1.): 261 | super().__init__() 262 | self.weight = weight 263 | self.metric = nn.L1Loss() 264 | 265 | def forward(self, preds_ic, target_ic): 266 | return self.weight * self.metric(preds_ic, target_ic) 267 | 268 | class F0Loss(nn.Module): 269 | def __init__(self, scale=10., weight=1.): 270 | super().__init__() 271 | self.scale = scale 272 | self.weight = weight 273 | 
self.metric = nn.L1Loss() 274 | 275 | def forward(self, preds_f0, target_f0): 276 | ''' (Bs, Nt) ''' 277 | target_mean = target_f0.mean() 278 | target_f0 = target_f0 - target_mean 279 | preds_f0 = preds_f0 - target_mean 280 | target_std = target_f0.std() 281 | target_f0 = target_f0 / target_std 282 | preds_f0 = preds_f0 / target_std 283 | 284 | preds_f0 = preds_f0 * self.scale 285 | target_f0 = target_f0 * self.scale 286 | return self.weight * self.metric(preds_f0, target_f0) 287 | 288 | def adv_loss(logits, target): 289 | assert target in [1, 0] 290 | targets = torch.full_like(logits, fill_value=target) 291 | loss = F.binary_cross_entropy_with_logits(logits, targets) 292 | return loss 293 | 294 | class DisLoss(nn.Module): 295 | def __init__(self): 296 | ''' Discriminator Loss ''' 297 | super().__init__() 298 | 299 | def forward(self, real_disc, fake_disc): 300 | loss_real = sum([adv_loss(d, 1) for d in real_disc]) / len(real_disc) 301 | loss_fake = sum([adv_loss(d, 0) for d in fake_disc]) / len(fake_disc) 302 | return loss_real + loss_fake 303 | 304 | class GenLoss(nn.Module): 305 | def __init__(self): 306 | ''' Generator Loss ''' 307 | super().__init__() 308 | 309 | def forward(self, fake_disc): 310 | return sum([adv_loss(d, 1) for d in fake_disc]) / len(fake_disc) 311 | 312 | 313 | 314 | 315 | 316 | 317 | 318 | 319 | -------------------------------------------------------------------------------- /src/utils/misc.py: -------------------------------------------------------------------------------- 1 | import os 2 | import yaml 3 | import torch 4 | import numpy as np 5 | import torch.nn.functional as F 6 | from scipy.interpolate import RectBivariateSpline 7 | 8 | from contextlib import contextmanager,redirect_stderr,redirect_stdout 9 | from os import devnull 10 | 11 | chars = [c for c in '0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ'] 12 | 13 | @contextmanager 14 | def suppress_stdout_stderr(): 15 | """A context manager that redirects stdout and 
stderr to devnull""" 16 | with open(devnull, 'w') as fnull: 17 | with redirect_stderr(fnull) as err, redirect_stdout(fnull) as out: 18 | yield (err, out) 19 | 20 | def batchify(x, batch_size, n_samples): 21 | pass 22 | 23 | def random_str(length=8): 24 | return "".join(np.random.choice(chars, length)) 25 | 26 | def sqrt(x): 27 | return x.pow(.5) if isinstance(x, torch.Tensor) else x**.5 28 | 29 | def soft_bow(v_rel, a=100): 30 | return np.sqrt(2*a) * v_rel * torch.exp(-a * v_rel**2 + 0.5) 31 | 32 | def hard_bow(v_rel, a=5, eps=0.1, hard_sign=True): 33 | sign = torch.sign(v_rel) if hard_sign else torch.tanh(100 * v_rel) 34 | return sign * (eps + (1-eps) * torch.exp(-a * v_rel.abs())) 35 | 36 | def raised_cosine(N, h, ctr, wid, n): 37 | ''' N (int): number of maximal samples in space 38 | h (float): spatial grid cell width 39 | ctr (B,1,1): center points for each batch 40 | wid (B,1,1): width lengths for each batch 41 | n (B,): number of actual samples in space 42 | ''' 43 | xax = torch.linspace(h, 1, N).to(ctr.device).view(1,-1,1) # (1, N, 1) 44 | ctr = (ctr * n / N) 45 | wid = wid / N 46 | ind = torch.sign(torch.relu(-(xax - ctr - wid / 2) * (xax - ctr + wid / 2))) 47 | out = 0.5 * ind * (1 + torch.cos(2 * np.pi * (xax - ctr) / wid)) 48 | return out / out.abs().sum(1, keepdim=True) # (batch_Size, N, 1) 49 | 50 | def floor_dirac_delta(n, ctr, N): 51 | ''' torch::Tensor n, // number of samples in space 52 | torch::Tensor ctr, // center point of raised cosine curve 53 | int N 54 | ''' 55 | xax = torch.ones_like(ctr).view(-1,1,1).repeat(1,N,1).cumsum(1) - 1 56 | idx = torch.floor(ctr * n).view(-1,1,1) 57 | #return torch.floor(xax).eq(idx).to(n.dtype()) # (batch_size, N, 1) 58 | return torch.floor(xax).eq(idx) # (batch_size, N, 1) 59 | 60 | def triangular(N, n, p_x, p_a): 61 | ''' N (int): number of maximal samples in space 62 | n (B, 1, 1): number of actual samples in space 63 | p_x (B, Nt, 1): peak position 64 | p_a (B, Nt, 1): peak amplitude 65 | ''' 66 | vel_l = 
torch.where(p_x.le(0), torch.zeros_like(p_x), p_a / p_x / n) 67 | vel_r = torch.where(p_x.le(0), torch.zeros_like(p_x), p_a / (1-p_x) / n) 68 | vel_l = ((vel_l * torch.ones_like(vel_l).repeat(1,1,N)).cumsum(2) - vel_l).clamp(min=0) 69 | vel_r = ((vel_r * torch.ones_like(vel_r).repeat(1,1,N)).cumsum(2) - vel_r * (N-n+1)).clamp(min=0).flip(2) 70 | tri = torch.minimum(vel_l, vel_r) 71 | assert not torch.isnan(tri).any(), torch.isnan(tri.flatten(1).sum(1)) 72 | return tri 73 | 74 | def pre_shaper(x, sr, velocity=10): 75 | w = torch.tanh(torch.ones_like(x).cumsum(-1) / sr * velocity) 76 | return w * x 77 | 78 | def post_shaper(x, sr, pulloff, velocity=100): 79 | offset = x.size(-1) - int(sr * pulloff) 80 | w = torch.tanh(torch.ones_like(x).cumsum(-1) / sr * velocity).flip(-1) 81 | w = F.pad(w.narrow(-1,offset,w.size(-1)-offset), (0,offset)) 82 | return w * x 83 | 84 | def random_uniform(floor, ceiling, size=None, weight=None, dtype=None): 85 | if not isinstance(size, tuple): size = (size,) 86 | if weight is None: weight = torch.ones(size, dtype=dtype) 87 | # NOTE: torch.rand(..., dtype=dtype) for dtype \in [torch.float32, torch.float64] 88 | # can result in different random number generation 89 | # (for different precisions; despite fixiing the random seed.) 90 | return (ceiling - floor) * torch.rand(size=size).to(dtype) * weight + floor 91 | 92 | def equidistant(floor, ceiling, steps, dtype=None): 93 | return torch.linspace(floor, ceiling, steps).to(dtype) 94 | 95 | def get_masks(model_name, bs, disjoint=True): 96 | ''' setting `disjoint=False` enables multiple excitations allowed 97 | (e.g., bowing over hammered strings.) While this could be a 98 | charming choice, but it can also drive the simulation unstable. 
99 | ''' 100 | # boolean mask that determines whether to impose each excitation 101 | if model_name.endswith('bow'): 102 | bow_mask = torch.ones( size=(bs,)).view(-1,1,1) 103 | hammer_mask = torch.zeros(size=(bs,)).view(-1,1,1) 104 | elif model_name.endswith('hammer'): 105 | bow_mask = torch.zeros(size=(bs,)).view(-1,1,1) 106 | hammer_mask = torch.ones( size=(bs,)).view(-1,1,1) 107 | elif model_name.endswith('pluck'): 108 | bow_mask = torch.zeros(size=(bs,)).view(-1,1,1) 109 | hammer_mask = torch.zeros(size=(bs,)).view(-1,1,1) 110 | else: 111 | bow_mask = torch.rand(size=(bs,)).gt(0.5).view(-1,1,1) 112 | hammer_mask = torch.rand(size=(bs,)).gt(0.5).view(-1,1,1) 113 | if disjoint: 114 | both_are_true = torch.logical_and( 115 | torch.logical_or(bow_mask, hammer_mask), 116 | torch.logical_or(bow_mask, hammer_mask.logical_not()) 117 | ) 118 | hammer_mask[both_are_true] = False 119 | bow_mask = bow_mask.view(-1,1,1) 120 | hammer_mask = hammer_mask.view(-1,1,1) 121 | return [bow_mask, hammer_mask] 122 | 123 | def f0_interpolate(f0_1, n_frames, tmax): 124 | t_0 = np.linspace(0, tmax, n_frames) 125 | t_1 = np.linspace(0, tmax, f0_1.shape[0]) 126 | return np.interp(t_0, t_1, f0_1) 127 | 128 | def interpolate1d(u, xaxis, xvals, k=5): 129 | ''' u: (1, Nx) 130 | xaxis: (1, Nx_input) 131 | xvals: (1, Nx_output) 132 | -> (1, Nx_output) 133 | ''' 134 | t = np.arange(k)[:,None] / k 135 | rbs = RectBivariateSpline(t, xaxis, u.repeat(k,0), kx=1, ky=k) 136 | return rbs(t, xvals, grid=True)[k//2][None,:] 137 | 138 | def interpolate(u, taxis, xaxis, xvals, kx=5, ky=5): 139 | ''' u: (Nt, Nx) 140 | taxis: (Nt, 1) 141 | xaxis: (1, Nx_input) 142 | xvals: (1, Nx_output) 143 | -> (Nt, Nx_output) 144 | ''' 145 | rbs = RectBivariateSpline(taxis, xaxis, u, kx=kx, ky=ky) 146 | return rbs(taxis, xvals, grid=True) 147 | 148 | def torch_interpolate(x, scale_factor): 149 | y = F.interpolate(x, scale_factor=scale_factor) 150 | res = x.size(-1) - y.size(-1) 151 | if res % 2 == 0: y = F.pad(y, (res//2, 
res//2)) 152 | else: y = F.pad(y, (res//2, res//2+1)) 153 | return y 154 | 155 | 156 | def minmax_normalize(x, dim=-1): 157 | x_min = x.min(dim, keepdim=True).values 158 | x = x - x_min 159 | x_max = x.max(dim, keepdim=True).values 160 | x = x / x_max 161 | return x 162 | 163 | def get_minmax(x): 164 | if np.isnan(x.sum()): 165 | return None, None 166 | return np.nan_to_num(x.min()), np.nan_to_num(x.max()) 167 | 168 | def select_with_batched_index(input, dim, index): 169 | ''' input: (bs, ..., n, ...) 170 | dim : (int) 171 | index: (bs, ..., 1, ...) index to select on dim `dim` 172 | -> out : (bs, ..., 1, ...) for each batch, select `index`-th element on dim `dim` 173 | ''' 174 | assert input.size(0) == index.size(0), [input.shpae, index.shape] 175 | bs = input.size(0) 176 | ins = input.chunk(bs, 0) 177 | idx = index.chunk(bs, 0) 178 | out = [] 179 | for b in range(bs): 180 | out.append(batched_index_select(ins[b], dim, idx[b])) 181 | return torch.cat(out, dim=0) 182 | 183 | def batched_index_select(input, dim, index): 184 | ''' input: (..., n, ...) 185 | dim : (int) 186 | index: (..., k, ...) index to select on dim `dim` 187 | -> out : (..., k, ...) 
select k out of n elements on dim `dim` 188 | ''' 189 | Nx = len(list(input.shape)) 190 | expanse = [-1 if k==(dim % Nx) else 1 for k in range(Nx)] 191 | tiler = [1 if k==(dim % Nx) else n for k, n in enumerate(input.shape)] 192 | index = index.to(torch.int64).view(expanse).tile(tiler) 193 | return torch.gather(input, dim, index) 194 | 195 | def random_index(max_N, idx_N): 196 | if max_N < idx_N: 197 | # choosing with replacement 198 | return torch.randint(0, max_N, (idx_N,)) 199 | else: 200 | # choosing without replacement 201 | return torch.randperm(max_N)[:idx_N] 202 | 203 | def ell_infty_normalize(x, normalize_dims=1): 204 | eps = torch.finfo(x.dtype).eps 205 | x_shape = list(x.shape) 206 | m_shape = x_shape[:normalize_dims] + [1] * (len(x_shape) - normalize_dims) 207 | x_max = x.abs().flatten(normalize_dims).max(normalize_dims).values + eps 208 | x_gain = 1. / x_max.view(m_shape) 209 | return x * x_gain, x_gain 210 | 211 | def sinusoidal_embedding(x, n, gain=10000, dim=-1): 212 | ''' let `x` be normalized to be in the nondimensional (0 ~ 1) range ''' 213 | assert n % 2 == 0, n 214 | x = x.unsqueeze(-1) 215 | shape = [1] * len(list(x.shape)); shape[dim] = -1 # e.g., [1,1,-1] 216 | half_n = n // 2 217 | 218 | expnt = torch.arange(half_n, device=x.device, dtype=x.dtype).view(shape) 219 | _embed = torch.exp(expnt * -(np.log(gain) / (half_n - 1))) 220 | _embed = torch.exp(expnt * -(np.log(gain) / (half_n - 1))) 221 | _embed = x * _embed 222 | emb = torch.cat((torch.sin(_embed), torch.cos(_embed)), dim) 223 | return emb # list(x.shape) + [n] 224 | 225 | def fourier_feature(x, B): 226 | ''' x: (Bs, ..., in_dim) 227 | B: (in_dim, out_dim) 228 | ''' 229 | if B is None: 230 | return x 231 | else: 232 | x_proj = (2.*np.pi*x) @ B 233 | return torch.cat((torch.sin(x_proj), torch.cos(x_proj)), dim=-1) 234 | 235 | def save_simulation_data(directory, excitation_type, overall_results, constants): 236 | os.makedirs(directory, exist_ok=True) 237 | string_params = 
overall_results.pop('string_params') 238 | hammer_params = overall_results.pop('hammer_params') 239 | bow_params = overall_results.pop('bow_params') 240 | simulation_dict = overall_results 241 | string_dict = { 242 | 'kappa': string_params[0], 243 | 'alpha': string_params[1], 244 | 'u0' : string_params[2], 245 | 'v0' : string_params[3], 246 | 'p_a' : string_params[4], 247 | 'f0' : string_params[5], 248 | 'pos' : string_params[6], 249 | 'T60' : string_params[7], 250 | 'target_f0': string_params[8], 251 | } 252 | hammer_dict = { 253 | 'x_H' : hammer_params[0], 254 | 'v_H' : hammer_params[1], 255 | 'u_H' : hammer_params[2], 256 | 'w_H' : hammer_params[3], 257 | 'M_r' : hammer_params[4], 258 | 'alpha': hammer_params[5], 259 | } 260 | bow_dict = { 261 | 'x_B' : bow_params[0], 262 | 'v_B' : bow_params[1], 263 | 'F_B' : bow_params[2], 264 | 'phi_0': bow_params[3], 265 | 'phi_1': bow_params[4], 266 | 'wid_B': bow_params[5], 267 | } 268 | 269 | def sample(val): 270 | try: 271 | _val = val.item(0) 272 | except AttributeError as err: 273 | if isinstance(val, float) or isinstance(val, int): 274 | _val = val 275 | else: 276 | raise err 277 | return _val 278 | short_configuration = { 279 | 'excitation_type': excitation_type, 280 | 'theta_t' : constants[1], 281 | 'lambda_c': constants[2], 282 | } 283 | short_configuration['value-string'] = {} 284 | for key, val in string_dict.items(): 285 | short_configuration['value-string'].update({ key : sample(val) }) 286 | short_configuration['value-hammer'] = {} 287 | for key, val in hammer_dict.items(): 288 | short_configuration['value-hammer'].update({ key : sample(val) }) 289 | short_configuration['value-bow'] = {} 290 | for key, val in bow_dict.items(): 291 | short_configuration['value-bow'].update({ key : sample(val) }) 292 | 293 | np.savez_compressed(f'{directory}/simulation.npz', **simulation_dict) 294 | np.savez_compressed(f'{directory}/string_params.npz', **string_dict) 295 | np.savez_compressed(f'{directory}/hammer_params.npz', 
**hammer_dict) 296 | np.savez_compressed(f'{directory}/bow_params.npz', **bow_dict) 297 | 298 | with open(f"{directory}/simulation_config.yaml", 'w') as f: 299 | yaml.dump(short_configuration, f, default_flow_style=False) 300 | 301 | def add_noise(x, c, vals, eps=1e-5): 302 | noise = eps * torch.randn_like(x) 303 | for val in vals: 304 | mask = torch.where(c == val, torch.ones_like(c), torch.zeros_like(c)) 305 | x = x + mask * noise 306 | return x 307 | 308 | def downsample(x, factor=None, size=None): 309 | ''' x: (Bs, Nt) -> (Bs, Nt // factor) 310 | ''' 311 | if size is None: 312 | size = x.size(1) // factor + bool(x.size(1) % factor) 313 | else: 314 | assert factor is None, [factor, size] 315 | return F.interpolate(x.unsqueeze(1), size=size, mode='linear').squeeze(1) 316 | 317 | 318 | if __name__=='__main__': 319 | N = 10 320 | B = 1 321 | h = 1 / N 322 | ctr = 0.5 * torch.ones(B).view(-1,1,1) 323 | wid = 1 * torch.ones(B).view(-1,1,1) 324 | n = N * torch.ones(B) 325 | ''' N (int): number of maximal samples in space 326 | h (float): spatial grid cell width 327 | ctr (B,1,1): center points for each batch 328 | wid (B,1,1): width lengths for each batch 329 | n (B,): number of actual samples in space 330 | ''' 331 | c = raised_cosine(N, h, ctr, wid, n) 332 | print(c.shape) 333 | import matplotlib.pyplot as plt 334 | plt.figure() 335 | plt.plot(c[0,:,0]) 336 | plt.savefig('asdf.png') 337 | 338 | -------------------------------------------------------------------------------- /src/utils/optimizer.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import torch 3 | import torch.nn as nn 4 | 5 | from torch.optim.lr_scheduler import _LRScheduler 6 | 7 | 8 | class NoamLR(_LRScheduler): 9 | def __init__(self, optimizer, warmup_steps, model_size=None): 10 | self.warmup_steps = warmup_steps 11 | 12 | self.model_size = model_size 13 | 14 | super().__init__(optimizer) 15 | 16 | def get_lr(self): 17 | last_epoch = max(1, 
self.last_epoch) 18 | scale = min(last_epoch ** (-0.5), last_epoch * self.warmup_steps ** (-1.5)) 19 | if self.model_size is None: 20 | scale = scale * self.warmup_steps ** (0.5) 21 | return [base_lr * scale for base_lr in self.base_lrs] 22 | else: 23 | return [self.model_size ** (-0.5) * scale for _ in self.base_lrs] 24 | 25 | 26 | class Novograd(torch.optim.Optimizer): 27 | """ 28 | Implements Novograd algorithm. 29 | 30 | Args: 31 | params (iterable): iterable of parameters to optimize or dicts defining 32 | parameter groups 33 | lr (float, optional): learning rate (default: 1e-3) 34 | betas (Tuple[float, float], optional): coefficients used for computing 35 | running averages of gradient and its square (default: (0.95, 0)) 36 | eps (float, optional): term added to the denominator to improve 37 | numerical stability (default: 1e-8) 38 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0) 39 | grad_averaging: gradient averaging 40 | amsgrad (boolean, optional): whether to use the AMSGrad variant of this 41 | algorithm from the paper `On the Convergence of Adam and Beyond`_ 42 | (default: False) 43 | """ 44 | 45 | def __init__(self, params, lr=1e-3, betas=(0.95, 0), eps=1e-8, 46 | weight_decay=0, grad_averaging=False, amsgrad=False): 47 | if not 0.0 <= lr: 48 | raise ValueError("Invalid learning rate: {}".format(lr)) 49 | if not 0.0 <= eps: 50 | raise ValueError("Invalid epsilon value: {}".format(eps)) 51 | if not 0.0 <= betas[0] < 1.0: 52 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 53 | if not 0.0 <= betas[1] < 1.0: 54 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 55 | defaults = dict(lr=lr, betas=betas, eps=eps, 56 | weight_decay=weight_decay, 57 | grad_averaging=grad_averaging, 58 | amsgrad=amsgrad) 59 | 60 | super(Novograd, self).__init__(params, defaults) 61 | 62 | def __setstate__(self, state): 63 | super(Novograd, self).__setstate__(state) 64 | for group in self.param_groups: 
65 | group.setdefault('amsgrad', False) 66 | 67 | def step(self, closure=None): 68 | """Performs a single optimization step. 69 | 70 | Arguments: 71 | closure (callable, optional): A closure that reevaluates the model 72 | and returns the loss. 73 | """ 74 | loss = None 75 | if closure is not None: 76 | loss = closure() 77 | 78 | for group in self.param_groups: 79 | for p in group['params']: 80 | if p.grad is None: 81 | continue 82 | grad = p.grad.data 83 | if grad.is_sparse: 84 | raise RuntimeError('Sparse gradients are not supported.') 85 | amsgrad = group['amsgrad'] 86 | 87 | state = self.state[p] 88 | 89 | # State initialization 90 | if len(state) == 0: 91 | state['step'] = 0 92 | # Exponential moving average of gradient values 93 | state['exp_avg'] = torch.zeros_like(p.data) 94 | # Exponential moving average of squared gradient values 95 | state['exp_avg_sq'] = torch.zeros([]).to(state['exp_avg'].device) 96 | if amsgrad: 97 | # Maintains max of all exp. moving avg. of sq. grad. values 98 | state['max_exp_avg_sq'] = torch.zeros([]).to(state['exp_avg'].device) 99 | 100 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 101 | if amsgrad: 102 | max_exp_avg_sq = state['max_exp_avg_sq'] 103 | beta1, beta2 = group['betas'] 104 | 105 | state['step'] += 1 106 | 107 | norm = torch.sum(torch.pow(grad, 2)) 108 | 109 | if exp_avg_sq == 0: 110 | exp_avg_sq.copy_(norm) 111 | else: 112 | exp_avg_sq.mul_(beta2).add_(norm, alpha=1 - beta2) 113 | 114 | if amsgrad: 115 | # Maintains the maximum of all 2nd moment running avg. till now 116 | torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq) 117 | # Use the max. for normalizing running avg. 
of gradient 118 | denom = max_exp_avg_sq.sqrt().add_(group['eps']) 119 | else: 120 | denom = exp_avg_sq.sqrt().add_(group['eps']) 121 | 122 | grad.div_(denom) 123 | if group['weight_decay'] != 0: 124 | grad.add_(p.data, alpha=group['weight_decay']) 125 | if group['grad_averaging']: 126 | grad.mul_(1 - beta1) 127 | exp_avg.mul_(beta1).add_(grad) 128 | 129 | p.data.add_(exp_avg, alpha=-group['lr']) 130 | 131 | return loss 132 | 133 | 134 | def get_optimizer(optimizer_name, model_parameters, config): 135 | 136 | if optimizer_name == 'sgd': 137 | optimizer = torch.optim.SGD(model_parameters, **config) 138 | elif optimizer_name == 'adam': 139 | optimizer = torch.optim.Adam(model_parameters, **config) 140 | elif optimizer_name == "adamw": 141 | optimizer = torch.optim.AdamW(model_parameters, **config) 142 | elif optimizer_name == "radam": 143 | optimizer = torch.optim.RAdam(model_parameters, **config) 144 | elif optimizer_name == "novograd": 145 | optimizer = Novograd(model_parameters, **config) 146 | else: 147 | print('Unknown optimizer', optimizer_name) 148 | sys.exit() 149 | return optimizer 150 | 151 | def get_scheduler(scheduler_name, optimizer, config): 152 | if scheduler_name == 'step': 153 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, **config) 154 | elif scheduler_name == 'multistep': 155 | scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, **config) 156 | elif scheduler_name == 'sgdr': 157 | scheduler = SGDRScheduler(optimizer, **config) 158 | elif scheduler_name == 'lambda_lr': 159 | scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, **config) 160 | elif scheduler_name == 'reduce_on_plateau': 161 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, **config) 162 | elif scheduler_name == 'cosine': 163 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, **config) 164 | elif scheduler_name == 'noam': 165 | scheduler = NoamLR(optimizer, **config) 166 | elif scheduler_name == 'constant': 167 | scheduler = 
None 168 | else: 169 | print('Unknown scheduler', scheduler_name) 170 | sys.exit() 171 | return scheduler 172 | 173 | 174 | -------------------------------------------------------------------------------- /src/utils/vnv.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def RDE(f_1, f_exact, dim=None): 4 | num = f_1 - f_exact 5 | den = f_exact 6 | num = num.sum(dim) if dim is not None else num 7 | den = den.sum(dim) if dim is not None else den 8 | return 100 * num / den 9 | 10 | --------------------------------------------------------------------------------