├── neuralpiano ├── artwork │ ├── __init__.py │ ├── Neural-Piano-Logo.png │ ├── Neural-Piano-Sign.png │ ├── Tegridy-Code-2025.png │ ├── Neural-Piano-Artwork.png │ ├── Project-Los-Angeles.png │ └── README.md ├── output_samples │ ├── __init__.py │ └── README.md ├── seed_midis │ ├── __init__.py │ ├── README.md │ ├── neural_piano_seed_midi_1.mid │ ├── neural_piano_seed_midi_2.mid │ ├── neural_piano_seed_midi_3.mid │ ├── neural_piano_seed_midi_4.mid │ ├── neural_piano_seed_midi_5.mid │ └── neural_piano_seed_midi_6.mid ├── music2latent │ ├── __init__.py │ ├── hparams_inference.py │ ├── README.md │ ├── hparams.py │ ├── utils.py │ ├── audio.py │ ├── inference.py │ └── models.py ├── __init__.py ├── README.md ├── sample_midis.py ├── denoise.py ├── bass.py ├── mixer.py ├── enhancer.py ├── master.py └── neuralpiano.py ├── requirements.txt ├── MANIFEST.in ├── README.md ├── pyproject.toml └── LICENSE /neuralpiano/artwork/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /neuralpiano/output_samples/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /neuralpiano/seed_midis/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /neuralpiano/output_samples/README.md: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /neuralpiano/music2latent/__init__.py: -------------------------------------------------------------------------------- 1 | from .inference import EncoderDecoder -------------------------------------------------------------------------------- /neuralpiano/artwork/Neural-Piano-Logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/asigalov61/neuralpiano/main/neuralpiano/artwork/Neural-Piano-Logo.png -------------------------------------------------------------------------------- /neuralpiano/artwork/Neural-Piano-Sign.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/asigalov61/neuralpiano/main/neuralpiano/artwork/Neural-Piano-Sign.png -------------------------------------------------------------------------------- /neuralpiano/artwork/Tegridy-Code-2025.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/asigalov61/neuralpiano/main/neuralpiano/artwork/Tegridy-Code-2025.png -------------------------------------------------------------------------------- /neuralpiano/seed_midis/README.md: -------------------------------------------------------------------------------- 1 | # Neural Piano sample seed MIDIs 2 | 3 | *** 4 | 5 | ### Project Los Angeles 6 | ### Tegridy Code 2025 7 | -------------------------------------------------------------------------------- /neuralpiano/artwork/Neural-Piano-Artwork.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/asigalov61/neuralpiano/main/neuralpiano/artwork/Neural-Piano-Artwork.png -------------------------------------------------------------------------------- 
/neuralpiano/artwork/Project-Los-Angeles.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/asigalov61/neuralpiano/main/neuralpiano/artwork/Project-Los-Angeles.png -------------------------------------------------------------------------------- /neuralpiano/seed_midis/neural_piano_seed_midi_1.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/asigalov61/neuralpiano/main/neuralpiano/seed_midis/neural_piano_seed_midi_1.mid -------------------------------------------------------------------------------- /neuralpiano/seed_midis/neural_piano_seed_midi_2.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/asigalov61/neuralpiano/main/neuralpiano/seed_midis/neural_piano_seed_midi_2.mid -------------------------------------------------------------------------------- /neuralpiano/seed_midis/neural_piano_seed_midi_3.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/asigalov61/neuralpiano/main/neuralpiano/seed_midis/neural_piano_seed_midi_3.mid -------------------------------------------------------------------------------- /neuralpiano/seed_midis/neural_piano_seed_midi_4.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/asigalov61/neuralpiano/main/neuralpiano/seed_midis/neural_piano_seed_midi_4.mid -------------------------------------------------------------------------------- /neuralpiano/seed_midis/neural_piano_seed_midi_5.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/asigalov61/neuralpiano/main/neuralpiano/seed_midis/neural_piano_seed_midi_5.mid -------------------------------------------------------------------------------- /neuralpiano/seed_midis/neural_piano_seed_midi_6.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/asigalov61/neuralpiano/main/neuralpiano/seed_midis/neural_piano_seed_midi_6.mid -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | tqdm 2 | ipywidgets 3 | matplotlib 4 | hf-transfer 5 | huggingface_hub 6 | soundfile 7 | torch 8 | numpy==1.24.4 9 | librosa 10 | midirenderer -------------------------------------------------------------------------------- /neuralpiano/artwork/README.md: -------------------------------------------------------------------------------- 1 | # Neural Piano Concept Artwork 2 | 3 | *** 4 | 5 | ## Images were created with Stable Diffusion 3.5 Large image AI model 6 | 7 | *** 8 | 9 | ### Project Los Angeles 10 | ### Tegridy Code 2025 11 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include neuralpiano/README.md 2 | recursive-include neuralpiano/ * 3 | recursive-include neuralpiano/artwork * 4 | recursive-include neuralpiano/output_samples * 5 | recursive-include neuralpiano/seed_midis * 6 | recursive-include neuralpiano/music2latent * -------------------------------------------------------------------------------- /neuralpiano/__init__.py: 
-------------------------------------------------------------------------------- 1 | from .sample_midis import get_sample_midi_files 2 | 3 | from .music2latent.inference import EncoderDecoder 4 | 5 | from .denoise import denoise_audio 6 | 7 | from .bass import enhance_audio_bass 8 | 9 | from .enhancer import enhance_audio_full 10 | 11 | from .master import master_mono_piano 12 | 13 | from .mixer import mix_audio 14 | 15 | from .neuralpiano import * -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # [WIP] Neural Piano 2 | ## Hi-Fi neural MIDI piano synthesizer and MIDI renderer 3 | 4 | Neural-Piano-Artwork 5 | 6 | *** 7 | 8 | ## Installation 9 | 10 | ### pip and setuptools 11 | 12 | ```sh 13 | # It is recommended that you upgrade pip and setuptools prior to install for max compatibility 14 | !pip install --upgrade pip 15 | !pip install --upgrade setuptools build wheel 16 | ``` 17 | 18 | ### pip install 19 | 20 | ```sh 21 | # The following command will install Neural Piano pip package 22 | # Please note that Neural Piano requires Nvidia GPU with at least 40GB VRAM 23 | 24 | !pip install -U neuralpiano 25 | ``` 26 | 27 | *** 28 | 29 | ## Quick-start use example 30 | 31 | ```python 32 | # Import main Neural Piano module 33 | import neuralpiano 34 | 35 | # Render MIDI 36 | neuralpiano.render_midi('input.mid', 'output.wav') 37 | ``` 38 | 39 | *** 40 | 41 | ### Project Los Angeles 42 | ### Tegridy Code 2025 43 | -------------------------------------------------------------------------------- /neuralpiano/README.md: -------------------------------------------------------------------------------- 1 | # [WIP] Neural Piano 2 | ## Hi-Fi neural MIDI piano synthesizer and MIDI renderer 3 | 4 | Neural-Piano-Artwork 5 | 6 | *** 7 | 8 | ## Installation 9 | 10 | ### pip and setuptools 11 | 12 | ```sh 13 | # It is recommended that you upgrade pip and setuptools prior to install for max compatibility 14 | !pip install --upgrade pip 15 | !pip install --upgrade setuptools build wheel 16 | ``` 17 | 18 | ### pip install 19 | 20 | ```sh 21 | # The following command will install Neural Piano pip package 22 | # Please note that Neural Piano requires Nvidia GPU with at least 40GB VRAM 23 | 24 | !pip install -U neuralpiano 25 | ``` 26 | 27 | *** 28 | 29 | ## Quick-start use example 30 | 31 | ```python 32 | # Import main Neural Piano module 33 | import neuralpiano 34 | 35 | # Render MIDI 36 | neuralpiano.render_midi('input.mid', 'output.wav') 37 | ``` 38 | 39 | *** 40 | 41 | ### Project Los Angeles 42 | ### Tegridy Code 2025 43 | -------------------------------------------------------------------------------- /neuralpiano/music2latent/hparams_inference.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | home_root = os.getcwd() 4 | 5 | load_path_inference_multi_instrumental_default = os.path.join(home_root, 'models/music2latent.pt') 6 | load_path_inference_solo_piano_default = os.path.join(home_root, 'models/music2latent_maestro_loss_16.871_iters_45500.pt') 7 | load_path_inference_solo_piano_v1_default = os.path.join(home_root, 'models/music2latent_maestro_loss_27.834_iters_14300.pt') 8 | 9 | max_batch_size_encode = 1 # maximum inference batch size for encoding: tune it depending on the available GPU memory 10 | max_waveform_length_encode = 44100*60 # maximum length of waveforms in the batch for encoding: tune it depending on the available 
GPU memory 11 | max_batch_size_decode = 1 # maximum inference batch size for decoding: tune it depending on the available GPU memory 12 | max_waveform_length_decode = 44100*60 # maximum length of waveforms in the batch for decoding: tune it depending on the available GPU memory 13 | 14 | sigma_rescale = 0.06 # rescale sigma for inference -------------------------------------------------------------------------------- /neuralpiano/sample_midis.py: -------------------------------------------------------------------------------- 1 | #=================================================================================================== 2 | # Neural Piano sample_midis Python module 3 | #=================================================================================================== 4 | # Project Los Angeles 5 | # Tegridy Code 2025 6 | #=================================================================================================== 7 | # License: Apache 2.0 8 | #=================================================================================================== 9 | 10 | import importlib.resources as pkg_resources 11 | from neuralpiano import seed_midis 12 | 13 | #=================================================================================================== 14 | 15 | def get_sample_midi_files(): 16 | 17 | midi_files = [] 18 | 19 | for resource in pkg_resources.contents(seed_midis): 20 | if resource.endswith('.mid'): 21 | with pkg_resources.path(seed_midis, resource) as p: 22 | midi_files.append((resource, str(p))) 23 | 24 | return sorted(midi_files) 25 | 26 | #=================================================================================================== 27 | # This is the end of the sample_midis Python module 28 | #=================================================================================================== -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=80.0.0", "wheel"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "neuralpiano" 7 | version = "25.11.37" 8 | description = "Hi-Fi neural MIDI piano synthesizer and MIDI renderer" 9 | license = "Apache-2.0" 10 | license-files = ["LICENSE"] 11 | authors = [ 12 | { name = "Alex Lev", email = "alexlev61@proton.me" } 13 | ] 14 | maintainers = [ 15 | { name = "Alex Lev", email = "alexlev61@proton.me" } 16 | ] 17 | dependencies = [ 18 | "tqdm", 19 | "ipywidgets", 20 | "matplotlib", 21 | "hf-transfer", 22 | "huggingface_hub", 23 | "soundfile", 24 | "torch", 25 | "numpy==1.24.4", 26 | "librosa", 27 | "midirenderer" 28 | ] 29 | keywords = ["MIDI", "music", "music ai", "MIDI piano", "Piano synthesizer", "renderer", "MIDI synthesizer", "MIDI renderer", "neural synthesizer", "artificial intelligence", "AI", "audio synthesis"] 30 | classifiers = [ 31 | "Development Status :: 5 - Production/Stable", 32 | "Intended Audience :: Science/Research", 33 | "Topic :: Multimedia :: Sound/Audio :: MIDI", 34 | "Operating System :: OS Independent", 35 | "Programming Language :: Python :: 3.6", 36 | "Programming Language :: Python :: 3.7", 37 | "Programming Language :: Python :: 3.8", 38 | "Programming Language :: Python :: 3.9", 39 | "Programming Language :: Python :: 3.10", 40 | "Programming Language :: Python :: 3.11", 41 | "Programming Language :: Python :: 3.12" 42 | ] 43 | requires-python = ">=3.6" 44 | 45 | [project.readme] 46 | file = 
"neuralpiano/README.md" 47 | content-type = "text/markdown" 48 | 49 | [project.urls] 50 | Homepage = "https://github.com/asigalov61/neuralpiano" 51 | SoundCloud = "https://soundcloud.com/aleksandr-sigalov-61/" 52 | Samples = "https://github.com/asigalov61/neuralpiano/tree/main/neuralpiano/output_samples" 53 | Documentation = "https://github.com/asigalov61/neuralpiano" 54 | Issues = "https://github.com/asigalov61/neuralpiano/issues" 55 | Discussions = "https://github.com/asigalov61/neuralpiano/discussions" 56 | Dataset = "https://magenta.withgoogle.com/datasets/maestro" 57 | Demo = "https://huggingface.co/datasets/projectlosangeles/Neural-Piano-Synthesizer" 58 | -------------------------------------------------------------------------------- /neuralpiano/music2latent/README.md: -------------------------------------------------------------------------------- 1 | # Music2Latent 2 | Encode and decode audio samples to/from compressed representations! Useful for efficient generative modelling applications and for other downstream tasks. 3 | 4 | ![music2latent](music2latent.png) 5 | 6 | Read the ISMIR 2024 paper [here](https://arxiv.org/abs/2408.06500). 7 | Listen to audio samples [here](https://sonycslparis.github.io/music2latent-companion/). 8 | 9 | Under the hood, __Music2Latent__ uses a __Consistency Autoencoder__ model to efficiently encode and decode audio samples. 10 | 11 | 44.1 kHz audio is encoded into a sequence of __~10 Hz__, and each of the latents has 64 channels. 12 | 48 kHz audio can also be encoded, which results in a sequence of ~12 Hz. 13 | A generative model can then be trained on these embeddings, or they can be used for other downstream tasks. 14 | 15 | Music2Latent was trained on __music__ and on __speech__. Refer to the [paper](https://arxiv.org/abs/2408.06500) for more details. 16 | 17 | 18 | ## Installation 19 | 20 | ```bash 21 | pip install music2latent 22 | ``` 23 | The model weights will be downloaded automatically the first time the code is run. 24 | 25 | 26 | ## How to use 27 | To encode and decode audio samples to/from latent embeddings: 28 | ```bash 29 | audio_path = librosa.example('trumpet') 30 | wv, sr = librosa.load(audio_path, sr=44100) 31 | 32 | from music2latent import EncoderDecoder 33 | encdec = EncoderDecoder() 34 | 35 | latent = encdec.encode(wv) 36 | # latent has shape (batch_size/audio_channels, dim (64), sequence_length) 37 | 38 | wv_rec = encdec.decode(latent) 39 | ``` 40 | To extract encoder features to use in downstream tasks: 41 | ```bash 42 | features = encoder.encode(wv, extract_features=True) 43 | ``` 44 | These features are extracted before the encoder bottleneck, and thus have more channels (contain more information) than the latents used for reconstruction. It will not be possible to directly decode these features back to audio. 45 | 46 | music2latent supports more advanced usage, including GPU memory management controls. Please refer to __tutorial.ipynb__. 47 | 48 | 49 | ## License 50 | This library is released under the CC BY-NC 4.0 license. Please refer to the LICENSE file for more details. 51 | 52 | 53 | 54 | This work was conducted by [Marco Pasini](https://twitter.com/marco_ppasini) during his PhD at Queen Mary University of London, in partnership with Sony Computer Science Laboratories Paris. 55 | This work was supervised by Stefan Lattner and George Fazekas. 
-------------------------------------------------------------------------------- /neuralpiano/music2latent/hparams.py: -------------------------------------------------------------------------------- 1 | # GENERAL 2 | mixed_precision = True # use mixed precision 3 | seed = 42 # seed for Pytorch 4 | 5 | # DATA 6 | data_channels = 2 # channels of input data 7 | data_length = 64 # sequence length of input data 8 | data_length_test = 1024//4 # sequence length of data used for testing 9 | sample_rate = 44100 # sampling rate of input/output audio 10 | 11 | hop = 128*4 # hop size of transformation 12 | 13 | alpha_rescale = 0.65 14 | beta_rescale = 0.34 15 | sigma_data = 0.5 16 | 17 | 18 | 19 | # MODEL 20 | base_channels = 64 # base channel number for architecture 21 | layers_list = [2,2,2,2,2] # number of blocks per each resolution level 22 | multipliers_list = [1,2,4,4,4] # base channels multipliers for each resolution level 23 | attention_list = [0,0,1,1,1] # for each resolution, 0 if no attention is performed, 1 if attention is performed 24 | freq_downsample_list = [1,0,0,0] # for each resolution, 0 if frequency 4x downsampling, 1 if standard frequency 2x and time 2x downsampling 25 | 26 | layers_list_encoder = [1,1,1,1,1] # number of blocks per each resolution level 27 | attention_list_encoder = [0,0,1,1,1] # for each resolution, 0 if no attention is performed, 1 if attention is performed 28 | bottleneck_base_channels = 512 # base channels to use for block before/after bottleneck 29 | num_bottleneck_layers = 4 # number of blocks to use before/after bottleneck 30 | frequency_scaling = True 31 | 32 | heads = 4 # number of attention heads 33 | cond_channels = 256 # dimension of time embedding 34 | use_fourier = False # if True, use random Fourier embedding, if False, use Positional 35 | fourier_scale = 0.2 # scale parameter for gaussian fourier layer (original is 0.02, but to me it appears too small) 36 | normalization = True # use group normalization 37 | dropout_rate = 0. # dropout rate 38 | min_res_dropout = 16 # dropout is applied on equal or smaller feature map resolutions 39 | init_as_zero = True # initialize convolution kernels before skip connections with zeros 40 | 41 | bottleneck_channels = 32*2 # channels of encoder bottleneck 42 | 43 | pre_normalize_2d_to_1d = True # pre-normalize 2D to 1D connection in encoder 44 | pre_normalize_downsampling_encoder = True # pre-normalize downsampling layers in encoder 45 | 46 | sigma_min = 0.002 # minimum sigma 47 | sigma_max = 80. # maximum sigma 48 | rho = 7. # rho parameter for sigma schedule -------------------------------------------------------------------------------- /neuralpiano/music2latent/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | from os import path as ospath 3 | 4 | from huggingface_hub import hf_hub_download 5 | 6 | import torch 7 | 8 | from .hparams import * 9 | 10 | # Get scaling coefficients c_skip, c_out, c_in based on noise sigma 11 | # These are used to scale the input and output of the consistency model, while satisfying the boundary condition for consistency models 12 | # Parameters: 13 | # sigma: noise level 14 | # Returns: 15 | # c_skip, c_out, c_in: scaling coefficients 16 | def get_c(sigma): 17 | sigma_correct = sigma_min 18 | c_skip = (sigma_data**2.)/(((sigma-sigma_correct)**2.) + (sigma_data**2.)) 19 | c_out = (sigma_data*(sigma-sigma_correct))/(((sigma_data**2.) 
+ (sigma**2.))**0.5) 20 | c_in = 1./(((sigma**2.)+(sigma_data**2.))**0.5) 21 | return c_skip.reshape(-1,1,1,1), c_out.reshape(-1,1,1,1), c_in.reshape(-1,1,1,1) 22 | 23 | # Get noise level sigma_i based on index i and number of discretization steps k 24 | # Parameters: 25 | # i: index 26 | # k: number of discretization steps 27 | # Returns: 28 | # sigma_i: noise level corresponding to index i 29 | def get_sigma(i, k): 30 | return (sigma_min**(1./rho) + ((i-1)/(k-1))*(sigma_max**(1./rho)-sigma_min**(1./rho)))**rho 31 | 32 | # Get noise level sigma for a continuous index i in [0, 1] 33 | # Follows parameterization in https://openreview.net/pdf?id=FmqFfMTNnv 34 | # Parameters: 35 | # i: continuous index in [0, 1] 36 | # Returns: 37 | # sigma: corresponding noise level 38 | def get_sigma_continuous(i): 39 | return (sigma_min**(1./rho) + i*(sigma_max**(1./rho)-sigma_min**(1./rho)))**rho 40 | 41 | 42 | # Add Gaussian noise to input x based on given noise and sigma 43 | # Parameters: 44 | # x: input tensor 45 | # noise: tensor containing Gaussian noise 46 | # sigma: noise level 47 | # Returns: 48 | # x_noisy: x with noise added 49 | def add_noise(x, noise, sigma): 50 | return x + sigma.reshape(-1,1,1,1)*noise 51 | 52 | 53 | # Reverse the probability flow ODE by one step 54 | # Parameters: 55 | # x: input 56 | # noise: Gaussian noise 57 | # sigma: noise level 58 | # Returns: 59 | # x: x after reversing ODE by one step 60 | def reverse_step(x, noise, sigma): 61 | return x + ((sigma**2 - sigma_min**2)**0.5)*noise 62 | 63 | 64 | # Denoise samples at a given noise level 65 | # Parameters: 66 | # model: consistency model 67 | # noisy_samples: input noisy samples 68 | # sigma: noise level 69 | # Returns: 70 | # pred_noises: predicted noise 71 | # pred_samples: denoised samples 72 | def denoise(model, noisy_samples, sigma, latents=None): 73 | # Denoise samples 74 | with torch.no_grad(): 75 | with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=mixed_precision): 76 | if latents is not None: 77 | pred_samples = model(latents, noisy_samples, sigma) 78 | else: 79 | pred_samples = model(noisy_samples, sigma) 80 | # Sample noise 81 | pred_noises = torch.randn_like(pred_samples) 82 | return pred_noises, pred_samples 83 | 84 | # Reverse the diffusion process to generate samples 85 | # Parameters: 86 | # model: trained consistency model 87 | # initial_noise: initial noise to start from 88 | # diffusion_steps: number of steps to reverse 89 | # Returns: 90 | # final_samples: generated samples 91 | def reverse_diffusion(model, initial_noise, diffusion_steps, latents=None): 92 | next_noisy_samples = initial_noise 93 | # Reverse process step-by-step 94 | for k in range(diffusion_steps): 95 | 96 | # Get sigma values 97 | sigma = get_sigma(diffusion_steps+1-k, diffusion_steps+1) 98 | next_sigma = get_sigma(diffusion_steps-k, diffusion_steps+1) 99 | 100 | # Denoise 101 | noisy_samples = next_noisy_samples 102 | pred_noises, pred_samples = denoise(model, noisy_samples, sigma, latents) 103 | 104 | # Step to next (lower) noise level 105 | next_noisy_samples = reverse_step(pred_samples, pred_noises, next_sigma) 106 | 107 | return pred_samples.detach().cpu() 108 | 109 | 110 | def is_path(variable): 111 | return isinstance(variable, str) and os.path.exists(variable) 112 | 113 | 114 | 115 | def download_models(): 116 | home_root = os.getcwd() 117 | models_dir = os.path.join(home_root, "models") 118 | os.makedirs(models_dir, exist_ok=True) 119 | 120 | if not os.path.exists(os.path.join(models_dir, 
"SGM-v2.01-YamahaGrand-Guit-Bass-v2.7.sf2")): 121 | 122 | files = [ 123 | ("SonyCSLParis/music2latent", "music2latent.pt", "model"), 124 | ("asigalov61/music2latent-maestro", "music2latent_maestro_loss_16.871_iters_45500.pt", "model"), 125 | ("asigalov61/music2latent-maestro", "music2latent_maestro_loss_27.834_iters_14300.pt", "model"), 126 | ("projectlosangeles/soundfonts4u", "SGM-v2.01-YamahaGrand-Guit-Bass-v2.7.sf2", "dataset"), 127 | ] 128 | 129 | for repo_id, filename, repo_type in files: 130 | print(f"Downloading {filename} from {repo_id} ...") 131 | # download to Hugging Face cache and get the real path returned 132 | cached_path = hf_hub_download(repo_id=repo_id, repo_type=repo_type, filename=filename, local_dir=models_dir) 133 | 134 | print("Models were downloaded successfully!") -------------------------------------------------------------------------------- /neuralpiano/music2latent/audio.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import numpy as np 4 | 5 | from .hparams import * 6 | 7 | 8 | def wv2spec(wv, hop_size=256, fac=4): 9 | X = stft(wv, hop_size=hop_size, fac=fac, device=wv.device) 10 | X = power2db(torch.abs(X)**2) 11 | X = normalize(X) 12 | return X 13 | 14 | def spec2wv(S,P, hop_size=256, fac=4): 15 | S = denormalize(S) 16 | S = torch.sqrt(db2power(S)) 17 | P = P * np.pi 18 | SP = torch.complex(S * torch.cos(P), S * torch.sin(P)) 19 | return istft(SP, fac=fac, hop_size=hop_size, device=SP.device) 20 | 21 | def denormalize_realimag(x): 22 | x = x/beta_rescale 23 | return torch.sign(x)*(x.abs()**(1./alpha_rescale)) 24 | 25 | def normalize_complex(x): 26 | return beta_rescale*(x.abs()**alpha_rescale).to(torch.complex64)*torch.exp(1j*torch.angle(x).to(torch.complex64)) 27 | 28 | def denormalize_complex(x): 29 | x = x/beta_rescale 30 | return (x.abs()**(1./alpha_rescale)).to(torch.complex64)*torch.exp(1j*torch.angle(x).to(torch.complex64)) 31 | 32 | def wv2complex(wv, hop_size=256, fac=4): 33 | X = stft(wv, hop_size=hop_size, fac=fac, device=wv.device) 34 | return X[:,:hop_size*2,:] 35 | 36 | def wv2realimag(wv, hop_size=256, fac=4): 37 | X = wv2complex(wv, hop_size, fac) 38 | X = normalize_complex(X) 39 | return torch.stack((torch.real(X),torch.imag(X)), -3) 40 | 41 | def realimag2wv(x, hop_size=256, fac=4): 42 | x = torch.nn.functional.pad(x, (0,0,0,1)) 43 | real,imag = torch.chunk(x, 2, -3) 44 | X = torch.complex(real.squeeze(-3),imag.squeeze(-3)) 45 | X = denormalize_complex(X) 46 | return istft(X, fac=fac, hop_size=hop_size, device=X.device).clamp(-1.,1.) 47 | 48 | def to_representation_encoder(x): 49 | return wv2realimag(x, hop) 50 | 51 | def to_representation(x): 52 | return wv2realimag(x, hop) 53 | 54 | def to_waveform(x): 55 | return realimag2wv(x, hop) 56 | 57 | def overlap_and_add(signal, frame_step): 58 | 59 | outer_dimensions = signal.shape[:-2] 60 | outer_rank = torch.numel(torch.tensor(outer_dimensions)) 61 | 62 | def full_shape(inner_shape): 63 | s = torch.cat([torch.tensor(outer_dimensions), torch.tensor(inner_shape)], 0) 64 | s = list(s) 65 | s = [int(el) for el in s] 66 | return s 67 | 68 | frame_length = signal.shape[-1] 69 | frames = signal.shape[-2] 70 | 71 | # Compute output length. 72 | output_length = frame_length + frame_step * (frames - 1) 73 | 74 | # Compute the number of segments, per frame. 75 | segments = -(-frame_length // frame_step) # Divide and round up. 
76 | 77 | signal = torch.nn.functional.pad(signal, (0, segments * frame_step - frame_length, 0, segments)) 78 | 79 | shape = full_shape([frames + segments, segments, frame_step]) 80 | signal = torch.reshape(signal, shape) 81 | 82 | perm = torch.cat([torch.arange(0, outer_rank), torch.tensor([el+outer_rank for el in [1, 0, 2]])], 0) 83 | perm = list(perm) 84 | perm = [int(el) for el in perm] 85 | signal = torch.permute(signal, perm) 86 | 87 | shape = full_shape([(frames + segments) * segments, frame_step]) 88 | signal = torch.reshape(signal, shape) 89 | 90 | signal = signal[..., :(frames + segments - 1) * segments, :] 91 | 92 | shape = full_shape([segments, (frames + segments - 1), frame_step]) 93 | signal = torch.reshape(signal, shape) 94 | 95 | signal = signal.sum(-3) 96 | 97 | # Flatten the array. 98 | shape = full_shape([(frames + segments - 1) * frame_step]) 99 | signal = torch.reshape(signal, shape) 100 | 101 | # Truncate to final length. 102 | signal = signal[..., :output_length] 103 | 104 | return signal 105 | 106 | def inverse_stft_window(frame_length, frame_step, forward_window): 107 | denom = forward_window**2 108 | overlaps = -(-frame_length // frame_step) 109 | denom = F.pad(denom, (0, overlaps * frame_step - frame_length)) 110 | denom = torch.reshape(denom, [overlaps, frame_step]) 111 | denom = torch.sum(denom, 0, keepdim=True) 112 | denom = torch.tile(denom, [overlaps, 1]) 113 | denom = torch.reshape(denom, [overlaps * frame_step]) 114 | return forward_window / denom[:frame_length] 115 | 116 | def istft(SP, fac=4, hop_size=256, device='cuda'): 117 | x = torch.fft.irfft(SP, dim=-2) 118 | window = torch.hann_window(fac*hop_size).to(device) 119 | window = inverse_stft_window(fac*hop_size, hop_size, window) 120 | x = x*window.unsqueeze(-1) 121 | return overlap_and_add(x.permute(0,2,1), hop_size) 122 | 123 | def frame(signal, frame_length, frame_step, pad_end=False, pad_value=0, axis=-1): 124 | """ 125 | equivalent of tf.signal.frame 126 | """ 127 | signal_length = signal.shape[axis] 128 | if pad_end: 129 | frames_overlap = frame_length - frame_step 130 | rest_samples = np.abs(signal_length - frames_overlap) % np.abs(frame_length - frames_overlap) 131 | pad_size = int(frame_length - rest_samples) 132 | if pad_size != 0: 133 | pad_axis = [0] * signal.ndim 134 | pad_axis[axis] = pad_size 135 | signal = F.pad(signal, pad_axis, "constant", pad_value) 136 | frames = signal.unfold(axis, frame_length, frame_step) 137 | return frames 138 | 139 | def stft(wv, fac=4, hop_size=256, device='cuda'): 140 | window = torch.hann_window(fac*hop_size).to(device) 141 | framed_signals = frame(wv, fac*hop_size, hop_size) 142 | framed_signals = framed_signals*window 143 | return torch.fft.rfft(framed_signals, n=None, dim=- 1, norm=None).permute(0,2,1) 144 | 145 | def normalize(S, mu_rescale=-25., sigma_rescale=75.): 146 | return (S - mu_rescale) / sigma_rescale 147 | 148 | def denormalize(S, mu_rescale=-25., sigma_rescale=75.): 149 | return (S * sigma_rescale) + mu_rescale 150 | 151 | def db2power(S_db, ref=1.0): 152 | return ref * torch.pow(10.0, 0.1 * S_db) 153 | 154 | def power2db(power, ref_value=1.0, amin=1e-10): 155 | log_spec = 10.0 * torch.log10(torch.maximum(torch.tensor(amin), power)) 156 | log_spec -= 10.0 * torch.log10(torch.maximum(torch.tensor(amin), torch.tensor(ref_value))) 157 | return log_spec -------------------------------------------------------------------------------- /neuralpiano/denoise.py: -------------------------------------------------------------------------------- 1 | # 
Denoise Python module 2 | 3 | import math 4 | from typing import Tuple, Dict, Optional, Union 5 | 6 | import numpy as np 7 | import torch 8 | import torch.nn.functional as F 9 | 10 | TensorLike = Union[torch.Tensor, np.ndarray] 11 | 12 | def denoise_audio(mono: TensorLike, 13 | sr: int = 48000, 14 | n_fft: int = 4096, 15 | hop: Optional[int] = None, 16 | noise_seconds: float = 0.8, 17 | max_atten_db: float = 18.0, 18 | noise_floor_db: float = -60.0, 19 | smoothing_time_ms: float = 40.0, 20 | smoothing_freq_bins: int = 3, 21 | noise_sample: Optional[TensorLike] = None, 22 | device: Optional[torch.device] = None, 23 | dtype: torch.dtype = torch.float32 24 | ) -> Tuple[torch.Tensor, Dict]: 25 | 26 | """ 27 | Conservative denoiser tuned for solo piano. 28 | Returns denoised 1-D torch.Tensor and diagnostics dict. 29 | """ 30 | 31 | # --- prepare tensor --- 32 | if isinstance(mono, np.ndarray): 33 | x = torch.from_numpy(mono.astype(np.float32)) 34 | elif isinstance(mono, torch.Tensor): 35 | x = mono.clone() 36 | else: 37 | raise TypeError("mono must be a numpy array or torch.Tensor") 38 | 39 | if x.ndim != 1: 40 | raise ValueError("mono must be 1-D (mono)") 41 | 42 | device = device or (x.device if isinstance(x, torch.Tensor) else torch.device('cpu')) 43 | x = x.to(device=device, dtype=dtype, copy=False) 44 | N = x.shape[-1] 45 | eps = 1e-12 46 | 47 | hop = hop or (n_fft // 4) 48 | win = torch.hann_window(n_fft, device=device, dtype=dtype) 49 | 50 | # STFT / ISTFT helpers 51 | def stft_torch(sig): 52 | return torch.stft(sig, n_fft=n_fft, hop_length=hop, win_length=n_fft, 53 | window=win, center=True, return_complex=True, pad_mode='reflect') 54 | 55 | def istft_torch(X, length): 56 | return torch.istft(X, n_fft=n_fft, hop_length=hop, win_length=n_fft, 57 | window=win, center=True, length=length) 58 | 59 | # compute STFT 60 | X = stft_torch(x) # complex tensor shape [freq_bins, frames] 61 | mag = X.abs() + eps 62 | freq_bins, frames = mag.shape 63 | 64 | # noise magnitude estimate 65 | if noise_sample is not None: 66 | if isinstance(noise_sample, np.ndarray): 67 | n_wav = torch.from_numpy(noise_sample.astype(np.float32)).to(device=device, dtype=dtype) 68 | else: 69 | n_wav = noise_sample.to(device=device, dtype=dtype) 70 | if n_wav.ndim > 1: 71 | n_wav = n_wav.mean(dim=1) if n_wav.shape[0] == 1 else n_wav.view(-1) 72 | Xn = stft_torch(n_wav) 73 | noise_mag = Xn.abs().mean(dim=1) 74 | else: 75 | if noise_seconds <= 0: 76 | frames_for_noise = 1 77 | else: 78 | approx_frames = max(1, int(math.ceil((noise_seconds * sr) / hop))) 79 | frames_for_noise = min(approx_frames, frames) 80 | noise_mag = mag[:, :frames_for_noise].mean(dim=1) 81 | 82 | # floor noise magnitude 83 | noise_floor_lin = 10.0 ** (noise_floor_db / 20.0) 84 | noise_mag = noise_mag.clamp(min=noise_floor_lin) 85 | 86 | # Wiener-like gain: S^2 / (S^2 + N^2) 87 | S2 = mag ** 2 88 | N2 = noise_mag.unsqueeze(1) ** 2 89 | G = S2 / (S2 + N2 + eps) # shape [freq_bins, frames] 90 | 91 | # convert to attenuation dB and clamp 92 | att_db = -20.0 * torch.log10(G.clamp(min=1e-12)) 93 | att_db_clamped = att_db.clamp(max=max_atten_db) 94 | G_limited = 10.0 ** (-att_db_clamped / 20.0) # shape [freq_bins, frames] 95 | 96 | # ------------------------- 97 | # Time smoothing (per-frequency) 98 | # ------------------------- 99 | time_smooth_frames = max(1, int(round((smoothing_time_ms / 1000.0) * sr / hop))) 100 | if time_smooth_frames > 1: 101 | k = torch.hann_window(time_smooth_frames, device=device, dtype=dtype) 102 | k = k / k.sum() 103 | kernel = 
k.view(1, 1, -1).repeat(freq_bins, 1, 1) # [freq_bins,1,k] 104 | inp = G_limited.unsqueeze(0) # [1, freq_bins, frames] 105 | G_time = F.conv1d(inp, kernel, padding=(time_smooth_frames - 1) // 2, groups=freq_bins).squeeze(0) 106 | else: 107 | G_time = G_limited 108 | 109 | # ------------------------- 110 | # Frequency smoothing (per-frame) 111 | # ------------------------- 112 | if smoothing_freq_bins > 0: 113 | kf = torch.ones(smoothing_freq_bins, device=device, dtype=dtype) 114 | kf = kf / kf.sum() 115 | G_perm = G_time.permute(1, 0).unsqueeze(1) # [frames, 1, freq_bins] 116 | kernel_f = kf.view(1, 1, -1) 117 | G_freq_perm = F.conv1d(G_perm, kernel_f, padding=(smoothing_freq_bins - 1) // 2).squeeze(1) # [frames, freq_bins] 118 | G_freq = G_freq_perm.permute(1, 0) # back to [freq_bins, frames] 119 | else: 120 | G_freq = G_time 121 | 122 | # Ensure orientation is [freq_bins, frames] 123 | if G_freq.ndim != 2: 124 | raise RuntimeError("Unexpected G_freq dimensionality") 125 | if G_freq.shape[0] != freq_bins and G_freq.shape[1] == freq_bins: 126 | G_freq = G_freq.permute(1, 0) 127 | 128 | # Build frequency-dependent protection vector [freq_bins, 1] 129 | freqs = torch.linspace(0.0, float(sr) / 2.0, steps=freq_bins, device=device, dtype=dtype).unsqueeze(1) 130 | low_protect = (freqs < 200.0).to(dtype) 131 | high_allow = (freqs > 6000.0).to(dtype) 132 | 133 | # apply frequency-dependent scaling (broadcast across frames) 134 | G_final = G_freq * (1.0 - 0.35 * low_protect) * (1.0 + 0.25 * high_allow) 135 | G_final = G_final.clamp(min=0.0, max=1.0) 136 | 137 | # ------------------------- 138 | # ALIGNMENT FIX: ensure G_final matches X.shape exactly 139 | # ------------------------- 140 | # If smoothing changed the frame count by ±1 (or more), trim or pad G_final to match X. 141 | # Trim if longer, pad by repeating last column if shorter. 
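# (F.conv1d with padding=(k - 1) // 2 shortens the convolved axis by one sample
#  whenever the kernel length k is even, which is how such a mismatch can arise here.)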
142 | if G_final.shape[0] != freq_bins: 143 | # if frequency axis mismatched, try safe transpose or resize 144 | if G_final.shape[1] == freq_bins and G_final.shape[0] == frames: 145 | G_final = G_final.permute(1, 0) 146 | else: 147 | # fallback: resize frequency axis by trimming or repeating last row 148 | if G_final.shape[0] > freq_bins: 149 | G_final = G_final[:freq_bins, :] 150 | else: 151 | pad_rows = freq_bins - G_final.shape[0] 152 | last_row = G_final[-1:, :].repeat(pad_rows, 1) 153 | G_final = torch.cat([G_final, last_row], dim=0) 154 | 155 | if G_final.shape[1] != frames: 156 | if G_final.shape[1] > frames: 157 | G_final = G_final[:, :frames] 158 | else: 159 | # pad by repeating last column 160 | pad_cols = frames - G_final.shape[1] 161 | last_col = G_final[:, -1:].repeat(1, pad_cols) 162 | G_final = torch.cat([G_final, last_col], dim=1) 163 | 164 | # final safety check 165 | if G_final.shape != (freq_bins, frames): 166 | raise RuntimeError(f"Unable to align mask to STFT shape: mask {G_final.shape}, STFT {(freq_bins, frames)}") 167 | 168 | # apply mask (preserve phase) 169 | X_denoised = X * G_final 170 | 171 | # reconstruct 172 | y = istft_torch(X_denoised, length=N) 173 | 174 | # tiny residual low-frequency subtraction (gentle) 175 | lf_cut = 60.0 176 | lf_bin = int(round(lf_cut / (sr / 2.0) * (freq_bins - 1))) 177 | if lf_bin >= 1: 178 | lf_rms = mag[:lf_bin, :].mean().item() 179 | subtract = 0.02 * lf_rms 180 | y = y - subtract * torch.mean(y) 181 | 182 | # final safety normalize (very gentle) 183 | peak_in = float(x.abs().max().item()) 184 | peak_out = float(y.abs().max().item()) 185 | if peak_out > 0.999: 186 | y = y * (0.999 / peak_out) 187 | final_scale = 0.999 / peak_out 188 | else: 189 | final_scale = 1.0 190 | 191 | diagnostics = { 192 | "sr": sr, 193 | "n_fft": n_fft, 194 | "hop": hop, 195 | "noise_seconds": noise_seconds, 196 | "max_atten_db": max_atten_db, 197 | "smoothing_time_ms": smoothing_time_ms, 198 | "smoothing_freq_bins": smoothing_freq_bins, 199 | "input_peak": peak_in, 200 | "output_peak": float(y.abs().max().item()), 201 | "final_scale": final_scale, 202 | } 203 | 204 | return y.to(dtype=dtype, device=device), diagnostics -------------------------------------------------------------------------------- /neuralpiano/bass.py: -------------------------------------------------------------------------------- 1 | # Bass Python module 2 | 3 | import math 4 | from typing import Tuple, Dict, Optional, Union 5 | 6 | import numpy as np 7 | import torch 8 | import torch.nn.functional as F 9 | 10 | TensorLike = Union[torch.Tensor, np.ndarray] 11 | 12 | def enhance_audio_bass(mono: TensorLike, 13 | sr: int = 48000, 14 | low_cutoff: float = 200.0, 15 | ir_len: int = 1025, 16 | low_gain_db: float = 8.0, 17 | sub_mix: float = 0.75, 18 | compressor_thresh_db: float = -24.0, 19 | compressor_ratio: float = 3.0, 20 | compressor_attack_ms: float = 8.0, 21 | compressor_release_ms: float = 120.0, 22 | makeup_db: float = 0.0, 23 | drive: float = 1.15, 24 | wet_mix: float = 0.9, 25 | downsample_target: int = 4000, 26 | device: Optional[torch.device] = None, 27 | dtype: torch.dtype = torch.float32 28 | ) -> Tuple[torch.Tensor, Dict]: 29 | 30 | """ 31 | Fast bass enhancement optimized for GPU with robust shape alignment. 32 | 33 | Returns: 34 | enhanced 1-D torch.Tensor (same length as input), diagnostics dict. 
35 | """ 36 | 37 | # --- prepare input tensor --- 38 | if isinstance(mono, np.ndarray): 39 | x = torch.from_numpy(mono.astype(np.float32)) 40 | elif isinstance(mono, torch.Tensor): 41 | x = mono.clone() 42 | else: 43 | raise TypeError("mono must be numpy or torch.Tensor") 44 | if x.ndim != 1: 45 | raise ValueError("mono must be 1-D (mono)") 46 | 47 | device = device or (x.device if isinstance(x, torch.Tensor) else torch.device('cpu')) 48 | x = x.to(device=device, dtype=dtype, copy=False) 49 | N = x.shape[-1] 50 | eps = 1e-12 51 | 52 | # --- helpers --- 53 | def db2lin(db): return 10.0 ** (db / 20.0) 54 | def next_pow2(n): return 1 << ((n - 1).bit_length()) 55 | 56 | # linear-phase lowpass IR builder (small, efficient) 57 | def make_lowpass_ir(cutoff_hz, sr_local, length): 58 | if length % 2 == 0: 59 | length += 1 60 | t = torch.arange(length, device=device, dtype=dtype) - (length - 1) / 2.0 61 | sinc_arg = 2.0 * cutoff_hz / sr_local * t 62 | h = torch.sinc(sinc_arg) 63 | beta = 6.0 64 | # kaiser_window signature: (window_length, periodic=False, beta, *, dtype=None, layout=None, device=None) 65 | win = torch.kaiser_window(length, False, beta, dtype=dtype, device=device) 66 | h = h * win 67 | h = h / (h.sum() + eps) 68 | return h 69 | 70 | # FFT convolution using optional precomputed kernel FFT 71 | def fft_conv_signal_kernel(sig, kernel, kernel_fft=None): 72 | n = sig.shape[-1] 73 | k = kernel.shape[-1] 74 | out_len = n + k - 1 75 | size = next_pow2(out_len) 76 | S = torch.fft.rfft(sig, n=size) 77 | if kernel_fft is None: 78 | K = torch.fft.rfft(kernel, n=size) 79 | else: 80 | K = kernel_fft 81 | Y = S * K 82 | y = torch.fft.irfft(Y, n=size)[:out_len] 83 | return y[:n], K 84 | 85 | # --- 1) extract low band via FFT conv (single pass) --- 86 | lp = make_lowpass_ir(low_cutoff, sr, ir_len) 87 | low, lp_fft = fft_conv_signal_kernel(x, lp) # low is same device/dtype 88 | 89 | # --- 2) downsample low band for fast processing --- 90 | ds = max(1, int(round(sr / float(downsample_target)))) 91 | ds_sr = sr // ds 92 | if ds > 1: 93 | pad = (-low.shape[-1]) % ds 94 | if pad: 95 | low_p = F.pad(low.unsqueeze(0).unsqueeze(0), (0, pad)).squeeze(0).squeeze(0) 96 | else: 97 | low_p = low 98 | # decimate by averaging each block (cheap, preserves low-band energy) 99 | low_ds = low_p.view(-1, ds).mean(dim=1) 100 | else: 101 | low_ds = low 102 | 103 | # --- 3) compressor on downsampled low band (vectorized) --- 104 | win_len = max(1, int(round(0.01 * ds_sr))) 105 | sq = low_ds * low_ds 106 | pad = (win_len - 1) // 2 107 | sq_p = F.pad(sq.unsqueeze(0).unsqueeze(0), (pad, win_len - 1 - pad)).squeeze(0).squeeze(0) 108 | kernel = torch.ones(win_len, device=device, dtype=dtype) / float(win_len) 109 | rms_sq = F.conv1d(sq_p.unsqueeze(0).unsqueeze(0), kernel.view(1,1,-1)).squeeze(0).squeeze(0) 110 | rms = torch.sqrt(rms_sq + 1e-12) 111 | 112 | # soft-knee gain computer vectorized 113 | lvl_db = 20.0 * torch.log10(rms.clamp(min=1e-12)) 114 | knee = 3.0 115 | over = lvl_db - compressor_thresh_db 116 | zero = torch.zeros_like(over) 117 | gr_db = torch.where( 118 | over <= -knee, 119 | zero, 120 | torch.where( 121 | over >= knee, 122 | compressor_thresh_db + (over / compressor_ratio) - lvl_db, 123 | - ((1.0 - 1.0/compressor_ratio) * (over + knee)**2) / (4.0 * knee) 124 | ) 125 | ).clamp(max=0.0) 126 | gain_lin = 10.0 ** (gr_db / 20.0) 127 | 128 | # smooth gain with FIR approximations for attack/release 129 | def exp_kernel(tc_ms, sr_local, length=64): 130 | if length < 3: 131 | length = 3 132 | tau = max(1e-6, tc_ms / 
1000.0) 133 | t = torch.arange(length, device=device, dtype=dtype) 134 | k = torch.exp(-t / (tau * sr_local)) 135 | k = k / k.sum() 136 | return k 137 | 138 | atk_k = exp_kernel(compressor_attack_ms, ds_sr, length=64) 139 | rel_k = exp_kernel(compressor_release_ms, ds_sr, length=128) 140 | # convolve (padding may change length slightly) 141 | g_atk = F.conv1d(gain_lin.unsqueeze(0).unsqueeze(0), atk_k.view(1,1,-1), padding=(atk_k.numel()-1)//2).squeeze(0).squeeze(0) 142 | g_smooth = F.conv1d(g_atk.unsqueeze(0).unsqueeze(0), rel_k.view(1,1,-1), padding=(rel_k.numel()-1)//2).squeeze(0).squeeze(0) 143 | 144 | # --- ALIGN: ensure g_smooth matches low_ds length --- 145 | if g_smooth.shape[0] != low_ds.shape[0]: 146 | if g_smooth.shape[0] > low_ds.shape[0]: 147 | g_smooth = g_smooth[:low_ds.shape[0]] 148 | else: 149 | pad_len = low_ds.shape[0] - g_smooth.shape[0] 150 | if g_smooth.numel() == 0: 151 | # fallback to ones if something went wrong 152 | g_smooth = torch.ones(low_ds.shape[0], device=device, dtype=dtype) 153 | else: 154 | last = g_smooth[-1:].repeat(pad_len) 155 | g_smooth = torch.cat([g_smooth, last], dim=0) 156 | 157 | # apply makeup 158 | makeup_lin = db2lin(makeup_db) 159 | low_ds_comp = low_ds * g_smooth * makeup_lin 160 | 161 | # --- 4) subharmonic generation on downsampled signal --- 162 | rect = torch.clamp(low_ds_comp, min=0.0) 163 | lp_sub_len = 513 if ds_sr >= 4000 else 257 164 | lp_sub = make_lowpass_ir(200.0, ds_sr, lp_sub_len) 165 | rect_lp, _ = fft_conv_signal_kernel(rect, lp_sub) 166 | sub_gain = db2lin(low_gain_db) 167 | sub_ds = rect_lp * sub_gain 168 | 169 | # soft saturation (vectorized) 170 | one = torch.tensor(1.0, device=device, dtype=dtype) 171 | sat_low_ds = torch.tanh(low_ds_comp * drive) / torch.tanh(one) 172 | sat_sub_ds = torch.tanh(sub_ds * (drive * 0.8)) / torch.tanh(one) 173 | enhanced_low_ds = (1.0 - sub_mix) * sat_low_ds + sub_mix * sat_sub_ds 174 | 175 | # --- 5) upsample enhanced low back to original rate --- 176 | if ds > 1: 177 | # ensure length before upsampling is consistent 178 | L_needed = (low.shape[-1] + ds - 1) // ds 179 | if enhanced_low_ds.shape[0] < L_needed: 180 | pad_len = L_needed - enhanced_low_ds.shape[0] 181 | enhanced_low_ds = torch.cat([enhanced_low_ds, enhanced_low_ds[-1:].repeat(pad_len)], dim=0) 182 | enhanced_low = F.interpolate(enhanced_low_ds.view(1,1,-1), scale_factor=ds, mode='linear', align_corners=False).view(-1)[:low.shape[-1]] 183 | else: 184 | enhanced_low = enhanced_low_ds 185 | # ensure exact length 186 | if enhanced_low.shape[0] != low.shape[-1]: 187 | if enhanced_low.shape[0] > low.shape[-1]: 188 | enhanced_low = enhanced_low[:low.shape[-1]] 189 | else: 190 | enhanced_low = torch.cat([enhanced_low, enhanced_low[-1:].repeat(low.shape[-1] - enhanced_low.shape[0])], dim=0) 191 | 192 | # --- 6) band-limit enhanced low using original lp FFT (cheap because we have lp_fft) --- 193 | enhanced_low_band, _ = fft_conv_signal_kernel(enhanced_low, lp, kernel_fft=lp_fft) 194 | 195 | # --- 7) scale wet and mix back --- 196 | wet = enhanced_low_band * db2lin(low_gain_db) 197 | out = (1.0 - wet_mix) * x + wet_mix * (x + wet) 198 | 199 | # final gentle limiter 200 | peak = float(out.abs().max().item()) 201 | if peak > 0.999: 202 | out = out * (0.999 / peak) 203 | 204 | diagnostics = { 205 | "sr": sr, 206 | "low_cutoff": low_cutoff, 207 | "ir_len": ir_len, 208 | "low_gain_db": low_gain_db, 209 | "sub_mix": sub_mix, 210 | "downsample_factor": ds, 211 | "downsample_rate": ds_sr, 212 | "compressor_avg_gain_db": float(20.0 * 
math.log10(max(g_smooth.mean().item(), 1e-12))), 213 | "input_peak": float(x.abs().max().item()), 214 | "output_peak": float(out.abs().max().item()), 215 | } 216 | 217 | return out.to(dtype=dtype, device=device), diagnostics -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2025 Tegridy Code 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /neuralpiano/music2latent/inference.py: -------------------------------------------------------------------------------- 1 | import soundfile as sf 2 | import torch 3 | import numpy as np 4 | from tqdm import tqdm 5 | 6 | from .hparams import * 7 | from .hparams_inference import * 8 | from .utils import * 9 | from .models import * 10 | from .audio import * 11 | 12 | 13 | def _should_show_progress(): 14 | """ 15 | Return True when progress bars should be shown (main process). 
16 | Hides bars in distributed workers other than rank 0. 17 | """ 18 | try: 19 | if torch.distributed.is_available() and torch.distributed.is_initialized(): 20 | return torch.distributed.get_rank() == 0 21 | except Exception: 22 | # If distributed is not configured or any error occurs, show progress 23 | return True 24 | return True 25 | 26 | 27 | class EncoderDecoder: 28 | def __init__(self, load_multi_instrumental_model=False, use_v1_piano_model=False, load_path_inference=None, device=None): 29 | download_models() 30 | if device is None: 31 | self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 32 | else: 33 | self.device = device 34 | self.load_path_inference = load_path_inference 35 | if load_path_inference is None: 36 | if load_multi_instrumental_model: 37 | self.load_path_inference = load_path_inference_multi_instrumental_default 38 | else: 39 | if use_v1_piano_model: 40 | self.load_path_inference = load_path_inference_solo_piano_v1_default 41 | else: 42 | self.load_path_inference = load_path_inference_solo_piano_default 43 | self.get_models() 44 | 45 | def get_models(self): 46 | gen = UNet().to(self.device) 47 | checkpoint = torch.load(self.load_path_inference, map_location=self.device, weights_only=False) 48 | gen.load_state_dict(checkpoint['gen_state_dict'], strict=False) 49 | self.gen = gen 50 | 51 | def encode(self, path_or_audio, max_waveform_length=None, max_batch_size=None, extract_features=False, show_progress=True): 52 | ''' 53 | path_or_audio: path of audio sample to encode or numpy array of waveform to encode 54 | max_waveform_length: maximum length of waveforms in the batch for encoding: tune it depending on the available GPU memory 55 | max_batch_size: maximum inference batch size for encoding: tune it depending on the available GPU memory 56 | extract_features: if True, return raw features (no sigma rescale) 57 | show_progress: whether to show tqdm progress bars 58 | 59 | WARNING! 
if input is numpy array of stereo waveform, it must have shape [waveform_samples, audio_channels] 60 | 61 | Returns latents with shape [audio_channels, dim, length] 62 | ''' 63 | if max_waveform_length is None: 64 | max_waveform_length = max_waveform_length_encode 65 | if max_batch_size is None: 66 | max_batch_size = max_batch_size_encode 67 | return encode_audio_inference(path_or_audio, self, max_waveform_length, max_batch_size, device=self.device, extract_features=extract_features, show_progress=show_progress) 68 | 69 | def decode(self, latent, denoising_steps=1, max_waveform_length=None, max_batch_size=None, show_progress=True): 70 | ''' 71 | latent: numpy array of latents to decode with shape [audio_channels, dim, length] 72 | denoising_steps: number of denoising steps to use for decoding 73 | max_waveform_length: maximum length of waveforms in the batch for decoding: tune it depending on the available GPU memory 74 | max_batch_size: maximum inference batch size for decoding: tune it depending on the available GPU memory 75 | show_progress: whether to show tqdm progress bars 76 | 77 | Returns numpy array of decoded waveform with shape [waveform_samples, audio_channels] 78 | ''' 79 | if max_waveform_length is None: 80 | max_waveform_length = max_waveform_length_decode 81 | if max_batch_size is None: 82 | max_batch_size = max_batch_size_decode 83 | return decode_latent_inference(latent, self, max_waveform_length, max_batch_size, diffusion_steps=denoising_steps, device=self.device, show_progress=show_progress) 84 | 85 | 86 | 87 | 88 | 89 | 90 | # decode samples with consistency model to real/imag STFT spectrograms 91 | # Parameters: 92 | # model: trained consistency model 93 | # latents: latent representation with shape [audio_channels/batch_size, dim, length] 94 | # diffusion_steps: number of steps 95 | # Returns: 96 | # decoded_spectrograms with shape [audio_channels/batch_size, data_channels, hop*2, length*downscaling_factor] 97 | def decode_to_representation(model, latents, diffusion_steps=1, device='cuda'): 98 | num_samples = latents.shape[0] 99 | downscaling_factor = 2**freq_downsample_list.count(0) 100 | sample_length = int(latents.shape[-1]*downscaling_factor) 101 | initial_noise = torch.randn((num_samples, data_channels, hop*2, sample_length)).to(device)*sigma_max 102 | decoded_spectrograms = reverse_diffusion(model, initial_noise, diffusion_steps, latents=latents) 103 | return decoded_spectrograms 104 | 105 | 106 | 107 | 108 | # Encode audio sample for inference 109 | # Parameters: 110 | # audio_path: path of audio sample 111 | # model: trained consistency model 112 | # device: device to run the model on 113 | # Returns: 114 | # latent: compressed latent representation with shape [audio_channels, dim, latent_length] 115 | @torch.no_grad() 116 | def encode_audio_inference(audio_path, trainer, max_waveform_length_encode, max_batch_size_encode, device='cuda', extract_features=False, show_progress=True): 117 | trainer.gen = trainer.gen.to(device) 118 | trainer.gen.eval() 119 | downscaling_factor = 2**freq_downsample_list.count(0) 120 | if is_path(audio_path): 121 | audio, sr = sf.read(audio_path, dtype='float32', always_2d=True) 122 | audio = np.transpose(audio, [1,0]) 123 | else: 124 | audio = audio_path 125 | sr = None 126 | if len(audio.shape)==1: 127 | # check if audio is numpy array, then use np.expand_dims, if it is a pytorch tensor, then use torch.unsqueeze 128 | if isinstance(audio, np.ndarray): 129 | audio = np.expand_dims(audio, 0) 130 | else: 131 | audio = 
torch.unsqueeze(audio, 0) 132 | audio_channels = audio.shape[0] 133 | if isinstance(audio, np.ndarray): 134 | audio = torch.from_numpy(audio).to(device) 135 | else: 136 | # check if audio tensor is on cpu. if it is, move it to the device 137 | if audio.device.type=='cpu': 138 | audio = audio.to(device) 139 | 140 | # EXPERIMENTAL: crop audio to be divisible by downscaling_factor 141 | cropped_length = ((((audio.shape[-1]-3*hop)//hop)//downscaling_factor)*hop*downscaling_factor)+3*hop 142 | audio = audio[:,:cropped_length] 143 | 144 | repr_encoder = to_representation_encoder(audio) 145 | sample_length = repr_encoder.shape[-1] 146 | max_sample_length = (int(max_waveform_length_encode/hop)//downscaling_factor)*downscaling_factor 147 | 148 | # if repr_encoder is longer than max_sample_length, chunk it into max_sample_length chunks, the last chunk will be zero-padded, then concatenate the chunks into the batch dimension (before encoding them) 149 | pad_size = 0 150 | if sample_length > max_sample_length: 151 | # pad repr_encoder with copies of the sample to be divisible by max_sample_length 152 | pad_size = max_sample_length - (sample_length % max_sample_length) 153 | # repeat repr_encoder such that repr_encoder.shape[-1] is higher than pad_size, then crop it such that repr_encoder.shape[-1]=pad_size 154 | repr_encoder_pad = torch.cat([repr_encoder for _ in range(1+(pad_size//repr_encoder.shape[-1]))], dim=-1)[:,:,:,:pad_size] 155 | repr_encoder = torch.cat([repr_encoder, repr_encoder_pad], dim=-1) 156 | repr_encoder = torch.split(repr_encoder, max_sample_length, dim=-1) 157 | repr_encoder = torch.cat(repr_encoder, dim=0) 158 | # encode repr_encoder using a maximum batch size (dimension 0) of max_batch_size_inference, if repr_encoder is longer than max_batch_size_inference, chunk it into max_batch_size_inference chunks, the last chunk will maybe have less samples in the batch, then encode the chunks and concatenate the results into the batch dimension 159 | max_batch_size = max_batch_size_encode 160 | if repr_encoder.shape[0] > max_batch_size: 161 | repr_encoder_ls = torch.split(repr_encoder, max_batch_size, dim=0) 162 | latent_ls = [] 163 | show = show_progress and _should_show_progress() 164 | for chunk in tqdm(repr_encoder_ls, desc="Encoding chunks", unit="chunk", leave=False, disable=not show): 165 | # Use autocast as before; mixed_precision controls dtype 166 | with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=mixed_precision): 167 | latent_chunk = trainer.gen.encoder(chunk, extract_features=extract_features) 168 | latent_ls.append(latent_chunk) 169 | latent = torch.cat(latent_ls, dim=0) 170 | else: 171 | # single batch encode 172 | with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=mixed_precision): 173 | latent = trainer.gen.encoder(repr_encoder, extract_features=extract_features) 174 | # split samples 175 | if latent.shape[0]>1: 176 | latent_ls = torch.split(latent, audio_channels, 0) 177 | latent = torch.cat(latent_ls, -1) 178 | latent = latent[:,:,:latent.shape[-1]-(pad_size//downscaling_factor)] 179 | if extract_features: 180 | return latent 181 | else: 182 | return latent/sigma_rescale 183 | 184 | 185 | 186 | # Decode latent representation for inference, use the same framework as in encode_audio_inference, but in reverse order for decoding 187 | # Parameters: 188 | # latent: compressed latent representation with shape [audio_channels, dim, length] 189 | # model: trained consistency model 190 | # diffusion_steps: number of diffusion steps to use 
for decoding 191 | # device: device to run the model on 192 | # Returns: 193 | # audio: numpy array of decoded waveform with shape [waveform_samples, audio_channels] 194 | @torch.no_grad() 195 | def decode_latent_inference(latent, trainer, max_waveform_length_decode, max_batch_size_decode, diffusion_steps=1, device='cuda', show_progress=True): 196 | trainer.gen = trainer.gen.to(device) 197 | trainer.gen.eval() 198 | downscaling_factor = 2**freq_downsample_list.count(0) 199 | latent = latent*sigma_rescale 200 | # check if latent is numpy array, then convert to tensor 201 | if isinstance(latent, np.ndarray): 202 | latent = torch.from_numpy(latent) 203 | # check if latent tensor is on cpu. if it is, move it to the device 204 | if latent.device.type=='cpu': 205 | latent = latent.to(device) 206 | # if latent has only 2 dimensions, add a third dimension as axis 0 207 | if len(latent.shape)==2: 208 | latent = torch.unsqueeze(latent, 0) 209 | audio_channels = latent.shape[0] 210 | latent_length = latent.shape[-1] 211 | max_latent_length = int(max_waveform_length_decode/hop)//downscaling_factor 212 | 213 | # if latent is longer than max_latent_length, chunk it into max_latent_length chunks, the last chunk will be zero-padded, then concatenate the chunks into the batch dimension (before decoding them) 214 | pad_size = 0 215 | if latent_length > max_latent_length: 216 | # pad latent with copies of itself to be divisible by max_latent_length 217 | pad_size = max_latent_length - (latent_length % max_latent_length) 218 | # repeat latent such that latent.shape[-1] is higher than pad_size, then crop it such that latent.shape[-1]=pad_size 219 | latent_pad = torch.cat([latent for _ in range(1+(pad_size//latent.shape[-1]))], dim=-1)[:,:,:pad_size] 220 | latent = torch.cat([latent, latent_pad], dim=-1) 221 | latent = torch.split(latent, max_latent_length, dim=-1) 222 | latent = torch.cat(latent, dim=0) 223 | # decode latent using a maximum batch size (dimension 0) of max_batch_size_inference, if latent is longer than max_batch_size_inference, chunk it into max_batch_size_inference chunks, the last chunk will maybe have less samples in the batch, then decode the chunks and concatenate the results into the batch dimension 224 | max_batch_size = max_batch_size_decode 225 | show = show_progress and _should_show_progress() 226 | if latent.shape[0] > max_batch_size: 227 | latent_ls = torch.split(latent, max_batch_size, dim=0) 228 | repr_ls = [] 229 | for chunk in tqdm(latent_ls, desc="Decoding chunks", unit="chunk", leave=False, disable=not show): 230 | repr_chunk = decode_to_representation(trainer.gen, chunk, diffusion_steps=diffusion_steps, device=device) 231 | repr_ls.append(repr_chunk) 232 | repr = torch.cat(repr_ls, dim=0) 233 | else: 234 | # single batch decode 235 | repr = decode_to_representation(trainer.gen, latent, diffusion_steps=diffusion_steps, device=device) 236 | # split samples 237 | if repr.shape[0]>1: 238 | repr_ls = torch.split(repr, audio_channels, 0) 239 | repr = torch.cat(repr_ls, -1) 240 | repr = repr[:,:,:,:repr.shape[-1]-(pad_size*downscaling_factor)] 241 | return to_waveform(repr) -------------------------------------------------------------------------------- /neuralpiano/mixer.py: -------------------------------------------------------------------------------- 1 | # Mixer Python module 2 | 3 | import os 4 | import math 5 | from typing import List, Optional, Union, Sequence, Tuple 6 | import numpy as np 7 | import librosa 8 | import soundfile as sf 9 | from concurrent.futures import 
ThreadPoolExecutor, as_completed 10 | 11 | # Optional progress bar 12 | try: 13 | from tqdm import tqdm 14 | _HAS_TQDM = True 15 | except Exception: 16 | _HAS_TQDM = False 17 | 18 | def _ensure_list_defaults(values, n, default=0.0): 19 | if values is None: 20 | return [default] * n 21 | if isinstance(values, (int, float)): 22 | return [float(values)] * n 23 | vals = list(values) 24 | if len(vals) >= n: 25 | return vals[:n] 26 | return vals + [default] * (n - len(vals)) 27 | 28 | def _resample_channel(channel, orig_sr, target_sr, res_type='kaiser_fast'): 29 | if orig_sr == target_sr: 30 | return channel 31 | return librosa.resample(channel, orig_sr=orig_sr, target_sr=target_sr, res_type=res_type) 32 | 33 | def _apply_fades(track: np.ndarray, sr: int, fade_in_s: float, fade_out_s: float) -> np.ndarray: 34 | n_samples = track.shape[1] 35 | if n_samples == 0: 36 | return track 37 | fi = max(0, int(round(fade_in_s * sr))) 38 | fo = max(0, int(round(fade_out_s * sr))) 39 | if fi + fo > n_samples and (fi + fo) > 0: 40 | scale = n_samples / (fi + fo) 41 | fi = int(round(fi * scale)) 42 | fo = int(round(fo * scale)) 43 | env = np.ones(n_samples, dtype=np.float32) 44 | if fi > 0: 45 | t = np.linspace(0.0, 1.0, fi, endpoint=True, dtype=np.float32) 46 | env[:fi] = 0.5 * (1.0 - np.cos(np.pi * t)) 47 | if fo > 0: 48 | t = np.linspace(0.0, 1.0, fo, endpoint=True, dtype=np.float32) 49 | env[-fo:] = 0.5 * (1.0 + np.cos(np.pi * t)) 50 | return track * env[np.newaxis, :] 51 | 52 | def _normalize_input_array(arr: np.ndarray) -> np.ndarray: 53 | arr = np.asarray(arr) 54 | if arr.ndim == 1: 55 | return np.expand_dims(arr.astype(np.float32), 0) 56 | if arr.ndim == 2: 57 | # Heuristic: if shape looks like (samples, channels) transpose to (channels, samples) 58 | if arr.shape[1] <= 2 and arr.shape[0] > arr.shape[1]: 59 | return arr.T.astype(np.float32) 60 | return arr.astype(np.float32) 61 | raise ValueError("Input numpy array must be 1D or 2D (samples,channels) or (channels,samples)") 62 | 63 | def mix_audio(input_items: Sequence[Union[str, Tuple[np.ndarray, int]]], 64 | output_path: str = 'mixed.wav', 65 | gain_db: Optional[Union[float, List[float]]] = 0.0, 66 | pan: Optional[Union[float, List[float]]] = 0.0, 67 | delays: Optional[Union[float, List[float]]] = 0.0, 68 | fade_in: Optional[Union[float, List[float]]] = 0.0, 69 | fade_out: Optional[Union[float, List[float]]] = 0.0, 70 | target_sr: Optional[int] = None, 71 | normalize: bool = True, 72 | trim_trailing_silence: bool = False, 73 | trim_threshold_db: float = -60.0, 74 | trim_padding_seconds: float = 0.01, 75 | output_subtype: Optional[str] = None, 76 | workers: int = 4, 77 | show_progress: bool = False, 78 | verbose: bool = False, 79 | return_mix: bool = False 80 | ) -> Optional[Tuple[np.ndarray, int]]: 81 | 82 | """ 83 | Mix inputs (file paths and/or (array, sr) tuples) into one audio file. 84 | If return_mix is True the function returns (mix_array, sample_rate) where mix_array 85 | is shaped (samples, channels) and dtype float32. 86 | 87 | All other parameters behave as in previous versions: 88 | - gain_db, pan, delays, fade_in, fade_out accept single values or shorter lists. 89 | - target_sr defaults to the highest input sample rate. 90 | - normalize scales final peak to 0.999 when True. 91 | - trim_trailing_silence removes trailing silence below trim_threshold_db. 92 | - workers controls parallel processing. 93 | - show_progress uses tqdm if available. 94 | - verbose prints progress messages. 
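    Example (illustrative; the file names are placeholders for your own audio):

    >>> # Mix two takes with per-track gain and opposite panning; when
    >>> # return_mix is False the output is written and its path is returned.
    >>> mix_audio(["piano_take.wav", "room_mics.wav"],
    ...           output_path="mixed.wav",
    ...           gain_db=[0.0, -3.0],
    ...           pan=[-0.2, 0.2])
    >>> # Or keep the result in memory as a (samples, channels) float32 array.
    >>> mix, sr = mix_audio(["piano_take.wav"], return_mix=True)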
95 | """ 96 | 97 | n = len(input_items) 98 | if n == 0: 99 | raise ValueError("input_items must contain at least one element") 100 | 101 | # Prepare per-track parameter lists (allow shorter lists) 102 | gains_db = _ensure_list_defaults(gain_db, n, default=0.0) 103 | pans = _ensure_list_defaults(pan, n, default=0.0) 104 | delays_sec = _ensure_list_defaults(delays, n, default=0.0) 105 | fades_in = _ensure_list_defaults(fade_in, n, default=0.0) 106 | fades_out = _ensure_list_defaults(fade_out, n, default=0.0) 107 | pans = [max(-1.0, min(1.0, float(p))) for p in pans] 108 | 109 | if verbose: 110 | print(f"[mix_audio] Preparing {n} inputs...") 111 | 112 | tracks = [] 113 | srs = [] 114 | channel_counts = [] 115 | 116 | load_iter = list(enumerate(input_items)) 117 | if show_progress and _HAS_TQDM: 118 | load_iter = list(tqdm(load_iter, desc="Loading inputs", unit="item")) 119 | 120 | for idx, item in load_iter: 121 | if isinstance(item, str): 122 | if verbose: 123 | print(f" loading file [{idx+1}/{n}]: {item}") 124 | y, sr = librosa.load(item, sr=None, mono=False) 125 | if y.ndim == 1: 126 | y = np.expand_dims(y, 0) 127 | tracks.append(y.astype(np.float32)) 128 | srs.append(sr) 129 | channel_counts.append(y.shape[0]) 130 | elif isinstance(item, (tuple, list)) and len(item) == 2: 131 | arr, sr = item 132 | if verbose: 133 | print(f" using array input [{idx+1}/{n}] with sr={sr}") 134 | y = _normalize_input_array(arr) 135 | tracks.append(y.astype(np.float32)) 136 | srs.append(int(sr)) 137 | channel_counts.append(y.shape[0]) 138 | else: 139 | raise ValueError("Each input item must be a file path or a tuple (array, sr)") 140 | 141 | # Decide target sample rate 142 | if target_sr is None: 143 | target_sr = max(srs) 144 | if verbose: 145 | print(f"[mix_audio] Target sample rate: {target_sr} Hz") 146 | 147 | # Determine output channel count: stereo if any input has >=2 channels, else mono 148 | out_ch = 2 if any(c >= 2 for c in channel_counts) else 1 149 | if verbose: 150 | print(f"[mix_audio] Output channels: {out_ch}") 151 | 152 | # Process each track: resample, channel handling, panning, fades, apply gain 153 | processed = [None] * n 154 | 155 | def process_track(i): 156 | t = tracks[i] 157 | sr = srs[i] 158 | g_db = gains_db[i] 159 | p_val = pans[i] 160 | delay = float(delays_sec[i]) 161 | fi = float(fades_in[i]) 162 | fo = float(fades_out[i]) 163 | 164 | # Downmix multichannel (>2) to mono first 165 | if t.shape[0] > 2: 166 | t = np.expand_dims(np.mean(t, axis=0), 0) 167 | 168 | # Resample channels 169 | if sr != target_sr: 170 | if workers > 1 and t.shape[0] > 1: 171 | with ThreadPoolExecutor(max_workers=min(workers, t.shape[0])) as ex: 172 | futures = [ex.submit(_resample_channel, t[ch], sr, target_sr, 'kaiser_fast') for ch in range(t.shape[0])] 173 | resampled_ch = [f.result() for f in futures] 174 | t = np.vstack(resampled_ch) 175 | else: 176 | t = np.vstack([_resample_channel(t[ch], sr, target_sr, 'kaiser_fast') for ch in range(t.shape[0])]) 177 | 178 | # Channel handling and panning 179 | if t.shape[0] == 1 and out_ch == 2: 180 | mono = t[0] 181 | angle = (p_val + 1.0) * (math.pi / 4.0) 182 | left_gain = math.cos(angle) 183 | right_gain = math.sin(angle) 184 | left = mono * left_gain 185 | right = mono * right_gain 186 | t = np.vstack([left, right]) 187 | elif t.shape[0] == 2 and out_ch == 2: 188 | left, right = t[0], t[1] 189 | mono = 0.5 * (left + right) 190 | angle = (p_val + 1.0) * (math.pi / 4.0) 191 | left_gain = math.cos(angle) 192 | right_gain = math.sin(angle) 193 | t = 
np.vstack([mono * left_gain, mono * right_gain]) 194 | elif t.shape[0] == 2 and out_ch == 1: 195 | mono = 0.5 * (t[0] + t[1]) 196 | t = np.expand_dims(mono, 0) 197 | 198 | # Apply fades (in seconds) 199 | if (fi > 0.0) or (fo > 0.0): 200 | t = _apply_fades(t, target_sr, fi, fo) 201 | 202 | # Apply gain (dB -> linear) 203 | lin = 10.0 ** (g_db / 20.0) 204 | t = t * lin 205 | 206 | # Convert delay seconds to sample offset (can be negative) 207 | offset_samples = int(round(delay * target_sr)) 208 | return t, offset_samples 209 | 210 | if verbose: 211 | print("[mix_audio] Resampling and processing tracks...") 212 | 213 | # Parallel processing of tracks 214 | if workers > 1: 215 | with ThreadPoolExecutor(max_workers=workers) as ex: 216 | futures = {ex.submit(process_track, i): i for i in range(n)} 217 | if show_progress and _HAS_TQDM: 218 | pbar = tqdm(total=n, desc="Processing", unit="track") 219 | for fut in as_completed(futures): 220 | i = futures[fut] 221 | processed[i] = fut.result() 222 | if show_progress and _HAS_TQDM: 223 | pbar.update(1) 224 | if show_progress and _HAS_TQDM: 225 | pbar.close() 226 | else: 227 | proc_iter = range(n) 228 | if show_progress and _HAS_TQDM: 229 | proc_iter = tqdm(proc_iter, desc="Processing", unit="track") 230 | for i in proc_iter: 231 | processed[i] = process_track(i) 232 | 233 | # Determine final mix length considering offsets 234 | end_positions = [] 235 | for t, offset in processed: 236 | start = max(0, offset) 237 | end = start + t.shape[1] 238 | end_positions.append(end) 239 | max_len = max(end_positions) if end_positions else 0 240 | min_start = min(offset for _, offset in processed) 241 | 242 | # If there are negative offsets, shift everything forward so earliest sample >= 0 243 | shift_forward = 0 244 | if min_start < 0: 245 | shift_forward = -min_start 246 | max_len += shift_forward 247 | if verbose: 248 | print(f"[mix_audio] Negative offsets detected. 
Shifting all tracks forward by {shift_forward} samples") 249 | 250 | # Create mix buffer and add tracks at offsets 251 | mix = np.zeros((out_ch, max_len), dtype=np.float32) 252 | if show_progress and _HAS_TQDM: 253 | mix_iter = tqdm(processed, desc="Mixing", unit="track") 254 | else: 255 | mix_iter = processed 256 | 257 | for (t, offset) in mix_iter: 258 | start = offset + shift_forward 259 | if start < 0: 260 | clip = -start 261 | if clip >= t.shape[1]: 262 | continue 263 | t = t[:, clip:] 264 | start = 0 265 | end = start + t.shape[1] 266 | mix[:, start:end] += t 267 | 268 | # Normalize to avoid clipping 269 | if normalize: 270 | peak = np.max(np.abs(mix)) 271 | if peak > 0: 272 | mix *= (0.999 / peak) 273 | if verbose: 274 | print(f"[mix_audio] Normalized by factor {(0.999/peak):.6f}") 275 | 276 | # Optional trailing silence removal 277 | if trim_trailing_silence: 278 | if verbose: 279 | print(f"[mix_audio] Trimming trailing silence below {trim_threshold_db} dBFS") 280 | threshold_lin = 10.0 ** (trim_threshold_db / 20.0) 281 | abs_max = np.max(np.abs(mix), axis=0) 282 | non_silent = np.where(abs_max > threshold_lin)[0] 283 | if non_silent.size > 0: 284 | last_idx = int(non_silent[-1]) 285 | pad_samples = int(round(trim_padding_seconds * target_sr)) 286 | new_len = min(mix.shape[1], last_idx + 1 + pad_samples) 287 | mix = mix[:, :new_len] 288 | if verbose: 289 | print(f"[mix_audio] Trimmed to {new_len} samples ({new_len/target_sr:.3f} s)") 290 | else: 291 | keep = int(round(trim_padding_seconds * target_sr)) 292 | mix = mix[:, :keep] 293 | if verbose: 294 | print(f"[mix_audio] All silent; keeping {keep} samples ({keep/target_sr:.3f} s)") 295 | 296 | out = mix.T # (samples, channels) 297 | 298 | # Infer format from extension and validate 299 | ext = os.path.splitext(output_path)[1].lower().lstrip('.') 300 | fmt = ext.upper() 301 | available = sf.available_formats() 302 | if fmt not in available: 303 | raise ValueError(f"Output format {fmt} not supported by soundfile. Available: {list(available.keys())}") 304 | 305 | if return_mix: 306 | print(f"[mix_audio] Returning output") 307 | # Return a copy to avoid accidental modification of internal buffer 308 | return out.copy(), int(target_sr) 309 | 310 | if verbose: 311 | print(f"[mix_audio] Writing output to {output_path} (format={fmt}, subtype={output_subtype})") 312 | 313 | sf.write(output_path, out, samplerate=target_sr, format=fmt, subtype=output_subtype) 314 | 315 | if verbose: 316 | print("[mix_audio] Done.") 317 | 318 | return output_path -------------------------------------------------------------------------------- /neuralpiano/enhancer.py: -------------------------------------------------------------------------------- 1 | # Enhancer Python module 2 | 3 | """ 4 | 5 | A lightweight PyTorch-based audio enhancement module focused on 6 | reducing reverb and improving clarity for solo piano or other 7 | monophonic acoustic recordings. 8 | 9 | Features 10 | -------- 11 | - STFT-based spectral denoising with Wiener-style blending. 12 | - 2-D smoothing of spectral gain (time + frequency). 13 | - Band-specific shaping (low / mid / high) and transient preservation. 14 | - Mild multiband compression to control dynamics across frequency bands. 15 | - Subtle harmonic excitation on the highest band to add perceived presence. 16 | - Gentle residual smoothing subtraction to reduce perceived reverb. 17 | - Final limiter and RMS normalization. 18 | - Optional overall gain in dB applied before final limiter/normalization. 
19 | - Optional stereo output (mono duplication + per-channel normalization). 20 | 21 | Design notes 22 | ------------ 23 | This module is intended for offline processing of single-channel 24 | recordings. It uses reasonably large FFT sizes and smoothing kernels 25 | to produce stable, musical results. Defaults are tuned for piano 26 | recordings sampled at 48 kHz but are configurable. 27 | 28 | Example 29 | ------- 30 | >>> import soundfile as sf 31 | >>> from enhancer import enhance_audio_full 32 | >>> audio, sr = sf.read("piano_mono.wav") 33 | >>> enhanced, shape = enhance_audio_full(audio, sr=sr, overall_gain_db=-1.0, output_as_stereo=True) 34 | >>> # enhanced is a numpy array (if input was numpy) or torch.Tensor (if input was torch) 35 | >>> # shape describes the returned array shape: (2, samples) for stereo or (samples,) for mono 36 | """ 37 | 38 | from typing import Union, Optional, Tuple 39 | import numpy as np 40 | import torch 41 | import torch.nn.functional as F 42 | 43 | TensorOrArray = Union[torch.Tensor, np.ndarray] 44 | 45 | def _to_torch(x: TensorOrArray, device: torch.device, dtype: torch.dtype) -> torch.Tensor: 46 | if isinstance(x, np.ndarray): 47 | t = torch.from_numpy(x) 48 | else: 49 | t = x.clone() 50 | if not torch.is_floating_point(t): 51 | t = t.float() 52 | return t.to(device=device, dtype=dtype).flatten() 53 | 54 | 55 | def _to_output(t: torch.Tensor, orig: TensorOrArray, return_type: Optional[str]) -> TensorOrArray: 56 | out_type = return_type or ('numpy' if isinstance(orig, np.ndarray) else 'torch') 57 | if out_type == 'numpy': 58 | return t.cpu().numpy() 59 | return t 60 | 61 | 62 | def _rms_val(x: torch.Tensor, eps: float = 1e-12) -> float: 63 | return float(torch.sqrt(torch.mean(x**2) + eps).item()) 64 | 65 | 66 | def _soft_clip(x: torch.Tensor, drive: float = 1.0) -> torch.Tensor: 67 | return torch.tanh(x * drive) / (torch.tanh(torch.tensor(1.0, device=x.device)) + 1e-12) 68 | 69 | 70 | def _multiband_compress(mag: torch.Tensor, 71 | freq_bins: torch.Tensor, 72 | sr: int, 73 | bands: tuple = ((20, 200), (200, 2000), (2000, 8000)), 74 | thresholds_db: tuple = (-18.0, -18.0, -18.0), 75 | ratios: tuple = (1.0, 1.8, 2.2), 76 | attack_frames: int = 1, 77 | release_frames: int = 8, 78 | device: torch.device = torch.device('cpu'), 79 | dtype: torch.dtype = torch.float32) -> torch.Tensor: 80 | """ 81 | 82 | Simple multiband compressor applied to magnitude spectrogram. 83 | 84 | Parameters 85 | ---------- 86 | mag : torch.Tensor 87 | Magnitude spectrogram (bins, frames). 88 | freq_bins : torch.Tensor 89 | Frequency values for each bin (bins,). 90 | sr : int 91 | Sample rate in Hz. 92 | bands : tuple 93 | Sequence of (low_hz, high_hz) band boundaries. 94 | thresholds_db : tuple 95 | Thresholds per band in dB (linear conversion applied internally). 96 | ratios : tuple 97 | Compression ratios per band. 98 | attack_frames : int 99 | Attack smoothing window in frames (approximate). 100 | release_frames : int 101 | Release smoothing window in frames (approximate). 102 | device, dtype : torch types 103 | Device and dtype for intermediate tensors. 104 | 105 | Returns 106 | ------- 107 | torch.Tensor 108 | Compressed magnitude spectrogram (bins, frames). 
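    Example (illustrative; the spectrogram shape and the n_fft of 4096 are
    chosen arbitrarily to match the expected (bins, frames) layout):

    >>> import torch
    >>> mag = torch.rand(2049, 100)                          # (bins, frames)
    >>> freqs = torch.fft.rfftfreq(4096, 1.0 / 48000).float()
    >>> out = _multiband_compress(mag, freqs, sr=48000)
    >>> out.shape
    torch.Size([2049, 100])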
109 | """ 110 | 111 | bins, frames = mag.shape 112 | out = mag.clone() 113 | for i, band in enumerate(bands): 114 | lo, hi = band 115 | mask = ((freq_bins >= lo) & (freq_bins < hi)).float().unsqueeze(1) 116 | if mask.sum() < 1: 117 | continue 118 | band_mag = (mag * mask).sum(dim=0) / (mask.sum() + 1e-12) 119 | if attack_frames > 0: 120 | env = torch.sqrt(F.conv1d(band_mag.unsqueeze(0).unsqueeze(0)**2, 121 | torch.ones(1, 1, attack_frames, device=device, dtype=dtype) / attack_frames, 122 | padding=0).squeeze() + 1e-12) 123 | else: 124 | env = band_mag.abs() 125 | threshold = 10 ** (thresholds_db[i] / 20.0) 126 | ratio = ratios[i] 127 | gain = torch.ones_like(env) 128 | over = env > threshold 129 | if over.any(): 130 | gain[over] = (threshold + (env[over] - threshold) / ratio) / (env[over] + 1e-12) 131 | kernel = torch.ones(release_frames, device=device, dtype=dtype) / float(release_frames) 132 | g = F.pad(gain.unsqueeze(0).unsqueeze(0), 133 | (release_frames // 2, release_frames - 1 - release_frames // 2), 134 | mode='replicate') 135 | g = F.conv1d(g, kernel.view(1, 1, release_frames)).squeeze() 136 | out = out * (1.0 - mask) + (out * mask) * g.unsqueeze(0) 137 | return out 138 | 139 | 140 | def enhance_audio_full(audio: TensorOrArray, 141 | sr: int = 48000, 142 | device: Union[str, torch.device] = 'cuda', 143 | dtype: torch.dtype = torch.float32, 144 | n_fft: int = 8192, 145 | hop_length: int = 2048, 146 | win_length: Optional[int] = None, 147 | hp_cut_hz: float = 30.0, 148 | denoise_strength: float = 0.55, 149 | min_gain: float = 0.25, 150 | time_smooth_k: int = 9, 151 | freq_smooth_k: int = 15, 152 | low_gain_db: float = -1.8, 153 | mid_gain_db: float = 1.6, 154 | high_gain_db: float = 1.8, 155 | transient_boost: float = 1.12, 156 | excite_amount: float = 0.01, 157 | excite_scale: float = 0.02, 158 | limiter_threshold_db: float = -0.5, 159 | target_rms_db: float = -18.0, 160 | overall_gain_db: float = -1.0, 161 | output_as_stereo: bool = False, 162 | return_type: Optional[str] = None, 163 | verbose: bool = False 164 | ) -> Tuple[TensorOrArray, Tuple[int, ...]]: 165 | 166 | """ 167 | Enhance a full audio buffer using STFT-domain processing. 168 | 169 | This function performs a sequence of spectral processing steps designed 170 | to reduce reverberation, suppress noise, preserve transients, and 171 | increase perceived clarity and presence for solo piano or similar 172 | acoustic material. 173 | 174 | Parameters 175 | ---------- 176 | audio : numpy.ndarray or torch.Tensor 177 | Input audio. Can be a 1-D array/tensor (samples,) or a 2-D array/tensor 178 | with channels. If multi-channel is provided, channels are averaged to 179 | mono before processing. 180 | sr : int, optional 181 | Sample rate in Hz (default 48000). 182 | device : str or torch.device, optional 183 | Device to run processing on (default 'cuda'). If CUDA is not available, 184 | set to 'cpu'. 185 | dtype : torch.dtype, optional 186 | Floating dtype for processing (default torch.float32). 187 | n_fft : int, optional 188 | FFT size for STFT (default 8192). 189 | hop_length : int, optional 190 | Hop length in samples between STFT frames (default 2048). 191 | win_length : int or None, optional 192 | Window length for STFT. If None, defaults to n_fft. 193 | hp_cut_hz : float, optional 194 | High-pass cutoff frequency in Hz applied to spectral gain (default 30.0). 195 | denoise_strength : float, optional 196 | Controls amount of spectral subtraction and Wiener blending (0..1). 
197 | min_gain : float, optional 198 | Minimum magnitude floor applied after processing (linear). 199 | time_smooth_k : int, optional 200 | Kernel size (frames) for temporal smoothing of spectral gain. 201 | freq_smooth_k : int, optional 202 | Kernel size (bins) for frequency smoothing of spectral gain. 203 | low_gain_db, mid_gain_db, high_gain_db : float, optional 204 | Per-band gain adjustments in dB applied after denoising. 205 | transient_boost : float, optional 206 | Multiplier for transient preservation (values >= 1.0). 207 | excite_amount : float, optional 208 | Drive for harmonic excitation signal generation (small positive). 209 | excite_scale : float, optional 210 | Scaling factor for adding harmonic excitation back into STFT. 211 | limiter_threshold_db : float, optional 212 | Peak limiter threshold in dB (default -0.5 dB). 213 | target_rms_db : float, optional 214 | Target RMS level in dBFS for final normalization (default -18 dB). 215 | overall_gain_db : float, optional 216 | Final overall gain in dB applied before limiter/normalization. 217 | A recommended default is -1.0 dB to avoid clipping while preserving 218 | perceived loudness improvements. 219 | output_as_stereo : bool, optional 220 | If True, duplicate the processed mono signal into two channels and 221 | normalize each channel to the target RMS (mono duplication + norm). 222 | return_type : str or None, optional 223 | If 'numpy', returns numpy arrays; if 'torch', returns torch tensors; 224 | if None, the return type matches the input type. 225 | verbose : bool, optional 226 | If True, prints processing debug information. 227 | 228 | Returns 229 | ------- 230 | (enhanced, shape) : tuple 231 | - enhanced : numpy.ndarray or torch.Tensor 232 | The processed audio. If `output_as_stereo` is False, this is a 1-D 233 | array/tensor with shape (samples,). If `output_as_stereo` is True, 234 | this is a 2-D array/tensor with shape (2, samples) representing 235 | stereo channels. 236 | - shape : tuple 237 | A small descriptor of the returned shape: 238 | - (n,) for mono output where n is number of samples 239 | - (2, n) for stereo output 240 | 241 | Notes 242 | ----- 243 | - The function converts multi-channel inputs to mono by averaging channels. 244 | - The `overall_gain_db` is applied before the final limiter and RMS 245 | normalization so that the limiter can prevent clipping if the gain 246 | increases peaks above the threshold. 247 | - When `output_as_stereo` is True the mono signal is duplicated and each 248 | channel is scaled to match the `target_rms_db` level independently. 249 | - For best results with piano recordings, use a sample rate of 44.1 kHz 250 | or 48 kHz and keep the input as a clean mono take when possible. 
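    - As a point of reference, overall_gain_db=-1.0 corresponds to a linear
      factor of 10 ** (-1.0 / 20.0) ≈ 0.891, and target_rms_db=-18.0 to a
      linear RMS of 10 ** (-18.0 / 20.0) ≈ 0.126.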
251 | 252 | Example 253 | ------- 254 | >>> enhanced, shape = enhance_audio_full(audio, sr=48000, overall_gain_db=-1.0, output_as_stereo=True) 255 | """ 256 | 257 | device = torch.device(device) 258 | x = _to_torch(audio, device=device, dtype=dtype) 259 | if x.dim() != 1: 260 | if x.dim() == 2: 261 | # If input is (channels, samples) or (samples, channels), try to reduce to mono by averaging channels 262 | if x.shape[0] <= 2 and x.shape[0] < x.shape[1]: 263 | x = x.mean(dim=0) 264 | else: 265 | x = x.mean(dim=1) 266 | else: 267 | x = x.view(-1) 268 | n = x.numel() 269 | if win_length is None: 270 | win_length = n_fft 271 | 272 | if verbose: 273 | print(f"[enhance_v2_fixed] device={device}, dtype={dtype}, n={n}, n_fft={n_fft}, hop={hop_length}, overall_gain_db={overall_gain_db}, output_as_stereo={output_as_stereo}") 274 | 275 | window = torch.hann_window(win_length, device=device, dtype=dtype) 276 | 277 | # Full STFT 278 | X = torch.stft(x, 279 | n_fft=n_fft, 280 | hop_length=hop_length, 281 | win_length=win_length, 282 | window=window, 283 | center=True, 284 | return_complex=True) 285 | 286 | mag = torch.abs(X) 287 | phase = torch.angle(X) 288 | bins, frames = mag.shape 289 | 290 | freq_bins = torch.fft.rfftfreq(n_fft, 1.0 / sr).to(device=device, dtype=dtype) 291 | hp_mask = (freq_bins >= hp_cut_hz).float().unsqueeze(1) 292 | low_mask = (freq_bins <= 200.0).float() 293 | mid_mask = ((freq_bins > 200.0) & (freq_bins <= 2000.0)).float() 294 | high_mask = (freq_bins > 2000.0).float() 295 | 296 | low_gain = 10 ** (low_gain_db / 20.0) 297 | mid_gain = 10 ** (mid_gain_db / 20.0) 298 | high_gain = 10 ** (high_gain_db / 20.0) 299 | band_gain = (low_gain * low_mask + mid_gain * mid_mask + high_gain * high_mask).unsqueeze(1) 300 | 301 | est_samples = min(int(0.5 * sr), n) 302 | est_frames = max(1, int(est_samples / hop_length)) 303 | noise_floor = mag[:, :est_frames].median(dim=1).values.unsqueeze(1).clamp(min=1e-9) 304 | 305 | # Spectral subtraction + Wiener blend 306 | S2 = mag**2 307 | N2 = noise_floor**2 308 | over_sub = 1.0 + (denoise_strength * 0.6) 309 | sub = S2 - over_sub * N2 310 | sub = torch.clamp(sub, min=0.0) 311 | gain = sub / (S2 + 1e-12) 312 | gain = torch.clamp(gain, 0.0, 1.0) 313 | gain = 1.0 - (1.0 - gain) * denoise_strength 314 | 315 | # 2-D smoothing: time then frequency 316 | time_k = max(3, time_smooth_k if time_smooth_k % 2 == 1 else time_smooth_k + 1) 317 | freq_k = max(3, freq_smooth_k if freq_smooth_k % 2 == 1 else freq_smooth_k + 1) 318 | 319 | # Time smoothing (grouped conv across frames) 320 | gain_t = gain.unsqueeze(0) # (1, bins, frames) 321 | bins_count = gain_t.shape[1] 322 | time_kernel = torch.ones(time_k, device=device, dtype=dtype) / float(time_k) 323 | k_time = time_kernel.view(1, 1, time_k).repeat(bins_count, 1, 1) 324 | pad_t = (time_k // 2, time_k - 1 - time_k // 2) 325 | gain_t = F.pad(gain_t, pad_t, mode='replicate') 326 | gain_t = F.conv1d(gain_t, k_time, groups=bins_count).squeeze(0) 327 | 328 | # Frequency smoothing (frames as batch) 329 | gain_f = gain_t.transpose(0, 1).unsqueeze(1) # (frames,1,bins) 330 | freq_kernel = torch.ones(freq_k, device=device, dtype=dtype).view(1, 1, freq_k) / float(freq_k) 331 | pad_f = (freq_k // 2, freq_k - 1 - freq_k // 2) 332 | gain_f = F.pad(gain_f, pad_f, mode='replicate') 333 | gain_f = F.conv1d(gain_f, freq_kernel, groups=1) # (frames,1,bins) 334 | gain = gain_f.squeeze(1).transpose(0, 1) # (bins, frames) 335 | 336 | # Apply highpass mask and band shaping, enforce min_gain floor 337 | gain = gain * hp_mask 338 | mag 
= mag * gain * band_gain 339 | mag = torch.clamp(mag, min=min_gain * 1e-6) 340 | 341 | # Transient preservation (high-frequency energy rise) 342 | hf = (mag * high_mask.unsqueeze(1)).sum(dim=0) 343 | prev = F.pad(hf, (1, 0))[:-1] 344 | rise = torch.clamp((hf - prev) / (prev + 1e-9), min=0.0) 345 | transient_gain = 1.0 + (transient_boost - 1.0) * torch.clamp(rise * 2.0, 0.0, 1.0) 346 | mag = mag * transient_gain.unsqueeze(0) 347 | 348 | # Mild multiband compression 349 | mag = _multiband_compress(mag, freq_bins, sr, 350 | bands=((20, 200), (200, 2000), (2000, 8000)), 351 | thresholds_db=(-22.0, -20.0, -20.0), 352 | ratios=(1.0, 1.6, 1.8), 353 | attack_frames=1, 354 | release_frames=6, 355 | device=device, 356 | dtype=dtype) 357 | 358 | # Reconstruct complex STFT 359 | X = mag * torch.exp(1j * phase) 360 | 361 | # Very subtle harmonic excitation on highest band 362 | high_mask_full = (freq_bins > 4000.0).float().unsqueeze(1) 363 | X_high = X * high_mask_full 364 | high_time = torch.istft(X_high, n_fft=n_fft, hop_length=hop_length, win_length=win_length, window=window, length=n) 365 | excite = _soft_clip(high_time, drive=1.0 + excite_amount) - high_time 366 | E = torch.stft(excite, n_fft=n_fft, hop_length=hop_length, win_length=win_length, window=window, center=True, return_complex=True) 367 | X = X + excite_scale * E 368 | 369 | # ISTFT back to time domain 370 | out = torch.istft(X, n_fft=n_fft, hop_length=hop_length, win_length=win_length, window=window, length=n) 371 | 372 | # Final gentle de-reverb: smooth residual and subtract tiny fraction 373 | residual = out - x 374 | res = residual.cpu() 375 | kernel_len = max(1, int(0.05 * sr)) # 50 ms smoothing 376 | if kernel_len > 1 and res.numel() > kernel_len: 377 | # prepare kernel on CPU 378 | k = torch.ones(1, 1, kernel_len, dtype=res.dtype) / float(kernel_len) 379 | # conv1d expects (batch, channels, length) 380 | res_padded = F.pad(res.unsqueeze(0).unsqueeze(0), (kernel_len // 2, kernel_len - 1 - kernel_len // 2), mode='replicate') 381 | res_smooth = F.conv1d(res_padded, k, padding=0).squeeze() 382 | # Align length robustly: trim or pad to match n 383 | if res_smooth.numel() > n: 384 | res_smooth = res_smooth[:n] 385 | elif res_smooth.numel() < n: 386 | res_smooth = F.pad(res_smooth, (0, n - res_smooth.numel())) 387 | # subtract a tiny fraction of smoothed residual to reduce perceived reverb 388 | out = out - 0.02 * res_smooth.to(device=device, dtype=dtype) 389 | 390 | # Apply overall gain in dB (before final limiter/normalization) 391 | if abs(overall_gain_db) > 1e-6: 392 | gain_lin = 10 ** (overall_gain_db / 20.0) 393 | out = out * gain_lin 394 | 395 | # Final limiter and RMS normalization 396 | peak = out.abs().max().clamp(min=1e-12).item() 397 | threshold = 10 ** (limiter_threshold_db / 20.0) 398 | if peak > threshold: 399 | out = out * (threshold / peak) 400 | 401 | current_rms = _rms_val(out) 402 | target_rms = 10 ** (target_rms_db / 20.0) 403 | if current_rms > 1e-12: 404 | out = out * (target_rms / current_rms) 405 | 406 | # If stereo output requested: duplicate mono to two channels and normalize (mono duplication + norm) 407 | if output_as_stereo: 408 | # create (2, n) tensor 409 | out_stereo = torch.stack([out, out], dim=0) # (2, n) 410 | # ensure each channel has same RMS as target_rms 411 | ch_rms = torch.sqrt(torch.mean(out_stereo**2, dim=1) + 1e-12) 412 | # avoid division by zero 413 | scale = (target_rms / ch_rms).unsqueeze(1) 414 | out_stereo = out_stereo * scale.to(device=device, dtype=dtype) 415 | final_out = out_stereo 
416 | final_shape = (2, n) 417 | else: 418 | final_out = out 419 | final_shape = (n,) 420 | 421 | return _to_output(final_out, audio, return_type), final_shape -------------------------------------------------------------------------------- /neuralpiano/master.py: -------------------------------------------------------------------------------- 1 | # Master Python module 2 | 3 | import torch 4 | import math 5 | from typing import Dict, Tuple 6 | 7 | Tensor = torch.Tensor 8 | 9 | def master_mono_piano(mono: Tensor, 10 | sr: int = 48000, 11 | hp_cut: float = 20.0, 12 | low_shelf_db: float = -1.5, 13 | low_shelf_fc: float = 250.0, 14 | presence_boost_db: float = 1.6, 15 | presence_fc: float = 3500.0, 16 | compressor_thresh_db: float = -6.0, 17 | compressor_ratio: float = 2.5, 18 | compressor_attack_ms: float = 8.0, 19 | compressor_release_ms: float = 80.0, 20 | compressor_makeup_db: float = 0.0, 21 | stereo_spread_ms: float = 6.0, 22 | stereo_spread_level_db: float = -12.0, 23 | reverb_size_sec: float = 0.6, 24 | reverb_mix: float = 0.06, 25 | limiter_ceiling_db: float = -0.3, 26 | dithering_bits: int = 24, 27 | gain_db: float = 20.0, 28 | device: torch.device = None, 29 | dtype: torch.dtype = None, 30 | iir_impulse_len: int = 2048, # length used to approximate IIRs as FIRs (tradeoff speed/accuracy) 31 | ) -> Tuple[Tensor, Dict]: 32 | """ 33 | Mastering pipeline for a mono input piano track producing a stereo master and diagnostics. 34 | 35 | Description and design choices 36 | - Purpose: provide a compact, deterministic, and high-throughput mastering chain implemented 37 | with PyTorch tensors. The pipeline uses short FIR approximations for small IIR filters, 38 | single-shot FFT convolution for medium/long FIRs (cascaded HP + shelf, reverb), a fast 39 | RMS-based compressor, fractional-delay stereo widening, a smooth soft limiter, and 40 | deterministic TPDF dithering. 41 | - Where gain_db is applied: an explicit master gain parameter (gain_db) is applied after 42 | EQ/compression/stereo processing and reverb but before the soft limiter and final safety 43 | scaling. This placement allows the limiter to react to the applied gain, preserving 44 | consistent ceiling behavior while enabling transparent loudness adjustments and avoiding 45 | excessive post-limiter boosting which would bypass the limiter's protection. 46 | - Implementation details: 47 | * EQ: designs 2nd-order biquads and converts them to short FIR impulse responses using 48 | direct IIR stepping (iir_to_fir). HP and low-shelf are cascaded via FFT convolution. 49 | * Presence: implemented as a small symmetric FIR (sinc-windowed) and applied with conv1d. 50 | * Compression: fast RMS detector downsampled to ~4 kHz with an attack/release recurrence 51 | executed on CPU for determinism and efficiency. Soft-knee gain computer produces a 52 | time-varying linear gain applied to the signal with optional makeup gain. 53 | * Stereo widen: fractional sub-sample delays are used to create left/right channels, then 54 | mid/side processing adjusts side level. 55 | * Reverb: simple tapped + exponential-tail IR built and FFT-convolved once. 56 | * Limiter: soft tanh-based limiter that scales output to target ceiling (limiter_ceiling_db). 57 | * Dither: deterministic vectorized LCG generates TPDF dither for specified bit depth. 58 | - Determinism and diagnostics: the function avoids non-deterministic ops and returns a 59 | diagnostics dict with numeric metrics (peaks, RMS, average reduction, applied scales). 
60 | - Precision and device handling: prefers float32 for performance; preserves device assignment; 61 | convolution routines use torch.fft (rfft/irfft) for CPU/GPU compatibility. 62 | 63 | Parameters 64 | - mono: Tensor of shape [N] or [1, N], mono PCM in float range roughly [-1, 1]. 65 | - sr: sample rate in Hz. 66 | - hp_cut: high-pass cutoff frequency in Hz. 67 | - low_shelf_db: low-shelf gain (dB). 68 | - low_shelf_fc: low-shelf center freq (Hz). 69 | - presence_boost_db: narrow presence boost amount (dB). 70 | - presence_fc: center freq of presence boost (Hz). 71 | - compressor_...: compressor threshold (dBFS), ratio, attack and release (ms), and makeup (dB). 72 | - stereo_spread_ms: nominal stereo spread in milliseconds used to compute fractional delays. 73 | - stereo_spread_level_db: side gain in dB applied to M/S side channel. 74 | - reverb_size_sec: approximate reverb tail time in seconds (affects IR length). 75 | - reverb_mix: wet fraction of reverb to blend with dry signal. 76 | - limiter_ceiling_db: final soft limiter ceiling in dBFS (should be <= 0). 77 | - dithering_bits: integer bits for TPDF dithering (1-32). Use 0 or outside range to disable. 78 | - gain_db: master gain applied (dB). Positive increases loudness, negative reduces. Applied 79 | before the limiter so the limiter shapes peaks introduced by the gain. 80 | - device, dtype: optional torch.device and dtype to force placement/precision. 81 | - iir_impulse_len: length of FIR used to approximate 2nd-order IIRs (tradeoff accuracy/speed). 82 | 83 | Returns 84 | - stereo: Tensor shaped [2, N] float32 in range (-1, 1) representing left/right master. 85 | - diagnostics: Dict with numeric measurements and settings for reproducibility. 86 | 87 | Notes 88 | - The function keeps intermediary operations vectorized. Non-trivial recurrence for the RMS 89 | detector runs on CPU for stability and determinism but the returned envelopes are moved back 90 | to the target device and dtype. 91 | - For large inputs and GPU usage, FFT sizes are chosen as powers of two for efficiency. 92 | - If you want gain to be applied as a pre-EQ trim (instead of pre-limiter), move the gain 93 | multiplication earlier; current placement intentionally lets the limiter handle the gain. 
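    - For example, the default gain_db=20.0 corresponds to a linear factor of
      10 ** (20.0 / 20.0) = 10.0; because this gain is applied before the soft
      limiter, any peaks it pushes above limiter_ceiling_db are pulled back to
      the ceiling rather than hard-clipped.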
94 | """ 95 | 96 | # --- Setup, sanitize --- 97 | if mono.ndim == 2 and mono.shape[0] == 1: 98 | x = mono.squeeze(0) 99 | elif mono.ndim == 1: 100 | x = mono 101 | else: 102 | raise ValueError("mono must be shape [1, N] or [N]") 103 | 104 | device = device or x.device 105 | # prefer float32 for fastest throughput unless user explicitly provided float64 106 | if dtype is None: 107 | dtype = x.dtype if x.dtype in (torch.float32, torch.float64) else torch.float32 108 | x = x.to(device=device, dtype=dtype, copy=False) 109 | N = x.shape[-1] 110 | eps = 1e-12 111 | 112 | diagnostics: Dict = {} 113 | def db2lin(d): return 10.0 ** (d / 20.0) 114 | def lin2db(v): return 20.0 * math.log10(max(v, 1e-12)) 115 | 116 | # --- Helper: design biquad coefficients (as before) --- 117 | def design_butter_hp(fc, fs, Q=0.7071): 118 | omega = 2.0 * math.pi * fc / fs 119 | alpha = math.sin(omega) / (2.0 * Q) 120 | cosw = math.cos(omega) 121 | b0 = (1 + cosw) / 2.0 122 | b1 = -(1 + cosw) 123 | b2 = (1 + cosw) / 2.0 124 | a0 = 1 + alpha 125 | a1 = -2 * cosw 126 | a2 = 1 - alpha 127 | return b0, b1, b2, a0, a1, a2 128 | 129 | def design_low_shelf(fc, fs, gain_db, Q=0.7071): 130 | A = 10 ** (gain_db / 40.0) 131 | w0 = 2.0 * math.pi * fc / fs 132 | alpha = math.sin(w0) / (2.0 * Q) 133 | cosw = math.cos(w0) 134 | b0 = A*( (A+1) - (A-1)*cosw + 2*math.sqrt(A)*alpha ) 135 | b1 = 2*A*( (A-1) - (A+1)*cosw ) 136 | b2 = A*( (A+1) - (A-1)*cosw - 2*math.sqrt(A)*alpha ) 137 | a0 = (A+1) + (A-1)*cosw + 2*math.sqrt(A)*alpha 138 | a1 = -2*( (A-1) + (A+1)*cosw ) 139 | a2 = (A+1) + (A-1)*cosw - 2*math.sqrt(A)*alpha 140 | return b0, b1, b2, a0, a1, a2 141 | 142 | # --- Utility: convert a 2nd-order IIR (b,a) to an FIR impulse response of length L 143 | # compute impulse response by stepping the IIR for L samples (this loop runs only for L ~ 1-4k) 144 | def iir_to_fir(b0, b1, b2, a0, a1, a2, L, device, dtype): 145 | # normalize 146 | b0n, b1n, b2n = b0 / a0, b1 / a0, b2 / a0 147 | a1n, a2n = a1 / a0, a2 / a0 148 | # compute impulse response on CPU or device (run on device if small L and device != cpu) 149 | run_device = device if (device.type != 'cuda' or L <= 8192) else device 150 | h = torch.zeros(L, device=run_device, dtype=dtype) 151 | x_prev1 = 0.0 152 | x_prev2 = 0.0 153 | y_prev1 = 0.0 154 | y_prev2 = 0.0 155 | # feed delta(0)=1, others 0 156 | for n in range(L): 157 | xv = 1.0 if n == 0 else 0.0 158 | yv = b0n * xv + b1n * x_prev1 + b2n * x_prev2 - a1n * y_prev1 - a2n * y_prev2 159 | h[n] = yv 160 | x_prev2 = x_prev1 161 | x_prev1 = xv 162 | y_prev2 = y_prev1 163 | y_prev1 = yv 164 | # ensure on main device 165 | if h.device != device: 166 | h = h.to(device=device, dtype=dtype) 167 | return h 168 | 169 | # --- 1) Build IIR -> FIR approximations for HP and shelf (single impulse responses) --- 170 | # keep impulse length small (configurable) to balance cost/accuracy 171 | b0,b1,b2,a0,a1,a2 = design_butter_hp(hp_cut, sr) 172 | hp_ir = iir_to_fir(b0,b1,b2,a0,a1,a2, iir_impulse_len, device, dtype) 173 | 174 | b0s,b1s,b2s,a0s,a1s,a2s = design_low_shelf(low_shelf_fc, sr, low_shelf_db) 175 | shelf_ir = iir_to_fir(b0s,b1s,b2s,a0s,a1s,a2s, iir_impulse_len, device, dtype) 176 | 177 | # cascade IIRs by convolving their IRs (use FFT convolution) 178 | def fft_convolve_full(sig, kernel, device, dtype): 179 | n = sig.shape[-1] 180 | k = kernel.shape[-1] 181 | out_len = n + k - 1 182 | size = 1 << ((out_len - 1).bit_length()) 183 | # cast to complex-friendly dtype (float32/64) 184 | S = torch.fft.rfft(sig, n=size) 185 | K = 
torch.fft.rfft(kernel, n=size) 186 | Y = S * K 187 | y = torch.fft.irfft(Y, n=size)[:out_len] 188 | # return same length as input (valid-ish) by trimming convolution to center-left aligned (like original) 189 | return y[:n] 190 | 191 | # apply HP then shelf by convolving with cascaded IR (hp_ir * shelf_ir) 192 | casc_ir = fft_convolve_full(hp_ir, shelf_ir, device, dtype)[:max(hp_ir.numel(), shelf_ir.numel())] 193 | # normalize tiny numerical offsets 194 | casc_ir = casc_ir / (casc_ir.abs().sum().clamp(min=eps)) 195 | 196 | # apply cascade IR to input (single FFT conv) 197 | x_hp_shelf = fft_convolve_full(x, casc_ir, device, dtype) 198 | 199 | # --- 2) Presence boost (small FIR) same approach but small kernel conv1d is very fast --- 200 | pres_len = min(256, max(65, int(sr * 0.0045))) # ~4.5 ms 201 | t_idx = torch.arange(pres_len, device=device, dtype=dtype) - (pres_len - 1) / 2.0 202 | h = (torch.sinc(2.0 * presence_fc / sr * t_idx) * torch.hann_window(pres_len, device=device, dtype=dtype)) 203 | h = h / (h.abs().sum() + eps) 204 | gain_lin = db2lin(presence_boost_db) 205 | presence_ir = (gain_lin - 1.0) * h 206 | presence_ir[(pres_len - 1) // 2] += 1.0 207 | # conv small kernel with conv1d (fast) 208 | x_eq = torch.nn.functional.conv1d(x_hp_shelf.view(1,1,-1), presence_ir.view(1,1,-1), padding=(pres_len-1)//2).view(-1) 209 | 210 | diagnostics.update({ 211 | "hp_cut": hp_cut, 212 | "low_shelf_db": low_shelf_db, 213 | "presence_db": presence_boost_db, 214 | "presence_len": pres_len, 215 | "iir_impulse_len": iir_impulse_len, 216 | }) 217 | 218 | # --- 3) Fast RMS compressor (vectorized with downsampled detector) --- 219 | sig = x_eq 220 | attack_tc = math.exp(-1.0 / max(1.0, (compressor_attack_ms * sr / 1000.0))) 221 | release_tc = math.exp(-1.0 / max(1.0, (compressor_release_ms * sr / 1000.0))) 222 | sq = sig * sig 223 | 224 | ds = max(1, int(sr // 4000)) # detector rate ~4kHz 225 | if ds > 1: 226 | # pad to multiple of ds 227 | pad = (-sq.shape[-1]) % ds 228 | if pad: 229 | sq_pad = torch.nn.functional.pad(sq, (0, pad)) 230 | else: 231 | sq_pad = sq 232 | sq_ds = sq_pad.view(-1).reshape(-1, ds).mean(dim=1) 233 | else: 234 | sq_ds = sq 235 | 236 | # recurrence on small downsampled vector executed on CPU (cheap) 237 | sq_ds_cpu = sq_ds.detach().cpu() 238 | env_ds = torch.empty_like(sq_ds_cpu) 239 | s_val = float(sq_ds_cpu[0].item()) 240 | a = attack_tc 241 | r = release_tc 242 | for i in range(sq_ds_cpu.shape[0]): 243 | v = float(sq_ds_cpu[i].item()) 244 | if v > s_val: 245 | s_val = a * s_val + (1.0 - a) * v 246 | else: 247 | s_val = r * s_val + (1.0 - r) * v 248 | env_ds[i] = s_val 249 | env_ds = env_ds.to(device=device, dtype=dtype) 250 | 251 | if ds > 1: 252 | env = env_ds.repeat_interleave(ds)[:N] 253 | else: 254 | env = env_ds 255 | 256 | rms_env = torch.sqrt(torch.clamp(env, min=eps)) 257 | lvl_db = 20.0 * torch.log10(torch.clamp(rms_env, min=1e-12)) 258 | knee = 3.0 259 | over = lvl_db - compressor_thresh_db 260 | # soft knee 261 | zero = torch.zeros_like(over) 262 | gain_reduction_db = torch.where( 263 | over <= -knee, 264 | zero, 265 | torch.where( 266 | over >= knee, 267 | compressor_thresh_db + (over / compressor_ratio) - lvl_db, 268 | - ((1.0 - 1.0/compressor_ratio) * (over + knee)**2) / (4.0 * knee) 269 | ) 270 | ).clamp(max=0.0) 271 | gain_lin = 10.0 ** (gain_reduction_db / 20.0) 272 | if ds > 1: 273 | gain_full = gain_lin.repeat_interleave(ds)[:N] 274 | else: 275 | gain_full = gain_lin 276 | makeup = db2lin(compressor_makeup_db) 277 | comp_out = sig * gain_full * makeup 278 | 
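    # Worked example of the soft-knee gain computer above, using the defaults
    # compressor_thresh_db=-6.0 and compressor_ratio=2.5: a detector level of
    # 0 dBFS gives over = 6 dB (above the 3 dB knee), so
    #   gain_reduction_db = -6.0 + 6.0 / 2.5 - 0.0 = -3.6 dB,
    # i.e. the signal is attenuated to roughly -3.6 dBFS before makeup gain.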
279 | diagnostics.update({ 280 | "compressor_thresh_db": compressor_thresh_db, 281 | "compressor_ratio": compressor_ratio, 282 | "compressor_attack_ms": compressor_attack_ms, 283 | "compressor_release_ms": compressor_release_ms, 284 | "compressor_makeup_db": compressor_makeup_db, 285 | "detector_downsample": ds, 286 | "avg_reduction_db": float((20.0 * torch.log10((gain_full.mean().clamp(min=1e-12))).item())), 287 | }) 288 | 289 | # --- 4) Stereo widening with fractional sub-sample delays (vectorized) --- 290 | spread_samples = max(1e-4, stereo_spread_ms * sr / 1000.0) 291 | left_delay = spread_samples * 0.5 292 | right_delay = -spread_samples * 0.3333 293 | 294 | def fractional_delay_vec(sig, delay): 295 | n = sig.shape[-1] 296 | idx = torch.arange(n, device=device, dtype=dtype) 297 | pos = idx - delay 298 | pos_floor = pos.floor().long() 299 | pos_ceil = pos_floor + 1 300 | frac = (pos - pos_floor.to(dtype)) 301 | pos_floor = pos_floor.clamp(0, n-1) 302 | pos_ceil = pos_ceil.clamp(0, n-1) 303 | s_floor = sig[pos_floor] 304 | s_ceil = sig[pos_ceil] 305 | return s_floor * (1.0 - frac) + s_ceil * frac 306 | 307 | left = 0.985 * comp_out + 0.015 * fractional_delay_vec(comp_out, left_delay) 308 | right = 0.985 * comp_out + 0.015 * fractional_delay_vec(comp_out, right_delay) 309 | 310 | mid = 0.5 * (left + right) 311 | side = 0.5 * (left - right) 312 | side = side * db2lin(stereo_spread_level_db) 313 | left = mid + side 314 | right = mid - side 315 | 316 | # --- 5) Reverb: build IR and FFT-convolve (single-shot) --- 317 | reverb_len = int(min(int(sr * reverb_size_sec), 65536)) 318 | reverb_len = max(reverb_len, int(0.02 * sr)) 319 | t = torch.arange(reverb_len, device=device, dtype=dtype) 320 | tail = torch.exp(-t / (reverb_size_sec * sr + 1e-12)) 321 | taps_ms = [12, 23, 37, 53, 79] 322 | ir = torch.zeros(reverb_len, device=device, dtype=dtype) 323 | for i, tm in enumerate(taps_ms): 324 | idx = int(round(sr * tm / 1000.0)) 325 | if idx < reverb_len: 326 | ir[idx] += (0.5 ** (i + 1)) 327 | ir += 0.15 * tail 328 | ir = ir / (ir.abs().sum() + eps) 329 | 330 | left_rev = fft_convolve_full(left, ir, device, dtype) 331 | right_rev = fft_convolve_full(right, ir, device, dtype) 332 | left = (1.0 - reverb_mix) * left + reverb_mix * left_rev 333 | right = (1.0 - reverb_mix) * right + reverb_mix * right_rev 334 | 335 | diagnostics.update({ 336 | "reverb_size_sec": reverb_size_sec, 337 | "reverb_mix": reverb_mix, 338 | "reverb_len": reverb_len, 339 | }) 340 | 341 | # --- MASTER GAIN: apply desired gain in linear domain before limiter --- 342 | if abs(gain_db) > 1e-12: 343 | gain_lin_master = db2lin(gain_db) 344 | left = left * gain_lin_master 345 | right = right * gain_lin_master 346 | diagnostics['applied_gain_db'] = float(gain_db) 347 | diagnostics['applied_gain_lin'] = float(gain_lin_master) 348 | else: 349 | diagnostics['applied_gain_db'] = 0.0 350 | diagnostics['applied_gain_lin'] = 1.0 351 | 352 | # --- 6) Soft limiter --- 353 | def soft_limiter(x_chan, ceiling_db): 354 | ceiling_lin = db2lin(ceiling_db) 355 | peak = x_chan.abs().max().clamp(min=eps) 356 | if peak <= ceiling_lin: 357 | return x_chan 358 | scaled = x_chan * (ceiling_lin / peak) 359 | out = torch.tanh(scaled * 1.25) / 1.25 360 | out = out / out.abs().max().clamp(min=eps) * ceiling_lin 361 | return out 362 | 363 | left = soft_limiter(left, limiter_ceiling_db) 364 | right = soft_limiter(right, limiter_ceiling_db) 365 | 366 | # final safety scaling 367 | peak_val = max(left.abs().max().item(), right.abs().max().item()) 368 | if peak_val 
> 0.999: 369 | scale = 0.999 / peak_val 370 | left = left * scale 371 | right = right * scale 372 | diagnostics['final_scale'] = float(scale) 373 | else: 374 | diagnostics['final_scale'] = 1.0 375 | 376 | # --- 7) Deterministic TPDF dithering (vectorized LCG) --- 377 | def vectorized_lcg(sz, seed): 378 | a = 1103515245 379 | c = 12345 380 | mod = 2**31 381 | seeds = (torch.arange(sz, device=device, dtype=torch.int64) * 1664525 + int(seed)) & (mod - 1) 382 | vals = (a * seeds + c) & (mod - 1) 383 | floats = (vals.to(dtype) / float(mod)) - 0.5 384 | return floats 385 | 386 | if 1 <= dithering_bits <= 32: 387 | q = 1.0 / (2 ** (dithering_bits - 1)) 388 | seed = (N ^ sr ^ 0x9e3779b1) & 0xffffffff 389 | na = vectorized_lcg(N, seed) 390 | nb = vectorized_lcg(N, seed ^ 0x6a09e667) 391 | tpdf = (na - nb) * q 392 | left = left + 0.5 * tpdf 393 | right = right + 0.5 * tpdf 394 | 395 | # --- Output and diagnostics --- 396 | stereo = torch.stack([left.to(torch.float32), right.to(torch.float32)], dim=0) 397 | stereo = stereo.clamp(-1.0 + 1e-9, 1.0 - 1e-9) 398 | 399 | left_peak = left.abs().max().item() 400 | right_peak = right.abs().max().item() 401 | left_rms = math.sqrt(float(torch.mean(left * left).item())) 402 | right_rms = math.sqrt(float(torch.mean(right * right).item())) 403 | diagnostics.update({ 404 | "left_peak": left_peak, "right_peak": right_peak, 405 | "left_peak_db": lin2db(left_peak), "right_peak_db": lin2db(right_peak), 406 | "left_rms_db": lin2db(left_rms), "right_rms_db": lin2db(right_rms), 407 | "num_samples": N, "sample_rate": sr, 408 | }) 409 | 410 | return stereo, diagnostics -------------------------------------------------------------------------------- /neuralpiano/neuralpiano.py: -------------------------------------------------------------------------------- 1 | #=============================================================================== 2 | # Neural Piano main Python module 3 | #=============================================================================== 4 | 5 | """ 6 | 7 | This module exposes `render_midi`, a high-level convenience function that: 8 | - Renders a MIDI file to a raw waveform using a SoundFont (SF2) via `midirenderer`. 9 | - Loads the rendered waveform and optionally trims silence. 10 | - Encodes the waveform into a latent representation and decodes it using a 11 | learned Encoder/Decoder model to produce a high-quality piano audio render. 12 | - Optionally applies a sequence of post-processing steps: denoising, 13 | bass enhancement, full-spectrum enhancement, and mastering. 14 | - Writes the final audio to disk or returns it in-memory. 15 | 16 | Design goals 17 | ------------ 18 | - Provide a single, well-documented function to convert MIDI -> polished WAV. 19 | - Keep sensible defaults so the function works out-of-the-box with a 20 | `models/` directory containing the required SoundFont and model artifacts. 21 | - Allow advanced users to override model paths, processing parameters, and 22 | device selection. 
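As a quick sketch of such overrides (the checkpoint path and the gain value
below are illustrative placeholders, not shipped defaults):

>>> from neuralpiano.neuralpiano import render_midi
>>> out_path = render_midi("score.mid",
...                        custom_model_path="models/my_checkpoint.pt",
...                        device="cpu",
...                        master=True,
...                        master_kwargs={"gain_db": 6.0})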
23 | 24 | Dependencies 25 | ------------ 26 | - Python 3.8+ 27 | - torch 28 | - librosa 29 | - soundfile (pysoundfile) 30 | - midirenderer 31 | - The package's internal modules: 32 | - .music2latent.inference.EncoderDecoder 33 | - .denoise.denoise_audio 34 | - .bass.enhance_audio_bass 35 | - .enhancer.enhance_audio_full 36 | - .master.master_mono_piano 37 | 38 | Typical usage 39 | ------------- 40 | >>> from neuralpiano.neuralpiano import render_midi 41 | >>> out_path = render_midi("score.mid") # writes ./score.wav using defaults 42 | >>> audio, sr = render_midi("score.mid", return_audio=True) # get numpy array and sr 43 | 44 | Notes and behavior 45 | ------------------ 46 | - By default the function expects a `models/` directory in the current working 47 | directory and a SoundFont file named `SGM-v2.01-YamahaGrand-Guit-Bass-v2.7.sf2`. 48 | Use `sf2_name` or `custom_model_path` to override. 49 | - The EncoderDecoder is instantiated per-call. For repeated renders in a 50 | long-running process, consider reusing a single EncoderDecoder instance 51 | (not provided by this convenience wrapper). 52 | - `sample_rate` controls the resampling rate used when loading the rendered 53 | MIDI waveform and is propagated to downstream processing where relevant. 54 | - `trim_silence` uses `librosa.effects.trim` with configurable `trim_top_db`, 55 | `trim_frame_length`, and `trim_hop_length`. 56 | - Post-processing steps are applied in this order when enabled: 57 | 1. denoise (if `denoise=True`) 58 | 2. bass enhancement (if `enhance_bass=True`) 59 | 3. full enhancement (if `enhance_full=True`) 60 | 4. mastering (if `master=True`) 61 | Each step accepts a kwargs dict (e.g., `denoise_kwargs`) to override defaults. 62 | - `device` accepts any value accepted by `torch.device` (e.g., 'cuda', 'cpu'). 63 | - When `return_audio=True`, the function returns `(final_audio, sample_rate)` as 64 | a NumPy array and sample rate. Otherwise it writes a WAV file and returns the 65 | output file path. 66 | - Verbosity: 67 | - `verbose` prints high-level progress messages. 68 | - `verbose_diag` prints additional diagnostic values useful for debugging. 69 | 70 | Exceptions 71 | ---------- 72 | The function may raise exceptions originating from: 73 | - File I/O (missing MIDI or models, permission errors). 74 | - `midirenderer` if the SoundFont or MIDI bytes are invalid. 75 | - `librosa` when loading or trimming audio. 76 | - Torch/model-related errors when instantiating or running the EncoderDecoder. 77 | - Any of the post-processing modules if they encounter invalid inputs. 78 | 79 | Parameters 80 | ---------- 81 | input_midi_file : str 82 | Path to the input MIDI file to render. 83 | output_audio_file : str or None 84 | Path to write the final WAV file. If None, the output filename is derived 85 | from the MIDI filename and written to the current working directory. 86 | sample_rate : int 87 | Target sample rate for loading and processing audio (default: 48000). 88 | denoising_steps : int 89 | Default number of denoising steps passed to the decoder. 90 | max_batch_size : int or None 91 | Maximum batch size to use when encoding/decoding; passed to EncoderDecoder. 92 | use_v1_piano_model : bool 93 | If True, instructs EncoderDecoder to use the v1 piano model variant. 94 | load_multi_instrumental_model : bool 95 | If True, load a multi-instrument model variant (if available). 96 | custom_model_path : str or None 97 | Path to a custom model checkpoint for inference; passed to EncoderDecoder.
98 | sf2_name : str 99 | Filename of the SoundFont inside the `models/` directory (default provided). 100 | trim_silence : bool 101 | If True, trim leading/trailing silence from the rendered waveform. 102 | trim_top_db : float 103 | `top_db` parameter for `librosa.effects.trim` (default: 60). 104 | trim_frame_length : int 105 | `frame_length` parameter for `librosa.effects.trim` (default: 2048). 106 | trim_hop_length : int 107 | `hop_length` parameter for `librosa.effects.trim` (default: 512). 108 | denoise : bool 109 | If True, run the denoiser post-processing step. 110 | denoise_kwargs : dict or None 111 | Additional keyword arguments for `denoise_audio`. 112 | enhance_bass : bool 113 | If True, run bass enhancement post-processing. 114 | bass_kwargs : dict or None 115 | Additional keyword arguments for `enhance_audio_bass`. `low_gain_db` is set 116 | from the top-level `low_gain_db` unless overridden here. 117 | low_gain_db : float 118 | Default low-frequency gain (dB) used by bass enhancement (default: 8.0). 119 | enhance_full : bool 120 | If True, run full-spectrum enhancement post-processing. 121 | enhance_full_kwargs : dict or None 122 | Additional keyword arguments for `enhance_audio_full`. 123 | master : bool 124 | If True, run final mastering pass. 125 | master_kwargs : dict or None 126 | Additional keyword arguments for `master_mono_piano`. `gain_db` is set from 127 | the top-level `overall_gain_db` unless overridden here. 128 | overall_gain_db : float 129 | Default gain (dB) applied during mastering (default: 10.0). 130 | device : str 131 | Device string for torch (e.g., 'cuda' or 'cpu'). Converted to `torch.device` 132 | when passed to post-processing functions. 133 | return_audio : bool 134 | If True, return `(audio_numpy, sample_rate)` instead of writing a file. 135 | verbose : bool 136 | Print progress messages when True. 137 | verbose_diag : bool 138 | Print diagnostic information when True. 139 | 140 | Example 141 | ------- 142 | Render a MIDI to disk with default processing: 143 | 144 | >>> render_midi("song.mid") 145 | 146 | Render and receive audio in-memory without post-processing: 147 | 148 | >>> audio, sr = render_midi("song.mid", denoise=False, enhance_bass=False, 149 | ... 
enhance_full=False, master=False, return_audio=True) 150 | 151 | """ 152 | 153 | #=============================================================================== 154 | 155 | import os 156 | import io 157 | from pathlib import Path 158 | 159 | import torch 160 | 161 | import librosa 162 | import soundfile as sf 163 | 164 | import midirenderer 165 | 166 | from .music2latent.inference import EncoderDecoder 167 | 168 | from .denoise import denoise_audio 169 | 170 | from .bass import enhance_audio_bass 171 | 172 | from .enhancer import enhance_audio_full 173 | 174 | from .master import master_mono_piano 175 | 176 | #=============================================================================== 177 | 178 | def render_midi(input_midi_file, 179 | output_audio_file=None, 180 | sample_rate=48000, 181 | denoising_steps=10, 182 | max_batch_size=None, 183 | use_v1_piano_model=False, 184 | load_multi_instrumental_model=False, 185 | custom_model_path=None, 186 | sf2_name='SGM-v2.01-YamahaGrand-Guit-Bass-v2.7.sf2', 187 | trim_silence=True, 188 | trim_top_db=60, 189 | trim_frame_length=2048, 190 | trim_hop_length=512, 191 | denoise=True, 192 | denoise_kwargs=None, 193 | enhance_bass=False, 194 | bass_kwargs=None, 195 | low_gain_db=8.0, 196 | enhance_full=True, 197 | enhance_full_kwargs=None, 198 | master=False, 199 | master_kwargs=None, 200 | overall_gain_db=10.0, 201 | device='cuda', 202 | return_audio=False, 203 | verbose=True, 204 | verbose_diag=False 205 | ): 206 | 207 | """ 208 | Render a MIDI file to a polished piano audio waveform or WAV file. 209 | 210 | This function orchestrates the full Neural Piano pipeline: 211 | 1. Render MIDI -> raw waveform using a SoundFont via `midirenderer`. 212 | 2. Load and optionally trim silence from the rendered waveform. 213 | 3. Encode waveform to latent and decode with the model to synthesize audio. 214 | 4. Optionally apply denoising, bass enhancement, full enhancement, and mastering. 215 | 5. Return the final audio as a NumPy array or write it to disk as a WAV file. 216 | 217 | Parameters 218 | ---------- 219 | input_midi_file : str 220 | Path to the input MIDI file. 221 | output_audio_file : str or None 222 | Path to write the final WAV file. If None, derived from MIDI filename. 223 | sample_rate : int 224 | Sample rate for loading and processing audio (default: 48000). 225 | denoising_steps : int 226 | Default denoising steps passed to decoder. 227 | max_batch_size : int or None 228 | Max batch size for EncoderDecoder operations. 229 | use_v1_piano_model : bool 230 | Use v1 piano model variant if True. 231 | load_multi_instrumental_model : bool 232 | Load multi-instrument model variant if True. 233 | custom_model_path : str or None 234 | Path to a custom model checkpoint for inference. 235 | sf2_name : str 236 | SoundFont filename located in the `models/` directory. 237 | trim_silence : bool 238 | Trim leading/trailing silence from the rendered waveform. 239 | trim_top_db : float 240 | `top_db` for `librosa.effects.trim`. 241 | trim_frame_length : int 242 | `frame_length` for `librosa.effects.trim`. 243 | trim_hop_length : int 244 | `hop_length` for `librosa.effects.trim`. 245 | denoise : bool 246 | Run denoiser post-processing when True. 247 | denoise_kwargs : dict or None 248 | Extra kwargs for `denoise_audio`. 249 | enhance_bass : bool 250 | Run bass enhancement when True. 251 | bass_kwargs : dict or None 252 | Extra kwargs for `enhance_audio_bass`. `low_gain_db` is set from the 253 | top-level argument unless overridden here. 
254 | low_gain_db : float 255 | Default low-frequency gain (dB) for bass enhancement. 256 | enhance_full : bool 257 | Run full-spectrum enhancement when True. 258 | enhance_full_kwargs : dict or None 259 | Extra kwargs for `enhance_audio_full`. 260 | master : bool 261 | Run final mastering when True. 262 | master_kwargs : dict or None 263 | Extra kwargs for `master_mono_piano`. `gain_db` is set from the top-level 264 | argument unless overridden here. 265 | overall_gain_db : float 266 | Default gain (dB) applied during mastering. 267 | device : str 268 | Torch device string (e.g., 'cuda' or 'cpu'). 269 | return_audio : bool 270 | If True, return `(audio_numpy, sample_rate)` instead of writing a file. 271 | verbose : bool 272 | Print progress messages. 273 | verbose_diag : bool 274 | Print diagnostic information for debugging. 275 | 276 | Returns 277 | ------- 278 | str or (numpy.ndarray, int) 279 | If `return_audio` is False (default), returns the path to the written WAV file. 280 | If `return_audio` is True, returns a tuple `(audio_numpy, sample_rate)` where 281 | `audio_numpy` is a 1-D NumPy array (mono) and `sample_rate` is an int. 282 | 283 | Raises 284 | ------ 285 | FileNotFoundError 286 | If the input MIDI file or required model/SF2 files are missing. 287 | RuntimeError 288 | If model inference or post-processing fails (propagates underlying errors). 289 | 290 | """ 291 | 292 | def _pv(msg): 293 | if verbose: 294 | print(msg) 295 | 296 | _pv('=' * 70) 297 | _pv('Neural Piano') 298 | _pv('=' * 70) 299 | 300 | # Normalize kwargs buckets 301 | denoise_kwargs = {} if denoise_kwargs is None else dict(denoise_kwargs) 302 | bass_kwargs = {} if bass_kwargs is None else dict(bass_kwargs) 303 | enhance_full_kwargs = {} if enhance_full_kwargs is None else dict(enhance_full_kwargs) 304 | master_kwargs = {} if master_kwargs is None else dict(master_kwargs) 305 | 306 | # Provide sensible defaults from top-level args unless overridden in kwargs 307 | if 'low_gain_db' not in bass_kwargs: 308 | bass_kwargs['low_gain_db'] = low_gain_db 309 | 310 | if 'overall_gain_db' not in enhance_full_kwargs: 311 | enhance_full_kwargs['overall_gain_db'] = overall_gain_db 312 | 313 | if 'gain_db' not in master_kwargs: 314 | master_kwargs['gain_db'] = overall_gain_db 315 | 316 | home_root = os.getcwd() 317 | models_dir = os.path.join(home_root, "models") 318 | sf2_path = os.path.join(models_dir, sf2_name) 319 | 320 | if verbose_diag: 321 | _pv(home_root) 322 | _pv(models_dir) 323 | _pv(sf2_path) 324 | _pv('=' * 70) 325 | 326 | _pv('Prepping model...') 327 | encdec = EncoderDecoder(load_multi_instrumental_model=load_multi_instrumental_model, 328 | use_v1_piano_model=use_v1_piano_model, 329 | load_path_inference=custom_model_path 330 | ) 331 | 332 | if verbose_diag: 333 | try: 334 | _pv(encdec.gen) 335 | except Exception: 336 | _pv('encdec.gen: ') 337 | _pv('=' * 70) 338 | 339 | _pv('Reading and rendering MIDI file...') 340 | wav_data = midirenderer.render_wave_from( 341 | Path(sf2_path).read_bytes(), 342 | Path(input_midi_file).read_bytes() 343 | ) 344 | 345 | if verbose_diag: 346 | _pv(len(wav_data)) 347 | _pv('=' * 70) 348 | 349 | _pv('Loading rendered MIDI...') 350 | with io.BytesIO(wav_data) as byte_stream: 351 | wv, sr = librosa.load(byte_stream, sr=sample_rate) 352 | 353 | if verbose_diag: 354 | _pv(sr) 355 | _pv(wv.shape) 356 | _pv('=' * 70) 357 | 358 | if trim_silence: 359 | _pv('Trimming leading and trailing silence from rendered waveform...') 360 | wv_trimmed, trim_interval = librosa.effects.trim( 361 | wv, 
362 | top_db=trim_top_db, 363 | frame_length=trim_frame_length, 364 | hop_length=trim_hop_length 365 | ) 366 | start_sample, end_sample = trim_interval 367 | orig_dur = len(wv) / sr 368 | trimmed_dur = len(wv_trimmed) / sr 369 | if verbose: 370 | _pv(f' Trimmed samples: start={start_sample}, end={end_sample}') 371 | _pv(f' Duration before={orig_dur:.3f}s, after={trimmed_dur:.3f}s') 372 | wv = wv_trimmed 373 | else: 374 | _pv('Silence trimming disabled; using full rendered waveform.') 375 | 376 | if verbose_diag: 377 | _pv(wv.shape) 378 | _pv('=' * 70) 379 | 380 | _pv('Encoding...') 381 | latent = encdec.encode(wv, 382 | max_batch_size=max_batch_size, 383 | show_progress=verbose 384 | ) 385 | 386 | if verbose_diag: 387 | try: 388 | _pv(latent.shape) 389 | except Exception: 390 | _pv('latent.shape: ') 391 | _pv('=' * 70) 392 | 393 | _pv('Rendering...') 394 | audio = encdec.decode(latent, 395 | denoising_steps=denoising_steps, 396 | max_batch_size=max_batch_size, 397 | show_progress=verbose 398 | ) 399 | 400 | audio = audio.squeeze() 401 | 402 | if verbose_diag: 403 | try: 404 | _pv(audio.shape) 405 | except Exception: 406 | _pv('audio.shape: ') 407 | _pv('=' * 70) 408 | 409 | # Post-processing: denoise 410 | if denoise: 411 | _pv('Denoising...') 412 | # Always pass sr and device; allow denoise_kwargs to override them if provided 413 | denoise_call_kwargs = dict(sr=sr, device=torch.device(device)) 414 | denoise_call_kwargs.update(denoise_kwargs) 415 | audio, den_diag = denoise_audio(audio, **denoise_call_kwargs) 416 | 417 | if verbose_diag: 418 | _pv(den_diag) 419 | _pv('=' * 70) 420 | 421 | # Post-processing: bass enhancement 422 | if enhance_bass: 423 | _pv('Enhancing bass...') 424 | bass_call_kwargs = dict(sr=sr, device=torch.device(device)) 425 | bass_call_kwargs.update(bass_kwargs) 426 | audio, bass_diag = enhance_audio_bass(audio, **bass_call_kwargs) 427 | 428 | if verbose_diag: 429 | _pv(bass_diag) 430 | _pv('=' * 70) 431 | 432 | # Post-processing: full enhancement (placed before mastering) 433 | if enhance_full: 434 | _pv('Enhancing full audio...') 435 | full_call_kwargs = dict(sr=sr, device=torch.device(device)) 436 | full_call_kwargs.update(enhance_full_kwargs) 437 | 438 | if not master: 439 | output_as_stereo = True 440 | 441 | else: 442 | output_as_stereo = False 443 | 444 | audio, full_diag = enhance_audio_full(audio, 445 | output_as_stereo=output_as_stereo, 446 | **full_call_kwargs 447 | ) 448 | 449 | if verbose_diag: 450 | _pv(full_diag) 451 | _pv('=' * 70) 452 | 453 | # Post-processing: mastering 454 | if master: 455 | _pv('Mastering...') 456 | master_call_kwargs = dict(device=torch.device(device)) 457 | master_call_kwargs.update(master_kwargs) 458 | audio, mas_diag = master_mono_piano(audio, **master_call_kwargs) 459 | 460 | if verbose_diag: 461 | _pv(mas_diag) 462 | _pv('=' * 70) 463 | 464 | if verbose_diag: 465 | try: 466 | _pv(audio.shape) 467 | except Exception: 468 | _pv('audio.shape: ') 469 | _pv('=' * 70) 470 | 471 | _pv('Creating final audio...') 472 | final_audio = audio.cpu().numpy().squeeze().T 473 | 474 | if verbose_diag: 475 | _pv(final_audio.shape) 476 | _pv(sr) 477 | _pv('=' * 70) 478 | 479 | if return_audio: 480 | _pv('Returning final audio...') 481 | _pv('=' * 70) 482 | _pv('Done!') 483 | _pv('=' * 70) 484 | return final_audio, sr 485 | 486 | else: 487 | _pv('Saving final audio...') 488 | if output_audio_file is None: 489 | midi_name = os.path.basename(input_midi_file) 490 | output_name, _ = os.path.splitext(midi_name) 491 | output_audio_file = 
os.path.join(home_root, output_name + '.wav') 492 | 493 | if verbose_diag: 494 | _pv(output_audio_file) 495 | _pv(sr) 496 | _pv('=' * 70) 497 | 498 | sf.write(output_audio_file, final_audio, samplerate=sr) 499 | 500 | _pv('=' * 70) 501 | _pv('Done!') 502 | _pv('=' * 70) 503 | 504 | return output_audio_file -------------------------------------------------------------------------------- /neuralpiano/music2latent/models.py: -------------------------------------------------------------------------------- 1 | from .utils import * 2 | from .audio import * 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | 9 | def zero_init(module): 10 | if init_as_zero: 11 | for p in module.parameters(): 12 | p.detach().zero_() 13 | return module 14 | 15 | def upsample_1d(x): 16 | return F.interpolate(x, scale_factor=2, mode="nearest") 17 | 18 | def downsample_1d(x): 19 | return F.avg_pool1d(x, kernel_size=2, stride=2) 20 | 21 | def upsample_2d(x): 22 | return F.interpolate(x, scale_factor=2, mode="nearest") 23 | 24 | def downsample_2d(x): 25 | return F.avg_pool2d(x, kernel_size=2, stride=2) 26 | 27 | 28 | class LayerNorm(nn.Module): 29 | def __init__(self, dim): 30 | super(LayerNorm, self).__init__() 31 | self.ln = torch.nn.LayerNorm(dim) 32 | 33 | def forward(self, input): 34 | x = input.permute(0,2,3,1) 35 | x = self.ln(x) 36 | x = x.permute(0,3,1,2) 37 | return x 38 | 39 | class FreqGain(nn.Module): 40 | def __init__(self, freq_dim): 41 | super(FreqGain, self).__init__() 42 | self.scale = nn.Parameter(torch.ones((1,1,freq_dim,1))) 43 | 44 | def forward(self, input): 45 | return input*self.scale 46 | 47 | 48 | class UpsampleConv(nn.Module): 49 | def __init__(self, in_channels, out_channels=None, use_2d=False, normalize=False): 50 | super(UpsampleConv, self).__init__() 51 | self.normalize = normalize 52 | 53 | self.use_2d = use_2d 54 | 55 | if out_channels is None: 56 | out_channels = in_channels 57 | 58 | if normalize: 59 | self.norm = nn.GroupNorm(min(in_channels//4, 32), in_channels) 60 | 61 | if use_2d: 62 | self.c = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding='same') 63 | else: 64 | self.c = nn.Conv1d(in_channels, out_channels, kernel_size=3, stride=1, padding='same') 65 | 66 | def forward(self, x): 67 | 68 | if self.normalize: 69 | x = self.norm(x) 70 | 71 | if self.use_2d: 72 | x = upsample_2d(x) 73 | else: 74 | x = upsample_1d(x) 75 | x = self.c(x) 76 | 77 | return x 78 | 79 | class DownsampleConv(nn.Module): 80 | def __init__(self, in_channels, out_channels=None, use_2d=False, normalize=False): 81 | super(DownsampleConv, self).__init__() 82 | self.normalize = normalize 83 | 84 | if out_channels is None: 85 | out_channels = in_channels 86 | 87 | if normalize: 88 | self.norm = nn.GroupNorm(min(in_channels//4, 32), in_channels) 89 | 90 | if use_2d: 91 | self.c = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1) 92 | else: 93 | self.c = nn.Conv1d(in_channels, out_channels, kernel_size=3, stride=2, padding=1) 94 | 95 | def forward(self, x): 96 | 97 | if self.normalize: 98 | x = self.norm(x) 99 | x = self.c(x) 100 | 101 | return x 102 | 103 | class UpsampleFreqConv(nn.Module): 104 | def __init__(self, in_channels, out_channels=None, normalize=False): 105 | super(UpsampleFreqConv, self).__init__() 106 | self.normalize = normalize 107 | 108 | if out_channels is None: 109 | out_channels = in_channels 110 | 111 | if normalize: 112 | self.norm = nn.GroupNorm(min(in_channels//4, 32), in_channels) 113 | 114 | self.c = 
nn.Conv2d(in_channels, out_channels, kernel_size=(5,1), stride=1, padding='same') 115 | 116 | def forward(self, x): 117 | if self.normalize: 118 | x = self.norm(x) 119 | x = F.interpolate(x, scale_factor=(4,1), mode="nearest") 120 | x = self.c(x) 121 | return x 122 | 123 | class DownsampleFreqConv(nn.Module): 124 | def __init__(self, in_channels, out_channels=None, normalize=False): 125 | super(DownsampleFreqConv, self).__init__() 126 | self.normalize = normalize 127 | 128 | if out_channels is None: 129 | out_channels = in_channels 130 | 131 | if normalize: 132 | self.norm = nn.GroupNorm(min(in_channels//4, 32), in_channels) 133 | 134 | self.c = nn.Conv2d(in_channels, out_channels, kernel_size=(5,1), stride=(4,1), padding=(2,0)) 135 | 136 | def forward(self, x): 137 | if self.normalize: 138 | x = self.norm(x) 139 | x = self.c(x) 140 | return x 141 | 142 | class MultiheadAttention(nn.MultiheadAttention): 143 | def _reset_parameters(self): 144 | super()._reset_parameters() 145 | self.out_proj = zero_init(self.out_proj) 146 | 147 | class Attention(nn.Module): 148 | def __init__(self, dim, heads=4, normalize=True, use_2d=False): 149 | super(Attention, self).__init__() 150 | 151 | self.normalize = normalize 152 | self.use_2d = use_2d 153 | 154 | self.mha = MultiheadAttention(embed_dim=dim, num_heads=heads, dropout=0.0, add_zero_attn=False, batch_first=True) 155 | if normalize: 156 | self.norm = nn.GroupNorm(min(dim//4, 32), dim) 157 | 158 | def forward(self, x): 159 | 160 | inp = x 161 | 162 | if self.normalize: 163 | x = self.norm(x) 164 | 165 | if self.use_2d: 166 | x = x.permute(0,3,2,1) # shape: [bs,len,freq,channels] 167 | bs,len,freq,channels = x.shape[0],x.shape[1],x.shape[2],x.shape[3] 168 | x = x.reshape(bs*len,freq,channels) # shape: [bs*len,freq,channels] 169 | else: 170 | x = x.permute(0,2,1) # shape: [bs,len,channels] 171 | 172 | x = self.mha(x, x, x, need_weights=False)[0] 173 | 174 | if self.use_2d: 175 | x = x.reshape(bs,len,freq,channels).permute(0,3,2,1) 176 | else: 177 | x = x.permute(0,2,1) 178 | x = x+inp 179 | 180 | return x 181 | 182 | 183 | 184 | class ResBlock(nn.Module): 185 | def __init__(self, in_channels, out_channels, cond_channels=None, kernel_size=3, downsample=False, upsample=False, normalize=True, leaky=False, attention=False, heads=4, use_2d=False, normalize_residual=False): 186 | super(ResBlock, self).__init__() 187 | self.normalize = normalize 188 | self.attention = attention 189 | self.upsample = upsample 190 | self.downsample = downsample 191 | self.leaky = leaky 192 | self.kernel_size = kernel_size 193 | self.normalize_residual = normalize_residual 194 | self.use_2d = use_2d 195 | if use_2d: 196 | Conv = nn.Conv2d 197 | else: 198 | Conv = nn.Conv1d 199 | self.conv1 = Conv(in_channels, out_channels, kernel_size=kernel_size, stride=1, padding='same') 200 | self.conv2 = zero_init(Conv(out_channels, out_channels, kernel_size=kernel_size, stride=1, padding='same')) 201 | if in_channels!=out_channels: 202 | self.res_conv = Conv(in_channels, out_channels, kernel_size=1, stride=1, padding=0) 203 | else: 204 | self.res_conv = nn.Identity() 205 | if normalize: 206 | self.norm1 = nn.GroupNorm(min(in_channels//4, 32), in_channels) 207 | self.norm2 = nn.GroupNorm(min(out_channels//4, 32), out_channels) 208 | if leaky: 209 | self.activation = nn.LeakyReLU(negative_slope=0.2) 210 | else: 211 | self.activation = nn.SiLU() 212 | if cond_channels is not None: 213 | self.proj_emb = zero_init(nn.Linear(cond_channels, out_channels)) 214 | self.dropout = 
nn.Dropout(dropout_rate) 215 | if attention: 216 | self.att = Attention(out_channels, heads, use_2d=use_2d) 217 | 218 | 219 | def forward(self, x, time_emb=None): 220 | if not self.normalize_residual: 221 | y = x.clone() 222 | if self.normalize: 223 | x = self.norm1(x) 224 | if self.normalize_residual: 225 | y = x.clone() 226 | x = self.activation(x) 227 | if self.downsample: 228 | if self.use_2d: 229 | x = downsample_2d(x) 230 | y = downsample_2d(y) 231 | else: 232 | x = downsample_1d(x) 233 | y = downsample_1d(y) 234 | if self.upsample: 235 | if self.use_2d: 236 | x = upsample_2d(x) 237 | y = upsample_2d(y) 238 | else: 239 | x = upsample_1d(x) 240 | y = upsample_1d(y) 241 | x = self.conv1(x) 242 | if time_emb is not None: 243 | if self.use_2d: 244 | x = x+self.proj_emb(time_emb)[:,:,None,None] 245 | else: 246 | x = x+self.proj_emb(time_emb)[:,:,None] 247 | if self.normalize: 248 | x = self.norm2(x) 249 | x = self.activation(x) 250 | if x.shape[-1]<=min_res_dropout: 251 | x = self.dropout(x) 252 | x = self.conv2(x) 253 | y = self.res_conv(y) 254 | x = x+y 255 | if self.attention: 256 | x = self.att(x) 257 | return x 258 | 259 | 260 | # adapted from https://github.com/yang-song/score_sde_pytorch/blob/main/models/layerspp.py 261 | class GaussianFourierProjection(torch.nn.Module): 262 | """Gaussian Fourier embeddings for noise levels.""" 263 | 264 | def __init__(self, embedding_size=128, scale=0.02): 265 | super().__init__() 266 | self.W = torch.nn.Parameter(torch.randn(embedding_size//2) * scale, requires_grad=False) 267 | 268 | def forward(self, x): 269 | x_proj = x[:, None] * self.W[None, :] * 2. * np.pi 270 | return torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1) 271 | 272 | 273 | class PositionalEmbedding(torch.nn.Module): 274 | def __init__(self, embedding_size=128, max_positions=10000): 275 | super().__init__() 276 | self.embedding_size = embedding_size 277 | self.max_positions = max_positions 278 | 279 | def forward(self, x): 280 | freqs = torch.arange(start=0, end=self.embedding_size//2, dtype=torch.float32, device=x.device) 281 | freqs = freqs / (self.embedding_size // 2 - 1) 282 | freqs = (1 / self.max_positions) ** freqs 283 | x = x.ger(freqs.to(x.dtype)) 284 | x = torch.cat([torch.sin(x), torch.cos(x)], dim=-1) 285 | return x 286 | 287 | class Encoder(nn.Module): 288 | def __init__(self): 289 | super(Encoder, self).__init__() 290 | 291 | layers_list = layers_list_encoder 292 | attention_list = attention_list_encoder 293 | self.layers_list = layers_list 294 | self.multipliers_list = multipliers_list 295 | input_channels = base_channels*multipliers_list[0] 296 | Conv = nn.Conv2d 297 | self.gain = FreqGain(freq_dim=hop*2) 298 | 299 | channels = data_channels 300 | self.conv_inp = Conv(channels, input_channels, kernel_size=3, stride=1, padding=1) 301 | 302 | self.freq_dim = (hop*2)//(4**freq_downsample_list.count(1)) 303 | self.freq_dim = self.freq_dim//(2**freq_downsample_list.count(0)) 304 | 305 | # DOWNSAMPLING 306 | down_layers = [] 307 | for i, (num_layers,multiplier) in enumerate(zip(layers_list,multipliers_list)): 308 | output_channels = base_channels*multiplier 309 | for num in range(num_layers): 310 | down_layers.append(ResBlock(input_channels, output_channels, normalize=normalization, attention=attention_list[i]==1, heads=heads, use_2d=True)) 311 | input_channels = output_channels 312 | if i!=(len(layers_list)-1): 313 | if freq_downsample_list[i]==1: 314 | down_layers.append(DownsampleFreqConv(input_channels, normalize=pre_normalize_downsampling_encoder)) 315 | 
else: 316 | down_layers.append(DownsampleConv(input_channels, use_2d=True, normalize=pre_normalize_downsampling_encoder)) 317 | 318 | if pre_normalize_2d_to_1d: 319 | self.prenorm_1d_to_2d = nn.GroupNorm(min(input_channels//4, 32), input_channels) 320 | 321 | bottleneck_layers = [] 322 | output_channels = bottleneck_base_channels 323 | bottleneck_layers.append(nn.Conv1d(input_channels*self.freq_dim, output_channels, kernel_size=1, stride=1, padding='same')) 324 | for i in range(num_bottleneck_layers): 325 | bottleneck_layers.append(ResBlock(output_channels, output_channels, normalize=normalization, use_2d=False)) 326 | self.bottleneck_layers = nn.ModuleList(bottleneck_layers) 327 | 328 | self.norm_out = nn.GroupNorm(min(output_channels//4, 32), output_channels) 329 | self.activation_out = nn.SiLU() 330 | self.conv_out = nn.Conv1d(output_channels, bottleneck_channels, kernel_size=1, stride=1, padding='same') 331 | self.activation_bottleneck = nn.Tanh() 332 | 333 | self.down_layers = nn.ModuleList(down_layers) 334 | 335 | 336 | def forward(self, x, extract_features=False): 337 | 338 | x = self.conv_inp(x) 339 | if frequency_scaling: 340 | x = self.gain(x) 341 | 342 | # DOWNSAMPLING 343 | k = 0 344 | for i,num_layers in enumerate(self.layers_list): 345 | for num in range(num_layers): 346 | x = self.down_layers[k](x) 347 | k = k+1 348 | if i!=(len(self.layers_list)-1): 349 | x = self.down_layers[k](x) 350 | k = k+1 351 | 352 | if pre_normalize_2d_to_1d: 353 | x = self.prenorm_1d_to_2d(x) 354 | 355 | x = x.reshape(x.size(0), x.size(1) * x.size(2), x.size(3)) 356 | if extract_features: 357 | return x 358 | 359 | for layer in self.bottleneck_layers: 360 | x = layer(x) 361 | 362 | x = self.norm_out(x) 363 | x = self.activation_out(x) 364 | x = self.conv_out(x) 365 | x = self.activation_bottleneck(x) 366 | 367 | return x 368 | 369 | 370 | class Decoder(nn.Module): 371 | def __init__(self): 372 | super(Decoder, self).__init__() 373 | 374 | layers_list = layers_list_encoder 375 | attention_list = attention_list_encoder 376 | self.layers_list = layers_list_encoder 377 | self.multipliers_list = multipliers_list 378 | input_channels = base_channels*multipliers_list[-1] 379 | 380 | output_channels = bottleneck_base_channels 381 | self.conv_inp = nn.Conv1d(bottleneck_channels, output_channels, kernel_size=1, stride=1, padding='same') 382 | 383 | self.freq_dim = (hop*2)//(4**freq_downsample_list.count(1)) 384 | self.freq_dim = self.freq_dim//(2**freq_downsample_list.count(0)) 385 | 386 | bottleneck_layers = [] 387 | for i in range(num_bottleneck_layers): 388 | bottleneck_layers.append(ResBlock(output_channels, output_channels, cond_channels, normalize=normalization, use_2d=False)) 389 | 390 | self.conv_out_bottleneck = nn.Conv1d(output_channels, input_channels*self.freq_dim, kernel_size=1, stride=1, padding='same') 391 | self.bottleneck_layers = nn.ModuleList(bottleneck_layers) 392 | 393 | # UPSAMPLING 394 | multipliers_list_upsampling = list(reversed(multipliers_list))[1:]+list(reversed(multipliers_list))[:1] 395 | freq_upsample_list = list(reversed(freq_downsample_list)) 396 | up_layers = [] 397 | for i, (num_layers,multiplier) in enumerate(zip(reversed(layers_list),multipliers_list_upsampling)): 398 | for num in range(num_layers): 399 | up_layers.append(ResBlock(input_channels, input_channels, normalize=normalization, attention=list(reversed(attention_list))[i]==1, heads=heads, use_2d=True)) 400 | if i!=(len(layers_list)-1): 401 | output_channels = base_channels*multiplier 402 | if 
freq_upsample_list[i]==1: 403 | up_layers.append(UpsampleFreqConv(input_channels, output_channels)) 404 | else: 405 | up_layers.append(UpsampleConv(input_channels, output_channels, use_2d=True)) 406 | input_channels = output_channels 407 | 408 | self.up_layers = nn.ModuleList(up_layers) 409 | 410 | 411 | def forward(self, x): 412 | 413 | x = self.conv_inp(x) 414 | 415 | for layer in self.bottleneck_layers: 416 | x = layer(x) 417 | x = self.conv_out_bottleneck(x) 418 | 419 | x_ls = torch.chunk(x.unsqueeze(-2), self.freq_dim, -3) 420 | x = torch.cat(x_ls, -2) 421 | 422 | # UPSAMPLING 423 | k = 0 424 | pyramid_list = [] 425 | for i,num_layers in enumerate(reversed(self.layers_list)): 426 | for num in range(num_layers): 427 | x = self.up_layers[k](x) 428 | k = k+1 429 | pyramid_list.append(x) 430 | if i!=(len(self.layers_list)-1): 431 | x = self.up_layers[k](x) 432 | k = k+1 433 | 434 | pyramid_list = pyramid_list[::-1] 435 | 436 | return pyramid_list 437 | 438 | 439 | class UNet(nn.Module): 440 | def __init__(self): 441 | super(UNet, self).__init__() 442 | 443 | self.layers_list = layers_list 444 | self.multipliers_list = multipliers_list 445 | input_channels = base_channels*multipliers_list[0] 446 | Conv = nn.Conv2d 447 | 448 | self.encoder = Encoder() 449 | self.decoder = Decoder() 450 | 451 | if use_fourier: 452 | self.emb = GaussianFourierProjection(embedding_size=cond_channels, scale=fourier_scale) 453 | else: 454 | self.emb = PositionalEmbedding(embedding_size=cond_channels) 455 | 456 | self.emb_proj = nn.Sequential(nn.Linear(cond_channels, cond_channels), nn.SiLU(), nn.Linear(cond_channels, cond_channels), nn.SiLU()) 457 | 458 | self.scale_inp = nn.Sequential(nn.Linear(cond_channels, cond_channels), nn.SiLU(), nn.Linear(cond_channels, cond_channels), nn.SiLU(), zero_init(nn.Linear(cond_channels, hop*2))) 459 | self.scale_out = nn.Sequential(nn.Linear(cond_channels, cond_channels), nn.SiLU(), nn.Linear(cond_channels, cond_channels), nn.SiLU(), zero_init(nn.Linear(cond_channels, hop*2))) 460 | 461 | self.conv_inp = Conv(data_channels, input_channels, kernel_size=3, stride=1, padding=1) 462 | 463 | # DOWNSAMPLING 464 | down_layers = [] 465 | for i, (num_layers,multiplier) in enumerate(zip(layers_list,multipliers_list)): 466 | output_channels = base_channels*multiplier 467 | for num in range(num_layers): 468 | down_layers.append(Conv(output_channels, output_channels, kernel_size=1, stride=1, padding=0)) 469 | down_layers.append(ResBlock(output_channels, output_channels, cond_channels, normalize=normalization, attention=attention_list[i]==1, heads=heads, use_2d=True)) 470 | input_channels = output_channels 471 | if i!=(len(layers_list)-1): 472 | output_channels = base_channels*multipliers_list[i+1] 473 | if freq_downsample_list[i]==1: 474 | down_layers.append(DownsampleFreqConv(input_channels, output_channels)) 475 | else: 476 | down_layers.append(DownsampleConv(input_channels, output_channels, use_2d=True)) 477 | 478 | # UPSAMPLING 479 | multipliers_list_upsampling = list(reversed(multipliers_list))[1:]+list(reversed(multipliers_list))[:1] 480 | freq_upsample_list = list(reversed(freq_downsample_list)) 481 | up_layers = [] 482 | for i, (num_layers,multiplier) in enumerate(zip(reversed(layers_list),multipliers_list_upsampling)): 483 | for num in range(num_layers): 484 | up_layers.append(Conv(input_channels, input_channels, kernel_size=1, stride=1, padding=0)) 485 | up_layers.append(ResBlock(input_channels, input_channels, cond_channels, normalize=normalization, 
attention=list(reversed(attention_list))[i]==1, heads=heads, use_2d=True)) 486 | if i!=(len(layers_list)-1): 487 | output_channels = base_channels*multiplier 488 | if freq_upsample_list[i]==1: 489 | up_layers.append(UpsampleFreqConv(input_channels, output_channels)) 490 | else: 491 | up_layers.append(UpsampleConv(input_channels, output_channels, use_2d=True)) 492 | input_channels = output_channels 493 | 494 | self.conv_decoded = Conv(input_channels, input_channels, kernel_size=1, stride=1, padding=0) 495 | self.norm_out = nn.GroupNorm(min(input_channels//4, 32), input_channels) 496 | self.activation_out = nn.SiLU() 497 | self.conv_out = zero_init(Conv(input_channels, data_channels, kernel_size=3, stride=1, padding=1)) 498 | 499 | self.down_layers = nn.ModuleList(down_layers) 500 | self.up_layers = nn.ModuleList(up_layers) 501 | 502 | 503 | def forward(self, latents, x, sigma=None, pyramid_latents=None): 504 | 505 | if sigma is None: 506 | sigma = sigma_max 507 | 508 | inp = x 509 | 510 | # CONDITIONING 511 | sigma = torch.ones((x.shape[0],), dtype=torch.float32).to(x.device)*sigma 512 | sigma_log = torch.log(sigma)/4. 513 | emb_sigma_log = self.emb(sigma_log) 514 | time_emb = self.emb_proj(emb_sigma_log) 515 | 516 | scale_w_inp = self.scale_inp(emb_sigma_log).reshape(x.shape[0],1,-1,1) 517 | scale_w_out = self.scale_out(emb_sigma_log).reshape(x.shape[0],1,-1,1) 518 | 519 | c_skip, c_out, c_in = get_c(sigma) 520 | 521 | x = c_in*x 522 | 523 | if latents.shape == x.shape: 524 | latents = self.encoder(latents) 525 | 526 | if pyramid_latents is None: 527 | pyramid_latents = self.decoder(latents) 528 | 529 | x = self.conv_inp(x) 530 | if frequency_scaling: 531 | x = (1.+scale_w_inp)*x 532 | 533 | skip_list = [] 534 | 535 | # DOWNSAMPLING 536 | k = 0 537 | r = 0 538 | for i,num_layers in enumerate(self.layers_list): 539 | for num in range(num_layers): 540 | d = self.down_layers[k](pyramid_latents[i]) 541 | k = k+1 542 | x = (x+d)/np.sqrt(2.) 543 | x = self.down_layers[k](x, time_emb) 544 | skip_list.append(x) 545 | k = k+1 546 | if i!=(len(self.layers_list)-1): 547 | x = self.down_layers[k](x) 548 | k = k+1 549 | 550 | # UPSAMPLING 551 | k = 0 552 | for i,num_layers in enumerate(reversed(self.layers_list)): 553 | for num in range(num_layers): 554 | d = self.up_layers[k](pyramid_latents[-i-1]) 555 | k = k+1 556 | x = (x+skip_list.pop()+d)/np.sqrt(3.) 557 | x = self.up_layers[k](x, time_emb) 558 | k = k+1 559 | if i!=(len(self.layers_list)-1): 560 | x = self.up_layers[k](x) 561 | k = k+1 562 | 563 | d = self.conv_decoded(pyramid_latents[0]) 564 | x = (x+d)/np.sqrt(2.) 565 | 566 | x = self.norm_out(x) 567 | x = self.activation_out(x) 568 | if frequency_scaling: 569 | x = (1.+scale_w_out)*x 570 | x = self.conv_out(x) 571 | 572 | out = c_skip*inp + c_out*x 573 | 574 | return out --------------------------------------------------------------------------------
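A note on the output parametrization in `UNet.forward` above: the input is scaled by `c_in`, the network prediction is mixed with the skip path as `out = c_skip*inp + c_out*x`, and the coefficients come from `get_c(sigma)`, which is star-imported at the top of models.py and not included in this listing. As a rough orientation only, the sketch below uses the standard EDM-style preconditioning (Karras et al., 2022) that this pattern usually corresponds to; `SIGMA_DATA`, `get_c_edm`, and the toy `net` are illustrative stand-ins, not the repository's actual values or code.

import torch

SIGMA_DATA = 0.5  # assumed data standard deviation; the real value lives in the package hparams

def get_c_edm(sigma: torch.Tensor, sigma_data: float = SIGMA_DATA):
    # Returns (c_skip, c_out, c_in) shaped [bs,1,1,1] so they broadcast over
    # a [bs, channels, freq, time] spectrogram batch.
    sigma = sigma.reshape(-1, 1, 1, 1)
    c_skip = sigma_data**2 / (sigma**2 + sigma_data**2)
    c_out = sigma * sigma_data / torch.sqrt(sigma**2 + sigma_data**2)
    c_in = 1.0 / torch.sqrt(sigma**2 + sigma_data**2)
    return c_skip, c_out, c_in

# Tiny demo: at small sigma the skip path dominates (out ~= inp); at large
# sigma the network prediction dominates (out ~= c_out * net(c_in * x)).
x = torch.randn(2, 2, 8, 16)       # stand-in for a spectrogram batch
net = lambda z: torch.tanh(z)      # stand-in for the UNet body
for s in (1e-3, 1.0, 80.0):
    sigma = torch.full((x.shape[0],), s)
    c_skip, c_out, c_in = get_c_edm(sigma)
    out = c_skip * x + c_out * net(c_in * x)
    print(f"sigma={s:g}  c_skip={c_skip[0].item():.4f}  c_out={c_out[0].item():.4f}")

Note also that `conv_out` in the UNet is wrapped in `zero_init`, so when `init_as_zero` is enabled a freshly initialized model outputs `c_skip*inp`: it starts from the skip connection and learns only the correction term on top of it.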