├── neuralpiano
│   ├── artwork
│   │   ├── __init__.py
│   │   ├── Neural-Piano-Logo.png
│   │   ├── Neural-Piano-Sign.png
│   │   ├── Tegridy-Code-2025.png
│   │   ├── Neural-Piano-Artwork.png
│   │   ├── Project-Los-Angeles.png
│   │   └── README.md
│   ├── output_samples
│   │   ├── __init__.py
│   │   └── README.md
│   ├── seed_midis
│   │   ├── __init__.py
│   │   ├── README.md
│   │   ├── neural_piano_seed_midi_1.mid
│   │   ├── neural_piano_seed_midi_2.mid
│   │   ├── neural_piano_seed_midi_3.mid
│   │   ├── neural_piano_seed_midi_4.mid
│   │   ├── neural_piano_seed_midi_5.mid
│   │   └── neural_piano_seed_midi_6.mid
│   ├── music2latent
│   │   ├── __init__.py
│   │   ├── hparams_inference.py
│   │   ├── README.md
│   │   ├── hparams.py
│   │   ├── utils.py
│   │   ├── audio.py
│   │   ├── inference.py
│   │   └── models.py
│   ├── __init__.py
│   ├── README.md
│   ├── sample_midis.py
│   ├── denoise.py
│   ├── bass.py
│   ├── mixer.py
│   ├── enhancer.py
│   ├── master.py
│   └── neuralpiano.py
├── requirements.txt
├── MANIFEST.in
├── README.md
├── pyproject.toml
└── LICENSE
/neuralpiano/artwork/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/neuralpiano/output_samples/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/neuralpiano/seed_midis/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/neuralpiano/output_samples/README.md:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/neuralpiano/music2latent/__init__.py:
--------------------------------------------------------------------------------
1 | from .inference import EncoderDecoder
--------------------------------------------------------------------------------
/neuralpiano/artwork/Neural-Piano-Logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/asigalov61/neuralpiano/main/neuralpiano/artwork/Neural-Piano-Logo.png
--------------------------------------------------------------------------------
/neuralpiano/artwork/Neural-Piano-Sign.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/asigalov61/neuralpiano/main/neuralpiano/artwork/Neural-Piano-Sign.png
--------------------------------------------------------------------------------
/neuralpiano/artwork/Tegridy-Code-2025.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/asigalov61/neuralpiano/main/neuralpiano/artwork/Tegridy-Code-2025.png
--------------------------------------------------------------------------------
/neuralpiano/seed_midis/README.md:
--------------------------------------------------------------------------------
1 | # Neural Piano sample seed MIDIs
2 |
3 | ***
4 |
5 | ### Project Los Angeles
6 | ### Tegridy Code 2025
7 |
--------------------------------------------------------------------------------
/neuralpiano/artwork/Neural-Piano-Artwork.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/asigalov61/neuralpiano/main/neuralpiano/artwork/Neural-Piano-Artwork.png
--------------------------------------------------------------------------------
/neuralpiano/artwork/Project-Los-Angeles.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/asigalov61/neuralpiano/main/neuralpiano/artwork/Project-Los-Angeles.png
--------------------------------------------------------------------------------
/neuralpiano/seed_midis/neural_piano_seed_midi_1.mid:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/asigalov61/neuralpiano/main/neuralpiano/seed_midis/neural_piano_seed_midi_1.mid
--------------------------------------------------------------------------------
/neuralpiano/seed_midis/neural_piano_seed_midi_2.mid:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/asigalov61/neuralpiano/main/neuralpiano/seed_midis/neural_piano_seed_midi_2.mid
--------------------------------------------------------------------------------
/neuralpiano/seed_midis/neural_piano_seed_midi_3.mid:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/asigalov61/neuralpiano/main/neuralpiano/seed_midis/neural_piano_seed_midi_3.mid
--------------------------------------------------------------------------------
/neuralpiano/seed_midis/neural_piano_seed_midi_4.mid:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/asigalov61/neuralpiano/main/neuralpiano/seed_midis/neural_piano_seed_midi_4.mid
--------------------------------------------------------------------------------
/neuralpiano/seed_midis/neural_piano_seed_midi_5.mid:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/asigalov61/neuralpiano/main/neuralpiano/seed_midis/neural_piano_seed_midi_5.mid
--------------------------------------------------------------------------------
/neuralpiano/seed_midis/neural_piano_seed_midi_6.mid:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/asigalov61/neuralpiano/main/neuralpiano/seed_midis/neural_piano_seed_midi_6.mid
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | tqdm
2 | ipywidgets
3 | matplotlib
4 | hf-transfer
5 | huggingface_hub
6 | soundfile
7 | torch
8 | numpy==1.24.4
9 | librosa
10 | midirenderer
--------------------------------------------------------------------------------
/neuralpiano/artwork/README.md:
--------------------------------------------------------------------------------
1 | # Neural Piano Concept Artwork
2 |
3 | ***
4 |
5 | ## Images were created with the Stable Diffusion 3.5 Large image AI model
6 |
7 | ***
8 |
9 | ### Project Los Angeles
10 | ### Tegridy Code 2025
11 |
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include neuralpiano/README.md
2 | recursive-include neuralpiano/ *
3 | recursive-include neuralpiano/artwork *
4 | recursive-include neuralpiano/output_samples *
5 | recursive-include neuralpiano/seed_midis *
6 | recursive-include neuralpiano/music2latent *
--------------------------------------------------------------------------------
/neuralpiano/__init__.py:
--------------------------------------------------------------------------------
1 | from .sample_midis import get_sample_midi_files
2 |
3 | from .music2latent.inference import EncoderDecoder
4 |
5 | from .denoise import denoise_audio
6 |
7 | from .bass import enhance_audio_bass
8 |
9 | from .enhancer import enhance_audio_full
10 |
11 | from .master import master_mono_piano
12 |
13 | from .mixer import mix_audio
14 |
15 | from .neuralpiano import *
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # [WIP] Neural Piano
2 | ## Hi-Fi neural MIDI piano synthesizer and MIDI renderer
3 |
4 |
5 |
6 | ***
7 |
8 | ## Installation
9 |
10 | ### pip and setuptools
11 |
12 | ```sh
13 | # It is recommended that you upgrade pip and setuptools prior to installation for maximum compatibility
14 | !pip install --upgrade pip
15 | !pip install --upgrade setuptools build wheel
16 | ```
17 |
18 | ### pip install
19 |
20 | ```sh
21 | # The following command will install the Neural Piano pip package
22 | # Please note that Neural Piano requires an Nvidia GPU with at least 40GB VRAM
23 |
24 | !pip install -U neuralpiano
25 | ```
26 |
27 | ***
28 |
29 | ## Quick-start use example
30 |
31 | ```python
32 | # Import main Neural Piano module
33 | import neuralpiano
34 |
35 | # Render MIDI
36 | neuralpiano.render_midi('input.mid', 'output.wav')
37 | ```
38 |
39 | ***
40 |
41 | ### Project Los Angeles
42 | ### Tegridy Code 2025
43 |
--------------------------------------------------------------------------------
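Note on the quick-start above: the package also ships six seed MIDIs (see `neuralpiano/seed_midis/`) that can be fed straight into the renderer. A minimal sketch, assuming `render_midi` behaves as in the quick-start example and using `get_sample_midi_files()` exported by `neuralpiano/__init__.py`:

```python
# Sketch: render one of the bundled seed MIDIs instead of your own file.
# get_sample_midi_files() is exported by neuralpiano/__init__.py and returns
# (filename, path) tuples; render_midi is used exactly as in the quick-start.
import neuralpiano

seed_midis = neuralpiano.get_sample_midi_files()
print('Bundled seed MIDIs:', [name for name, _ in seed_midis])

first_name, first_path = seed_midis[0]
neuralpiano.render_midi(first_path, 'seed_render.wav')
```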
/neuralpiano/README.md:
--------------------------------------------------------------------------------
1 | # [WIP] Neural Piano
2 | ## Hi-Fi neural MIDI piano synthesizer and MIDI renderer
3 |
4 |
5 |
6 | ***
7 |
8 | ## Installation
9 |
10 | ### pip and setuptools
11 |
12 | ```sh
13 | # It is recommended that you upgrade pip and setuptools prior to installation for maximum compatibility
14 | !pip install --upgrade pip
15 | !pip install --upgrade setuptools build wheel
16 | ```
17 |
18 | ### pip install
19 |
20 | ```sh
21 | # The following command will install the Neural Piano pip package
22 | # Please note that Neural Piano requires an Nvidia GPU with at least 40GB VRAM
23 |
24 | !pip install -U neuralpiano
25 | ```
26 |
27 | ***
28 |
29 | ## Quick-start use example
30 |
31 | ```python
32 | # Import main Neural Piano module
33 | import neuralpiano
34 |
35 | # Render MIDI
36 | neuralpiano.render_midi('input.mid', 'output.wav')
37 | ```
38 |
39 | ***
40 |
41 | ### Project Los Angeles
42 | ### Tegridy Code 2025
43 |
--------------------------------------------------------------------------------
/neuralpiano/music2latent/hparams_inference.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | home_root = os.getcwd()
4 |
5 | load_path_inference_multi_instrumental_default = os.path.join(home_root, 'models/music2latent.pt')
6 | load_path_inference_solo_piano_default = os.path.join(home_root, 'models/music2latent_maestro_loss_16.871_iters_45500.pt')
7 | load_path_inference_solo_piano_v1_default = os.path.join(home_root, 'models/music2latent_maestro_loss_27.834_iters_14300.pt')
8 |
9 | max_batch_size_encode = 1 # maximum inference batch size for encoding: tune it depending on the available GPU memory
10 | max_waveform_length_encode = 44100*60 # maximum length of waveforms in the batch for encoding: tune it depending on the available GPU memory
11 | max_batch_size_decode = 1 # maximum inference batch size for decoding: tune it depending on the available GPU memory
12 | max_waveform_length_decode = 44100*60 # maximum length of waveforms in the batch for decoding: tune it depending on the available GPU memory
13 |
14 | sigma_rescale = 0.06 # rescale sigma for inference
--------------------------------------------------------------------------------
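The three `load_path_inference_*` defaults above name the checkpoints that `download_models()` (in `music2latent/utils.py`) fetches into `./models`. A minimal sketch of how these paths map onto the `EncoderDecoder` constructor flags visible in `music2latent/inference.py`; the custom checkpoint path below is a placeholder:

```python
# Sketch of how the default paths above map onto the EncoderDecoder
# constructor flags visible in music2latent/inference.py. Checkpoints are
# fetched into ./models by download_models(); the custom path is a placeholder.
from neuralpiano.music2latent import EncoderDecoder

# Default: solo-piano checkpoint (load_path_inference_solo_piano_default)
encdec_piano = EncoderDecoder()

# Multi-instrumental checkpoint (load_path_inference_multi_instrumental_default)
encdec_multi = EncoderDecoder(load_multi_instrumental_model=True)

# v1 solo-piano checkpoint (load_path_inference_solo_piano_v1_default)
encdec_piano_v1 = EncoderDecoder(use_v1_piano_model=True)

# Or point directly at a checkpoint file
encdec_custom = EncoderDecoder(load_path_inference='models/music2latent.pt')
```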
/neuralpiano/sample_midis.py:
--------------------------------------------------------------------------------
1 | #===================================================================================================
2 | # Neural Piano sample_midis Python module
3 | #===================================================================================================
4 | # Project Los Angeles
5 | # Tegridy Code 2025
6 | #===================================================================================================
7 | # License: Apache 2.0
8 | #===================================================================================================
9 |
10 | import importlib.resources as pkg_resources
11 | from neuralpiano import seed_midis
12 |
13 | #===================================================================================================
14 |
15 | def get_sample_midi_files():
16 |
17 | midi_files = []
18 |
19 | for resource in pkg_resources.contents(seed_midis):
20 | if resource.endswith('.mid'):
21 | with pkg_resources.path(seed_midis, resource) as p:
22 | midi_files.append((resource, str(p)))
23 |
24 | return sorted(midi_files)
25 |
26 | #===================================================================================================
27 | # This is the end of the sample_midis Python module
28 | #===================================================================================================
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["setuptools>=80.0.0", "wheel"]
3 | build-backend = "setuptools.build_meta"
4 |
5 | [project]
6 | name = "neuralpiano"
7 | version = "25.11.37"
8 | description = "Hi-Fi neural MIDI piano synthesizer and MIDI renderer"
9 | license = "Apache-2.0"
10 | license-files = ["LICENSE"]
11 | authors = [
12 | { name = "Alex Lev", email = "alexlev61@proton.me" }
13 | ]
14 | maintainers = [
15 | { name = "Alex Lev", email = "alexlev61@proton.me" }
16 | ]
17 | dependencies = [
18 | "tqdm",
19 | "ipywidgets",
20 | "matplotlib",
21 | "hf-transfer",
22 | "huggingface_hub",
23 | "soundfile",
24 | "torch",
25 | "numpy==1.24.4",
26 | "librosa",
27 | "midirenderer"
28 | ]
29 | keywords = ["MIDI", "music", "music ai", "MIDI piano", "Piano synthesizer", "renderer", "MIDI synthesizer", "MIDI renderer", "neural synthesizer", "artificial intelligence", "AI", "audio synthesis"]
30 | classifiers = [
31 | "Development Status :: 5 - Production/Stable",
32 | "Intended Audience :: Science/Research",
33 | "Topic :: Multimedia :: Sound/Audio :: MIDI",
34 | "Operating System :: OS Independent",
35 | "Programming Language :: Python :: 3.6",
36 | "Programming Language :: Python :: 3.7",
37 | "Programming Language :: Python :: 3.8",
38 | "Programming Language :: Python :: 3.9",
39 | "Programming Language :: Python :: 3.10",
40 | "Programming Language :: Python :: 3.11",
41 | "Programming Language :: Python :: 3.12"
42 | ]
43 | requires-python = ">=3.6"
44 |
45 | [project.readme]
46 | file = "neuralpiano/README.md"
47 | content-type = "text/markdown"
48 |
49 | [project.urls]
50 | Homepage = "https://github.com/asigalov61/neuralpiano"
51 | SoundCloud = "https://soundcloud.com/aleksandr-sigalov-61/"
52 | Samples = "https://github.com/asigalov61/neuralpiano/tree/main/neuralpiano/output_samples"
53 | Documentation = "https://github.com/asigalov61/neuralpiano"
54 | Issues = "https://github.com/asigalov61/neuralpiano/issues"
55 | Discussions = "https://github.com/asigalov61/neuralpiano/discussions"
56 | Dataset = "https://magenta.withgoogle.com/datasets/maestro"
57 | Demo = "https://huggingface.co/datasets/projectlosangeles/Neural-Piano-Synthesizer"
58 |
--------------------------------------------------------------------------------
/neuralpiano/music2latent/README.md:
--------------------------------------------------------------------------------
1 | # Music2Latent
2 | Encode and decode audio samples to/from compressed representations! Useful for efficient generative modelling applications and for other downstream tasks.
3 |
4 | 
5 |
6 | Read the ISMIR 2024 paper [here](https://arxiv.org/abs/2408.06500).
7 | Listen to audio samples [here](https://sonycslparis.github.io/music2latent-companion/).
8 |
9 | Under the hood, __Music2Latent__ uses a __Consistency Autoencoder__ model to efficiently encode and decode audio samples.
10 |
11 | 44.1 kHz audio is encoded into a latent sequence at __~10 Hz__, and each latent has 64 channels.
12 | 48 kHz audio can also be encoded, which results in a ~12 Hz latent sequence.
13 | A generative model can then be trained on these embeddings, or they can be used for other downstream tasks.
14 |
15 | Music2Latent was trained on __music__ and on __speech__. Refer to the [paper](https://arxiv.org/abs/2408.06500) for more details.
16 |
17 |
18 | ## Installation
19 |
20 | ```bash
21 | pip install music2latent
22 | ```
23 | The model weights will be downloaded automatically the first time the code is run.
24 |
25 |
26 | ## How to use
27 | To encode and decode audio samples to/from latent embeddings:
28 | ```python
29 | import librosa
30 | wv, sr = librosa.load(librosa.example('trumpet'), sr=44100)
31 |
32 | from music2latent import EncoderDecoder
33 | encdec = EncoderDecoder()
34 |
35 | latent = encdec.encode(wv)
36 | # latent has shape (batch_size/audio_channels, dim (64), sequence_length)
37 |
38 | wv_rec = encdec.decode(latent)
39 | ```
40 | To extract encoder features to use in downstream tasks:
41 | ```python
42 | features = encdec.encode(wv, extract_features=True)
43 | ```
44 | These features are extracted before the encoder bottleneck, and thus have more channels (contain more information) than the latents used for reconstruction. They cannot be decoded directly back to audio.
45 |
46 | music2latent supports more advanced usage, including GPU memory management controls. Please refer to __tutorial.ipynb__.
47 |
48 |
49 | ## License
50 | This library is released under the CC BY-NC 4.0 license. Please refer to the LICENSE file for more details.
51 |
52 |
53 |
54 | This work was conducted by [Marco Pasini](https://twitter.com/marco_ppasini) during his PhD at Queen Mary University of London, in partnership with Sony Computer Science Laboratories Paris.
55 | This work was supervised by Stefan Lattner and George Fazekas.
--------------------------------------------------------------------------------
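A small round-trip sketch extending the README example above: encode, decode, and write the reconstruction to disk with `soundfile` (already a project dependency in `requirements.txt`). The return type of `decode()` is assumed to be tensor-like, hence the defensive conversion:

```python
# Round-trip sketch extending the README example above: encode, decode, then
# write the reconstruction to disk. librosa and soundfile are both project
# dependencies (see requirements.txt); shapes follow the README's comments.
import librosa
import soundfile as sf

from music2latent import EncoderDecoder

wv, sr = librosa.load(librosa.example('trumpet'), sr=44100)

encdec = EncoderDecoder()
latent = encdec.encode(wv)        # (audio_channels, 64, sequence_length)
print('latent shape:', tuple(latent.shape))

wv_rec = encdec.decode(latent)
if hasattr(wv_rec, 'cpu'):        # assumed: decode() returns a torch tensor
    wv_rec = wv_rec.squeeze().cpu().numpy()
sf.write('trumpet_reconstruction.wav', wv_rec, sr)
```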
/neuralpiano/music2latent/hparams.py:
--------------------------------------------------------------------------------
1 | # GENERAL
2 | mixed_precision = True # use mixed precision
3 | seed = 42 # seed for Pytorch
4 |
5 | # DATA
6 | data_channels = 2 # channels of input data
7 | data_length = 64 # sequence length of input data
8 | data_length_test = 1024//4 # sequence length of data used for testing
9 | sample_rate = 44100 # sampling rate of input/output audio
10 |
11 | hop = 128*4 # hop size of transformation
12 |
13 | alpha_rescale = 0.65
14 | beta_rescale = 0.34
15 | sigma_data = 0.5
16 |
17 |
18 |
19 | # MODEL
20 | base_channels = 64 # base channel number for architecture
21 | layers_list = [2,2,2,2,2] # number of blocks per each resolution level
22 | multipliers_list = [1,2,4,4,4] # base channels multipliers for each resolution level
23 | attention_list = [0,0,1,1,1] # for each resolution, 0 if no attention is performed, 1 if attention is performed
24 | freq_downsample_list = [1,0,0,0] # for each resolution, 0 if frequency 4x downsampling, 1 if standard frequency 2x and time 2x downsampling
25 |
26 | layers_list_encoder = [1,1,1,1,1] # number of blocks per each resolution level
27 | attention_list_encoder = [0,0,1,1,1] # for each resolution, 0 if no attention is performed, 1 if attention is performed
28 | bottleneck_base_channels = 512 # base channels to use for block before/after bottleneck
29 | num_bottleneck_layers = 4 # number of blocks to use before/after bottleneck
30 | frequency_scaling = True
31 |
32 | heads = 4 # number of attention heads
33 | cond_channels = 256 # dimension of time embedding
34 | use_fourier = False # if True, use random Fourier embedding, if False, use Positional
35 | fourier_scale = 0.2 # scale parameter for gaussian fourier layer (original is 0.02, but to me it appears too small)
36 | normalization = True # use group normalization
37 | dropout_rate = 0. # dropout rate
38 | min_res_dropout = 16 # dropout is applied on equal or smaller feature map resolutions
39 | init_as_zero = True # initialize convolution kernels before skip connections with zeros
40 |
41 | bottleneck_channels = 32*2 # channels of encoder bottleneck
42 |
43 | pre_normalize_2d_to_1d = True # pre-normalize 2D to 1D connection in encoder
44 | pre_normalize_downsampling_encoder = True # pre-normalize downsampling layers in encoder
45 |
46 | sigma_min = 0.002 # minimum sigma
47 | sigma_max = 80. # maximum sigma
48 | rho = 7. # rho parameter for sigma schedule
--------------------------------------------------------------------------------
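A back-of-envelope check relating the `hop` and `sample_rate` values above to the "~10 Hz" latent rate quoted in the music2latent README; the additional 8x time downsampling inside the encoder is an assumption, not a value read from this file:

```python
# Back-of-envelope check relating hop and sample_rate above to the "~10 Hz"
# latent rate quoted in the music2latent README. The extra 8x time
# downsampling inside the encoder is an assumption, not a value in this file.
sample_rate = 44100
hop = 128 * 4                        # 512 samples between STFT frames

stft_frame_rate = sample_rate / hop  # ~86.1 Hz
assumed_time_downsampling = 8        # assumption about the encoder stack

latent_rate = stft_frame_rate / assumed_time_downsampling
print(f'STFT frame rate: {stft_frame_rate:.1f} Hz')
print(f'Approximate latent rate: {latent_rate:.1f} Hz')  # ~10.8 Hz
```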
/neuralpiano/music2latent/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | from os import path as ospath
3 |
4 | from huggingface_hub import hf_hub_download
5 |
6 | import torch
7 |
8 | from .hparams import *
9 |
10 | # Get scaling coefficients c_skip, c_out, c_in based on noise sigma
11 | # These are used to scale the input and output of the consistency model, while satisfying the boundary condition for consistency models
12 | # Parameters:
13 | # sigma: noise level
14 | # Returns:
15 | # c_skip, c_out, c_in: scaling coefficients
16 | def get_c(sigma):
17 | sigma_correct = sigma_min
18 | c_skip = (sigma_data**2.)/(((sigma-sigma_correct)**2.) + (sigma_data**2.))
19 | c_out = (sigma_data*(sigma-sigma_correct))/(((sigma_data**2.) + (sigma**2.))**0.5)
20 | c_in = 1./(((sigma**2.)+(sigma_data**2.))**0.5)
21 | return c_skip.reshape(-1,1,1,1), c_out.reshape(-1,1,1,1), c_in.reshape(-1,1,1,1)
22 |
23 | # Get noise level sigma_i based on index i and number of discretization steps k
24 | # Parameters:
25 | # i: index
26 | # k: number of discretization steps
27 | # Returns:
28 | # sigma_i: noise level corresponding to index i
29 | def get_sigma(i, k):
30 | return (sigma_min**(1./rho) + ((i-1)/(k-1))*(sigma_max**(1./rho)-sigma_min**(1./rho)))**rho
31 |
32 | # Get noise level sigma for a continuous index i in [0, 1]
33 | # Follows parameterization in https://openreview.net/pdf?id=FmqFfMTNnv
34 | # Parameters:
35 | # i: continuous index in [0, 1]
36 | # Returns:
37 | # sigma: corresponding noise level
38 | def get_sigma_continuous(i):
39 | return (sigma_min**(1./rho) + i*(sigma_max**(1./rho)-sigma_min**(1./rho)))**rho
40 |
41 |
42 | # Add Gaussian noise to input x based on given noise and sigma
43 | # Parameters:
44 | # x: input tensor
45 | # noise: tensor containing Gaussian noise
46 | # sigma: noise level
47 | # Returns:
48 | # x_noisy: x with noise added
49 | def add_noise(x, noise, sigma):
50 | return x + sigma.reshape(-1,1,1,1)*noise
51 |
52 |
53 | # Reverse the probability flow ODE by one step
54 | # Parameters:
55 | # x: input
56 | # noise: Gaussian noise
57 | # sigma: noise level
58 | # Returns:
59 | # x: x after reversing ODE by one step
60 | def reverse_step(x, noise, sigma):
61 | return x + ((sigma**2 - sigma_min**2)**0.5)*noise
62 |
63 |
64 | # Denoise samples at a given noise level
65 | # Parameters:
66 | # model: consistency model
67 | # noisy_samples: input noisy samples
68 | # sigma: noise level
69 | # Returns:
70 | # pred_noises: predicted noise
71 | # pred_samples: denoised samples
72 | def denoise(model, noisy_samples, sigma, latents=None):
73 | # Denoise samples
74 | with torch.no_grad():
75 | with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=mixed_precision):
76 | if latents is not None:
77 | pred_samples = model(latents, noisy_samples, sigma)
78 | else:
79 | pred_samples = model(noisy_samples, sigma)
80 | # Sample noise
81 | pred_noises = torch.randn_like(pred_samples)
82 | return pred_noises, pred_samples
83 |
84 | # Reverse the diffusion process to generate samples
85 | # Parameters:
86 | # model: trained consistency model
87 | # initial_noise: initial noise to start from
88 | # diffusion_steps: number of steps to reverse
89 | # Returns:
90 | # final_samples: generated samples
91 | def reverse_diffusion(model, initial_noise, diffusion_steps, latents=None):
92 | next_noisy_samples = initial_noise
93 | # Reverse process step-by-step
94 | for k in range(diffusion_steps):
95 |
96 | # Get sigma values
97 | sigma = get_sigma(diffusion_steps+1-k, diffusion_steps+1)
98 | next_sigma = get_sigma(diffusion_steps-k, diffusion_steps+1)
99 |
100 | # Denoise
101 | noisy_samples = next_noisy_samples
102 | pred_noises, pred_samples = denoise(model, noisy_samples, sigma, latents)
103 |
104 | # Step to next (lower) noise level
105 | next_noisy_samples = reverse_step(pred_samples, pred_noises, next_sigma)
106 |
107 | return pred_samples.detach().cpu()
108 |
109 |
110 | def is_path(variable):
111 | return isinstance(variable, str) and os.path.exists(variable)
112 |
113 |
114 |
115 | def download_models():
116 | home_root = os.getcwd()
117 | models_dir = os.path.join(home_root, "models")
118 | os.makedirs(models_dir, exist_ok=True)
119 |
120 | if not os.path.exists(os.path.join(models_dir, "SGM-v2.01-YamahaGrand-Guit-Bass-v2.7.sf2")):
121 |
122 | files = [
123 | ("SonyCSLParis/music2latent", "music2latent.pt", "model"),
124 | ("asigalov61/music2latent-maestro", "music2latent_maestro_loss_16.871_iters_45500.pt", "model"),
125 | ("asigalov61/music2latent-maestro", "music2latent_maestro_loss_27.834_iters_14300.pt", "model"),
126 | ("projectlosangeles/soundfonts4u", "SGM-v2.01-YamahaGrand-Guit-Bass-v2.7.sf2", "dataset"),
127 | ]
128 |
129 | for repo_id, filename, repo_type in files:
130 | print(f"Downloading {filename} from {repo_id} ...")
131 | # download to Hugging Face cache and get the real path returned
132 | cached_path = hf_hub_download(repo_id=repo_id, repo_type=repo_type, filename=filename, local_dir=models_dir)
133 |
134 | print("Models were downloaded successfully!")
--------------------------------------------------------------------------------
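For reference, the scaling coefficients and discrete noise schedule implemented by `get_c` and `get_sigma` above, transcribed directly from the code (with `sigma_correct = sigma_min`):

```latex
% Scaling coefficients implemented by get_c(sigma), with sigma_correct = sigma_min:
\[
c_{\mathrm{skip}}(\sigma) = \frac{\sigma_{\mathrm{data}}^{2}}{(\sigma-\sigma_{\min})^{2} + \sigma_{\mathrm{data}}^{2}}, \qquad
c_{\mathrm{out}}(\sigma)  = \frac{\sigma_{\mathrm{data}}\,(\sigma-\sigma_{\min})}{\sqrt{\sigma_{\mathrm{data}}^{2} + \sigma^{2}}}, \qquad
c_{\mathrm{in}}(\sigma)   = \frac{1}{\sqrt{\sigma^{2} + \sigma_{\mathrm{data}}^{2}}}
\]

% Discrete noise schedule implemented by get_sigma(i, k):
\[
\sigma_i = \Bigl(\sigma_{\min}^{1/\rho} + \tfrac{i-1}{k-1}\bigl(\sigma_{\max}^{1/\rho} - \sigma_{\min}^{1/\rho}\bigr)\Bigr)^{\rho}
\]
```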
/neuralpiano/music2latent/audio.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import numpy as np
4 |
5 | from .hparams import *
6 |
7 |
8 | def wv2spec(wv, hop_size=256, fac=4):
9 | X = stft(wv, hop_size=hop_size, fac=fac, device=wv.device)
10 | X = power2db(torch.abs(X)**2)
11 | X = normalize(X)
12 | return X
13 |
14 | def spec2wv(S,P, hop_size=256, fac=4):
15 | S = denormalize(S)
16 | S = torch.sqrt(db2power(S))
17 | P = P * np.pi
18 | SP = torch.complex(S * torch.cos(P), S * torch.sin(P))
19 | return istft(SP, fac=fac, hop_size=hop_size, device=SP.device)
20 |
21 | def denormalize_realimag(x):
22 | x = x/beta_rescale
23 | return torch.sign(x)*(x.abs()**(1./alpha_rescale))
24 |
25 | def normalize_complex(x):
26 | return beta_rescale*(x.abs()**alpha_rescale).to(torch.complex64)*torch.exp(1j*torch.angle(x).to(torch.complex64))
27 |
28 | def denormalize_complex(x):
29 | x = x/beta_rescale
30 | return (x.abs()**(1./alpha_rescale)).to(torch.complex64)*torch.exp(1j*torch.angle(x).to(torch.complex64))
31 |
32 | def wv2complex(wv, hop_size=256, fac=4):
33 | X = stft(wv, hop_size=hop_size, fac=fac, device=wv.device)
34 | return X[:,:hop_size*2,:]
35 |
36 | def wv2realimag(wv, hop_size=256, fac=4):
37 | X = wv2complex(wv, hop_size, fac)
38 | X = normalize_complex(X)
39 | return torch.stack((torch.real(X),torch.imag(X)), -3)
40 |
41 | def realimag2wv(x, hop_size=256, fac=4):
42 | x = torch.nn.functional.pad(x, (0,0,0,1))
43 | real,imag = torch.chunk(x, 2, -3)
44 | X = torch.complex(real.squeeze(-3),imag.squeeze(-3))
45 | X = denormalize_complex(X)
46 | return istft(X, fac=fac, hop_size=hop_size, device=X.device).clamp(-1.,1.)
47 |
48 | def to_representation_encoder(x):
49 | return wv2realimag(x, hop)
50 |
51 | def to_representation(x):
52 | return wv2realimag(x, hop)
53 |
54 | def to_waveform(x):
55 | return realimag2wv(x, hop)
56 |
57 | def overlap_and_add(signal, frame_step):
58 |
59 | outer_dimensions = signal.shape[:-2]
60 | outer_rank = torch.numel(torch.tensor(outer_dimensions))
61 |
62 | def full_shape(inner_shape):
63 | s = torch.cat([torch.tensor(outer_dimensions), torch.tensor(inner_shape)], 0)
64 | s = list(s)
65 | s = [int(el) for el in s]
66 | return s
67 |
68 | frame_length = signal.shape[-1]
69 | frames = signal.shape[-2]
70 |
71 | # Compute output length.
72 | output_length = frame_length + frame_step * (frames - 1)
73 |
74 | # Compute the number of segments, per frame.
75 | segments = -(-frame_length // frame_step) # Divide and round up.
76 |
77 | signal = torch.nn.functional.pad(signal, (0, segments * frame_step - frame_length, 0, segments))
78 |
79 | shape = full_shape([frames + segments, segments, frame_step])
80 | signal = torch.reshape(signal, shape)
81 |
82 | perm = torch.cat([torch.arange(0, outer_rank), torch.tensor([el+outer_rank for el in [1, 0, 2]])], 0)
83 | perm = list(perm)
84 | perm = [int(el) for el in perm]
85 | signal = torch.permute(signal, perm)
86 |
87 | shape = full_shape([(frames + segments) * segments, frame_step])
88 | signal = torch.reshape(signal, shape)
89 |
90 | signal = signal[..., :(frames + segments - 1) * segments, :]
91 |
92 | shape = full_shape([segments, (frames + segments - 1), frame_step])
93 | signal = torch.reshape(signal, shape)
94 |
95 | signal = signal.sum(-3)
96 |
97 | # Flatten the array.
98 | shape = full_shape([(frames + segments - 1) * frame_step])
99 | signal = torch.reshape(signal, shape)
100 |
101 | # Truncate to final length.
102 | signal = signal[..., :output_length]
103 |
104 | return signal
105 |
106 | def inverse_stft_window(frame_length, frame_step, forward_window):
107 | denom = forward_window**2
108 | overlaps = -(-frame_length // frame_step)
109 | denom = F.pad(denom, (0, overlaps * frame_step - frame_length))
110 | denom = torch.reshape(denom, [overlaps, frame_step])
111 | denom = torch.sum(denom, 0, keepdim=True)
112 | denom = torch.tile(denom, [overlaps, 1])
113 | denom = torch.reshape(denom, [overlaps * frame_step])
114 | return forward_window / denom[:frame_length]
115 |
116 | def istft(SP, fac=4, hop_size=256, device='cuda'):
117 | x = torch.fft.irfft(SP, dim=-2)
118 | window = torch.hann_window(fac*hop_size).to(device)
119 | window = inverse_stft_window(fac*hop_size, hop_size, window)
120 | x = x*window.unsqueeze(-1)
121 | return overlap_and_add(x.permute(0,2,1), hop_size)
122 |
123 | def frame(signal, frame_length, frame_step, pad_end=False, pad_value=0, axis=-1):
124 | """
125 | equivalent of tf.signal.frame
126 | """
127 | signal_length = signal.shape[axis]
128 | if pad_end:
129 | frames_overlap = frame_length - frame_step
130 | rest_samples = np.abs(signal_length - frames_overlap) % np.abs(frame_length - frames_overlap)
131 | pad_size = int(frame_length - rest_samples)
132 | if pad_size != 0:
133 | pad_axis = [0] * signal.ndim
134 | pad_axis[axis] = pad_size
135 | signal = F.pad(signal, pad_axis, "constant", pad_value)
136 | frames = signal.unfold(axis, frame_length, frame_step)
137 | return frames
138 |
139 | def stft(wv, fac=4, hop_size=256, device='cuda'):
140 | window = torch.hann_window(fac*hop_size).to(device)
141 | framed_signals = frame(wv, fac*hop_size, hop_size)
142 | framed_signals = framed_signals*window
143 | return torch.fft.rfft(framed_signals, n=None, dim=-1, norm=None).permute(0,2,1)
144 |
145 | def normalize(S, mu_rescale=-25., sigma_rescale=75.):
146 | return (S - mu_rescale) / sigma_rescale
147 |
148 | def denormalize(S, mu_rescale=-25., sigma_rescale=75.):
149 | return (S * sigma_rescale) + mu_rescale
150 |
151 | def db2power(S_db, ref=1.0):
152 | return ref * torch.pow(10.0, 0.1 * S_db)
153 |
154 | def power2db(power, ref_value=1.0, amin=1e-10):
155 | log_spec = 10.0 * torch.log10(torch.maximum(torch.tensor(amin), power))
156 | log_spec -= 10.0 * torch.log10(torch.maximum(torch.tensor(amin), torch.tensor(ref_value)))
157 | return log_spec
--------------------------------------------------------------------------------
/neuralpiano/denoise.py:
--------------------------------------------------------------------------------
1 | # Denoise Python module
2 |
3 | import math
4 | from typing import Tuple, Dict, Optional, Union
5 |
6 | import numpy as np
7 | import torch
8 | import torch.nn.functional as F
9 |
10 | TensorLike = Union[torch.Tensor, np.ndarray]
11 |
12 | def denoise_audio(mono: TensorLike,
13 | sr: int = 48000,
14 | n_fft: int = 4096,
15 | hop: Optional[int] = None,
16 | noise_seconds: float = 0.8,
17 | max_atten_db: float = 18.0,
18 | noise_floor_db: float = -60.0,
19 | smoothing_time_ms: float = 40.0,
20 | smoothing_freq_bins: int = 3,
21 | noise_sample: Optional[TensorLike] = None,
22 | device: Optional[torch.device] = None,
23 | dtype: torch.dtype = torch.float32
24 | ) -> Tuple[torch.Tensor, Dict]:
25 |
26 | """
27 | Conservative denoiser tuned for solo piano.
28 | Returns denoised 1-D torch.Tensor and diagnostics dict.
29 | """
30 |
31 | # --- prepare tensor ---
32 | if isinstance(mono, np.ndarray):
33 | x = torch.from_numpy(mono.astype(np.float32))
34 | elif isinstance(mono, torch.Tensor):
35 | x = mono.clone()
36 | else:
37 | raise TypeError("mono must be a numpy array or torch.Tensor")
38 |
39 | if x.ndim != 1:
40 | raise ValueError("mono must be 1-D (mono)")
41 |
42 | device = device or (x.device if isinstance(x, torch.Tensor) else torch.device('cpu'))
43 | x = x.to(device=device, dtype=dtype, copy=False)
44 | N = x.shape[-1]
45 | eps = 1e-12
46 |
47 | hop = hop or (n_fft // 4)
48 | win = torch.hann_window(n_fft, device=device, dtype=dtype)
49 |
50 | # STFT / ISTFT helpers
51 | def stft_torch(sig):
52 | return torch.stft(sig, n_fft=n_fft, hop_length=hop, win_length=n_fft,
53 | window=win, center=True, return_complex=True, pad_mode='reflect')
54 |
55 | def istft_torch(X, length):
56 | return torch.istft(X, n_fft=n_fft, hop_length=hop, win_length=n_fft,
57 | window=win, center=True, length=length)
58 |
59 | # compute STFT
60 | X = stft_torch(x) # complex tensor shape [freq_bins, frames]
61 | mag = X.abs() + eps
62 | freq_bins, frames = mag.shape
63 |
64 | # noise magnitude estimate
65 | if noise_sample is not None:
66 | if isinstance(noise_sample, np.ndarray):
67 | n_wav = torch.from_numpy(noise_sample.astype(np.float32)).to(device=device, dtype=dtype)
68 | else:
69 | n_wav = noise_sample.to(device=device, dtype=dtype)
70 | if n_wav.ndim > 1:
71 | n_wav = n_wav.view(-1) if n_wav.shape[0] == 1 else n_wav.mean(dim=0)  # collapse extra channels to a 1-D mono signal
72 | Xn = stft_torch(n_wav)
73 | noise_mag = Xn.abs().mean(dim=1)
74 | else:
75 | if noise_seconds <= 0:
76 | frames_for_noise = 1
77 | else:
78 | approx_frames = max(1, int(math.ceil((noise_seconds * sr) / hop)))
79 | frames_for_noise = min(approx_frames, frames)
80 | noise_mag = mag[:, :frames_for_noise].mean(dim=1)
81 |
82 | # floor noise magnitude
83 | noise_floor_lin = 10.0 ** (noise_floor_db / 20.0)
84 | noise_mag = noise_mag.clamp(min=noise_floor_lin)
85 |
86 | # Wiener-like gain: S^2 / (S^2 + N^2)
87 | S2 = mag ** 2
88 | N2 = noise_mag.unsqueeze(1) ** 2
89 | G = S2 / (S2 + N2 + eps) # shape [freq_bins, frames]
90 |
91 | # convert to attenuation dB and clamp
92 | att_db = -20.0 * torch.log10(G.clamp(min=1e-12))
93 | att_db_clamped = att_db.clamp(max=max_atten_db)
94 | G_limited = 10.0 ** (-att_db_clamped / 20.0) # shape [freq_bins, frames]
95 |
96 | # -------------------------
97 | # Time smoothing (per-frequency)
98 | # -------------------------
99 | time_smooth_frames = max(1, int(round((smoothing_time_ms / 1000.0) * sr / hop)))
100 | if time_smooth_frames > 1:
101 | k = torch.hann_window(time_smooth_frames, device=device, dtype=dtype)
102 | k = k / k.sum()
103 | kernel = k.view(1, 1, -1).repeat(freq_bins, 1, 1) # [freq_bins,1,k]
104 | inp = G_limited.unsqueeze(0) # [1, freq_bins, frames]
105 | G_time = F.conv1d(inp, kernel, padding=(time_smooth_frames - 1) // 2, groups=freq_bins).squeeze(0)
106 | else:
107 | G_time = G_limited
108 |
109 | # -------------------------
110 | # Frequency smoothing (per-frame)
111 | # -------------------------
112 | if smoothing_freq_bins > 0:
113 | kf = torch.ones(smoothing_freq_bins, device=device, dtype=dtype)
114 | kf = kf / kf.sum()
115 | G_perm = G_time.permute(1, 0).unsqueeze(1) # [frames, 1, freq_bins]
116 | kernel_f = kf.view(1, 1, -1)
117 | G_freq_perm = F.conv1d(G_perm, kernel_f, padding=(smoothing_freq_bins - 1) // 2).squeeze(1) # [frames, freq_bins]
118 | G_freq = G_freq_perm.permute(1, 0) # back to [freq_bins, frames]
119 | else:
120 | G_freq = G_time
121 |
122 | # Ensure orientation is [freq_bins, frames]
123 | if G_freq.ndim != 2:
124 | raise RuntimeError("Unexpected G_freq dimensionality")
125 | if G_freq.shape[0] != freq_bins and G_freq.shape[1] == freq_bins:
126 | G_freq = G_freq.permute(1, 0)
127 |
128 | # Build frequency-dependent protection vector [freq_bins, 1]
129 | freqs = torch.linspace(0.0, float(sr) / 2.0, steps=freq_bins, device=device, dtype=dtype).unsqueeze(1)
130 | low_protect = (freqs < 200.0).to(dtype)
131 | high_allow = (freqs > 6000.0).to(dtype)
132 |
133 | # apply frequency-dependent scaling (broadcast across frames)
134 | G_final = G_freq * (1.0 - 0.35 * low_protect) * (1.0 + 0.25 * high_allow)
135 | G_final = G_final.clamp(min=0.0, max=1.0)
136 |
137 | # -------------------------
138 | # ALIGNMENT FIX: ensure G_final matches X.shape exactly
139 | # -------------------------
140 | # If smoothing changed the frame count by ±1 (or more), trim or pad G_final to match X.
141 | # Trim if longer, pad by repeating last column if shorter.
142 | if G_final.shape[0] != freq_bins:
143 | # if frequency axis mismatched, try safe transpose or resize
144 | if G_final.shape[1] == freq_bins and G_final.shape[0] == frames:
145 | G_final = G_final.permute(1, 0)
146 | else:
147 | # fallback: resize frequency axis by trimming or repeating last row
148 | if G_final.shape[0] > freq_bins:
149 | G_final = G_final[:freq_bins, :]
150 | else:
151 | pad_rows = freq_bins - G_final.shape[0]
152 | last_row = G_final[-1:, :].repeat(pad_rows, 1)
153 | G_final = torch.cat([G_final, last_row], dim=0)
154 |
155 | if G_final.shape[1] != frames:
156 | if G_final.shape[1] > frames:
157 | G_final = G_final[:, :frames]
158 | else:
159 | # pad by repeating last column
160 | pad_cols = frames - G_final.shape[1]
161 | last_col = G_final[:, -1:].repeat(1, pad_cols)
162 | G_final = torch.cat([G_final, last_col], dim=1)
163 |
164 | # final safety check
165 | if G_final.shape != (freq_bins, frames):
166 | raise RuntimeError(f"Unable to align mask to STFT shape: mask {G_final.shape}, STFT {(freq_bins, frames)}")
167 |
168 | # apply mask (preserve phase)
169 | X_denoised = X * G_final
170 |
171 | # reconstruct
172 | y = istft_torch(X_denoised, length=N)
173 |
174 | # tiny residual low-frequency subtraction (gentle)
175 | lf_cut = 60.0
176 | lf_bin = int(round(lf_cut / (sr / 2.0) * (freq_bins - 1)))
177 | if lf_bin >= 1:
178 | lf_rms = mag[:lf_bin, :].mean().item()
179 | subtract = 0.02 * lf_rms
180 | y = y - subtract * torch.mean(y)
181 |
182 | # final safety normalize (very gentle)
183 | peak_in = float(x.abs().max().item())
184 | peak_out = float(y.abs().max().item())
185 | if peak_out > 0.999:
186 | y = y * (0.999 / peak_out)
187 | final_scale = 0.999 / peak_out
188 | else:
189 | final_scale = 1.0
190 |
191 | diagnostics = {
192 | "sr": sr,
193 | "n_fft": n_fft,
194 | "hop": hop,
195 | "noise_seconds": noise_seconds,
196 | "max_atten_db": max_atten_db,
197 | "smoothing_time_ms": smoothing_time_ms,
198 | "smoothing_freq_bins": smoothing_freq_bins,
199 | "input_peak": peak_in,
200 | "output_peak": float(y.abs().max().item()),
201 | "final_scale": final_scale,
202 | }
203 |
204 | return y.to(dtype=dtype, device=device), diagnostics
--------------------------------------------------------------------------------
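A minimal usage sketch for `denoise_audio` as defined above; the noisy sine is a synthetic stand-in for real piano audio, and only parameters shown in the signature are used:

```python
# Minimal usage sketch for denoise_audio (signature as defined above).
# The input must be a 1-D mono signal; the noisy sine below is synthetic.
import numpy as np

from neuralpiano import denoise_audio

sr = 48000
t = np.arange(sr * 3) / sr                         # 3 seconds
mono = 0.5 * np.sin(2 * np.pi * 440.0 * t)         # 440 Hz tone
mono += 0.01 * np.random.randn(mono.shape[0])      # a little broadband noise

clean, diag = denoise_audio(mono, sr=sr)
print(type(clean), clean.shape, diag["input_peak"], diag["output_peak"])
```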
/neuralpiano/bass.py:
--------------------------------------------------------------------------------
1 | # Bass Python module
2 |
3 | import math
4 | from typing import Tuple, Dict, Optional, Union
5 |
6 | import numpy as np
7 | import torch
8 | import torch.nn.functional as F
9 |
10 | TensorLike = Union[torch.Tensor, np.ndarray]
11 |
12 | def enhance_audio_bass(mono: TensorLike,
13 | sr: int = 48000,
14 | low_cutoff: float = 200.0,
15 | ir_len: int = 1025,
16 | low_gain_db: float = 8.0,
17 | sub_mix: float = 0.75,
18 | compressor_thresh_db: float = -24.0,
19 | compressor_ratio: float = 3.0,
20 | compressor_attack_ms: float = 8.0,
21 | compressor_release_ms: float = 120.0,
22 | makeup_db: float = 0.0,
23 | drive: float = 1.15,
24 | wet_mix: float = 0.9,
25 | downsample_target: int = 4000,
26 | device: Optional[torch.device] = None,
27 | dtype: torch.dtype = torch.float32
28 | ) -> Tuple[torch.Tensor, Dict]:
29 |
30 | """
31 | Fast bass enhancement optimized for GPU with robust shape alignment.
32 |
33 | Returns:
34 | enhanced 1-D torch.Tensor (same length as input), diagnostics dict.
35 | """
36 |
37 | # --- prepare input tensor ---
38 | if isinstance(mono, np.ndarray):
39 | x = torch.from_numpy(mono.astype(np.float32))
40 | elif isinstance(mono, torch.Tensor):
41 | x = mono.clone()
42 | else:
43 | raise TypeError("mono must be numpy or torch.Tensor")
44 | if x.ndim != 1:
45 | raise ValueError("mono must be 1-D (mono)")
46 |
47 | device = device or (x.device if isinstance(x, torch.Tensor) else torch.device('cpu'))
48 | x = x.to(device=device, dtype=dtype, copy=False)
49 | N = x.shape[-1]
50 | eps = 1e-12
51 |
52 | # --- helpers ---
53 | def db2lin(db): return 10.0 ** (db / 20.0)
54 | def next_pow2(n): return 1 << ((n - 1).bit_length())
55 |
56 | # linear-phase lowpass IR builder (small, efficient)
57 | def make_lowpass_ir(cutoff_hz, sr_local, length):
58 | if length % 2 == 0:
59 | length += 1
60 | t = torch.arange(length, device=device, dtype=dtype) - (length - 1) / 2.0
61 | sinc_arg = 2.0 * cutoff_hz / sr_local * t
62 | h = torch.sinc(sinc_arg)
63 | beta = 6.0
64 | # torch.kaiser_window signature: (window_length, periodic=True, beta=12.0, *, dtype=None, layout=torch.strided, device=None, requires_grad=False); periodic=False is passed positionally here
65 | win = torch.kaiser_window(length, False, beta, dtype=dtype, device=device)
66 | h = h * win
67 | h = h / (h.sum() + eps)
68 | return h
69 |
70 | # FFT convolution using optional precomputed kernel FFT
71 | def fft_conv_signal_kernel(sig, kernel, kernel_fft=None):
72 | n = sig.shape[-1]
73 | k = kernel.shape[-1]
74 | out_len = n + k - 1
75 | size = next_pow2(out_len)
76 | S = torch.fft.rfft(sig, n=size)
77 | if kernel_fft is None:
78 | K = torch.fft.rfft(kernel, n=size)
79 | else:
80 | K = kernel_fft
81 | Y = S * K
82 | y = torch.fft.irfft(Y, n=size)[:out_len]
83 | return y[:n], K
84 |
85 | # --- 1) extract low band via FFT conv (single pass) ---
86 | lp = make_lowpass_ir(low_cutoff, sr, ir_len)
87 | low, lp_fft = fft_conv_signal_kernel(x, lp) # low is same device/dtype
88 |
89 | # --- 2) downsample low band for fast processing ---
90 | ds = max(1, int(round(sr / float(downsample_target))))
91 | ds_sr = sr // ds
92 | if ds > 1:
93 | pad = (-low.shape[-1]) % ds
94 | if pad:
95 | low_p = F.pad(low.unsqueeze(0).unsqueeze(0), (0, pad)).squeeze(0).squeeze(0)
96 | else:
97 | low_p = low
98 | # decimate by averaging each block (cheap, preserves low-band energy)
99 | low_ds = low_p.view(-1, ds).mean(dim=1)
100 | else:
101 | low_ds = low
102 |
103 | # --- 3) compressor on downsampled low band (vectorized) ---
104 | win_len = max(1, int(round(0.01 * ds_sr)))
105 | sq = low_ds * low_ds
106 | pad = (win_len - 1) // 2
107 | sq_p = F.pad(sq.unsqueeze(0).unsqueeze(0), (pad, win_len - 1 - pad)).squeeze(0).squeeze(0)
108 | kernel = torch.ones(win_len, device=device, dtype=dtype) / float(win_len)
109 | rms_sq = F.conv1d(sq_p.unsqueeze(0).unsqueeze(0), kernel.view(1,1,-1)).squeeze(0).squeeze(0)
110 | rms = torch.sqrt(rms_sq + 1e-12)
111 |
112 | # soft-knee gain computer vectorized
113 | lvl_db = 20.0 * torch.log10(rms.clamp(min=1e-12))
114 | knee = 3.0
115 | over = lvl_db - compressor_thresh_db
116 | zero = torch.zeros_like(over)
117 | gr_db = torch.where(
118 | over <= -knee,
119 | zero,
120 | torch.where(
121 | over >= knee,
122 | compressor_thresh_db + (over / compressor_ratio) - lvl_db,
123 | - ((1.0 - 1.0/compressor_ratio) * (over + knee)**2) / (4.0 * knee)
124 | )
125 | ).clamp(max=0.0)
126 | gain_lin = 10.0 ** (gr_db / 20.0)
127 |
128 | # smooth gain with FIR approximations for attack/release
129 | def exp_kernel(tc_ms, sr_local, length=64):
130 | if length < 3:
131 | length = 3
132 | tau = max(1e-6, tc_ms / 1000.0)
133 | t = torch.arange(length, device=device, dtype=dtype)
134 | k = torch.exp(-t / (tau * sr_local))
135 | k = k / k.sum()
136 | return k
137 |
138 | atk_k = exp_kernel(compressor_attack_ms, ds_sr, length=64)
139 | rel_k = exp_kernel(compressor_release_ms, ds_sr, length=128)
140 | # convolve (padding may change length slightly)
141 | g_atk = F.conv1d(gain_lin.unsqueeze(0).unsqueeze(0), atk_k.view(1,1,-1), padding=(atk_k.numel()-1)//2).squeeze(0).squeeze(0)
142 | g_smooth = F.conv1d(g_atk.unsqueeze(0).unsqueeze(0), rel_k.view(1,1,-1), padding=(rel_k.numel()-1)//2).squeeze(0).squeeze(0)
143 |
144 | # --- ALIGN: ensure g_smooth matches low_ds length ---
145 | if g_smooth.shape[0] != low_ds.shape[0]:
146 | if g_smooth.shape[0] > low_ds.shape[0]:
147 | g_smooth = g_smooth[:low_ds.shape[0]]
148 | else:
149 | pad_len = low_ds.shape[0] - g_smooth.shape[0]
150 | if g_smooth.numel() == 0:
151 | # fallback to ones if something went wrong
152 | g_smooth = torch.ones(low_ds.shape[0], device=device, dtype=dtype)
153 | else:
154 | last = g_smooth[-1:].repeat(pad_len)
155 | g_smooth = torch.cat([g_smooth, last], dim=0)
156 |
157 | # apply makeup
158 | makeup_lin = db2lin(makeup_db)
159 | low_ds_comp = low_ds * g_smooth * makeup_lin
160 |
161 | # --- 4) subharmonic generation on downsampled signal ---
162 | rect = torch.clamp(low_ds_comp, min=0.0)
163 | lp_sub_len = 513 if ds_sr >= 4000 else 257
164 | lp_sub = make_lowpass_ir(200.0, ds_sr, lp_sub_len)
165 | rect_lp, _ = fft_conv_signal_kernel(rect, lp_sub)
166 | sub_gain = db2lin(low_gain_db)
167 | sub_ds = rect_lp * sub_gain
168 |
169 | # soft saturation (vectorized)
170 | one = torch.tensor(1.0, device=device, dtype=dtype)
171 | sat_low_ds = torch.tanh(low_ds_comp * drive) / torch.tanh(one)
172 | sat_sub_ds = torch.tanh(sub_ds * (drive * 0.8)) / torch.tanh(one)
173 | enhanced_low_ds = (1.0 - sub_mix) * sat_low_ds + sub_mix * sat_sub_ds
174 |
175 | # --- 5) upsample enhanced low back to original rate ---
176 | if ds > 1:
177 | # ensure length before upsampling is consistent
178 | L_needed = (low.shape[-1] + ds - 1) // ds
179 | if enhanced_low_ds.shape[0] < L_needed:
180 | pad_len = L_needed - enhanced_low_ds.shape[0]
181 | enhanced_low_ds = torch.cat([enhanced_low_ds, enhanced_low_ds[-1:].repeat(pad_len)], dim=0)
182 | enhanced_low = F.interpolate(enhanced_low_ds.view(1,1,-1), scale_factor=ds, mode='linear', align_corners=False).view(-1)[:low.shape[-1]]
183 | else:
184 | enhanced_low = enhanced_low_ds
185 | # ensure exact length
186 | if enhanced_low.shape[0] != low.shape[-1]:
187 | if enhanced_low.shape[0] > low.shape[-1]:
188 | enhanced_low = enhanced_low[:low.shape[-1]]
189 | else:
190 | enhanced_low = torch.cat([enhanced_low, enhanced_low[-1:].repeat(low.shape[-1] - enhanced_low.shape[0])], dim=0)
191 |
192 | # --- 6) band-limit enhanced low using original lp FFT (cheap because we have lp_fft) ---
193 | enhanced_low_band, _ = fft_conv_signal_kernel(enhanced_low, lp, kernel_fft=lp_fft)
194 |
195 | # --- 7) scale wet and mix back ---
196 | wet = enhanced_low_band * db2lin(low_gain_db)
197 | out = (1.0 - wet_mix) * x + wet_mix * (x + wet)
198 |
199 | # final gentle limiter
200 | peak = float(out.abs().max().item())
201 | if peak > 0.999:
202 | out = out * (0.999 / peak)
203 |
204 | diagnostics = {
205 | "sr": sr,
206 | "low_cutoff": low_cutoff,
207 | "ir_len": ir_len,
208 | "low_gain_db": low_gain_db,
209 | "sub_mix": sub_mix,
210 | "downsample_factor": ds,
211 | "downsample_rate": ds_sr,
212 | "compressor_avg_gain_db": float(20.0 * math.log10(max(g_smooth.mean().item(), 1e-12))),
213 | "input_peak": float(x.abs().max().item()),
214 | "output_peak": float(out.abs().max().item()),
215 | }
216 |
217 | return out.to(dtype=dtype, device=device), diagnostics
--------------------------------------------------------------------------------
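A sketch chaining `denoise_audio` into `enhance_audio_bass`: both accept a 1-D mono tensor or array and return a `(tensor, diagnostics)` pair, so the denoiser's output can be passed straight through. The placeholder audio and gain value are only examples:

```python
# Sketch: chain the denoiser above into the bass enhancer. Both accept a
# 1-D mono tensor/array and return (tensor, diagnostics); gains are examples.
import numpy as np

from neuralpiano import denoise_audio, enhance_audio_bass

sr = 48000
mono = np.random.randn(sr * 2).astype(np.float32) * 0.1   # placeholder audio

denoised, _ = denoise_audio(mono, sr=sr)
bassy, bass_diag = enhance_audio_bass(denoised, sr=sr, low_gain_db=6.0)
print(bass_diag["downsample_rate"], bass_diag["output_peak"])
```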
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright 2025 Tegridy Code
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/neuralpiano/music2latent/inference.py:
--------------------------------------------------------------------------------
1 | import soundfile as sf
2 | import torch
3 | import numpy as np
4 | from tqdm import tqdm
5 |
6 | from .hparams import *
7 | from .hparams_inference import *
8 | from .utils import *
9 | from .models import *
10 | from .audio import *
11 |
12 |
13 | def _should_show_progress():
14 | """
15 | Return True when progress bars should be shown (main process).
16 | Hides bars in distributed workers other than rank 0.
17 | """
18 | try:
19 | if torch.distributed.is_available() and torch.distributed.is_initialized():
20 | return torch.distributed.get_rank() == 0
21 | except Exception:
22 | # If distributed is not configured or any error occurs, show progress
23 | return True
24 | return True
25 |
26 |
27 | class EncoderDecoder:
28 | def __init__(self, load_multi_instrumental_model=False, use_v1_piano_model=False, load_path_inference=None, device=None):
29 | download_models()
30 | if device is None:
31 | self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
32 | else:
33 | self.device = device
34 | self.load_path_inference = load_path_inference
35 | if load_path_inference is None:
36 | if load_multi_instrumental_model:
37 | self.load_path_inference = load_path_inference_multi_instrumental_default
38 | else:
39 | if use_v1_piano_model:
40 | self.load_path_inference = load_path_inference_solo_piano_v1_default
41 | else:
42 | self.load_path_inference = load_path_inference_solo_piano_default
43 | self.get_models()
44 |
45 | def get_models(self):
46 | gen = UNet().to(self.device)
47 | checkpoint = torch.load(self.load_path_inference, map_location=self.device, weights_only=False)
48 | gen.load_state_dict(checkpoint['gen_state_dict'], strict=False)
49 | self.gen = gen
50 |
51 | def encode(self, path_or_audio, max_waveform_length=None, max_batch_size=None, extract_features=False, show_progress=True):
52 | '''
53 | path_or_audio: path of audio sample to encode or numpy array of waveform to encode
54 | max_waveform_length: maximum length of waveforms in the batch for encoding: tune it depending on the available GPU memory
55 | max_batch_size: maximum inference batch size for encoding: tune it depending on the available GPU memory
56 | extract_features: if True, return raw features (no sigma rescale)
57 | show_progress: whether to show tqdm progress bars
58 |
59 | WARNING! if input is numpy array of stereo waveform, it must have shape [waveform_samples, audio_channels]
60 |
61 | Returns latents with shape [audio_channels, dim, length]
62 | '''
63 | if max_waveform_length is None:
64 | max_waveform_length = max_waveform_length_encode
65 | if max_batch_size is None:
66 | max_batch_size = max_batch_size_encode
67 | return encode_audio_inference(path_or_audio, self, max_waveform_length, max_batch_size, device=self.device, extract_features=extract_features, show_progress=show_progress)
68 |
69 | def decode(self, latent, denoising_steps=1, max_waveform_length=None, max_batch_size=None, show_progress=True):
70 | '''
71 | latent: numpy array of latents to decode with shape [audio_channels, dim, length]
72 | denoising_steps: number of denoising steps to use for decoding
73 | max_waveform_length: maximum length of waveforms in the batch for decoding: tune it depending on the available GPU memory
74 | max_batch_size: maximum inference batch size for decoding: tune it depending on the available GPU memory
75 | show_progress: whether to show tqdm progress bars
76 |
77 | Returns numpy array of decoded waveform with shape [waveform_samples, audio_channels]
78 | '''
79 | if max_waveform_length is None:
80 | max_waveform_length = max_waveform_length_decode
81 | if max_batch_size is None:
82 | max_batch_size = max_batch_size_decode
83 | return decode_latent_inference(latent, self, max_waveform_length, max_batch_size, diffusion_steps=denoising_steps, device=self.device, show_progress=show_progress)
84 |
85 |
86 |
87 |
88 |
89 |
90 | # decode samples with consistency model to real/imag STFT spectrograms
91 | # Parameters:
92 | # model: trained consistency model
93 | # latents: latent representation with shape [audio_channels/batch_size, dim, length]
94 | # diffusion_steps: number of steps
95 | # Returns:
96 | # decoded_spectrograms with shape [audio_channels/batch_size, data_channels, hop*2, length*downscaling_factor]
97 | def decode_to_representation(model, latents, diffusion_steps=1, device='cuda'):
98 | num_samples = latents.shape[0]
99 | downscaling_factor = 2**freq_downsample_list.count(0)
100 | sample_length = int(latents.shape[-1]*downscaling_factor)
101 | initial_noise = torch.randn((num_samples, data_channels, hop*2, sample_length)).to(device)*sigma_max
102 | decoded_spectrograms = reverse_diffusion(model, initial_noise, diffusion_steps, latents=latents)
103 | return decoded_spectrograms
104 |
105 |
106 |
107 |
108 | # Encode an audio sample for inference
109 | # Parameters:
110 | # audio_path: path of the audio sample, or a waveform (numpy array or torch tensor)
111 | # trainer: object holding the trained consistency model (trainer.gen)
112 | # device: device to run the model on
113 | # Returns:
114 | # latent: compressed latent representation with shape [audio_channels, dim, latent_length]
115 | @torch.no_grad()
116 | def encode_audio_inference(audio_path, trainer, max_waveform_length_encode, max_batch_size_encode, device='cuda', extract_features=False, show_progress=True):
117 | trainer.gen = trainer.gen.to(device)
118 | trainer.gen.eval()
119 | downscaling_factor = 2**freq_downsample_list.count(0)
120 | if is_path(audio_path):
121 | audio, sr = sf.read(audio_path, dtype='float32', always_2d=True)
122 | audio = np.transpose(audio, [1,0])
123 | else:
124 | audio = audio_path
125 | sr = None
126 | if len(audio.shape)==1:
127 | # check if audio is numpy array, then use np.expand_dims, if it is a pytorch tensor, then use torch.unsqueeze
128 | if isinstance(audio, np.ndarray):
129 | audio = np.expand_dims(audio, 0)
130 | else:
131 | audio = torch.unsqueeze(audio, 0)
132 | audio_channels = audio.shape[0]
133 | if isinstance(audio, np.ndarray):
134 | audio = torch.from_numpy(audio).to(device)
135 | else:
136 | # check if audio tensor is on cpu. if it is, move it to the device
137 | if audio.device.type=='cpu':
138 | audio = audio.to(device)
139 |
140 | # EXPERIMENTAL: crop audio to be divisible by downscaling_factor
141 | cropped_length = ((((audio.shape[-1]-3*hop)//hop)//downscaling_factor)*hop*downscaling_factor)+3*hop
142 | audio = audio[:,:cropped_length]
143 |
144 | repr_encoder = to_representation_encoder(audio)
145 | sample_length = repr_encoder.shape[-1]
146 | max_sample_length = (int(max_waveform_length_encode/hop)//downscaling_factor)*downscaling_factor
147 |
148 |     # if repr_encoder is longer than max_sample_length, split it into max_sample_length chunks (the last chunk is padded with repeated copies of the sample), then stack the chunks along the batch dimension before encoding
149 | pad_size = 0
150 | if sample_length > max_sample_length:
151 | # pad repr_encoder with copies of the sample to be divisible by max_sample_length
152 | pad_size = max_sample_length - (sample_length % max_sample_length)
153 | # repeat repr_encoder such that repr_encoder.shape[-1] is higher than pad_size, then crop it such that repr_encoder.shape[-1]=pad_size
154 | repr_encoder_pad = torch.cat([repr_encoder for _ in range(1+(pad_size//repr_encoder.shape[-1]))], dim=-1)[:,:,:,:pad_size]
155 | repr_encoder = torch.cat([repr_encoder, repr_encoder_pad], dim=-1)
156 | repr_encoder = torch.split(repr_encoder, max_sample_length, dim=-1)
157 | repr_encoder = torch.cat(repr_encoder, dim=0)
158 |     # encode repr_encoder with a maximum batch size (dimension 0) of max_batch_size_encode: if there are more chunks than max_batch_size_encode, split them into batches (the last batch may hold fewer samples), encode each batch, and concatenate the results along the batch dimension
159 | max_batch_size = max_batch_size_encode
160 | if repr_encoder.shape[0] > max_batch_size:
161 | repr_encoder_ls = torch.split(repr_encoder, max_batch_size, dim=0)
162 | latent_ls = []
163 | show = show_progress and _should_show_progress()
164 | for chunk in tqdm(repr_encoder_ls, desc="Encoding chunks", unit="chunk", leave=False, disable=not show):
165 | # Use autocast as before; mixed_precision controls dtype
166 | with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=mixed_precision):
167 | latent_chunk = trainer.gen.encoder(chunk, extract_features=extract_features)
168 | latent_ls.append(latent_chunk)
169 | latent = torch.cat(latent_ls, dim=0)
170 | else:
171 | # single batch encode
172 | with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=mixed_precision):
173 | latent = trainer.gen.encoder(repr_encoder, extract_features=extract_features)
174 | # split samples
175 | if latent.shape[0]>1:
176 | latent_ls = torch.split(latent, audio_channels, 0)
177 | latent = torch.cat(latent_ls, -1)
178 | latent = latent[:,:,:latent.shape[-1]-(pad_size//downscaling_factor)]
179 | if extract_features:
180 | return latent
181 | else:
182 | return latent/sigma_rescale
183 |
184 |
185 |
186 | # Decode a latent representation for inference, mirroring the chunking scheme used in encode_audio_inference
187 | # Parameters:
188 | # latent: compressed latent representation with shape [audio_channels, dim, length]
189 | # trainer: object holding the trained consistency model (trainer.gen)
190 | # diffusion_steps: number of diffusion steps to use for decoding
191 | # device: device to run the model on
192 | # Returns:
193 | # audio: numpy array of decoded waveform with shape [waveform_samples, audio_channels]
194 | @torch.no_grad()
195 | def decode_latent_inference(latent, trainer, max_waveform_length_decode, max_batch_size_decode, diffusion_steps=1, device='cuda', show_progress=True):
196 | trainer.gen = trainer.gen.to(device)
197 | trainer.gen.eval()
198 | downscaling_factor = 2**freq_downsample_list.count(0)
199 | latent = latent*sigma_rescale
200 | # check if latent is numpy array, then convert to tensor
201 | if isinstance(latent, np.ndarray):
202 | latent = torch.from_numpy(latent)
203 | # check if latent tensor is on cpu. if it is, move it to the device
204 | if latent.device.type=='cpu':
205 | latent = latent.to(device)
206 | # if latent has only 2 dimensions, add a third dimension as axis 0
207 | if len(latent.shape)==2:
208 | latent = torch.unsqueeze(latent, 0)
209 | audio_channels = latent.shape[0]
210 | latent_length = latent.shape[-1]
211 | max_latent_length = int(max_waveform_length_decode/hop)//downscaling_factor
212 |
213 |     # if latent is longer than max_latent_length, split it into max_latent_length chunks (the last chunk is padded with repeated copies of the latent), then stack the chunks along the batch dimension before decoding
214 | pad_size = 0
215 | if latent_length > max_latent_length:
216 | # pad latent with copies of itself to be divisible by max_latent_length
217 | pad_size = max_latent_length - (latent_length % max_latent_length)
218 | # repeat latent such that latent.shape[-1] is higher than pad_size, then crop it such that latent.shape[-1]=pad_size
219 | latent_pad = torch.cat([latent for _ in range(1+(pad_size//latent.shape[-1]))], dim=-1)[:,:,:pad_size]
220 | latent = torch.cat([latent, latent_pad], dim=-1)
221 | latent = torch.split(latent, max_latent_length, dim=-1)
222 | latent = torch.cat(latent, dim=0)
223 |     # decode latent with a maximum batch size (dimension 0) of max_batch_size_decode: if there are more chunks than max_batch_size_decode, split them into batches (the last batch may hold fewer samples), decode each batch, and concatenate the results along the batch dimension
224 | max_batch_size = max_batch_size_decode
225 | show = show_progress and _should_show_progress()
226 | if latent.shape[0] > max_batch_size:
227 | latent_ls = torch.split(latent, max_batch_size, dim=0)
228 | repr_ls = []
229 | for chunk in tqdm(latent_ls, desc="Decoding chunks", unit="chunk", leave=False, disable=not show):
230 | repr_chunk = decode_to_representation(trainer.gen, chunk, diffusion_steps=diffusion_steps, device=device)
231 | repr_ls.append(repr_chunk)
232 | repr = torch.cat(repr_ls, dim=0)
233 | else:
234 | # single batch decode
235 | repr = decode_to_representation(trainer.gen, latent, diffusion_steps=diffusion_steps, device=device)
236 | # split samples
237 | if repr.shape[0]>1:
238 | repr_ls = torch.split(repr, audio_channels, 0)
239 | repr = torch.cat(repr_ls, -1)
240 | repr = repr[:,:,:,:repr.shape[-1]-(pad_size*downscaling_factor)]
241 | return to_waveform(repr)
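242 |
243 |
244 | # --- Usage sketch (illustrative, not part of the original module) ---
245 | # A minimal round-trip, assuming the default solo-piano checkpoint can be downloaded.
246 | # Because of the relative imports above, run it as a module, e.g.:
247 | #   python -m neuralpiano.music2latent.inference
248 | if __name__ == '__main__':
249 |     sr = 44100
250 |     t = np.linspace(0.0, 2.0, 2 * sr, endpoint=False, dtype=np.float32)
251 |     waveform = 0.5 * np.sin(2.0 * np.pi * 440.0 * t)  # mono stand-in for a real recording
252 |
253 |     codec = EncoderDecoder()  # downloads/loads the default solo piano model
254 |     latents = codec.encode(waveform)                      # [audio_channels, dim, length]
255 |     audio_out = codec.decode(latents, denoising_steps=1)  # [waveform_samples, audio_channels]
256 |     print('latents:', tuple(latents.shape), 'decoded:', tuple(audio_out.shape))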
--------------------------------------------------------------------------------
/neuralpiano/mixer.py:
--------------------------------------------------------------------------------
1 | # Mixer Python module
2 |
3 | import os
4 | import math
5 | from typing import List, Optional, Union, Sequence, Tuple
6 | import numpy as np
7 | import librosa
8 | import soundfile as sf
9 | from concurrent.futures import ThreadPoolExecutor, as_completed
10 |
11 | # Optional progress bar
12 | try:
13 | from tqdm import tqdm
14 | _HAS_TQDM = True
15 | except Exception:
16 | _HAS_TQDM = False
17 |
18 | def _ensure_list_defaults(values, n, default=0.0):
19 | if values is None:
20 | return [default] * n
21 | if isinstance(values, (int, float)):
22 | return [float(values)] * n
23 | vals = list(values)
24 | if len(vals) >= n:
25 | return vals[:n]
26 | return vals + [default] * (n - len(vals))
27 |
28 | def _resample_channel(channel, orig_sr, target_sr, res_type='kaiser_fast'):
29 | if orig_sr == target_sr:
30 | return channel
31 | return librosa.resample(channel, orig_sr=orig_sr, target_sr=target_sr, res_type=res_type)
32 |
33 | def _apply_fades(track: np.ndarray, sr: int, fade_in_s: float, fade_out_s: float) -> np.ndarray:
34 | n_samples = track.shape[1]
35 | if n_samples == 0:
36 | return track
37 | fi = max(0, int(round(fade_in_s * sr)))
38 | fo = max(0, int(round(fade_out_s * sr)))
39 | if fi + fo > n_samples and (fi + fo) > 0:
40 | scale = n_samples / (fi + fo)
41 | fi = int(round(fi * scale))
42 | fo = int(round(fo * scale))
43 | env = np.ones(n_samples, dtype=np.float32)
44 | if fi > 0:
45 | t = np.linspace(0.0, 1.0, fi, endpoint=True, dtype=np.float32)
46 | env[:fi] = 0.5 * (1.0 - np.cos(np.pi * t))
47 | if fo > 0:
48 | t = np.linspace(0.0, 1.0, fo, endpoint=True, dtype=np.float32)
49 | env[-fo:] = 0.5 * (1.0 + np.cos(np.pi * t))
50 | return track * env[np.newaxis, :]
51 |
52 | def _normalize_input_array(arr: np.ndarray) -> np.ndarray:
53 | arr = np.asarray(arr)
54 | if arr.ndim == 1:
55 | return np.expand_dims(arr.astype(np.float32), 0)
56 | if arr.ndim == 2:
57 | # Heuristic: if shape looks like (samples, channels) transpose to (channels, samples)
58 | if arr.shape[1] <= 2 and arr.shape[0] > arr.shape[1]:
59 | return arr.T.astype(np.float32)
60 | return arr.astype(np.float32)
61 | raise ValueError("Input numpy array must be 1D or 2D (samples,channels) or (channels,samples)")
62 |
63 | def mix_audio(input_items: Sequence[Union[str, Tuple[np.ndarray, int]]],
64 | output_path: str = 'mixed.wav',
65 | gain_db: Optional[Union[float, List[float]]] = 0.0,
66 | pan: Optional[Union[float, List[float]]] = 0.0,
67 | delays: Optional[Union[float, List[float]]] = 0.0,
68 | fade_in: Optional[Union[float, List[float]]] = 0.0,
69 | fade_out: Optional[Union[float, List[float]]] = 0.0,
70 | target_sr: Optional[int] = None,
71 | normalize: bool = True,
72 | trim_trailing_silence: bool = False,
73 | trim_threshold_db: float = -60.0,
74 | trim_padding_seconds: float = 0.01,
75 | output_subtype: Optional[str] = None,
76 | workers: int = 4,
77 | show_progress: bool = False,
78 | verbose: bool = False,
79 | return_mix: bool = False
80 | ) -> Optional[Tuple[np.ndarray, int]]:
81 |
82 | """
83 | Mix inputs (file paths and/or (array, sr) tuples) into one audio file.
84 | If return_mix is True the function returns (mix_array, sample_rate) where mix_array
85 | is shaped (samples, channels) and dtype float32.
86 |
87 |     Parameter notes:
88 | - gain_db, pan, delays, fade_in, fade_out accept single values or shorter lists.
89 | - target_sr defaults to the highest input sample rate.
90 | - normalize scales final peak to 0.999 when True.
91 | - trim_trailing_silence removes trailing silence below trim_threshold_db.
92 | - workers controls parallel processing.
93 | - show_progress uses tqdm if available.
94 | - verbose prints progress messages.
95 | """
96 |
97 | n = len(input_items)
98 | if n == 0:
99 | raise ValueError("input_items must contain at least one element")
100 |
101 | # Prepare per-track parameter lists (allow shorter lists)
102 | gains_db = _ensure_list_defaults(gain_db, n, default=0.0)
103 | pans = _ensure_list_defaults(pan, n, default=0.0)
104 | delays_sec = _ensure_list_defaults(delays, n, default=0.0)
105 | fades_in = _ensure_list_defaults(fade_in, n, default=0.0)
106 | fades_out = _ensure_list_defaults(fade_out, n, default=0.0)
107 | pans = [max(-1.0, min(1.0, float(p))) for p in pans]
108 |
109 | if verbose:
110 | print(f"[mix_audio] Preparing {n} inputs...")
111 |
112 | tracks = []
113 | srs = []
114 | channel_counts = []
115 |
116 | load_iter = list(enumerate(input_items))
117 | if show_progress and _HAS_TQDM:
118 | load_iter = list(tqdm(load_iter, desc="Loading inputs", unit="item"))
119 |
120 | for idx, item in load_iter:
121 | if isinstance(item, str):
122 | if verbose:
123 | print(f" loading file [{idx+1}/{n}]: {item}")
124 | y, sr = librosa.load(item, sr=None, mono=False)
125 | if y.ndim == 1:
126 | y = np.expand_dims(y, 0)
127 | tracks.append(y.astype(np.float32))
128 | srs.append(sr)
129 | channel_counts.append(y.shape[0])
130 | elif isinstance(item, (tuple, list)) and len(item) == 2:
131 | arr, sr = item
132 | if verbose:
133 | print(f" using array input [{idx+1}/{n}] with sr={sr}")
134 | y = _normalize_input_array(arr)
135 | tracks.append(y.astype(np.float32))
136 | srs.append(int(sr))
137 | channel_counts.append(y.shape[0])
138 | else:
139 | raise ValueError("Each input item must be a file path or a tuple (array, sr)")
140 |
141 | # Decide target sample rate
142 | if target_sr is None:
143 | target_sr = max(srs)
144 | if verbose:
145 | print(f"[mix_audio] Target sample rate: {target_sr} Hz")
146 |
147 | # Determine output channel count: stereo if any input has >=2 channels, else mono
148 | out_ch = 2 if any(c >= 2 for c in channel_counts) else 1
149 | if verbose:
150 | print(f"[mix_audio] Output channels: {out_ch}")
151 |
152 | # Process each track: resample, channel handling, panning, fades, apply gain
153 | processed = [None] * n
154 |
155 | def process_track(i):
156 | t = tracks[i]
157 | sr = srs[i]
158 | g_db = gains_db[i]
159 | p_val = pans[i]
160 | delay = float(delays_sec[i])
161 | fi = float(fades_in[i])
162 | fo = float(fades_out[i])
163 |
164 | # Downmix multichannel (>2) to mono first
165 | if t.shape[0] > 2:
166 | t = np.expand_dims(np.mean(t, axis=0), 0)
167 |
168 | # Resample channels
169 | if sr != target_sr:
170 | if workers > 1 and t.shape[0] > 1:
171 | with ThreadPoolExecutor(max_workers=min(workers, t.shape[0])) as ex:
172 | futures = [ex.submit(_resample_channel, t[ch], sr, target_sr, 'kaiser_fast') for ch in range(t.shape[0])]
173 | resampled_ch = [f.result() for f in futures]
174 | t = np.vstack(resampled_ch)
175 | else:
176 | t = np.vstack([_resample_channel(t[ch], sr, target_sr, 'kaiser_fast') for ch in range(t.shape[0])])
177 |
178 | # Channel handling and panning
179 | if t.shape[0] == 1 and out_ch == 2:
180 | mono = t[0]
181 | angle = (p_val + 1.0) * (math.pi / 4.0)
182 | left_gain = math.cos(angle)
183 | right_gain = math.sin(angle)
184 | left = mono * left_gain
185 | right = mono * right_gain
186 | t = np.vstack([left, right])
187 | elif t.shape[0] == 2 and out_ch == 2:
188 | left, right = t[0], t[1]
189 | mono = 0.5 * (left + right)
190 | angle = (p_val + 1.0) * (math.pi / 4.0)
191 | left_gain = math.cos(angle)
192 | right_gain = math.sin(angle)
193 | t = np.vstack([mono * left_gain, mono * right_gain])
194 | elif t.shape[0] == 2 and out_ch == 1:
195 | mono = 0.5 * (t[0] + t[1])
196 | t = np.expand_dims(mono, 0)
197 |
198 | # Apply fades (in seconds)
199 | if (fi > 0.0) or (fo > 0.0):
200 | t = _apply_fades(t, target_sr, fi, fo)
201 |
202 | # Apply gain (dB -> linear)
203 | lin = 10.0 ** (g_db / 20.0)
204 | t = t * lin
205 |
206 | # Convert delay seconds to sample offset (can be negative)
207 | offset_samples = int(round(delay * target_sr))
208 | return t, offset_samples
209 |
210 | if verbose:
211 | print("[mix_audio] Resampling and processing tracks...")
212 |
213 | # Parallel processing of tracks
214 | if workers > 1:
215 | with ThreadPoolExecutor(max_workers=workers) as ex:
216 | futures = {ex.submit(process_track, i): i for i in range(n)}
217 | if show_progress and _HAS_TQDM:
218 | pbar = tqdm(total=n, desc="Processing", unit="track")
219 | for fut in as_completed(futures):
220 | i = futures[fut]
221 | processed[i] = fut.result()
222 | if show_progress and _HAS_TQDM:
223 | pbar.update(1)
224 | if show_progress and _HAS_TQDM:
225 | pbar.close()
226 | else:
227 | proc_iter = range(n)
228 | if show_progress and _HAS_TQDM:
229 | proc_iter = tqdm(proc_iter, desc="Processing", unit="track")
230 | for i in proc_iter:
231 | processed[i] = process_track(i)
232 |
233 | # Determine final mix length considering offsets
234 | end_positions = []
235 | for t, offset in processed:
236 | start = max(0, offset)
237 | end = start + t.shape[1]
238 | end_positions.append(end)
239 | max_len = max(end_positions) if end_positions else 0
240 | min_start = min(offset for _, offset in processed)
241 |
242 | # If there are negative offsets, shift everything forward so earliest sample >= 0
243 | shift_forward = 0
244 | if min_start < 0:
245 | shift_forward = -min_start
246 | max_len += shift_forward
247 | if verbose:
248 | print(f"[mix_audio] Negative offsets detected. Shifting all tracks forward by {shift_forward} samples")
249 |
250 | # Create mix buffer and add tracks at offsets
251 | mix = np.zeros((out_ch, max_len), dtype=np.float32)
252 | if show_progress and _HAS_TQDM:
253 | mix_iter = tqdm(processed, desc="Mixing", unit="track")
254 | else:
255 | mix_iter = processed
256 |
257 | for (t, offset) in mix_iter:
258 | start = offset + shift_forward
259 | if start < 0:
260 | clip = -start
261 | if clip >= t.shape[1]:
262 | continue
263 | t = t[:, clip:]
264 | start = 0
265 | end = start + t.shape[1]
266 | mix[:, start:end] += t
267 |
268 | # Normalize to avoid clipping
269 | if normalize:
270 | peak = np.max(np.abs(mix))
271 | if peak > 0:
272 | mix *= (0.999 / peak)
273 | if verbose:
274 | print(f"[mix_audio] Normalized by factor {(0.999/peak):.6f}")
275 |
276 | # Optional trailing silence removal
277 | if trim_trailing_silence:
278 | if verbose:
279 | print(f"[mix_audio] Trimming trailing silence below {trim_threshold_db} dBFS")
280 | threshold_lin = 10.0 ** (trim_threshold_db / 20.0)
281 | abs_max = np.max(np.abs(mix), axis=0)
282 | non_silent = np.where(abs_max > threshold_lin)[0]
283 | if non_silent.size > 0:
284 | last_idx = int(non_silent[-1])
285 | pad_samples = int(round(trim_padding_seconds * target_sr))
286 | new_len = min(mix.shape[1], last_idx + 1 + pad_samples)
287 | mix = mix[:, :new_len]
288 | if verbose:
289 | print(f"[mix_audio] Trimmed to {new_len} samples ({new_len/target_sr:.3f} s)")
290 | else:
291 | keep = int(round(trim_padding_seconds * target_sr))
292 | mix = mix[:, :keep]
293 | if verbose:
294 | print(f"[mix_audio] All silent; keeping {keep} samples ({keep/target_sr:.3f} s)")
295 |
296 | out = mix.T # (samples, channels)
297 |
298 | # Infer format from extension and validate
299 | ext = os.path.splitext(output_path)[1].lower().lstrip('.')
300 | fmt = ext.upper()
301 | available = sf.available_formats()
302 | if fmt not in available:
303 | raise ValueError(f"Output format {fmt} not supported by soundfile. Available: {list(available.keys())}")
304 |
305 | if return_mix:
306 |         if verbose: print("[mix_audio] Returning mix array instead of writing a file")
307 | # Return a copy to avoid accidental modification of internal buffer
308 | return out.copy(), int(target_sr)
309 |
310 | if verbose:
311 | print(f"[mix_audio] Writing output to {output_path} (format={fmt}, subtype={output_subtype})")
312 |
313 | sf.write(output_path, out, samplerate=target_sr, format=fmt, subtype=output_subtype)
314 |
315 | if verbose:
316 | print("[mix_audio] Done.")
317 |
318 | return output_path
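319 |
320 |
321 | # --- Usage sketch (illustrative, not part of the original module) ---
322 | # Mixes two synthetic mono tones with per-track gain, delay and fade-out, returning the
323 | # mix as an array instead of writing a file. Note that pan only takes effect when the
324 | # output is stereo, i.e. when at least one input has two channels.
325 | if __name__ == '__main__':
326 |     sr = 44100
327 |     t = np.linspace(0.0, 2.0, 2 * sr, endpoint=False, dtype=np.float32)
328 |     tone_a = 0.4 * np.sin(2.0 * np.pi * 220.0 * t)
329 |     tone_b = 0.4 * np.sin(2.0 * np.pi * 330.0 * t)
330 |
331 |     mix, out_sr = mix_audio([(tone_a, sr), (tone_b, sr)],
332 |                             gain_db=[0.0, -3.0],
333 |                             delays=[0.0, 0.25],
334 |                             fade_out=1.0,
335 |                             return_mix=True)
336 |     print('mix shape:', mix.shape, 'sample rate:', out_sr)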
--------------------------------------------------------------------------------
/neuralpiano/enhancer.py:
--------------------------------------------------------------------------------
1 | # Enhancer Python module
2 |
3 | """
4 |
5 | A lightweight PyTorch-based audio enhancement module focused on
6 | reducing reverb and improving clarity for solo piano or other
7 | monophonic acoustic recordings.
8 |
9 | Features
10 | --------
11 | - STFT-based spectral denoising with Wiener-style blending.
12 | - 2-D smoothing of spectral gain (time + frequency).
13 | - Band-specific shaping (low / mid / high) and transient preservation.
14 | - Mild multiband compression to control dynamics across frequency bands.
15 | - Subtle harmonic excitation on the highest band to add perceived presence.
16 | - Gentle residual smoothing subtraction to reduce perceived reverb.
17 | - Final limiter and RMS normalization.
18 | - Optional overall gain in dB applied before final limiter/normalization.
19 | - Optional stereo output (mono duplication + per-channel normalization).
20 |
21 | Design notes
22 | ------------
23 | This module is intended for offline processing of single-channel
24 | recordings. It uses reasonably large FFT sizes and smoothing kernels
25 | to produce stable, musical results. Defaults are tuned for piano
26 | recordings sampled at 48 kHz but are configurable.
27 |
28 | Example
29 | -------
30 | >>> import soundfile as sf
31 | >>> from neuralpiano.enhancer import enhance_audio_full
32 | >>> audio, sr = sf.read("piano_mono.wav")
33 | >>> enhanced, shape = enhance_audio_full(audio, sr=sr, overall_gain_db=-1.0, output_as_stereo=True)
34 | >>> # enhanced is a numpy array (if input was numpy) or torch.Tensor (if input was torch)
35 | >>> # shape describes the returned array shape: (2, samples) for stereo or (samples,) for mono
36 | """
37 |
38 | from typing import Union, Optional, Tuple
39 | import numpy as np
40 | import torch
41 | import torch.nn.functional as F
42 |
43 | TensorOrArray = Union[torch.Tensor, np.ndarray]
44 |
45 | def _to_torch(x: TensorOrArray, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
46 | if isinstance(x, np.ndarray):
47 | t = torch.from_numpy(x)
48 | else:
49 | t = x.clone()
50 | if not torch.is_floating_point(t):
51 | t = t.float()
52 |     return t.to(device=device, dtype=dtype)  # keep the original shape; channel handling is done by the caller
53 |
54 |
55 | def _to_output(t: torch.Tensor, orig: TensorOrArray, return_type: Optional[str]) -> TensorOrArray:
56 | out_type = return_type or ('numpy' if isinstance(orig, np.ndarray) else 'torch')
57 | if out_type == 'numpy':
58 | return t.cpu().numpy()
59 | return t
60 |
61 |
62 | def _rms_val(x: torch.Tensor, eps: float = 1e-12) -> float:
63 | return float(torch.sqrt(torch.mean(x**2) + eps).item())
64 |
65 |
66 | def _soft_clip(x: torch.Tensor, drive: float = 1.0) -> torch.Tensor:
67 | return torch.tanh(x * drive) / (torch.tanh(torch.tensor(1.0, device=x.device)) + 1e-12)
68 |
69 |
70 | def _multiband_compress(mag: torch.Tensor,
71 | freq_bins: torch.Tensor,
72 | sr: int,
73 | bands: tuple = ((20, 200), (200, 2000), (2000, 8000)),
74 | thresholds_db: tuple = (-18.0, -18.0, -18.0),
75 | ratios: tuple = (1.0, 1.8, 2.2),
76 | attack_frames: int = 1,
77 | release_frames: int = 8,
78 | device: torch.device = torch.device('cpu'),
79 | dtype: torch.dtype = torch.float32) -> torch.Tensor:
80 | """
81 |
82 | Simple multiband compressor applied to magnitude spectrogram.
83 |
84 | Parameters
85 | ----------
86 | mag : torch.Tensor
87 | Magnitude spectrogram (bins, frames).
88 | freq_bins : torch.Tensor
89 | Frequency values for each bin (bins,).
90 | sr : int
91 | Sample rate in Hz.
92 | bands : tuple
93 | Sequence of (low_hz, high_hz) band boundaries.
94 | thresholds_db : tuple
95 | Thresholds per band in dB (linear conversion applied internally).
96 | ratios : tuple
97 | Compression ratios per band.
98 | attack_frames : int
99 | Attack smoothing window in frames (approximate).
100 | release_frames : int
101 | Release smoothing window in frames (approximate).
102 | device, dtype : torch types
103 | Device and dtype for intermediate tensors.
104 |
105 | Returns
106 | -------
107 | torch.Tensor
108 | Compressed magnitude spectrogram (bins, frames).
109 | """
110 |
111 | bins, frames = mag.shape
112 | out = mag.clone()
113 | for i, band in enumerate(bands):
114 | lo, hi = band
115 | mask = ((freq_bins >= lo) & (freq_bins < hi)).float().unsqueeze(1)
116 | if mask.sum() < 1:
117 | continue
118 | band_mag = (mag * mask).sum(dim=0) / (mask.sum() + 1e-12)
119 |         if attack_frames > 1:
120 |             # replicate-pad so the smoothed envelope keeps the same number of frames
121 |             sq_env = F.pad(band_mag.unsqueeze(0).unsqueeze(0)**2, (attack_frames // 2, attack_frames - 1 - attack_frames // 2), mode='replicate')
122 |             env = torch.sqrt(F.conv1d(sq_env, torch.ones(1, 1, attack_frames, device=device, dtype=dtype) / attack_frames).squeeze() + 1e-12)
123 |         else:
124 |             env = band_mag.abs()
125 | threshold = 10 ** (thresholds_db[i] / 20.0)
126 | ratio = ratios[i]
127 | gain = torch.ones_like(env)
128 | over = env > threshold
129 | if over.any():
130 | gain[over] = (threshold + (env[over] - threshold) / ratio) / (env[over] + 1e-12)
131 | kernel = torch.ones(release_frames, device=device, dtype=dtype) / float(release_frames)
132 | g = F.pad(gain.unsqueeze(0).unsqueeze(0),
133 | (release_frames // 2, release_frames - 1 - release_frames // 2),
134 | mode='replicate')
135 | g = F.conv1d(g, kernel.view(1, 1, release_frames)).squeeze()
136 | out = out * (1.0 - mask) + (out * mask) * g.unsqueeze(0)
137 | return out
138 |
139 |
140 | def enhance_audio_full(audio: TensorOrArray,
141 | sr: int = 48000,
142 | device: Union[str, torch.device] = 'cuda',
143 | dtype: torch.dtype = torch.float32,
144 | n_fft: int = 8192,
145 | hop_length: int = 2048,
146 | win_length: Optional[int] = None,
147 | hp_cut_hz: float = 30.0,
148 | denoise_strength: float = 0.55,
149 | min_gain: float = 0.25,
150 | time_smooth_k: int = 9,
151 | freq_smooth_k: int = 15,
152 | low_gain_db: float = -1.8,
153 | mid_gain_db: float = 1.6,
154 | high_gain_db: float = 1.8,
155 | transient_boost: float = 1.12,
156 | excite_amount: float = 0.01,
157 | excite_scale: float = 0.02,
158 | limiter_threshold_db: float = -0.5,
159 | target_rms_db: float = -18.0,
160 | overall_gain_db: float = -1.0,
161 | output_as_stereo: bool = False,
162 | return_type: Optional[str] = None,
163 | verbose: bool = False
164 | ) -> Tuple[TensorOrArray, Tuple[int, ...]]:
165 |
166 | """
167 | Enhance a full audio buffer using STFT-domain processing.
168 |
169 | This function performs a sequence of spectral processing steps designed
170 | to reduce reverberation, suppress noise, preserve transients, and
171 | increase perceived clarity and presence for solo piano or similar
172 | acoustic material.
173 |
174 | Parameters
175 | ----------
176 | audio : numpy.ndarray or torch.Tensor
177 | Input audio. Can be a 1-D array/tensor (samples,) or a 2-D array/tensor
178 | with channels. If multi-channel is provided, channels are averaged to
179 | mono before processing.
180 | sr : int, optional
181 | Sample rate in Hz (default 48000).
182 | device : str or torch.device, optional
183 | Device to run processing on (default 'cuda'). If CUDA is not available,
184 | set to 'cpu'.
185 | dtype : torch.dtype, optional
186 | Floating dtype for processing (default torch.float32).
187 | n_fft : int, optional
188 | FFT size for STFT (default 8192).
189 | hop_length : int, optional
190 | Hop length in samples between STFT frames (default 2048).
191 | win_length : int or None, optional
192 | Window length for STFT. If None, defaults to n_fft.
193 | hp_cut_hz : float, optional
194 | High-pass cutoff frequency in Hz applied to spectral gain (default 30.0).
195 | denoise_strength : float, optional
196 | Controls amount of spectral subtraction and Wiener blending (0..1).
197 | min_gain : float, optional
198 | Minimum magnitude floor applied after processing (linear).
199 | time_smooth_k : int, optional
200 | Kernel size (frames) for temporal smoothing of spectral gain.
201 | freq_smooth_k : int, optional
202 | Kernel size (bins) for frequency smoothing of spectral gain.
203 | low_gain_db, mid_gain_db, high_gain_db : float, optional
204 | Per-band gain adjustments in dB applied after denoising.
205 | transient_boost : float, optional
206 | Multiplier for transient preservation (values >= 1.0).
207 | excite_amount : float, optional
208 | Drive for harmonic excitation signal generation (small positive).
209 | excite_scale : float, optional
210 | Scaling factor for adding harmonic excitation back into STFT.
211 | limiter_threshold_db : float, optional
212 | Peak limiter threshold in dB (default -0.5 dB).
213 | target_rms_db : float, optional
214 | Target RMS level in dBFS for final normalization (default -18 dB).
215 | overall_gain_db : float, optional
216 | Final overall gain in dB applied before limiter/normalization.
217 | A recommended default is -1.0 dB to avoid clipping while preserving
218 | perceived loudness improvements.
219 | output_as_stereo : bool, optional
220 | If True, duplicate the processed mono signal into two channels and
221 | normalize each channel to the target RMS (mono duplication + norm).
222 | return_type : str or None, optional
223 | If 'numpy', returns numpy arrays; if 'torch', returns torch tensors;
224 | if None, the return type matches the input type.
225 | verbose : bool, optional
226 | If True, prints processing debug information.
227 |
228 | Returns
229 | -------
230 | (enhanced, shape) : tuple
231 | - enhanced : numpy.ndarray or torch.Tensor
232 | The processed audio. If `output_as_stereo` is False, this is a 1-D
233 | array/tensor with shape (samples,). If `output_as_stereo` is True,
234 | this is a 2-D array/tensor with shape (2, samples) representing
235 | stereo channels.
236 | - shape : tuple
237 | A small descriptor of the returned shape:
238 | - (n,) for mono output where n is number of samples
239 | - (2, n) for stereo output
240 |
241 | Notes
242 | -----
243 | - The function converts multi-channel inputs to mono by averaging channels.
244 | - The `overall_gain_db` is applied before the final limiter and RMS
245 | normalization so that the limiter can prevent clipping if the gain
246 | increases peaks above the threshold.
247 | - When `output_as_stereo` is True the mono signal is duplicated and each
248 | channel is scaled to match the `target_rms_db` level independently.
249 | - For best results with piano recordings, use a sample rate of 44.1 kHz
250 | or 48 kHz and keep the input as a clean mono take when possible.
251 |
252 | Example
253 | -------
254 | >>> enhanced, shape = enhance_audio_full(audio, sr=48000, overall_gain_db=-1.0, output_as_stereo=True)
255 | """
256 |
257 | device = torch.device(device)
258 | x = _to_torch(audio, device=device, dtype=dtype)
259 | if x.dim() != 1:
260 | if x.dim() == 2:
261 | # If input is (channels, samples) or (samples, channels), try to reduce to mono by averaging channels
262 | if x.shape[0] <= 2 and x.shape[0] < x.shape[1]:
263 | x = x.mean(dim=0)
264 | else:
265 | x = x.mean(dim=1)
266 | else:
267 | x = x.view(-1)
268 | n = x.numel()
269 | if win_length is None:
270 | win_length = n_fft
271 |
272 | if verbose:
273 |         print(f"[enhance_audio_full] device={device}, dtype={dtype}, n={n}, n_fft={n_fft}, hop={hop_length}, overall_gain_db={overall_gain_db}, output_as_stereo={output_as_stereo}")
274 |
275 | window = torch.hann_window(win_length, device=device, dtype=dtype)
276 |
277 | # Full STFT
278 | X = torch.stft(x,
279 | n_fft=n_fft,
280 | hop_length=hop_length,
281 | win_length=win_length,
282 | window=window,
283 | center=True,
284 | return_complex=True)
285 |
286 | mag = torch.abs(X)
287 | phase = torch.angle(X)
288 | bins, frames = mag.shape
289 |
290 | freq_bins = torch.fft.rfftfreq(n_fft, 1.0 / sr).to(device=device, dtype=dtype)
291 | hp_mask = (freq_bins >= hp_cut_hz).float().unsqueeze(1)
292 | low_mask = (freq_bins <= 200.0).float()
293 | mid_mask = ((freq_bins > 200.0) & (freq_bins <= 2000.0)).float()
294 | high_mask = (freq_bins > 2000.0).float()
295 |
296 | low_gain = 10 ** (low_gain_db / 20.0)
297 | mid_gain = 10 ** (mid_gain_db / 20.0)
298 | high_gain = 10 ** (high_gain_db / 20.0)
299 | band_gain = (low_gain * low_mask + mid_gain * mid_mask + high_gain * high_mask).unsqueeze(1)
300 |
301 | est_samples = min(int(0.5 * sr), n)
302 | est_frames = max(1, int(est_samples / hop_length))
303 | noise_floor = mag[:, :est_frames].median(dim=1).values.unsqueeze(1).clamp(min=1e-9)
304 |
305 | # Spectral subtraction + Wiener blend
306 | S2 = mag**2
307 | N2 = noise_floor**2
308 | over_sub = 1.0 + (denoise_strength * 0.6)
309 | sub = S2 - over_sub * N2
310 | sub = torch.clamp(sub, min=0.0)
311 | gain = sub / (S2 + 1e-12)
312 | gain = torch.clamp(gain, 0.0, 1.0)
313 | gain = 1.0 - (1.0 - gain) * denoise_strength
314 |
315 | # 2-D smoothing: time then frequency
316 | time_k = max(3, time_smooth_k if time_smooth_k % 2 == 1 else time_smooth_k + 1)
317 | freq_k = max(3, freq_smooth_k if freq_smooth_k % 2 == 1 else freq_smooth_k + 1)
318 |
319 | # Time smoothing (grouped conv across frames)
320 | gain_t = gain.unsqueeze(0) # (1, bins, frames)
321 | bins_count = gain_t.shape[1]
322 | time_kernel = torch.ones(time_k, device=device, dtype=dtype) / float(time_k)
323 | k_time = time_kernel.view(1, 1, time_k).repeat(bins_count, 1, 1)
324 | pad_t = (time_k // 2, time_k - 1 - time_k // 2)
325 | gain_t = F.pad(gain_t, pad_t, mode='replicate')
326 | gain_t = F.conv1d(gain_t, k_time, groups=bins_count).squeeze(0)
327 |
328 | # Frequency smoothing (frames as batch)
329 | gain_f = gain_t.transpose(0, 1).unsqueeze(1) # (frames,1,bins)
330 | freq_kernel = torch.ones(freq_k, device=device, dtype=dtype).view(1, 1, freq_k) / float(freq_k)
331 | pad_f = (freq_k // 2, freq_k - 1 - freq_k // 2)
332 | gain_f = F.pad(gain_f, pad_f, mode='replicate')
333 | gain_f = F.conv1d(gain_f, freq_kernel, groups=1) # (frames,1,bins)
334 | gain = gain_f.squeeze(1).transpose(0, 1) # (bins, frames)
335 |
336 | # Apply highpass mask and band shaping, enforce min_gain floor
337 | gain = gain * hp_mask
338 | mag = mag * gain * band_gain
339 | mag = torch.clamp(mag, min=min_gain * 1e-6)
340 |
341 | # Transient preservation (high-frequency energy rise)
342 | hf = (mag * high_mask.unsqueeze(1)).sum(dim=0)
343 | prev = F.pad(hf, (1, 0))[:-1]
344 | rise = torch.clamp((hf - prev) / (prev + 1e-9), min=0.0)
345 | transient_gain = 1.0 + (transient_boost - 1.0) * torch.clamp(rise * 2.0, 0.0, 1.0)
346 | mag = mag * transient_gain.unsqueeze(0)
347 |
348 | # Mild multiband compression
349 | mag = _multiband_compress(mag, freq_bins, sr,
350 | bands=((20, 200), (200, 2000), (2000, 8000)),
351 | thresholds_db=(-22.0, -20.0, -20.0),
352 | ratios=(1.0, 1.6, 1.8),
353 | attack_frames=1,
354 | release_frames=6,
355 | device=device,
356 | dtype=dtype)
357 |
358 | # Reconstruct complex STFT
359 | X = mag * torch.exp(1j * phase)
360 |
361 | # Very subtle harmonic excitation on highest band
362 | high_mask_full = (freq_bins > 4000.0).float().unsqueeze(1)
363 | X_high = X * high_mask_full
364 | high_time = torch.istft(X_high, n_fft=n_fft, hop_length=hop_length, win_length=win_length, window=window, length=n)
365 | excite = _soft_clip(high_time, drive=1.0 + excite_amount) - high_time
366 | E = torch.stft(excite, n_fft=n_fft, hop_length=hop_length, win_length=win_length, window=window, center=True, return_complex=True)
367 | X = X + excite_scale * E
368 |
369 | # ISTFT back to time domain
370 | out = torch.istft(X, n_fft=n_fft, hop_length=hop_length, win_length=win_length, window=window, length=n)
371 |
372 | # Final gentle de-reverb: smooth residual and subtract tiny fraction
373 | residual = out - x
374 | res = residual.cpu()
375 | kernel_len = max(1, int(0.05 * sr)) # 50 ms smoothing
376 | if kernel_len > 1 and res.numel() > kernel_len:
377 | # prepare kernel on CPU
378 | k = torch.ones(1, 1, kernel_len, dtype=res.dtype) / float(kernel_len)
379 | # conv1d expects (batch, channels, length)
380 | res_padded = F.pad(res.unsqueeze(0).unsqueeze(0), (kernel_len // 2, kernel_len - 1 - kernel_len // 2), mode='replicate')
381 | res_smooth = F.conv1d(res_padded, k, padding=0).squeeze()
382 | # Align length robustly: trim or pad to match n
383 | if res_smooth.numel() > n:
384 | res_smooth = res_smooth[:n]
385 | elif res_smooth.numel() < n:
386 | res_smooth = F.pad(res_smooth, (0, n - res_smooth.numel()))
387 | # subtract a tiny fraction of smoothed residual to reduce perceived reverb
388 | out = out - 0.02 * res_smooth.to(device=device, dtype=dtype)
389 |
390 | # Apply overall gain in dB (before final limiter/normalization)
391 | if abs(overall_gain_db) > 1e-6:
392 | gain_lin = 10 ** (overall_gain_db / 20.0)
393 | out = out * gain_lin
394 |
395 | # Final limiter and RMS normalization
396 | peak = out.abs().max().clamp(min=1e-12).item()
397 | threshold = 10 ** (limiter_threshold_db / 20.0)
398 | if peak > threshold:
399 | out = out * (threshold / peak)
400 |
401 | current_rms = _rms_val(out)
402 | target_rms = 10 ** (target_rms_db / 20.0)
403 | if current_rms > 1e-12:
404 | out = out * (target_rms / current_rms)
405 |
406 | # If stereo output requested: duplicate mono to two channels and normalize (mono duplication + norm)
407 | if output_as_stereo:
408 | # create (2, n) tensor
409 | out_stereo = torch.stack([out, out], dim=0) # (2, n)
410 | # ensure each channel has same RMS as target_rms
411 | ch_rms = torch.sqrt(torch.mean(out_stereo**2, dim=1) + 1e-12)
412 | # avoid division by zero
413 | scale = (target_rms / ch_rms).unsqueeze(1)
414 | out_stereo = out_stereo * scale.to(device=device, dtype=dtype)
415 | final_out = out_stereo
416 | final_shape = (2, n)
417 | else:
418 | final_out = out
419 | final_shape = (n,)
420 |
421 | return _to_output(final_out, audio, return_type), final_shape
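422 |
423 |
424 | # --- Usage sketch (illustrative, not part of the original module) ---
425 | # Runs the enhancement chain on a synthetic decaying tone. device='cpu' is used here so
426 | # the sketch runs anywhere; on a CUDA machine the default device can be kept.
427 | if __name__ == '__main__':
428 |     sr = 48000
429 |     t = np.linspace(0.0, 3.0, 3 * sr, endpoint=False, dtype=np.float32)
430 |     tone = (0.3 * np.sin(2.0 * np.pi * 440.0 * t) * np.exp(-t)).astype(np.float32)
431 |
432 |     enhanced, shape = enhance_audio_full(tone, sr=sr, device='cpu',
433 |                                          overall_gain_db=-1.0, output_as_stereo=True)
434 |     print('returned shape descriptor:', shape)  # (2, n) for stereo output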
--------------------------------------------------------------------------------
/neuralpiano/master.py:
--------------------------------------------------------------------------------
1 | # Master Python module
2 |
3 | import torch
4 | import math
5 | from typing import Dict, Tuple
6 |
7 | Tensor = torch.Tensor
8 |
9 | def master_mono_piano(mono: Tensor,
10 | sr: int = 48000,
11 | hp_cut: float = 20.0,
12 | low_shelf_db: float = -1.5,
13 | low_shelf_fc: float = 250.0,
14 | presence_boost_db: float = 1.6,
15 | presence_fc: float = 3500.0,
16 | compressor_thresh_db: float = -6.0,
17 | compressor_ratio: float = 2.5,
18 | compressor_attack_ms: float = 8.0,
19 | compressor_release_ms: float = 80.0,
20 | compressor_makeup_db: float = 0.0,
21 | stereo_spread_ms: float = 6.0,
22 | stereo_spread_level_db: float = -12.0,
23 | reverb_size_sec: float = 0.6,
24 | reverb_mix: float = 0.06,
25 | limiter_ceiling_db: float = -0.3,
26 | dithering_bits: int = 24,
27 | gain_db: float = 20.0,
28 | device: torch.device = None,
29 | dtype: torch.dtype = None,
30 | iir_impulse_len: int = 2048, # length used to approximate IIRs as FIRs (tradeoff speed/accuracy)
31 | ) -> Tuple[Tensor, Dict]:
32 | """
33 | Mastering pipeline for a mono input piano track producing a stereo master and diagnostics.
34 |
35 | Description and design choices
36 | - Purpose: provide a compact, deterministic, and high-throughput mastering chain implemented
37 | with PyTorch tensors. The pipeline uses short FIR approximations for small IIR filters,
38 | single-shot FFT convolution for medium/long FIRs (cascaded HP + shelf, reverb), a fast
39 | RMS-based compressor, fractional-delay stereo widening, a smooth soft limiter, and
40 | deterministic TPDF dithering.
41 | - Where gain_db is applied: an explicit master gain parameter (gain_db) is applied after
42 | EQ/compression/stereo processing and reverb but before the soft limiter and final safety
43 | scaling. This placement allows the limiter to react to the applied gain, preserving
44 | consistent ceiling behavior while enabling transparent loudness adjustments and avoiding
45 | excessive post-limiter boosting which would bypass the limiter's protection.
46 | - Implementation details:
47 | * EQ: designs 2nd-order biquads and converts them to short FIR impulse responses using
48 | direct IIR stepping (iir_to_fir). HP and low-shelf are cascaded via FFT convolution.
49 | * Presence: implemented as a small symmetric FIR (sinc-windowed) and applied with conv1d.
50 | * Compression: fast RMS detector downsampled to ~4 kHz with an attack/release recurrence
51 | executed on CPU for determinism and efficiency. Soft-knee gain computer produces a
52 | time-varying linear gain applied to the signal with optional makeup gain.
53 | * Stereo widen: fractional sub-sample delays are used to create left/right channels, then
54 | mid/side processing adjusts side level.
55 | * Reverb: simple tapped + exponential-tail IR built and FFT-convolved once.
56 | * Limiter: soft tanh-based limiter that scales output to target ceiling (limiter_ceiling_db).
57 | * Dither: deterministic vectorized LCG generates TPDF dither for specified bit depth.
58 | - Determinism and diagnostics: the function avoids non-deterministic ops and returns a
59 | diagnostics dict with numeric metrics (peaks, RMS, average reduction, applied scales).
60 | - Precision and device handling: prefers float32 for performance; preserves device assignment;
61 | convolution routines use torch.fft (rfft/irfft) for CPU/GPU compatibility.
62 |
63 | Parameters
64 | - mono: Tensor of shape [N] or [1, N], mono PCM in float range roughly [-1, 1].
65 | - sr: sample rate in Hz.
66 | - hp_cut: high-pass cutoff frequency in Hz.
67 | - low_shelf_db: low-shelf gain (dB).
68 | - low_shelf_fc: low-shelf center freq (Hz).
69 | - presence_boost_db: narrow presence boost amount (dB).
70 | - presence_fc: center freq of presence boost (Hz).
71 | - compressor_...: compressor threshold (dBFS), ratio, attack and release (ms), and makeup (dB).
72 | - stereo_spread_ms: nominal stereo spread in milliseconds used to compute fractional delays.
73 | - stereo_spread_level_db: side gain in dB applied to M/S side channel.
74 | - reverb_size_sec: approximate reverb tail time in seconds (affects IR length).
75 | - reverb_mix: wet fraction of reverb to blend with dry signal.
76 | - limiter_ceiling_db: final soft limiter ceiling in dBFS (should be <= 0).
77 | - dithering_bits: integer bits for TPDF dithering (1-32). Use 0 or outside range to disable.
78 | - gain_db: master gain applied (dB). Positive increases loudness, negative reduces. Applied
79 | before the limiter so the limiter shapes peaks introduced by the gain.
80 | - device, dtype: optional torch.device and dtype to force placement/precision.
81 | - iir_impulse_len: length of FIR used to approximate 2nd-order IIRs (tradeoff accuracy/speed).
82 |
83 | Returns
84 | - stereo: Tensor shaped [2, N] float32 in range (-1, 1) representing left/right master.
85 | - diagnostics: Dict with numeric measurements and settings for reproducibility.
86 |
87 | Notes
88 | - The function keeps intermediary operations vectorized. Non-trivial recurrence for the RMS
89 | detector runs on CPU for stability and determinism but the returned envelopes are moved back
90 | to the target device and dtype.
91 | - For large inputs and GPU usage, FFT sizes are chosen as powers of two for efficiency.
92 | - If you want gain to be applied as a pre-EQ trim (instead of pre-limiter), move the gain
93 | multiplication earlier; current placement intentionally lets the limiter handle the gain.
94 | """
95 |
96 | # --- Setup, sanitize ---
97 | if mono.ndim == 2 and mono.shape[0] == 1:
98 | x = mono.squeeze(0)
99 | elif mono.ndim == 1:
100 | x = mono
101 | else:
102 | raise ValueError("mono must be shape [1, N] or [N]")
103 |
104 | device = device or x.device
105 | # prefer float32 for fastest throughput unless user explicitly provided float64
106 | if dtype is None:
107 | dtype = x.dtype if x.dtype in (torch.float32, torch.float64) else torch.float32
108 | x = x.to(device=device, dtype=dtype, copy=False)
109 | N = x.shape[-1]
110 | eps = 1e-12
111 |
112 | diagnostics: Dict = {}
113 | def db2lin(d): return 10.0 ** (d / 20.0)
114 | def lin2db(v): return 20.0 * math.log10(max(v, 1e-12))
115 |
116 | # --- Helper: design biquad coefficients (as before) ---
117 | def design_butter_hp(fc, fs, Q=0.7071):
118 | omega = 2.0 * math.pi * fc / fs
119 | alpha = math.sin(omega) / (2.0 * Q)
120 | cosw = math.cos(omega)
121 | b0 = (1 + cosw) / 2.0
122 | b1 = -(1 + cosw)
123 | b2 = (1 + cosw) / 2.0
124 | a0 = 1 + alpha
125 | a1 = -2 * cosw
126 | a2 = 1 - alpha
127 | return b0, b1, b2, a0, a1, a2
128 |
129 | def design_low_shelf(fc, fs, gain_db, Q=0.7071):
130 | A = 10 ** (gain_db / 40.0)
131 | w0 = 2.0 * math.pi * fc / fs
132 | alpha = math.sin(w0) / (2.0 * Q)
133 | cosw = math.cos(w0)
134 | b0 = A*( (A+1) - (A-1)*cosw + 2*math.sqrt(A)*alpha )
135 | b1 = 2*A*( (A-1) - (A+1)*cosw )
136 | b2 = A*( (A+1) - (A-1)*cosw - 2*math.sqrt(A)*alpha )
137 | a0 = (A+1) + (A-1)*cosw + 2*math.sqrt(A)*alpha
138 | a1 = -2*( (A-1) + (A+1)*cosw )
139 | a2 = (A+1) + (A-1)*cosw - 2*math.sqrt(A)*alpha
140 | return b0, b1, b2, a0, a1, a2
141 |
142 | # --- Utility: convert a 2nd-order IIR (b,a) to an FIR impulse response of length L
143 | # compute impulse response by stepping the IIR for L samples (this loop runs only for L ~ 1-4k)
144 | def iir_to_fir(b0, b1, b2, a0, a1, a2, L, device, dtype):
145 | # normalize
146 | b0n, b1n, b2n = b0 / a0, b1 / a0, b2 / a0
147 | a1n, a2n = a1 / a0, a2 / a0
148 |         # compute the impulse response on the target device for small L; for long IRs on CUDA, step the loop on CPU and move the result afterwards
149 |         run_device = device if (device.type != 'cuda' or L <= 8192) else torch.device('cpu')
150 | h = torch.zeros(L, device=run_device, dtype=dtype)
151 | x_prev1 = 0.0
152 | x_prev2 = 0.0
153 | y_prev1 = 0.0
154 | y_prev2 = 0.0
155 | # feed delta(0)=1, others 0
156 | for n in range(L):
157 | xv = 1.0 if n == 0 else 0.0
158 | yv = b0n * xv + b1n * x_prev1 + b2n * x_prev2 - a1n * y_prev1 - a2n * y_prev2
159 | h[n] = yv
160 | x_prev2 = x_prev1
161 | x_prev1 = xv
162 | y_prev2 = y_prev1
163 | y_prev1 = yv
164 | # ensure on main device
165 | if h.device != device:
166 | h = h.to(device=device, dtype=dtype)
167 | return h
168 |
169 | # --- 1) Build IIR -> FIR approximations for HP and shelf (single impulse responses) ---
170 | # keep impulse length small (configurable) to balance cost/accuracy
171 | b0,b1,b2,a0,a1,a2 = design_butter_hp(hp_cut, sr)
172 | hp_ir = iir_to_fir(b0,b1,b2,a0,a1,a2, iir_impulse_len, device, dtype)
173 |
174 | b0s,b1s,b2s,a0s,a1s,a2s = design_low_shelf(low_shelf_fc, sr, low_shelf_db)
175 | shelf_ir = iir_to_fir(b0s,b1s,b2s,a0s,a1s,a2s, iir_impulse_len, device, dtype)
176 |
177 | # cascade IIRs by convolving their IRs (use FFT convolution)
178 | def fft_convolve_full(sig, kernel, device, dtype):
179 | n = sig.shape[-1]
180 | k = kernel.shape[-1]
181 | out_len = n + k - 1
182 | size = 1 << ((out_len - 1).bit_length())
183 | # cast to complex-friendly dtype (float32/64)
184 | S = torch.fft.rfft(sig, n=size)
185 | K = torch.fft.rfft(kernel, n=size)
186 | Y = S * K
187 | y = torch.fft.irfft(Y, n=size)[:out_len]
188 |         # trim the full convolution to the input length (left-aligned) so the output matches the original signal length
189 | return y[:n]
190 |
191 | # apply HP then shelf by convolving with cascaded IR (hp_ir * shelf_ir)
192 | casc_ir = fft_convolve_full(hp_ir, shelf_ir, device, dtype)[:max(hp_ir.numel(), shelf_ir.numel())]
193 | # normalize tiny numerical offsets
194 | casc_ir = casc_ir / (casc_ir.abs().sum().clamp(min=eps))
195 |
196 | # apply cascade IR to input (single FFT conv)
197 | x_hp_shelf = fft_convolve_full(x, casc_ir, device, dtype)
198 |
199 | # --- 2) Presence boost (small FIR) same approach but small kernel conv1d is very fast ---
200 | pres_len = min(256, max(65, int(sr * 0.0045))) # ~4.5 ms
201 | t_idx = torch.arange(pres_len, device=device, dtype=dtype) - (pres_len - 1) / 2.0
202 | h = (torch.sinc(2.0 * presence_fc / sr * t_idx) * torch.hann_window(pres_len, device=device, dtype=dtype))
203 | h = h / (h.abs().sum() + eps)
204 | gain_lin = db2lin(presence_boost_db)
205 | presence_ir = (gain_lin - 1.0) * h
206 | presence_ir[(pres_len - 1) // 2] += 1.0
207 | # conv small kernel with conv1d (fast)
208 | x_eq = torch.nn.functional.conv1d(x_hp_shelf.view(1,1,-1), presence_ir.view(1,1,-1), padding=(pres_len-1)//2).view(-1)
209 |
210 | diagnostics.update({
211 | "hp_cut": hp_cut,
212 | "low_shelf_db": low_shelf_db,
213 | "presence_db": presence_boost_db,
214 | "presence_len": pres_len,
215 | "iir_impulse_len": iir_impulse_len,
216 | })
217 |
218 | # --- 3) Fast RMS compressor (vectorized with downsampled detector) ---
219 | sig = x_eq
220 | attack_tc = math.exp(-1.0 / max(1.0, (compressor_attack_ms * sr / 1000.0)))
221 | release_tc = math.exp(-1.0 / max(1.0, (compressor_release_ms * sr / 1000.0)))
222 | sq = sig * sig
223 |
224 | ds = max(1, int(sr // 4000)) # detector rate ~4kHz
225 | if ds > 1:
226 | # pad to multiple of ds
227 | pad = (-sq.shape[-1]) % ds
228 | if pad:
229 | sq_pad = torch.nn.functional.pad(sq, (0, pad))
230 | else:
231 | sq_pad = sq
232 | sq_ds = sq_pad.view(-1).reshape(-1, ds).mean(dim=1)
233 | else:
234 | sq_ds = sq
235 |
236 | # recurrence on small downsampled vector executed on CPU (cheap)
237 | sq_ds_cpu = sq_ds.detach().cpu()
238 | env_ds = torch.empty_like(sq_ds_cpu)
239 | s_val = float(sq_ds_cpu[0].item())
240 | a = attack_tc
241 | r = release_tc
242 | for i in range(sq_ds_cpu.shape[0]):
243 | v = float(sq_ds_cpu[i].item())
244 | if v > s_val:
245 | s_val = a * s_val + (1.0 - a) * v
246 | else:
247 | s_val = r * s_val + (1.0 - r) * v
248 | env_ds[i] = s_val
249 | env_ds = env_ds.to(device=device, dtype=dtype)
250 |
251 | if ds > 1:
252 | env = env_ds.repeat_interleave(ds)[:N]
253 | else:
254 | env = env_ds
255 |
256 | rms_env = torch.sqrt(torch.clamp(env, min=eps))
257 | lvl_db = 20.0 * torch.log10(torch.clamp(rms_env, min=1e-12))
258 | knee = 3.0
259 | over = lvl_db - compressor_thresh_db
260 | # soft knee
261 | zero = torch.zeros_like(over)
262 | gain_reduction_db = torch.where(
263 | over <= -knee,
264 | zero,
265 | torch.where(
266 | over >= knee,
267 | compressor_thresh_db + (over / compressor_ratio) - lvl_db,
268 | - ((1.0 - 1.0/compressor_ratio) * (over + knee)**2) / (4.0 * knee)
269 | )
270 | ).clamp(max=0.0)
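# Soft-knee behaviour: below (threshold - knee) no gain reduction is applied; above
# (threshold + knee) the full ratio applies; inside the knee the reduction is blended
# in with a quadratic curve for a smooth transition. The clamp keeps the result at or
# below 0 dB (attenuation only).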
271 | gain_lin = 10.0 ** (gain_reduction_db / 20.0)
272 | if ds > 1:
273 | gain_full = gain_lin.repeat_interleave(ds)[:N]
274 | else:
275 | gain_full = gain_lin
276 | makeup = db2lin(compressor_makeup_db)
277 | comp_out = sig * gain_full * makeup
278 |
279 | diagnostics.update({
280 | "compressor_thresh_db": compressor_thresh_db,
281 | "compressor_ratio": compressor_ratio,
282 | "compressor_attack_ms": compressor_attack_ms,
283 | "compressor_release_ms": compressor_release_ms,
284 | "compressor_makeup_db": compressor_makeup_db,
285 | "detector_downsample": ds,
286 | "avg_reduction_db": float((20.0 * torch.log10((gain_full.mean().clamp(min=1e-12))).item())),
287 | })
288 |
289 | # --- 4) Stereo widening with fractional sub-sample delays (vectorized) ---
290 | spread_samples = max(1e-4, stereo_spread_ms * sr / 1000.0)
291 | left_delay = spread_samples * 0.5
292 | right_delay = -spread_samples * 0.3333
293 |
294 | def fractional_delay_vec(sig, delay):
295 | n = sig.shape[-1]
296 | idx = torch.arange(n, device=device, dtype=dtype)
297 | pos = idx - delay
298 | pos_floor = pos.floor().long()
299 | pos_ceil = pos_floor + 1
300 | frac = (pos - pos_floor.to(dtype))
301 | pos_floor = pos_floor.clamp(0, n-1)
302 | pos_ceil = pos_ceil.clamp(0, n-1)
303 | s_floor = sig[pos_floor]
304 | s_ceil = sig[pos_ceil]
305 | return s_floor * (1.0 - frac) + s_ceil * frac
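# Sub-sample delays are realized by linear interpolation between the two neighbouring
# samples; indices are clamped at the signal edges, so boundary samples simply repeat.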
306 |
307 | left = 0.985 * comp_out + 0.015 * fractional_delay_vec(comp_out, left_delay)
308 | right = 0.985 * comp_out + 0.015 * fractional_delay_vec(comp_out, right_delay)
309 |
310 | mid = 0.5 * (left + right)
311 | side = 0.5 * (left - right)
312 | side = side * db2lin(stereo_spread_level_db)
313 | left = mid + side
314 | right = mid - side
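# Mid/side widening: only the side (difference) signal is scaled by
# stereo_spread_level_db, so the mono sum (mid) of the output is left untouched.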
315 |
316 | # --- 5) Reverb: build IR and FFT-convolve (single-shot) ---
317 | reverb_len = int(min(int(sr * reverb_size_sec), 65536))
318 | reverb_len = max(reverb_len, int(0.02 * sr))
319 | t = torch.arange(reverb_len, device=device, dtype=dtype)
320 | tail = torch.exp(-t / (reverb_size_sec * sr + 1e-12))
321 | taps_ms = [12, 23, 37, 53, 79]
322 | ir = torch.zeros(reverb_len, device=device, dtype=dtype)
323 | for i, tm in enumerate(taps_ms):
324 | idx = int(round(sr * tm / 1000.0))
325 | if idx < reverb_len:
326 | ir[idx] += (0.5 ** (i + 1))
327 | ir += 0.15 * tail
328 | ir = ir / (ir.abs().sum() + eps)
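# The synthetic reverb IR is a handful of early-reflection taps at fixed millisecond
# offsets with geometrically decaying amplitude, plus a low-level exponential tail,
# normalized to unit L1 gain before convolution.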
329 |
330 | left_rev = fft_convolve_full(left, ir, device, dtype)
331 | right_rev = fft_convolve_full(right, ir, device, dtype)
332 | left = (1.0 - reverb_mix) * left + reverb_mix * left_rev
333 | right = (1.0 - reverb_mix) * right + reverb_mix * right_rev
334 |
335 | diagnostics.update({
336 | "reverb_size_sec": reverb_size_sec,
337 | "reverb_mix": reverb_mix,
338 | "reverb_len": reverb_len,
339 | })
340 |
341 | # --- MASTER GAIN: apply desired gain in linear domain before limiter ---
342 | if abs(gain_db) > 1e-12:
343 | gain_lin_master = db2lin(gain_db)
344 | left = left * gain_lin_master
345 | right = right * gain_lin_master
346 | diagnostics['applied_gain_db'] = float(gain_db)
347 | diagnostics['applied_gain_lin'] = float(gain_lin_master)
348 | else:
349 | diagnostics['applied_gain_db'] = 0.0
350 | diagnostics['applied_gain_lin'] = 1.0
351 |
352 | # --- 6) Soft limiter ---
353 | def soft_limiter(x_chan, ceiling_db):
354 | ceiling_lin = db2lin(ceiling_db)
355 | peak = x_chan.abs().max().clamp(min=eps)
356 | if peak <= ceiling_lin:
357 | return x_chan
358 | scaled = x_chan * (ceiling_lin / peak)
359 | out = torch.tanh(scaled * 1.25) / 1.25
360 | out = out / out.abs().max().clamp(min=eps) * ceiling_lin
361 | return out
362 |
363 | left = soft_limiter(left, limiter_ceiling_db)
364 | right = soft_limiter(right, limiter_ceiling_db)
365 |
366 | # final safety scaling
367 | peak_val = max(left.abs().max().item(), right.abs().max().item())
368 | if peak_val > 0.999:
369 | scale = 0.999 / peak_val
370 | left = left * scale
371 | right = right * scale
372 | diagnostics['final_scale'] = float(scale)
373 | else:
374 | diagnostics['final_scale'] = 1.0
375 |
376 | # --- 7) Deterministic TPDF dithering (vectorized LCG) ---
377 | def vectorized_lcg(sz, seed):
378 | a = 1103515245
379 | c = 12345
380 | mod = 2**31
381 | seeds = (torch.arange(sz, device=device, dtype=torch.int64) * 1664525 + int(seed)) & (mod - 1)
382 | vals = (a * seeds + c) & (mod - 1)
383 | floats = (vals.to(dtype) / float(mod)) - 0.5
384 | return floats
385 |
386 | if 1 <= dithering_bits <= 32:
387 | q = 1.0 / (2 ** (dithering_bits - 1))
388 | seed = (N ^ sr ^ 0x9e3779b1) & 0xffffffff
389 | na = vectorized_lcg(N, seed)
390 | nb = vectorized_lcg(N, seed ^ 0x6a09e667)
391 | tpdf = (na - nb) * q
392 | left = left + 0.5 * tpdf
393 | right = right + 0.5 * tpdf
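# TPDF dither: the difference of two independent uniform sequences gives a triangular
# distribution, scaled here to the quantization step of the requested bit depth. The
# hash-style generator is seeded from N and sr, so renders are bit-reproducible.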
394 |
395 | # --- Output and diagnostics ---
396 | stereo = torch.stack([left.to(torch.float32), right.to(torch.float32)], dim=0)
397 | stereo = stereo.clamp(-1.0 + 1e-9, 1.0 - 1e-9)
398 |
399 | left_peak = left.abs().max().item()
400 | right_peak = right.abs().max().item()
401 | left_rms = math.sqrt(float(torch.mean(left * left).item()))
402 | right_rms = math.sqrt(float(torch.mean(right * right).item()))
403 | diagnostics.update({
404 | "left_peak": left_peak, "right_peak": right_peak,
405 | "left_peak_db": lin2db(left_peak), "right_peak_db": lin2db(right_peak),
406 | "left_rms_db": lin2db(left_rms), "right_rms_db": lin2db(right_rms),
407 | "num_samples": N, "sample_rate": sr,
408 | })
409 |
410 | return stereo, diagnostics
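# Usage sketch (illustrative only; assumes a mono float tensor `mono` at the module's
# default sample rate):
#
#     stereo, diag = master_mono_piano(mono, device=torch.device('cpu'))
#     print(diag['left_peak_db'], diag['avg_reduction_db'])
#
# `stereo` is a (2, num_samples) float32 tensor and `diag` collects the settings and
# level measurements gathered above.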
--------------------------------------------------------------------------------
/neuralpiano/neuralpiano.py:
--------------------------------------------------------------------------------
1 | #===============================================================================
2 | # Neural Piano main Python module
3 | #===============================================================================
4 |
5 | """
6 |
7 | This module exposes `render_midi`, a high-level convenience function that:
8 | - Renders a MIDI file to a raw waveform using a SoundFont (SF2) via `midirenderer`.
9 | - Loads the rendered waveform and optionally trims silence.
10 | - Encodes the waveform into a latent representation and decodes it using a
11 | learned Encoder/Decoder model to produce a high-quality piano audio render.
12 | - Optionally applies a sequence of post-processing steps: denoising,
13 | bass enhancement, full-spectrum enhancement, and mastering.
14 | - Writes the final audio to disk or returns it in-memory.
15 |
16 | Design goals
17 | ------------
18 | - Provide a single, well-documented function to convert MIDI -> polished WAV.
19 | - Keep sensible defaults so the function works out-of-the-box with a
20 | `models/` directory containing the required SoundFont and model artifacts.
21 | - Allow advanced users to override model paths, processing parameters, and
22 | device selection.
23 |
24 | Dependencies
25 | ------------
26 | - Python 3.8+
27 | - torch
28 | - librosa
29 | - soundfile (pysoundfile)
30 | - midirenderer
31 | - The package's internal modules:
32 | - .music2latent.inference.EncoderDecoder
33 | - .denoise.denoise_audio
34 | - .bass.enhance_audio_bass
35 | - .enhancer.enhance_audio_full
36 | - .master.master_mono_piano
37 |
38 | Typical usage
39 | -------------
40 | >>> from neuralpiano.neuralpiano import render_midi
41 | >>> out_path = render_midi("score.mid") # writes ./score.wav using defaults
42 | >>> audio, sr = render_midi("score.mid", return_audio=True) # get numpy array and sr
43 |
44 | Notes and behavior
45 | ------------------
46 | - By default the function expects a `models/` directory in the current working
47 | directory and a SoundFont file named `SGM-v2.01-YamahaGrand-Guit-Bass-v2.7.sf2`.
48 | Use `sf2_name` to select a different SoundFont and `custom_model_path` to load a different inference checkpoint.
49 | - The EncoderDecoder is instantiated per-call. For repeated renders in a
50 | long-running process, consider reusing a single EncoderDecoder instance
51 | (not provided by this convenience wrapper).
52 | - `sample_rate` controls the resampling rate used when loading the rendered
53 | MIDI waveform and is propagated to downstream processing where relevant.
54 | - `trim_silence` uses `librosa.effects.trim` with configurable `trim_top_db`,
55 | `trim_frame_length`, and `trim_hop_length`.
56 | - Post-processing steps are applied in this order when enabled:
57 | 1. denoise (if `denoise=True`)
58 | 2. bass enhancement (if `enhance_bass=True`)
59 | 3. full enhancement (if `enhance_full=True`)
60 | 4. mastering (if `master=True`)
61 | Each step accepts a kwargs dict (e.g., `denoise_kwargs`) to override defaults.
62 | - `device` accepts any value accepted by `torch.device` (e.g., 'cuda', 'cpu').
63 | - When `return_audio=True`, the function returns `(final_audio, sample_rate)` as
64 | a NumPy array and sample rate. Otherwise it writes a WAV file and returns the
65 | output file path.
66 | - Verbosity:
67 | - `verbose` prints high-level progress messages.
68 | - `verbose_diag` prints additional diagnostic values useful for debugging.
69 |
70 | Exceptions
71 | ----------
72 | The function may raise exceptions originating from:
73 | - File I/O (missing MIDI or models, permission errors).
74 | - `midirenderer` if the SoundFont or MIDI bytes are invalid.
75 | - `librosa` when loading or trimming audio.
76 | - Torch/model-related errors when instantiating or running the EncoderDecoder.
77 | - Any of the post-processing modules if they encounter invalid inputs.
78 |
79 | Parameters
80 | ----------
81 | input_midi_file : str
82 | Path to the input MIDI file to render.
83 | output_audio_file : str or None
84 | Path to write the final WAV file. If None, the output filename is derived
85 | from the MIDI filename and written to the current working directory.
86 | sample_rate : int
87 | Target sample rate for loading and processing audio (default: 48000).
88 | denoising_steps : int
89 | Default number of denoising steps passed to the decoder.
90 | max_batch_size : int or None
91 | Maximum batch size to use when encoding/decoding; passed to EncoderDecoder.
92 | use_v1_piano_model : bool
93 | If True, instructs EncoderDecoder to use the v1 piano model variant.
94 | load_multi_instrumental_model : bool
95 | If True, load a multi-instrument model variant (if available).
96 | custom_model_path : str or None
97 | Path to a custom model checkpoint for inference; passed to EncoderDecoder.
98 | sf2_name : str
99 | Filename of the SoundFont inside the `models/` directory (default provided).
100 | trim_silence : bool
101 | If True, trim leading/trailing silence from the rendered waveform.
102 | trim_top_db : float
103 | `top_db` parameter for `librosa.effects.trim` (default: 60).
104 | trim_frame_length : int
105 | `frame_length` parameter for `librosa.effects.trim` (default: 2048).
106 | trim_hop_length : int
107 | `hop_length` parameter for `librosa.effects.trim` (default: 512).
108 | denoise : bool
109 | If True, run the denoiser post-processing step.
110 | denoise_kwargs : dict or None
111 | Additional keyword arguments for `denoise_audio`.
112 | enhance_bass : bool
113 | If True, run bass enhancement post-processing.
114 | bass_kwargs : dict or None
115 | Additional keyword arguments for `enhance_audio_bass`. `low_gain_db` is set
116 | from the top-level `low_gain_db` unless overridden here.
117 | low_gain_db : float
118 | Default low-frequency gain (dB) used by bass enhancement (default: 8.0).
119 | enhance_full : bool
120 | If True, run full-spectrum enhancement post-processing.
121 | enhance_full_kwargs : dict or None
122 | Additional keyword arguments for `enhance_audio_full`.
123 | master : bool
124 | If True, run final mastering pass.
125 | master_kwargs : dict or None
126 | Additional keyword arguments for `master_mono_piano`. `gain_db` is set from
127 | the top-level `overall_gain_db` unless overridden here.
128 | overall_gain_db : float
130 | Default gain (dB) applied during full enhancement and mastering (default: 10.0).
130 | device : str
131 | Device string for torch (e.g., 'cuda' or 'cpu'). Converted to `torch.device`
132 | when passed to post-processing functions.
133 | return_audio : bool
134 | If True, return `(audio_numpy, sample_rate)` instead of writing a file.
135 | verbose : bool
136 | Print progress messages when True.
137 | verbose_diag : bool
138 | Print diagnostic information when True.
139 |
140 | Example
141 | -------
142 | Render a MIDI to disk with default processing:
143 |
144 | >>> render_midi("song.mid")
145 |
146 | Render and receive audio in-memory without post-processing:
147 |
148 | >>> audio, sr = render_midi("song.mid", denoise=False, enhance_bass=False,
149 | ... enhance_full=False, master=False, return_audio=True)
150 |
151 | """
152 |
153 | #===============================================================================
154 |
155 | import os
156 | import io
157 | from pathlib import Path
158 |
159 | import torch
160 |
161 | import librosa
162 | import soundfile as sf
163 |
164 | import midirenderer
165 |
166 | from .music2latent.inference import EncoderDecoder
167 |
168 | from .denoise import denoise_audio
169 |
170 | from .bass import enhance_audio_bass
171 |
172 | from .enhancer import enhance_audio_full
173 |
174 | from .master import master_mono_piano
175 |
176 | #===============================================================================
177 |
178 | def render_midi(input_midi_file,
179 | output_audio_file=None,
180 | sample_rate=48000,
181 | denoising_steps=10,
182 | max_batch_size=None,
183 | use_v1_piano_model=False,
184 | load_multi_instrumental_model=False,
185 | custom_model_path=None,
186 | sf2_name='SGM-v2.01-YamahaGrand-Guit-Bass-v2.7.sf2',
187 | trim_silence=True,
188 | trim_top_db=60,
189 | trim_frame_length=2048,
190 | trim_hop_length=512,
191 | denoise=True,
192 | denoise_kwargs=None,
193 | enhance_bass=False,
194 | bass_kwargs=None,
195 | low_gain_db=8.0,
196 | enhance_full=True,
197 | enhance_full_kwargs=None,
198 | master=False,
199 | master_kwargs=None,
200 | overall_gain_db=10.0,
201 | device='cuda',
202 | return_audio=False,
203 | verbose=True,
204 | verbose_diag=False
205 | ):
206 |
207 | """
208 | Render a MIDI file to a polished piano audio waveform or WAV file.
209 |
210 | This function orchestrates the full Neural Piano pipeline:
211 | 1. Render MIDI -> raw waveform using a SoundFont via `midirenderer`.
212 | 2. Load and optionally trim silence from the rendered waveform.
213 | 3. Encode waveform to latent and decode with the model to synthesize audio.
214 | 4. Optionally apply denoising, bass enhancement, full enhancement, and mastering.
215 | 5. Return the final audio as a NumPy array or write it to disk as a WAV file.
216 |
217 | Parameters
218 | ----------
219 | input_midi_file : str
220 | Path to the input MIDI file.
221 | output_audio_file : str or None
222 | Path to write the final WAV file. If None, derived from MIDI filename.
223 | sample_rate : int
224 | Sample rate for loading and processing audio (default: 48000).
225 | denoising_steps : int
226 | Default denoising steps passed to decoder.
227 | max_batch_size : int or None
228 | Max batch size for EncoderDecoder operations.
229 | use_v1_piano_model : bool
230 | Use v1 piano model variant if True.
231 | load_multi_instrumental_model : bool
232 | Load multi-instrument model variant if True.
233 | custom_model_path : str or None
234 | Path to a custom model checkpoint for inference.
235 | sf2_name : str
236 | SoundFont filename located in the `models/` directory.
237 | trim_silence : bool
238 | Trim leading/trailing silence from the rendered waveform.
239 | trim_top_db : float
240 | `top_db` for `librosa.effects.trim`.
241 | trim_frame_length : int
242 | `frame_length` for `librosa.effects.trim`.
243 | trim_hop_length : int
244 | `hop_length` for `librosa.effects.trim`.
245 | denoise : bool
246 | Run denoiser post-processing when True.
247 | denoise_kwargs : dict or None
248 | Extra kwargs for `denoise_audio`.
249 | enhance_bass : bool
250 | Run bass enhancement when True.
251 | bass_kwargs : dict or None
252 | Extra kwargs for `enhance_audio_bass`. `low_gain_db` is set from the
253 | top-level argument unless overridden here.
254 | low_gain_db : float
255 | Default low-frequency gain (dB) for bass enhancement.
256 | enhance_full : bool
257 | Run full-spectrum enhancement when True.
258 | enhance_full_kwargs : dict or None
259 | Extra kwargs for `enhance_audio_full`.
260 | master : bool
261 | Run final mastering when True.
262 | master_kwargs : dict or None
263 | Extra kwargs for `master_mono_piano`. `gain_db` is set from the top-level
264 | argument unless overridden here.
265 | overall_gain_db : float
266 | Default gain (dB) applied during full enhancement and mastering.
267 | device : str
268 | Torch device string (e.g., 'cuda' or 'cpu').
269 | return_audio : bool
270 | If True, return `(audio_numpy, sample_rate)` instead of writing a file.
271 | verbose : bool
272 | Print progress messages.
273 | verbose_diag : bool
274 | Print diagnostic information for debugging.
275 |
276 | Returns
277 | -------
278 | str or (numpy.ndarray, int)
279 | If `return_audio` is False (default), returns the path to the written WAV file.
281 | If `return_audio` is True, returns a tuple `(audio_numpy, sample_rate)` where `audio_numpy`
282 | is a NumPy array (mono 1-D, or 2-D of shape `(num_samples, 2)` when the pipeline outputs stereo) and `sample_rate` is an int.
282 |
283 | Raises
284 | ------
285 | FileNotFoundError
286 | If the input MIDI file or required model/SF2 files are missing.
287 | RuntimeError
288 | If model inference or post-processing fails (propagates underlying errors).
289 |
290 | """
291 |
292 | def _pv(msg):
293 | if verbose:
294 | print(msg)
295 |
296 | _pv('=' * 70)
297 | _pv('Neural Piano')
298 | _pv('=' * 70)
299 |
300 | # Normalize kwargs buckets
301 | denoise_kwargs = {} if denoise_kwargs is None else dict(denoise_kwargs)
302 | bass_kwargs = {} if bass_kwargs is None else dict(bass_kwargs)
303 | enhance_full_kwargs = {} if enhance_full_kwargs is None else dict(enhance_full_kwargs)
304 | master_kwargs = {} if master_kwargs is None else dict(master_kwargs)
305 |
306 | # Provide sensible defaults from top-level args unless overridden in kwargs
307 | if 'low_gain_db' not in bass_kwargs:
308 | bass_kwargs['low_gain_db'] = low_gain_db
309 |
310 | if 'overall_gain_db' not in enhance_full_kwargs:
311 | enhance_full_kwargs['overall_gain_db'] = overall_gain_db
312 |
313 | if 'gain_db' not in master_kwargs:
314 | master_kwargs['gain_db'] = overall_gain_db
315 |
316 | home_root = os.getcwd()
317 | models_dir = os.path.join(home_root, "models")
318 | sf2_path = os.path.join(models_dir, sf2_name)
319 |
320 | if verbose_diag:
321 | _pv(home_root)
322 | _pv(models_dir)
323 | _pv(sf2_path)
324 | _pv('=' * 70)
325 |
326 | _pv('Prepping model...')
327 | encdec = EncoderDecoder(load_multi_instrumental_model=load_multi_instrumental_model,
328 | use_v1_piano_model=use_v1_piano_model,
329 | load_path_inference=custom_model_path
330 | )
331 |
332 | if verbose_diag:
333 | try:
334 | _pv(encdec.gen)
335 | except Exception:
336 | _pv('encdec.gen: <unavailable>')
337 | _pv('=' * 70)
338 |
339 | _pv('Reading and rendering MIDI file...')
340 | wav_data = midirenderer.render_wave_from(
341 | Path(sf2_path).read_bytes(),
342 | Path(input_midi_file).read_bytes()
343 | )
344 |
345 | if verbose_diag:
346 | _pv(len(wav_data))
347 | _pv('=' * 70)
348 |
349 | _pv('Loading rendered MIDI...')
350 | with io.BytesIO(wav_data) as byte_stream:
351 | wv, sr = librosa.load(byte_stream, sr=sample_rate)
352 |
353 | if verbose_diag:
354 | _pv(sr)
355 | _pv(wv.shape)
356 | _pv('=' * 70)
357 |
358 | if trim_silence:
359 | _pv('Trimming leading and trailing silence from rendered waveform...')
360 | wv_trimmed, trim_interval = librosa.effects.trim(
361 | wv,
362 | top_db=trim_top_db,
363 | frame_length=trim_frame_length,
364 | hop_length=trim_hop_length
365 | )
366 | start_sample, end_sample = trim_interval
367 | orig_dur = len(wv) / sr
368 | trimmed_dur = len(wv_trimmed) / sr
369 | if verbose:
370 | _pv(f' Trimmed samples: start={start_sample}, end={end_sample}')
371 | _pv(f' Duration before={orig_dur:.3f}s, after={trimmed_dur:.3f}s')
372 | wv = wv_trimmed
373 | else:
374 | _pv('Silence trimming disabled; using full rendered waveform.')
375 |
376 | if verbose_diag:
377 | _pv(wv.shape)
378 | _pv('=' * 70)
379 |
380 | _pv('Encoding...')
381 | latent = encdec.encode(wv,
382 | max_batch_size=max_batch_size,
383 | show_progress=verbose
384 | )
385 |
386 | if verbose_diag:
387 | try:
388 | _pv(latent.shape)
389 | except Exception:
390 | _pv('latent.shape: <unavailable>')
391 | _pv('=' * 70)
392 |
393 | _pv('Rendering...')
394 | audio = encdec.decode(latent,
395 | denoising_steps=denoising_steps,
396 | max_batch_size=max_batch_size,
397 | show_progress=verbose
398 | )
399 |
400 | audio = audio.squeeze()
401 |
402 | if verbose_diag:
403 | try:
404 | _pv(audio.shape)
405 | except Exception:
406 | _pv('audio.shape: <unavailable>')
407 | _pv('=' * 70)
408 |
409 | # Post-processing: denoise
410 | if denoise:
411 | _pv('Denoising...')
412 | # Always pass sr and device; allow denoise_kwargs to override them if provided
413 | denoise_call_kwargs = dict(sr=sr, device=torch.device(device))
414 | denoise_call_kwargs.update(denoise_kwargs)
415 | audio, den_diag = denoise_audio(audio, **denoise_call_kwargs)
416 |
417 | if verbose_diag:
418 | _pv(den_diag)
419 | _pv('=' * 70)
420 |
421 | # Post-processing: bass enhancement
422 | if enhance_bass:
423 | _pv('Enhancing bass...')
424 | bass_call_kwargs = dict(sr=sr, device=torch.device(device))
425 | bass_call_kwargs.update(bass_kwargs)
426 | audio, bass_diag = enhance_audio_bass(audio, **bass_call_kwargs)
427 |
428 | if verbose_diag:
429 | _pv(bass_diag)
430 | _pv('=' * 70)
431 |
432 | # Post-processing: full enhancement (placed before mastering)
433 | if enhance_full:
434 | _pv('Enhancing full audio...')
435 | full_call_kwargs = dict(sr=sr, device=torch.device(device))
436 | full_call_kwargs.update(enhance_full_kwargs)
437 |
438 | output_as_stereo = not master  # mastering itself outputs stereo, so keep the enhancer mono when a mastering pass follows
443 |
444 | audio, full_diag = enhance_audio_full(audio,
445 | output_as_stereo=output_as_stereo,
446 | **full_call_kwargs
447 | )
448 |
449 | if verbose_diag:
450 | _pv(full_diag)
451 | _pv('=' * 70)
452 |
453 | # Post-processing: mastering
454 | if master:
455 | _pv('Mastering...')
456 | master_call_kwargs = dict(device=torch.device(device))
457 | master_call_kwargs.update(master_kwargs)
458 | audio, mas_diag = master_mono_piano(audio, **master_call_kwargs)
459 |
460 | if verbose_diag:
461 | _pv(mas_diag)
462 | _pv('=' * 70)
463 |
464 | if verbose_diag:
465 | try:
466 | _pv(audio.shape)
467 | except Exception:
468 | _pv('audio.shape: <unavailable>')
469 | _pv('=' * 70)
470 |
471 | _pv('Creating final audio...')
472 | final_audio = audio.cpu().numpy().squeeze().T
473 |
474 | if verbose_diag:
475 | _pv(final_audio.shape)
476 | _pv(sr)
477 | _pv('=' * 70)
478 |
479 | if return_audio:
480 | _pv('Returning final audio...')
481 | _pv('=' * 70)
482 | _pv('Done!')
483 | _pv('=' * 70)
484 | return final_audio, sr
485 |
486 | else:
487 | _pv('Saving final audio...')
488 | if output_audio_file is None:
489 | midi_name = os.path.basename(input_midi_file)
490 | output_name, _ = os.path.splitext(midi_name)
491 | output_audio_file = os.path.join(home_root, output_name + '.wav')
492 |
493 | if verbose_diag:
494 | _pv(output_audio_file)
495 | _pv(sr)
496 | _pv('=' * 70)
497 |
498 | sf.write(output_audio_file, final_audio, samplerate=sr)
499 |
500 | _pv('=' * 70)
501 | _pv('Done!')
502 | _pv('=' * 70)
503 |
504 | return output_audio_file
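# Usage sketch (illustrative, not executed on import):
#
#     from neuralpiano.neuralpiano import render_midi
#     out_path = render_midi('song.mid')                       # writes ./song.wav
#     audio, sr = render_midi('song.mid', return_audio=True)   # in-memory result
#
# Each call instantiates a fresh EncoderDecoder; for repeated renders in a
# long-running process, reusing a single instance requires driving
# EncoderDecoder.encode()/decode() directly, as done step-by-step above.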
--------------------------------------------------------------------------------
/neuralpiano/music2latent/models.py:
--------------------------------------------------------------------------------
1 | from .utils import *
2 | from .audio import *
3 |
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 |
8 |
9 | def zero_init(module):
10 | if init_as_zero:
11 | for p in module.parameters():
12 | p.detach().zero_()
13 | return module
14 |
15 | def upsample_1d(x):
16 | return F.interpolate(x, scale_factor=2, mode="nearest")
17 |
18 | def downsample_1d(x):
19 | return F.avg_pool1d(x, kernel_size=2, stride=2)
20 |
21 | def upsample_2d(x):
22 | return F.interpolate(x, scale_factor=2, mode="nearest")
23 |
24 | def downsample_2d(x):
25 | return F.avg_pool2d(x, kernel_size=2, stride=2)
26 |
27 |
28 | class LayerNorm(nn.Module):
29 | def __init__(self, dim):
30 | super(LayerNorm, self).__init__()
31 | self.ln = torch.nn.LayerNorm(dim)
32 |
33 | def forward(self, input):
34 | x = input.permute(0,2,3,1)
35 | x = self.ln(x)
36 | x = x.permute(0,3,1,2)
37 | return x
38 |
39 | class FreqGain(nn.Module):
40 | def __init__(self, freq_dim):
41 | super(FreqGain, self).__init__()
42 | self.scale = nn.Parameter(torch.ones((1,1,freq_dim,1)))
43 |
44 | def forward(self, input):
45 | return input*self.scale
46 |
47 |
48 | class UpsampleConv(nn.Module):
49 | def __init__(self, in_channels, out_channels=None, use_2d=False, normalize=False):
50 | super(UpsampleConv, self).__init__()
51 | self.normalize = normalize
52 |
53 | self.use_2d = use_2d
54 |
55 | if out_channels is None:
56 | out_channels = in_channels
57 |
58 | if normalize:
59 | self.norm = nn.GroupNorm(min(in_channels//4, 32), in_channels)
60 |
61 | if use_2d:
62 | self.c = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding='same')
63 | else:
64 | self.c = nn.Conv1d(in_channels, out_channels, kernel_size=3, stride=1, padding='same')
65 |
66 | def forward(self, x):
67 |
68 | if self.normalize:
69 | x = self.norm(x)
70 |
71 | if self.use_2d:
72 | x = upsample_2d(x)
73 | else:
74 | x = upsample_1d(x)
75 | x = self.c(x)
76 |
77 | return x
78 |
79 | class DownsampleConv(nn.Module):
80 | def __init__(self, in_channels, out_channels=None, use_2d=False, normalize=False):
81 | super(DownsampleConv, self).__init__()
82 | self.normalize = normalize
83 |
84 | if out_channels is None:
85 | out_channels = in_channels
86 |
87 | if normalize:
88 | self.norm = nn.GroupNorm(min(in_channels//4, 32), in_channels)
89 |
90 | if use_2d:
91 | self.c = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1)
92 | else:
93 | self.c = nn.Conv1d(in_channels, out_channels, kernel_size=3, stride=2, padding=1)
94 |
95 | def forward(self, x):
96 |
97 | if self.normalize:
98 | x = self.norm(x)
99 | x = self.c(x)
100 |
101 | return x
102 |
103 | class UpsampleFreqConv(nn.Module):
104 | def __init__(self, in_channels, out_channels=None, normalize=False):
105 | super(UpsampleFreqConv, self).__init__()
106 | self.normalize = normalize
107 |
108 | if out_channels is None:
109 | out_channels = in_channels
110 |
111 | if normalize:
112 | self.norm = nn.GroupNorm(min(in_channels//4, 32), in_channels)
113 |
114 | self.c = nn.Conv2d(in_channels, out_channels, kernel_size=(5,1), stride=1, padding='same')
115 |
116 | def forward(self, x):
117 | if self.normalize:
118 | x = self.norm(x)
119 | x = F.interpolate(x, scale_factor=(4,1), mode="nearest")
120 | x = self.c(x)
121 | return x
122 |
123 | class DownsampleFreqConv(nn.Module):
124 | def __init__(self, in_channels, out_channels=None, normalize=False):
125 | super(DownsampleFreqConv, self).__init__()
126 | self.normalize = normalize
127 |
128 | if out_channels is None:
129 | out_channels = in_channels
130 |
131 | if normalize:
132 | self.norm = nn.GroupNorm(min(in_channels//4, 32), in_channels)
133 |
134 | self.c = nn.Conv2d(in_channels, out_channels, kernel_size=(5,1), stride=(4,1), padding=(2,0))
135 |
136 | def forward(self, x):
137 | if self.normalize:
138 | x = self.norm(x)
139 | x = self.c(x)
140 | return x
141 |
142 | class MultiheadAttention(nn.MultiheadAttention):
143 | def _reset_parameters(self):
144 | super()._reset_parameters()
145 | self.out_proj = zero_init(self.out_proj)
146 |
147 | class Attention(nn.Module):
148 | def __init__(self, dim, heads=4, normalize=True, use_2d=False):
149 | super(Attention, self).__init__()
150 |
151 | self.normalize = normalize
152 | self.use_2d = use_2d
153 |
154 | self.mha = MultiheadAttention(embed_dim=dim, num_heads=heads, dropout=0.0, add_zero_attn=False, batch_first=True)
155 | if normalize:
156 | self.norm = nn.GroupNorm(min(dim//4, 32), dim)
157 |
158 | def forward(self, x):
159 |
160 | inp = x
161 |
162 | if self.normalize:
163 | x = self.norm(x)
164 |
165 | if self.use_2d:
166 | x = x.permute(0,3,2,1) # shape: [bs,seq_len,freq,channels]
167 | bs, seq_len, freq, channels = x.shape
168 | x = x.reshape(bs*seq_len,freq,channels) # shape: [bs*seq_len,freq,channels]
169 | else:
170 | x = x.permute(0,2,1) # shape: [bs,len,channels]
171 |
172 | x = self.mha(x, x, x, need_weights=False)[0]
173 |
174 | if self.use_2d:
175 | x = x.reshape(bs,seq_len,freq,channels).permute(0,3,2,1)
176 | else:
177 | x = x.permute(0,2,1)
178 | x = x+inp
179 |
180 | return x
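# In 2-D mode the attention runs over the frequency axis only: time frames are folded
# into the batch dimension, so each time step attends across its own frequency bins.
# A residual connection (x + inp) wraps the attention output.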
181 |
182 |
183 |
184 | class ResBlock(nn.Module):
185 | def __init__(self, in_channels, out_channels, cond_channels=None, kernel_size=3, downsample=False, upsample=False, normalize=True, leaky=False, attention=False, heads=4, use_2d=False, normalize_residual=False):
186 | super(ResBlock, self).__init__()
187 | self.normalize = normalize
188 | self.attention = attention
189 | self.upsample = upsample
190 | self.downsample = downsample
191 | self.leaky = leaky
192 | self.kernel_size = kernel_size
193 | self.normalize_residual = normalize_residual
194 | self.use_2d = use_2d
195 | if use_2d:
196 | Conv = nn.Conv2d
197 | else:
198 | Conv = nn.Conv1d
199 | self.conv1 = Conv(in_channels, out_channels, kernel_size=kernel_size, stride=1, padding='same')
200 | self.conv2 = zero_init(Conv(out_channels, out_channels, kernel_size=kernel_size, stride=1, padding='same'))
201 | if in_channels!=out_channels:
202 | self.res_conv = Conv(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
203 | else:
204 | self.res_conv = nn.Identity()
205 | if normalize:
206 | self.norm1 = nn.GroupNorm(min(in_channels//4, 32), in_channels)
207 | self.norm2 = nn.GroupNorm(min(out_channels//4, 32), out_channels)
208 | if leaky:
209 | self.activation = nn.LeakyReLU(negative_slope=0.2)
210 | else:
211 | self.activation = nn.SiLU()
212 | if cond_channels is not None:
213 | self.proj_emb = zero_init(nn.Linear(cond_channels, out_channels))
214 | self.dropout = nn.Dropout(dropout_rate)
215 | if attention:
216 | self.att = Attention(out_channels, heads, use_2d=use_2d)
217 |
218 |
219 | def forward(self, x, time_emb=None):
220 | if not self.normalize_residual:
221 | y = x.clone()
222 | if self.normalize:
223 | x = self.norm1(x)
224 | if self.normalize_residual:
225 | y = x.clone()
226 | x = self.activation(x)
227 | if self.downsample:
228 | if self.use_2d:
229 | x = downsample_2d(x)
230 | y = downsample_2d(y)
231 | else:
232 | x = downsample_1d(x)
233 | y = downsample_1d(y)
234 | if self.upsample:
235 | if self.use_2d:
236 | x = upsample_2d(x)
237 | y = upsample_2d(y)
238 | else:
239 | x = upsample_1d(x)
240 | y = upsample_1d(y)
241 | x = self.conv1(x)
242 | if time_emb is not None:
243 | if self.use_2d:
244 | x = x+self.proj_emb(time_emb)[:,:,None,None]
245 | else:
246 | x = x+self.proj_emb(time_emb)[:,:,None]
247 | if self.normalize:
248 | x = self.norm2(x)
249 | x = self.activation(x)
250 | if x.shape[-1]<=min_res_dropout:
251 | x = self.dropout(x)
252 | x = self.conv2(x)
253 | y = self.res_conv(y)
254 | x = x+y
255 | if self.attention:
256 | x = self.att(x)
257 | return x
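# Pre-activation residual block: normalization and activation precede each conv, the
# optional conditioning embedding is projected and added per channel, and dropout is
# only applied once the spatial size has shrunk to min_res_dropout or below.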
258 |
259 |
260 | # adapted from https://github.com/yang-song/score_sde_pytorch/blob/main/models/layerspp.py
261 | class GaussianFourierProjection(torch.nn.Module):
262 | """Gaussian Fourier embeddings for noise levels."""
263 |
264 | def __init__(self, embedding_size=128, scale=0.02):
265 | super().__init__()
266 | self.W = torch.nn.Parameter(torch.randn(embedding_size//2) * scale, requires_grad=False)
267 |
268 | def forward(self, x):
269 | x_proj = x[:, None] * self.W[None, :] * 2. * np.pi
270 | return torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)
271 |
272 |
273 | class PositionalEmbedding(torch.nn.Module):
274 | def __init__(self, embedding_size=128, max_positions=10000):
275 | super().__init__()
276 | self.embedding_size = embedding_size
277 | self.max_positions = max_positions
278 |
279 | def forward(self, x):
280 | freqs = torch.arange(start=0, end=self.embedding_size//2, dtype=torch.float32, device=x.device)
281 | freqs = freqs / (self.embedding_size // 2 - 1)
282 | freqs = (1 / self.max_positions) ** freqs
283 | x = x.ger(freqs.to(x.dtype))
284 | x = torch.cat([torch.sin(x), torch.cos(x)], dim=-1)
285 | return x
286 |
287 | class Encoder(nn.Module):
288 | def __init__(self):
289 | super(Encoder, self).__init__()
290 |
291 | layers_list = layers_list_encoder
292 | attention_list = attention_list_encoder
293 | self.layers_list = layers_list
294 | self.multipliers_list = multipliers_list
295 | input_channels = base_channels*multipliers_list[0]
296 | Conv = nn.Conv2d
297 | self.gain = FreqGain(freq_dim=hop*2)
298 |
299 | channels = data_channels
300 | self.conv_inp = Conv(channels, input_channels, kernel_size=3, stride=1, padding=1)
301 |
302 | self.freq_dim = (hop*2)//(4**freq_downsample_list.count(1))
303 | self.freq_dim = self.freq_dim//(2**freq_downsample_list.count(0))
304 |
305 | # DOWNSAMPLING
306 | down_layers = []
307 | for i, (num_layers,multiplier) in enumerate(zip(layers_list,multipliers_list)):
308 | output_channels = base_channels*multiplier
309 | for num in range(num_layers):
310 | down_layers.append(ResBlock(input_channels, output_channels, normalize=normalization, attention=attention_list[i]==1, heads=heads, use_2d=True))
311 | input_channels = output_channels
312 | if i!=(len(layers_list)-1):
313 | if freq_downsample_list[i]==1:
314 | down_layers.append(DownsampleFreqConv(input_channels, normalize=pre_normalize_downsampling_encoder))
315 | else:
316 | down_layers.append(DownsampleConv(input_channels, use_2d=True, normalize=pre_normalize_downsampling_encoder))
317 |
318 | if pre_normalize_2d_to_1d:
319 | self.prenorm_1d_to_2d = nn.GroupNorm(min(input_channels//4, 32), input_channels)
320 |
321 | bottleneck_layers = []
322 | output_channels = bottleneck_base_channels
323 | bottleneck_layers.append(nn.Conv1d(input_channels*self.freq_dim, output_channels, kernel_size=1, stride=1, padding='same'))
324 | for i in range(num_bottleneck_layers):
325 | bottleneck_layers.append(ResBlock(output_channels, output_channels, normalize=normalization, use_2d=False))
326 | self.bottleneck_layers = nn.ModuleList(bottleneck_layers)
327 |
328 | self.norm_out = nn.GroupNorm(min(output_channels//4, 32), output_channels)
329 | self.activation_out = nn.SiLU()
330 | self.conv_out = nn.Conv1d(output_channels, bottleneck_channels, kernel_size=1, stride=1, padding='same')
331 | self.activation_bottleneck = nn.Tanh()
332 |
333 | self.down_layers = nn.ModuleList(down_layers)
334 |
335 |
336 | def forward(self, x, extract_features=False):
337 |
338 | x = self.conv_inp(x)
339 | if frequency_scaling:
340 | x = self.gain(x)
341 |
342 | # DOWNSAMPLING
343 | k = 0
344 | for i,num_layers in enumerate(self.layers_list):
345 | for num in range(num_layers):
346 | x = self.down_layers[k](x)
347 | k = k+1
348 | if i!=(len(self.layers_list)-1):
349 | x = self.down_layers[k](x)
350 | k = k+1
351 |
352 | if pre_normalize_2d_to_1d:
353 | x = self.prenorm_1d_to_2d(x)
354 |
355 | x = x.reshape(x.size(0), x.size(1) * x.size(2), x.size(3))
356 | if extract_features:
357 | return x
358 |
359 | for layer in self.bottleneck_layers:
360 | x = layer(x)
361 |
362 | x = self.norm_out(x)
363 | x = self.activation_out(x)
364 | x = self.conv_out(x)
365 | x = self.activation_bottleneck(x)
366 |
367 | return x
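# The encoder collapses the 2-D (channels x freq) feature map into a 1-D sequence,
# runs it through the bottleneck ResBlocks, and squashes the result with tanh so the
# latent is bounded in [-1, 1].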
368 |
369 |
370 | class Decoder(nn.Module):
371 | def __init__(self):
372 | super(Decoder, self).__init__()
373 |
374 | layers_list = layers_list_encoder
375 | attention_list = attention_list_encoder
376 | self.layers_list = layers_list_encoder
377 | self.multipliers_list = multipliers_list
378 | input_channels = base_channels*multipliers_list[-1]
379 |
380 | output_channels = bottleneck_base_channels
381 | self.conv_inp = nn.Conv1d(bottleneck_channels, output_channels, kernel_size=1, stride=1, padding='same')
382 |
383 | self.freq_dim = (hop*2)//(4**freq_downsample_list.count(1))
384 | self.freq_dim = self.freq_dim//(2**freq_downsample_list.count(0))
385 |
386 | bottleneck_layers = []
387 | for i in range(num_bottleneck_layers):
388 | bottleneck_layers.append(ResBlock(output_channels, output_channels, cond_channels, normalize=normalization, use_2d=False))
389 |
390 | self.conv_out_bottleneck = nn.Conv1d(output_channels, input_channels*self.freq_dim, kernel_size=1, stride=1, padding='same')
391 | self.bottleneck_layers = nn.ModuleList(bottleneck_layers)
392 |
393 | # UPSAMPLING
394 | multipliers_list_upsampling = list(reversed(multipliers_list))[1:]+list(reversed(multipliers_list))[:1]
395 | freq_upsample_list = list(reversed(freq_downsample_list))
396 | up_layers = []
397 | for i, (num_layers,multiplier) in enumerate(zip(reversed(layers_list),multipliers_list_upsampling)):
398 | for num in range(num_layers):
399 | up_layers.append(ResBlock(input_channels, input_channels, normalize=normalization, attention=list(reversed(attention_list))[i]==1, heads=heads, use_2d=True))
400 | if i!=(len(layers_list)-1):
401 | output_channels = base_channels*multiplier
402 | if freq_upsample_list[i]==1:
403 | up_layers.append(UpsampleFreqConv(input_channels, output_channels))
404 | else:
405 | up_layers.append(UpsampleConv(input_channels, output_channels, use_2d=True))
406 | input_channels = output_channels
407 |
408 | self.up_layers = nn.ModuleList(up_layers)
409 |
410 |
411 | def forward(self, x):
412 |
413 | x = self.conv_inp(x)
414 |
415 | for layer in self.bottleneck_layers:
416 | x = layer(x)
417 | x = self.conv_out_bottleneck(x)
418 |
419 | x_ls = torch.chunk(x.unsqueeze(-2), self.freq_dim, -3)
420 | x = torch.cat(x_ls, -2)
421 |
422 | # UPSAMPLING
423 | k = 0
424 | pyramid_list = []
425 | for i,num_layers in enumerate(reversed(self.layers_list)):
426 | for num in range(num_layers):
427 | x = self.up_layers[k](x)
428 | k = k+1
429 | pyramid_list.append(x)
430 | if i!=(len(self.layers_list)-1):
431 | x = self.up_layers[k](x)
432 | k = k+1
433 |
434 | pyramid_list = pyramid_list[::-1]
435 |
436 | return pyramid_list
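# The decoder does not emit audio directly: it returns a pyramid of feature maps, one
# per resolution (finest first after the reversal), which the UNet below consumes as
# conditioning at the matching scales.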
437 |
438 |
439 | class UNet(nn.Module):
440 | def __init__(self):
441 | super(UNet, self).__init__()
442 |
443 | self.layers_list = layers_list
444 | self.multipliers_list = multipliers_list
445 | input_channels = base_channels*multipliers_list[0]
446 | Conv = nn.Conv2d
447 |
448 | self.encoder = Encoder()
449 | self.decoder = Decoder()
450 |
451 | if use_fourier:
452 | self.emb = GaussianFourierProjection(embedding_size=cond_channels, scale=fourier_scale)
453 | else:
454 | self.emb = PositionalEmbedding(embedding_size=cond_channels)
455 |
456 | self.emb_proj = nn.Sequential(nn.Linear(cond_channels, cond_channels), nn.SiLU(), nn.Linear(cond_channels, cond_channels), nn.SiLU())
457 |
458 | self.scale_inp = nn.Sequential(nn.Linear(cond_channels, cond_channels), nn.SiLU(), nn.Linear(cond_channels, cond_channels), nn.SiLU(), zero_init(nn.Linear(cond_channels, hop*2)))
459 | self.scale_out = nn.Sequential(nn.Linear(cond_channels, cond_channels), nn.SiLU(), nn.Linear(cond_channels, cond_channels), nn.SiLU(), zero_init(nn.Linear(cond_channels, hop*2)))
460 |
461 | self.conv_inp = Conv(data_channels, input_channels, kernel_size=3, stride=1, padding=1)
462 |
463 | # DOWNSAMPLING
464 | down_layers = []
465 | for i, (num_layers,multiplier) in enumerate(zip(layers_list,multipliers_list)):
466 | output_channels = base_channels*multiplier
467 | for num in range(num_layers):
468 | down_layers.append(Conv(output_channels, output_channels, kernel_size=1, stride=1, padding=0))
469 | down_layers.append(ResBlock(output_channels, output_channels, cond_channels, normalize=normalization, attention=attention_list[i]==1, heads=heads, use_2d=True))
470 | input_channels = output_channels
471 | if i!=(len(layers_list)-1):
472 | output_channels = base_channels*multipliers_list[i+1]
473 | if freq_downsample_list[i]==1:
474 | down_layers.append(DownsampleFreqConv(input_channels, output_channels))
475 | else:
476 | down_layers.append(DownsampleConv(input_channels, output_channels, use_2d=True))
477 |
478 | # UPSAMPLING
479 | multipliers_list_upsampling = list(reversed(multipliers_list))[1:]+list(reversed(multipliers_list))[:1]
480 | freq_upsample_list = list(reversed(freq_downsample_list))
481 | up_layers = []
482 | for i, (num_layers,multiplier) in enumerate(zip(reversed(layers_list),multipliers_list_upsampling)):
483 | for num in range(num_layers):
484 | up_layers.append(Conv(input_channels, input_channels, kernel_size=1, stride=1, padding=0))
485 | up_layers.append(ResBlock(input_channels, input_channels, cond_channels, normalize=normalization, attention=list(reversed(attention_list))[i]==1, heads=heads, use_2d=True))
486 | if i!=(len(layers_list)-1):
487 | output_channels = base_channels*multiplier
488 | if freq_upsample_list[i]==1:
489 | up_layers.append(UpsampleFreqConv(input_channels, output_channels))
490 | else:
491 | up_layers.append(UpsampleConv(input_channels, output_channels, use_2d=True))
492 | input_channels = output_channels
493 |
494 | self.conv_decoded = Conv(input_channels, input_channels, kernel_size=1, stride=1, padding=0)
495 | self.norm_out = nn.GroupNorm(min(input_channels//4, 32), input_channels)
496 | self.activation_out = nn.SiLU()
497 | self.conv_out = zero_init(Conv(input_channels, data_channels, kernel_size=3, stride=1, padding=1))
498 |
499 | self.down_layers = nn.ModuleList(down_layers)
500 | self.up_layers = nn.ModuleList(up_layers)
501 |
502 |
503 | def forward(self, latents, x, sigma=None, pyramid_latents=None):
504 |
505 | if sigma is None:
506 | sigma = sigma_max
507 |
508 | inp = x
509 |
510 | # CONDITIONING
511 | sigma = torch.ones((x.shape[0],), dtype=torch.float32).to(x.device)*sigma
512 | sigma_log = torch.log(sigma)/4.
513 | emb_sigma_log = self.emb(sigma_log)
514 | time_emb = self.emb_proj(emb_sigma_log)
515 |
516 | scale_w_inp = self.scale_inp(emb_sigma_log).reshape(x.shape[0],1,-1,1)
517 | scale_w_out = self.scale_out(emb_sigma_log).reshape(x.shape[0],1,-1,1)
518 |
519 | c_skip, c_out, c_in = get_c(sigma)
520 |
521 | x = c_in*x
522 |
523 | if latents.shape == x.shape:
524 | latents = self.encoder(latents)
525 |
526 | if pyramid_latents is None:
527 | pyramid_latents = self.decoder(latents)
528 |
529 | x = self.conv_inp(x)
530 | if frequency_scaling:
531 | x = (1.+scale_w_inp)*x
532 |
533 | skip_list = []
534 |
535 | # DOWNSAMPLING
536 | k = 0
537 | r = 0
538 | for i,num_layers in enumerate(self.layers_list):
539 | for num in range(num_layers):
540 | d = self.down_layers[k](pyramid_latents[i])
541 | k = k+1
542 | x = (x+d)/np.sqrt(2.)
543 | x = self.down_layers[k](x, time_emb)
544 | skip_list.append(x)
545 | k = k+1
546 | if i!=(len(self.layers_list)-1):
547 | x = self.down_layers[k](x)
548 | k = k+1
549 |
550 | # UPSAMPLING
551 | k = 0
552 | for i,num_layers in enumerate(reversed(self.layers_list)):
553 | for num in range(num_layers):
554 | d = self.up_layers[k](pyramid_latents[-i-1])
555 | k = k+1
556 | x = (x+skip_list.pop()+d)/np.sqrt(3.)
557 | x = self.up_layers[k](x, time_emb)
558 | k = k+1
559 | if i!=(len(self.layers_list)-1):
560 | x = self.up_layers[k](x)
561 | k = k+1
562 |
563 | d = self.conv_decoded(pyramid_latents[0])
564 | x = (x+d)/np.sqrt(2.)
565 |
566 | x = self.norm_out(x)
567 | x = self.activation_out(x)
568 | if frequency_scaling:
569 | x = (1.+scale_w_out)*x
570 | x = self.conv_out(x)
571 |
572 | out = c_skip*inp + c_out*x
573 |
574 | return out
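# c_in scales the noisy input, and the output is blended as c_skip * input + c_out *
# prediction; get_c (from .utils) is assumed to supply these EDM/consistency-style
# preconditioning coefficients as a function of the noise level sigma.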
--------------------------------------------------------------------------------