├── .github
│   ├── FUNDING.yml
│   ├── ISSUE_TEMPLATE
│   │   └── bug_report.md
│   └── workflows
│       ├── code_formatter.yml
│       ├── nextjs.yml
│       └── unittest.yml
├── .gitignore
├── .vscode
│   └── settings.json
├── LICENSE
├── README.md
├── docs
│   ├── components
│   │   ├── counters.module.css
│   │   └── counters.tsx
│   ├── favicon.ico
│   ├── next-env.d.ts
│   ├── next.config.js
│   ├── package.json
│   ├── pages
│   │   ├── Usage
│   │   │   ├── _meta.json
│   │   │   ├── rvc.mdx
│   │   │   └── uvr.mdx
│   │   ├── _meta.json
│   │   ├── index.mdx
│   │   └── installation.mdx
│   ├── pnpm-lock.yaml
│   ├── theme.config.tsx
│   └── tsconfig.json
├── install.bat
├── install.sh
├── logs
│   ├── mute
│   │   ├── extracted
│   │   │   └── mute.npy
│   │   ├── f0
│   │   │   └── mute.wav.npy
│   │   ├── f0_voiced
│   │   │   └── mute.wav.npy
│   │   ├── sliced_audios
│   │   │   ├── mute32000.wav
│   │   │   ├── mute40000.wav
│   │   │   ├── mute44100.wav
│   │   │   └── mute48000.wav
│   │   └── sliced_audios_16k
│   │       └── mute.wav
│   └── reference
│       ├── ref32000.wav
│       ├── ref32000_f0c.npy
│       ├── ref32000_f0f.npy
│       ├── ref32000_feats.npy
│       ├── ref40000.wav
│       ├── ref40000_f0c.npy
│       ├── ref40000_f0f.npy
│       ├── ref40000_feats.npy
│       ├── ref48000.wav
│       ├── ref48000_f0c.npy
│       ├── ref48000_f0f.npy
│       └── ref48000_feats.npy
├── requirements.txt
├── rvc
│   ├── configs
│   │   ├── 32000.json
│   │   ├── 40000.json
│   │   ├── 48000.json
│   │   └── config.py
│   ├── infer
│   │   ├── infer.py
│   │   └── pipeline.py
│   ├── lib
│   │   ├── algorithm
│   │   │   ├── __init__.py
│   │   │   ├── attentions.py
│   │   │   ├── commons.py
│   │   │   ├── discriminators.py
│   │   │   ├── encoders.py
│   │   │   ├── generators
│   │   │   │   ├── __init__.py
│   │   │   │   ├── hifigan.py
│   │   │   │   ├── hifigan_mrf.py
│   │   │   │   ├── hifigan_nsf.py
│   │   │   │   └── refinegan.py
│   │   │   ├── modules.py
│   │   │   ├── normalization.py
│   │   │   ├── residuals.py
│   │   │   └── synthesizers.py
│   │   ├── predictors
│   │   │   ├── F0Extractor.py
│   │   │   ├── FCPE.py
│   │   │   └── RMVPE.py
│   │   ├── tools
│   │   │   ├── analyzer.py
│   │   │   ├── gdown.py
│   │   │   ├── launch_tensorboard.py
│   │   │   ├── model_download.py
│   │   │   ├── prerequisites_download.py
│   │   │   ├── pretrained_selector.py
│   │   │   ├── split_audio.py
│   │   │   ├── tts.py
│   │   │   └── tts_voices.json
│   │   ├── utils.py
│   │   └── zluda.py
│   ├── models
│   │   ├── embedders
│   │   │   ├── .gitkeep
│   │   │   └── embedders_custom
│   │   │       └── .gitkeep
│   │   ├── formant
│   │   │   └── .gitkeep
│   │   ├── predictors
│   │   │   └── .gitkeep
│   │   └── pretraineds
│   │       ├── .gitkeep
│   │       ├── custom
│   │       │   └── .gitkeep
│   │       └── hifi-gan
│   │           └── .gitkeep
│   └── train
│       ├── data_utils.py
│       ├── extract
│       │   ├── extract.py
│       │   └── preparing_files.py
│       ├── losses.py
│       ├── mel_processing.py
│       ├── preprocess
│       │   ├── preprocess.py
│       │   └── slicer.py
│       ├── process
│       │   ├── change_info.py
│       │   ├── extract_index.py
│       │   ├── extract_model.py
│       │   ├── model_blender.py
│       │   └── model_information.py
│       ├── train.py
│       └── utils.py
├── rvc_cli.py
├── uvr
│   ├── __init__.py
│   ├── architectures
│   │   ├── __init__.py
│   │   ├── demucs_separator.py
│   │   ├── mdx_separator.py
│   │   ├── mdxc_separator.py
│   │   └── vr_separator.py
│   ├── common_separator.py
│   ├── separator.py
│   └── uvr_lib_v5
│       ├── __init__.py
│       ├── attend.py
│       ├── bs_roformer.py
│       ├── demucs
│       │   ├── __init__.py
│       │   ├── __main__.py
│       │   ├── apply.py
│       │   ├── demucs.py
│       │   ├── filtering.py
│       │   ├── hdemucs.py
│       │   ├── htdemucs.py
│       │   ├── model.py
│       │   ├── model_v2.py
│       │   ├── pretrained.py
│       │   ├── repo.py
│       │   ├── spec.py
│       │   ├── states.py
│       │   ├── tasnet.py
│       │   ├── tasnet_v2.py
│       │   ├── transformer.py
│       │   └── utils.py
│       ├── mdxnet.py
│       ├── mel_band_roformer.py
│       ├── mixer.ckpt
│       ├── modules.py
│       ├── playsound.py
│       ├── pyrb.py
│       ├── results.py
│       ├── spec_utils.py
│       ├── stft.py
│       ├── tfc_tdf_v3.py
│       └── vr_network
│           ├── __init__.py
│           ├── layers.py
│           ├── layers_new.py
│           ├── model_param_init.py
│           ├── modelparams
│           │   ├── 1band_sr16000_hl512.json
│           │   ├── 1band_sr32000_hl512.json
│           │   ├── 1band_sr33075_hl384.json
│           │   ├── 1band_sr44100_hl1024.json
│           │   ├── 1band_sr44100_hl256.json
│           │   ├── 1band_sr44100_hl512.json
│           │   ├── 1band_sr44100_hl512_cut.json
│           │   ├── 1band_sr44100_hl512_nf1024.json
│           │   ├── 2band_32000.json
│           │   ├── 2band_44100_lofi.json
│           │   ├── 2band_48000.json
│           │   ├── 3band_44100.json
│           │   ├── 3band_44100_mid.json
│           │   ├── 3band_44100_msb2.json
│           │   ├── 4band_44100.json
│           │   ├── 4band_44100_mid.json
│           │   ├── 4band_44100_msb.json
│           │   ├── 4band_44100_msb2.json
│           │   ├── 4band_44100_reverse.json
│           │   ├── 4band_44100_sw.json
│           │   ├── 4band_v2.json
│           │   ├── 4band_v2_sn.json
│           │   ├── 4band_v3.json
│           │   ├── 4band_v3_sn.json
│           │   └── ensemble.json
│           ├── nets.py
│           └── nets_new.py
└── uvr_cli.py
/.github/FUNDING.yml:
--------------------------------------------------------------------------------
1 | # These are supported funding model platforms
2 |
3 | github: #
4 | patreon: # Replace with a single Patreon username
5 | open_collective: # Replace with a single Open Collective username
6 | ko_fi: iahispano
7 | tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel
8 | community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry
9 | liberapay: # Replace with a single Liberapay username
10 | issuehunt: # Replace with a single IssueHunt username
11 | otechie: # Replace with a single Otechie username
12 | lfx_crowdfunding: # Replace with a single LFX Crowdfunding project-name e.g., cloud-foundry
13 | custom: # Replace with up to 4 custom sponsorship URLs e.g., ['link1', 'link2']
14 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/bug_report.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Bug report
3 | about: Create a report to help us improve
4 | title: "[BUG]"
5 | labels: ''
6 | assignees: ''
7 |
8 | ---
9 |
10 | **Before You Report a Bug**
11 | Reporting a bug is essential for us to improve our service, but we need detailed information to address the issue effectively. Since every computer setup is unique, there can be various reasons behind a bug. Before reporting, consider potential causes and provide as much detail as possible to help us understand the problem.
12 |
13 | **Bug Description**
14 | Please provide a clear and concise description of the bug.
15 |
16 | **Steps to Reproduce**
17 | Outline the steps to replicate the issue:
18 | 1. Go to '...'
19 | 2. Click on '....'
20 | 3. Scroll down to '....'
21 | 4. Observe the error.
22 |
23 | **Expected Behavior**
24 | Describe what you expected to happen.
25 |
26 | **Assets**
27 | Include screenshots or videos if they can illustrate the issue.
28 |
29 | **Desktop Details:**
30 | - Operating System: [e.g., Windows 11]
31 | - Browser: [e.g., Chrome, Safari]
32 |
33 | **Additional Context**
34 | Any additional information that might be relevant to the issue.
35 |
--------------------------------------------------------------------------------
/.github/workflows/code_formatter.yml:
--------------------------------------------------------------------------------
1 | name: Code Formatter
2 |
3 | on:
4 | push:
5 | branches:
6 | - main
7 |
8 | jobs:
9 | push_format:
10 | runs-on: ubuntu-latest
11 |
12 | permissions:
13 | contents: write
14 | pull-requests: write
15 |
16 | steps:
17 | - uses: actions/checkout@v3
18 | with:
19 | ref: ${{github.ref_name}}
20 |
 21 |       - name: Set up Python
 22 |         uses: actions/setup-python@v4
 23 |         with:
 24 |           python-version: "3.10"
25 |
26 | - name: Install Black
27 | run: pip install "black[jupyter]"
28 |
29 | - name: Run Black
30 | # run: black $(git ls-files '*.py')
31 | run: black . --exclude=".*\.ipynb$"
32 |
33 | - name: Commit Back
34 | continue-on-error: true
35 | id: commitback
36 | run: |
37 | git config --local user.email "github-actions[bot]@users.noreply.github.com"
38 | git config --local user.name "github-actions[bot]"
39 | git add --all
40 | git commit -m "chore(format): run black on ${{github.ref_name}}"
41 |
42 | - name: Create Pull Request
43 | if: steps.commitback.outcome == 'success'
44 | continue-on-error: true
45 | uses: peter-evans/create-pull-request@v5
46 | with:
47 | delete-branch: true
48 | body: "Automatically apply code formatter change"
49 | title: "chore(format): run black on ${{github.ref_name}}"
50 | commit-message: "chore(format): run black on ${{github.ref_name}}"
51 | branch: formatter-${{github.ref_name}}
52 |
--------------------------------------------------------------------------------
/.github/workflows/nextjs.yml:
--------------------------------------------------------------------------------
1 | # Sample workflow for building and deploying a Next.js site to GitHub Pages
2 | #
3 | # To get started with Next.js see: https://nextjs.org/docs/getting-started
4 | #
5 | name: Deploy Next.js site to Pages
6 |
7 | on:
8 | # Runs on pushes targeting the default branch
9 | push:
10 | branches: ["main"]
11 |
12 | # Allows you to run this workflow manually from the Actions tab
13 | workflow_dispatch:
14 |
15 | # Sets permissions of the GITHUB_TOKEN to allow deployment to GitHub Pages
16 | permissions:
17 | contents: read
18 | pages: write
19 | id-token: write
20 |
21 | # Allow only one concurrent deployment, skipping runs queued between the run in-progress and latest queued.
22 | # However, do NOT cancel in-progress runs as we want to allow these production deployments to complete.
23 | concurrency:
24 | group: "pages"
25 | cancel-in-progress: false
26 |
27 | jobs:
28 | # Build job
29 | build:
30 | runs-on: ubuntu-latest
31 | steps:
32 | - name: Checkout
33 | uses: actions/checkout@v4
34 | - name: Detect package manager
35 | id: detect-package-manager
36 | run: |
37 | if [ -f "${{ github.workspace }}/yarn.lock" ]; then
38 | echo "manager=yarn" >> $GITHUB_OUTPUT
39 | echo "command=install" >> $GITHUB_OUTPUT
40 | echo "runner=yarn" >> $GITHUB_OUTPUT
41 | exit 0
42 | elif [ -f "${{ github.workspace }}/package.json" ]; then
43 | echo "manager=npm" >> $GITHUB_OUTPUT
44 | echo "command=ci" >> $GITHUB_OUTPUT
45 | echo "runner=npx --no-install" >> $GITHUB_OUTPUT
46 | exit 0
47 | else
48 | echo "Unable to determine package manager"
49 | exit 1
50 | fi
51 | - name: Setup Node
52 | uses: actions/setup-node@v4
53 | with:
54 | node-version: "20"
55 | cache: ${{ steps.detect-package-manager.outputs.manager }}
56 | - name: Setup Pages
57 | uses: actions/configure-pages@v5
58 | with:
59 | # Automatically inject basePath in your Next.js configuration file and disable
60 | # server side image optimization (https://nextjs.org/docs/api-reference/next/image#unoptimized).
61 | #
62 | # You may remove this line if you want to manage the configuration yourself.
63 | static_site_generator: next
64 | - name: Restore cache
65 | uses: actions/cache@v4
66 | with:
67 | path: |
68 | .next/cache
69 | # Generate a new cache whenever packages or source files change.
70 | key: ${{ runner.os }}-nextjs-${{ hashFiles('**/package-lock.json', '**/yarn.lock') }}-${{ hashFiles('**.[jt]s', '**.[jt]sx') }}
71 | # If source files changed but packages didn't, rebuild from a prior cache.
72 | restore-keys: |
73 | ${{ runner.os }}-nextjs-${{ hashFiles('**/package-lock.json', '**/yarn.lock') }}-
74 | - name: Install dependencies
75 | run: ${{ steps.detect-package-manager.outputs.manager }} ${{ steps.detect-package-manager.outputs.command }}
76 | - name: Build with Next.js
77 | run: ${{ steps.detect-package-manager.outputs.runner }} next build
78 | - name: Upload artifact
79 | uses: actions/upload-pages-artifact@v3
80 | with:
81 | path: ./out
82 |
83 | # Deployment job
84 | deploy:
85 | environment:
86 | name: github-pages
87 | url: ${{ steps.deployment.outputs.page_url }}
88 | runs-on: ubuntu-latest
89 | needs: build
90 | steps:
91 | - name: Deploy to GitHub Pages
92 | id: deployment
93 | uses: actions/deploy-pages@v4
94 |
--------------------------------------------------------------------------------
/.github/workflows/unittest.yml:
--------------------------------------------------------------------------------
1 | name: Test preprocess and extract
2 | on: [push, pull_request]
3 | jobs:
4 | build:
5 | runs-on: ${{ matrix.os }}
6 | strategy:
7 | matrix:
8 | python-version: ["3.9", "3.10"]
9 | os: [ubuntu-latest]
10 | fail-fast: true
11 |
12 | steps:
13 | - uses: actions/checkout@master
14 | - name: Set up Python ${{ matrix.python-version }}
15 | uses: actions/setup-python@v4
16 | with:
17 | python-version: ${{ matrix.python-version }}
18 | - name: Install dependencies
19 | run: |
20 | sudo apt update
21 | sudo apt -y install ffmpeg
22 | python -m pip install --upgrade pip
23 | python -m pip install --upgrade setuptools
24 | python -m pip install --upgrade wheel
25 | pip install torch torchvision torchaudio
26 | pip install -r requirements.txt
27 | python rvc_cli.py prerequisites --models True
28 | - name: Test Preprocess
29 | run: |
30 | python rvc_cli.py preprocess --model_name "Evaluate" --dataset_path "logs/mute/sliced_audios" --sample_rate 48000
31 | - name: Test Extract
32 | run: |
33 | python rvc_cli.py extract --model_name "Evaluate" --sample_rate 48000
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Ignore logs folder
2 | # logs
3 |
4 | # Ignore compiled executables
5 | *.exe
6 |
7 | # Ignore model files
8 | *.pt
9 | *.onnx
10 | *.pth
11 | *.index
12 | *.bin
13 | *.ckpt
14 | *.txt
15 |
16 | # Ignore Python bytecode files
17 | *.pyc
18 |
19 | # Ignore environment and virtual environment directories
20 | env/
21 | venv/
22 |
23 | # Ignore cached files
24 | .cache/
25 | docs/.next
26 | node_modules
27 |
28 | # Ignore specific project directories
29 | /tracks/
30 | /lyrics/
31 |
--------------------------------------------------------------------------------
/.vscode/settings.json:
--------------------------------------------------------------------------------
1 | {
2 | "python.languageServer": "Pylance",
3 | "python.analysis.diagnosticSeverityOverrides": {
4 | "reportShadowedImports": "none"
5 | },
6 | }
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
 1 | ## 🚀 RVC + UVR = A perfect set of tools for voice cloning, easy and free!
2 |
3 | [](https://colab.research.google.com/github/iahispano/applio/blob/master/assets/Applio_NoUI.ipynb)
4 |
5 | > [!NOTE]
6 | > Issues are not handled in this repository due to time constraints. For questions or discussions, feel free to join [AI Hispano on Discord](https://discord.gg/iahispano). If you're experiencing a technical issue, please report it on [Applio's Issues page](https://github.com/IAHispano/Applio/issues) and specify that the problem occurs via CLI.
 7 | > This repository is only partially maintained; use Applio for more frequent updates (its CLI is available through the `core.py` file).
8 |
9 | ### Installation
10 |
 11 | Ensure that you have the necessary Python packages installed by following these steps (Python 3.9 or 3.10 is recommended):
12 |
13 | #### Windows
14 |
 15 | Run the [install.bat](./install.bat) file to set up a Conda environment. Afterward, launch the application using `env/python.exe rvc_cli.py` instead of the conventional `python rvc_cli.py` command.
16 |
17 | #### Linux
18 |
19 | ```bash
20 | chmod +x install.sh
21 | ./install.sh
22 | ```
23 |
24 | ### Getting Started
25 |
26 | For detailed information and command-line options, refer to the help command:
27 |
28 | ```bash
29 | python rvc_cli.py -h
30 | python uvr_cli.py -h
31 | ```
32 |
 33 | These commands provide a clear overview of the available modes and their corresponding parameters. If you need more information, check the [documentation](https://rvc-cli.pages.dev/).
34 |
35 | ### References
36 |
37 | The RVC CLI builds upon the foundations of the following projects:
38 |
39 | - **Vocoders:**
40 |
41 | - [HiFi-GAN](https://github.com/jik876/hifi-gan) by jik876
42 | - [Vocos](https://github.com/gemelo-ai/vocos) by gemelo-ai
43 | - [BigVGAN](https://github.com/NVIDIA/BigVGAN) by NVIDIA
44 | - [BigVSAN](https://github.com/sony/bigvsan) by sony
45 | - [vocoders](https://github.com/reppy4620/vocoders) by reppy4620
46 | - [vocoder](https://github.com/fishaudio/vocoder) by fishaudio
47 |
48 | - **VC Clients:**
49 |
50 | - [Retrieval-based-Voice-Conversion-WebUI](https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI) by RVC-Project
51 | - [So-Vits-SVC](https://github.com/svc-develop-team/so-vits-svc) by svc-develop-team
52 | - [Mangio-RVC-Fork](https://github.com/Mangio621/Mangio-RVC-Fork) by Mangio621
53 | - [VITS](https://github.com/jaywalnut310/vits) by jaywalnut310
54 | - [Harmonify](https://huggingface.co/Eempostor/Harmonify) by Eempostor
55 | - [rvc-trainer](https://github.com/thepowerfuldeez/rvc-trainer) by thepowerfuldeez
56 |
57 | - **Pitch Extractors:**
58 |
59 | - [RMVPE](https://github.com/Dream-High/RMVPE) by Dream-High
60 | - [torchfcpe](https://github.com/CNChTu/FCPE) by CNChTu
61 | - [torchcrepe](https://github.com/maxrmorrison/torchcrepe) by maxrmorrison
62 | - [anyf0](https://github.com/SoulMelody/anyf0) by SoulMelody
63 |
64 | - **Other:**
65 | - [FAIRSEQ](https://github.com/facebookresearch/fairseq) by facebookresearch
66 | - [FAISS](https://github.com/facebookresearch/faiss) by facebookresearch
67 | - [ContentVec](https://github.com/auspicious3000/contentvec/) by auspicious3000
68 | - [audio-slicer](https://github.com/openvpi/audio-slicer) by openvpi
69 | - [python-audio-separator](https://github.com/karaokenerds/python-audio-separator) by karaokenerds
70 | - [ultimatevocalremovergui](https://github.com/Anjok07/ultimatevocalremovergui) by Anjok07
71 |
72 | We acknowledge and appreciate the contributions of the respective authors and communities involved in these projects.
73 |
--------------------------------------------------------------------------------
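
For a concrete quick-start matching the README's "Getting Started" section, the sketch below drives `rvc_cli.py` through Python's `subprocess`, reusing the same `prerequisites`, `preprocess`, and `extract` invocations exercised in `.github/workflows/unittest.yml`. The model name, dataset path, and sample rate are the CI test values and are placeholders for your own data.

```python
import subprocess
import sys


def run_cli(*args: str) -> None:
    """Invoke an rvc_cli.py subcommand and raise if it fails."""
    subprocess.run([sys.executable, "rvc_cli.py", *args], check=True)


# One-time setup: download the auxiliary models and executables.
run_cli("prerequisites", "--models", "True")

# Preprocess a dataset and extract features, mirroring the CI test.
run_cli(
    "preprocess",
    "--model_name", "Evaluate",
    "--dataset_path", "logs/mute/sliced_audios",
    "--sample_rate", "48000",
)
run_cli("extract", "--model_name", "Evaluate", "--sample_rate", "48000")
```

On a Windows install created by `install.bat`, run this script with `env/python.exe` so that `sys.executable` points at the bundled Conda environment.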
/docs/components/counters.module.css:
--------------------------------------------------------------------------------
1 | .counter {
2 | border: 1px solid #ccc;
3 | border-radius: 5px;
4 | padding: 2px 6px;
5 | margin: 12px 0 0;
6 | }
7 |
--------------------------------------------------------------------------------
/docs/components/counters.tsx:
--------------------------------------------------------------------------------
1 | // Example from https://beta.reactjs.org/learn
2 |
3 | import { useState } from 'react'
4 | import styles from './counters.module.css'
5 |
6 | function MyButton() {
7 | const [count, setCount] = useState(0)
8 |
9 | function handleClick() {
10 | setCount(count + 1)
11 | }
12 |
13 | return (
 14 | <div>
 15 | <button onClick={handleClick} className={styles.counter}>
 16 | Clicked {count} times
 17 | </button>
 18 | </div>
 19 | )
20 | }
21 |
22 | export default function MyApp() {
 23 | return <MyButton />
24 | }
25 |
--------------------------------------------------------------------------------
/docs/favicon.ico:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/blaisewf/rvc-cli/a732f87a54780e26e4a0e3c91b2d93f848eef6d2/docs/favicon.ico
--------------------------------------------------------------------------------
/docs/next-env.d.ts:
--------------------------------------------------------------------------------
 1 | /// <reference types="next" />
 2 | /// <reference types="next/image-types/global" />
3 |
4 | // NOTE: This file should not be edited
5 | // see https://nextjs.org/docs/basic-features/typescript for more information.
6 |
--------------------------------------------------------------------------------
/docs/next.config.js:
--------------------------------------------------------------------------------
1 | const withNextra = require('nextra')({
2 | theme: 'nextra-theme-docs',
3 | themeConfig: './theme.config.tsx',
4 | })
5 |
6 | module.exports = withNextra()
7 |
--------------------------------------------------------------------------------
/docs/package.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "rvc_cli",
3 | "version": "0.0.2",
4 | "description": "🚀 RVC + UVR = A perfect set of tools for voice cloning, easily and free!",
5 | "scripts": {
6 | "dev": "next dev",
7 | "build": "next build",
8 | "start": "next start"
9 | },
10 | "dependencies": {
11 | "next": "^13.0.6",
12 | "nextra": "latest",
13 | "nextra-theme-docs": "latest",
14 | "react": "^18.2.0",
15 | "react-dom": "^18.2.0"
16 | },
17 | "devDependencies": {
18 | "@types/node": "18.11.10",
19 | "typescript": "^4.9.3"
20 | }
21 | }
--------------------------------------------------------------------------------
/docs/pages/Usage/_meta.json:
--------------------------------------------------------------------------------
1 | {
2 | "rvc": "RVC",
3 | "uvr": "UVR"
4 | }
--------------------------------------------------------------------------------
/docs/pages/_meta.json:
--------------------------------------------------------------------------------
1 | {
2 | "index": "Introduction",
3 | "installation": "Installation",
4 | "contact": {
5 | "title": "Contact ↗",
6 | "type": "page",
7 | "href": "https://twitter.com/blaisewf",
8 | "newWindow": true
9 | }
10 | }
--------------------------------------------------------------------------------
/docs/pages/index.mdx:
--------------------------------------------------------------------------------
1 | # Introduction
2 |
3 | ### References
4 |
5 | The RVC CLI builds upon the foundations of the following projects:
6 |
7 | - **Vocoders:**
8 |
9 | - [HiFi-GAN](https://github.com/jik876/hifi-gan) by jik876
10 | - [Vocos](https://github.com/gemelo-ai/vocos) by gemelo-ai
11 | - [BigVGAN](https://github.com/NVIDIA/BigVGAN) by NVIDIA
12 | - [BigVSAN](https://github.com/sony/bigvsan) by sony
13 | - [vocoders](https://github.com/reppy4620/vocoders) by reppy4620
14 | - [vocoder](https://github.com/fishaudio/vocoder) by fishaudio
15 |
16 | - **VC Clients:**
17 |
18 | - [Retrieval-based-Voice-Conversion-WebUI](https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI) by RVC-Project
19 | - [So-Vits-SVC](https://github.com/svc-develop-team/so-vits-svc) by svc-develop-team
20 | - [Mangio-RVC-Fork](https://github.com/Mangio621/Mangio-RVC-Fork) by Mangio621
21 | - [VITS](https://github.com/jaywalnut310/vits) by jaywalnut310
22 | - [Harmonify](https://huggingface.co/Eempostor/Harmonify) by Eempostor
23 | - [rvc-trainer](https://github.com/thepowerfuldeez/rvc-trainer) by thepowerfuldeez
24 |
25 | - **Pitch Extractors:**
26 |
27 | - [RMVPE](https://github.com/Dream-High/RMVPE) by Dream-High
28 | - [torchfcpe](https://github.com/CNChTu/FCPE) by CNChTu
29 | - [torchcrepe](https://github.com/maxrmorrison/torchcrepe) by maxrmorrison
30 | - [anyf0](https://github.com/SoulMelody/anyf0) by SoulMelody
31 |
32 | - **Other:**
33 | - [FAIRSEQ](https://github.com/facebookresearch/fairseq) by facebookresearch
34 | - [FAISS](https://github.com/facebookresearch/faiss) by facebookresearch
35 | - [ContentVec](https://github.com/auspicious3000/contentvec/) by auspicious3000
36 | - [audio-slicer](https://github.com/openvpi/audio-slicer) by openvpi
37 | - [python-audio-separator](https://github.com/karaokenerds/python-audio-separator) by karaokenerds
38 | - [ultimatevocalremovergui](https://github.com/Anjok07/ultimatevocalremovergui) by Anjok07
39 |
40 | We acknowledge and appreciate the contributions of the respective authors and communities involved in these projects.
41 |
--------------------------------------------------------------------------------
/docs/pages/installation.mdx:
--------------------------------------------------------------------------------
1 | ## Installation Guides
2 |
3 | ### Windows
4 |
5 | #### 1. **Install Dependencies:**
6 | - Open a Command Prompt or PowerShell window.
7 | - Navigate to the directory containing the `install.bat` file.
8 | - Run the `install.bat` file. This will create a Conda environment and install all necessary dependencies.
9 | #### 2. **Run the RVC CLI:**
10 | - After installing requirements, you can start using the CLI. However, make sure you run the `prerequisites` command first to download additional models and executables.
11 | - Launch the application using:
12 | ```bash
 13 | env/python.exe rvc_cli.py prerequisites
14 | ```
15 | - Once the prerequisites are downloaded, you can use the RVC CLI as usual.
16 | ### macOS & Linux
17 |
18 | #### 1. **Install Python (3.9 or 3.10):**
19 | - If you don't have Python installed, use your system's package manager. For example, on Ubuntu:
20 | ```bash
21 | sudo apt update
22 | sudo apt install python3
23 | ```
24 | #### 2. **Create a Virtual Environment:**
25 | - Open a terminal window.
26 | - Navigate to the directory containing the `rvc_cli.py` file.
27 | - Run the following command to create a virtual environment:
28 | ```bash
29 | python3 -m venv venv
30 | ```
31 | #### 3. **Activate the Environment:**
32 | - Run the following command to activate the virtual environment:
33 | ```bash
34 | source venv/bin/activate
35 | ```
36 | #### 4. **Install Dependencies:**
37 | - Run the following command to install all necessary dependencies:
38 | ```bash
39 | pip install -r requirements.txt
40 | ```
41 | #### 5. **Run the RVC CLI with `prerequisites`:**
42 | - Launch the application and run the `prerequisites` command to download additional models and executables.
43 | ```bash
 44 | python rvc_cli.py prerequisites
45 | ```
46 | #### 6. **Run the RVC CLI:**
47 | - After the prerequisites are downloaded, you can use the RVC CLI as usual.
48 |
49 | **Note:**
50 |
51 | - For Linux, you may need to install additional packages depending on your system.
52 | - If you face any errors during the installation process, consult the respective package managers' documentation for further instructions.
53 |
54 |
55 |
56 |
--------------------------------------------------------------------------------
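
For readers who prefer to script the macOS/Linux steps above, here is a minimal sketch that reproduces steps 2 through 5 (virtual environment, dependencies, prerequisites) with the standard library. It assumes it is run from the repository root on a POSIX system, where the environment's interpreter lives at `venv/bin/python`.

```python
import subprocess
from pathlib import Path

venv_dir = Path("venv")

# Step 2: create the virtual environment (equivalent to `python3 -m venv venv`).
subprocess.run(["python3", "-m", "venv", str(venv_dir)], check=True)

# Steps 3-4: call the environment's interpreter directly instead of activating
# it, then install the project requirements into it.
venv_python = venv_dir / "bin" / "python"
subprocess.run(
    [str(venv_python), "-m", "pip", "install", "-r", "requirements.txt"],
    check=True,
)

# Step 5: download the additional models and executables.
subprocess.run([str(venv_python), "rvc_cli.py", "prerequisites"], check=True)
```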
/docs/theme.config.tsx:
--------------------------------------------------------------------------------
1 | import React from "react";
2 | import { DocsThemeConfig, useConfig } from "nextra-theme-docs";
3 |
4 | const config: DocsThemeConfig = {
5 | logo: 'RVC CLI',
6 | search: {
7 | placeholder: "What are you looking for? 🧐",
8 | },
9 | project: {
10 | link: "https://github.com/blaisewf/rvc_cli",
11 | },
12 | chat: {
13 | link: "https://discord.gg/iahispano",
14 | },
15 | docsRepositoryBase: "https://github.com/blaisewf/rvc_cli/tree/main/docs",
16 | footer: {
17 | text: (
18 |
19 | made w ❤️ by blaisewf
20 |
21 | ),
22 | },
23 | nextThemes: {
24 | defaultTheme: "dark",
25 | },
26 | feedback: {
27 | content: "Do you think we should improve something? Let us know!",
28 |
29 | },
30 | editLink: {
31 | component: null,
32 | },
33 | faviconGlyph: "favicon.ico",
34 | logoLink: "/",
35 | primaryHue: 317,
36 | head: () => {
37 | const { frontMatter } = useConfig();
38 |
39 | return (
40 | <>
41 |
42 |
43 |
44 |
45 |
49 |
53 |
54 |
55 |
56 |
60 |
61 |
62 | >
63 | );
64 | },
65 | useNextSeoProps() {
66 | return {
67 | titleTemplate: `%s - RVC CLI`,
68 | };
69 | },
70 | };
71 |
72 | export default config;
--------------------------------------------------------------------------------
/docs/tsconfig.json:
--------------------------------------------------------------------------------
1 | {
2 | "compilerOptions": {
3 | "target": "es5",
4 | "lib": ["dom", "dom.iterable", "esnext"],
5 | "allowJs": true,
6 | "skipLibCheck": true,
7 | "strict": false,
8 | "forceConsistentCasingInFileNames": true,
9 | "noEmit": true,
10 | "incremental": true,
11 | "esModuleInterop": true,
12 | "module": "esnext",
13 | "moduleResolution": "node",
14 | "resolveJsonModule": true,
15 | "isolatedModules": true,
16 | "jsx": "preserve"
17 | },
18 | "include": ["next-env.d.ts", "**/*.ts", "**/*.tsx"],
19 | "exclude": ["node_modules"]
20 | }
21 |
--------------------------------------------------------------------------------
/install.bat:
--------------------------------------------------------------------------------
1 | @echo off
2 | setlocal enabledelayedexpansion
3 | title RVC CLI Installer
4 |
5 | echo Welcome to the RVC CLI Installer!
6 | echo.
7 |
8 | set "INSTALL_DIR=%cd%"
9 | set "MINICONDA_DIR=%UserProfile%\Miniconda3"
10 | set "ENV_DIR=%INSTALL_DIR%\env"
11 | set "MINICONDA_URL=https://repo.anaconda.com/miniconda/Miniconda3-py310_24.7.1-0-Windows-x86_64.exe"
12 | set "CONDA_EXE=%MINICONDA_DIR%\Scripts\conda.exe"
13 |
14 | set "startTime=%TIME%"
15 | set "startHour=%TIME:~0,2%"
16 | set "startMin=%TIME:~3,2%"
17 | set "startSec=%TIME:~6,2%"
18 | set /a startHour=1%startHour% - 100
19 | set /a startMin=1%startMin% - 100
20 | set /a startSec=1%startSec% - 100
21 | set /a startTotal = startHour*3600 + startMin*60 + startSec
22 |
23 | call :cleanup
24 | call :install_miniconda
25 | call :create_conda_env
26 | call :install_dependencies
27 |
28 | set "endTime=%TIME%"
29 | set "endHour=%TIME:~0,2%"
30 | set "endMin=%TIME:~3,2%"
31 | set "endSec=%TIME:~6,2%"
32 | set /a endHour=1%endHour% - 100
33 | set /a endMin=1%endMin% - 100
34 | set /a endSec=1%endSec% - 100
35 | set /a endTotal = endHour*3600 + endMin*60 + endSec
36 | set /a elapsed = endTotal - startTotal
37 | if %elapsed% lss 0 set /a elapsed += 86400
38 | set /a hours = elapsed / 3600
39 | set /a minutes = (elapsed %% 3600) / 60
40 | set /a seconds = elapsed %% 60
41 |
42 | echo Installation time: %hours% hours, %minutes% minutes, %seconds% seconds.
43 | echo.
44 |
45 | echo RVC CLI has been installed successfully!
46 | echo.
47 | pause
48 | exit /b 0
49 |
50 | :cleanup
51 | echo Cleaning up unnecessary files...
52 | for %%F in (Makefile Dockerfile docker-compose.yaml *.sh) do if exist "%%F" del "%%F"
53 | echo Cleanup complete.
54 | echo.
55 | exit /b 0
56 |
57 | :install_miniconda
58 | if exist "%CONDA_EXE%" (
59 | echo Miniconda already installed. Skipping installation.
60 | exit /b 0
61 | )
62 |
63 | echo Miniconda not found. Starting download and installation...
64 | powershell -Command "& {Invoke-WebRequest -Uri '%MINICONDA_URL%' -OutFile 'miniconda.exe'}"
65 | if not exist "miniconda.exe" goto :download_error
66 |
67 | start /wait "" miniconda.exe /InstallationType=JustMe /RegisterPython=0 /S /D=%MINICONDA_DIR%
68 | if errorlevel 1 goto :install_error
69 |
70 | del miniconda.exe
71 | echo Miniconda installation complete.
72 | echo.
73 | exit /b 0
74 |
75 | :create_conda_env
76 | echo Creating Conda environment...
77 | call "%MINICONDA_DIR%\_conda.exe" create --no-shortcuts -y -k --prefix "%ENV_DIR%" python=3.10
78 | if errorlevel 1 goto :error
79 | echo Conda environment created successfully.
80 | echo.
81 |
82 | if exist "%ENV_DIR%\python.exe" (
83 | echo Installing uv package installer...
84 | "%ENV_DIR%\python.exe" -m pip install uv
85 | if errorlevel 1 goto :error
86 | echo uv installation complete.
87 | echo.
88 | )
89 | exit /b 0
90 |
91 | :install_dependencies
92 | echo Installing dependencies...
93 | call "%MINICONDA_DIR%\condabin\conda.bat" activate "%ENV_DIR%" || goto :error
94 | uv pip install --upgrade setuptools || goto :error
95 | uv pip install torch==2.7.0 torchvision torchaudio==2.7.0 --upgrade --index-url https://download.pytorch.org/whl/cu128 || goto :error
96 | uv pip install -r "%INSTALL_DIR%\requirements.txt" || goto :error
97 | call "%MINICONDA_DIR%\condabin\conda.bat" deactivate
98 | echo Dependencies installation complete.
99 | echo.
100 | exit /b 0
101 |
102 | :download_error
103 | echo Download failed. Please check your internet connection and try again.
104 | goto :error
105 |
106 | :install_error
107 | echo Miniconda installation failed.
108 | goto :error
109 |
110 | :error
111 | echo An error occurred during installation. Please check the output above for details.
112 | pause
113 | exit /b 1
--------------------------------------------------------------------------------
/install.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | python3 -m venv venv
4 | source venv/bin/activate
5 |
6 | pip install -r requirements.txt
7 |
--------------------------------------------------------------------------------
/logs/mute/extracted/mute.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/blaisewf/rvc-cli/a732f87a54780e26e4a0e3c91b2d93f848eef6d2/logs/mute/extracted/mute.npy
--------------------------------------------------------------------------------
/logs/mute/f0/mute.wav.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/blaisewf/rvc-cli/a732f87a54780e26e4a0e3c91b2d93f848eef6d2/logs/mute/f0/mute.wav.npy
--------------------------------------------------------------------------------
/logs/mute/f0_voiced/mute.wav.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/blaisewf/rvc-cli/a732f87a54780e26e4a0e3c91b2d93f848eef6d2/logs/mute/f0_voiced/mute.wav.npy
--------------------------------------------------------------------------------
/logs/mute/sliced_audios/mute32000.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/blaisewf/rvc-cli/a732f87a54780e26e4a0e3c91b2d93f848eef6d2/logs/mute/sliced_audios/mute32000.wav
--------------------------------------------------------------------------------
/logs/mute/sliced_audios/mute40000.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/blaisewf/rvc-cli/a732f87a54780e26e4a0e3c91b2d93f848eef6d2/logs/mute/sliced_audios/mute40000.wav
--------------------------------------------------------------------------------
/logs/mute/sliced_audios/mute44100.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/blaisewf/rvc-cli/a732f87a54780e26e4a0e3c91b2d93f848eef6d2/logs/mute/sliced_audios/mute44100.wav
--------------------------------------------------------------------------------
/logs/mute/sliced_audios/mute48000.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/blaisewf/rvc-cli/a732f87a54780e26e4a0e3c91b2d93f848eef6d2/logs/mute/sliced_audios/mute48000.wav
--------------------------------------------------------------------------------
/logs/mute/sliced_audios_16k/mute.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/blaisewf/rvc-cli/a732f87a54780e26e4a0e3c91b2d93f848eef6d2/logs/mute/sliced_audios_16k/mute.wav
--------------------------------------------------------------------------------
/logs/reference/ref32000.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/blaisewf/rvc-cli/a732f87a54780e26e4a0e3c91b2d93f848eef6d2/logs/reference/ref32000.wav
--------------------------------------------------------------------------------
/logs/reference/ref32000_f0c.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/blaisewf/rvc-cli/a732f87a54780e26e4a0e3c91b2d93f848eef6d2/logs/reference/ref32000_f0c.npy
--------------------------------------------------------------------------------
/logs/reference/ref32000_f0f.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/blaisewf/rvc-cli/a732f87a54780e26e4a0e3c91b2d93f848eef6d2/logs/reference/ref32000_f0f.npy
--------------------------------------------------------------------------------
/logs/reference/ref32000_feats.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/blaisewf/rvc-cli/a732f87a54780e26e4a0e3c91b2d93f848eef6d2/logs/reference/ref32000_feats.npy
--------------------------------------------------------------------------------
/logs/reference/ref40000.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/blaisewf/rvc-cli/a732f87a54780e26e4a0e3c91b2d93f848eef6d2/logs/reference/ref40000.wav
--------------------------------------------------------------------------------
/logs/reference/ref40000_f0c.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/blaisewf/rvc-cli/a732f87a54780e26e4a0e3c91b2d93f848eef6d2/logs/reference/ref40000_f0c.npy
--------------------------------------------------------------------------------
/logs/reference/ref40000_f0f.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/blaisewf/rvc-cli/a732f87a54780e26e4a0e3c91b2d93f848eef6d2/logs/reference/ref40000_f0f.npy
--------------------------------------------------------------------------------
/logs/reference/ref40000_feats.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/blaisewf/rvc-cli/a732f87a54780e26e4a0e3c91b2d93f848eef6d2/logs/reference/ref40000_feats.npy
--------------------------------------------------------------------------------
/logs/reference/ref48000.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/blaisewf/rvc-cli/a732f87a54780e26e4a0e3c91b2d93f848eef6d2/logs/reference/ref48000.wav
--------------------------------------------------------------------------------
/logs/reference/ref48000_f0c.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/blaisewf/rvc-cli/a732f87a54780e26e4a0e3c91b2d93f848eef6d2/logs/reference/ref48000_f0c.npy
--------------------------------------------------------------------------------
/logs/reference/ref48000_f0f.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/blaisewf/rvc-cli/a732f87a54780e26e4a0e3c91b2d93f848eef6d2/logs/reference/ref48000_f0f.npy
--------------------------------------------------------------------------------
/logs/reference/ref48000_feats.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/blaisewf/rvc-cli/a732f87a54780e26e4a0e3c91b2d93f848eef6d2/logs/reference/ref48000_feats.npy
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | # Core dependencies
2 | pip>=23.3; sys_platform == 'darwin'
3 | wheel; sys_platform == 'darwin'
4 | PyYAML; sys_platform == 'darwin'
5 | numpy==1.26.4
6 | requests>=2.31.0,<2.32.0
7 | tqdm
8 | wget
9 |
10 | # Audio processing
11 | ffmpeg-python>=0.2.0
12 | faiss-cpu==1.7.3
13 | librosa==0.9.2
14 | scipy==1.11.1
15 | soundfile==0.12.1
16 | noisereduce
17 | pedalboard
18 | stftpitchshift
19 | soxr
20 |
21 | # Machine learning and deep learning
22 | omegaconf>=2.0.6; sys_platform == 'darwin'
23 | numba; sys_platform == 'linux'
24 | numba==0.61.0; sys_platform == 'darwin' or sys_platform == 'win32'
25 | torch==2.7.0
26 | torchaudio==2.7.0
27 | torchvision
28 | torchcrepe==0.0.23
29 | torchfcpe
30 | einops
31 | transformers==4.44.2
32 |
33 | # Visualization and UI
34 | matplotlib==3.7.2
35 | tensorboard
36 | gradio==5.23.1
37 |
38 | # Miscellaneous utilities
39 | certifi>=2023.07.22; sys_platform == 'darwin'
40 | antlr4-python3-runtime==4.8; sys_platform == 'darwin'
41 | tensorboardX
42 | edge-tts==6.1.9
43 | pypresence
44 | beautifulsoup4
45 |
46 | # UVR
47 | samplerate==0.1.0
48 | six>=1.16
49 | pydub>=0.25
50 | onnx>=1.14
51 | onnx2torch>=1.5
52 | onnxruntime>=1.17; sys_platform != 'darwin'
53 | onnxruntime-gpu>=1.17; sys_platform != 'darwin'
54 | julius>=0.2
55 | diffq>=0.2
56 | ml_collections
57 | resampy>=0.4
58 | beartype==0.18.5
59 | rotary-embedding-torch==0.6.1
--------------------------------------------------------------------------------
/rvc/configs/32000.json:
--------------------------------------------------------------------------------
1 | {
2 | "train": {
3 | "log_interval": 200,
4 | "seed": 1234,
5 | "learning_rate": 1e-4,
6 | "betas": [0.8, 0.99],
7 | "eps": 1e-9,
8 | "lr_decay": 0.999875,
9 | "segment_size": 12800,
10 | "c_mel": 45,
11 | "c_kl": 1.0
12 | },
13 | "data": {
14 | "max_wav_value": 32768.0,
15 | "sample_rate": 32000,
16 | "filter_length": 1024,
17 | "hop_length": 320,
18 | "win_length": 1024,
19 | "n_mel_channels": 80,
20 | "mel_fmin": 0.0,
21 | "mel_fmax": null
22 | },
23 | "model": {
24 | "inter_channels": 192,
25 | "hidden_channels": 192,
26 | "filter_channels": 768,
27 | "text_enc_hidden_dim": 768,
28 | "n_heads": 2,
29 | "n_layers": 6,
30 | "kernel_size": 3,
31 | "p_dropout": 0,
32 | "resblock": "1",
33 | "resblock_kernel_sizes": [3,7,11],
34 | "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
35 | "upsample_rates": [10,8,2,2],
36 | "upsample_initial_channel": 512,
37 | "upsample_kernel_sizes": [20,16,4,4],
38 | "use_spectral_norm": false,
39 | "gin_channels": 256,
40 | "spk_embed_dim": 109
41 | }
42 | }
43 |
--------------------------------------------------------------------------------
/rvc/configs/40000.json:
--------------------------------------------------------------------------------
1 | {
2 | "train": {
3 | "log_interval": 200,
4 | "seed": 1234,
5 | "learning_rate": 1e-4,
6 | "betas": [0.8, 0.99],
7 | "eps": 1e-9,
8 | "lr_decay": 0.999875,
9 | "segment_size": 12800,
10 | "c_mel": 45,
11 | "c_kl": 1.0
12 | },
13 | "data": {
14 | "max_wav_value": 32768.0,
15 | "sample_rate": 40000,
16 | "filter_length": 2048,
17 | "hop_length": 400,
18 | "win_length": 2048,
19 | "n_mel_channels": 125,
20 | "mel_fmin": 0.0,
21 | "mel_fmax": null
22 | },
23 | "model": {
24 | "inter_channels": 192,
25 | "hidden_channels": 192,
26 | "filter_channels": 768,
27 | "text_enc_hidden_dim": 768,
28 | "n_heads": 2,
29 | "n_layers": 6,
30 | "kernel_size": 3,
31 | "p_dropout": 0,
32 | "resblock": "1",
33 | "resblock_kernel_sizes": [3,7,11],
34 | "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
35 | "upsample_rates": [10,10,2,2],
36 | "upsample_initial_channel": 512,
37 | "upsample_kernel_sizes": [16,16,4,4],
38 | "use_spectral_norm": false,
39 | "gin_channels": 256,
40 | "spk_embed_dim": 109
41 | }
42 | }
43 |
--------------------------------------------------------------------------------
/rvc/configs/48000.json:
--------------------------------------------------------------------------------
1 | {
2 | "train": {
3 | "log_interval": 200,
4 | "seed": 1234,
5 | "learning_rate": 1e-4,
6 | "betas": [0.8, 0.99],
7 | "eps": 1e-9,
8 | "lr_decay": 0.999875,
9 | "segment_size": 17280,
10 | "c_mel": 45,
11 | "c_kl": 1.0
12 | },
13 | "data": {
14 | "max_wav_value": 32768.0,
15 | "sample_rate": 48000,
16 | "filter_length": 2048,
17 | "hop_length": 480,
18 | "win_length": 2048,
19 | "n_mel_channels": 128,
20 | "mel_fmin": 0.0,
21 | "mel_fmax": null
22 | },
23 | "model": {
24 | "inter_channels": 192,
25 | "hidden_channels": 192,
26 | "filter_channels": 768,
27 | "text_enc_hidden_dim": 768,
28 | "n_heads": 2,
29 | "n_layers": 6,
30 | "kernel_size": 3,
31 | "p_dropout": 0,
32 | "resblock": "1",
33 | "resblock_kernel_sizes": [3,7,11],
34 | "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
35 | "upsample_rates": [12,10,2,2],
36 | "upsample_initial_channel": 512,
37 | "upsample_kernel_sizes": [24,20,4,4],
38 | "use_spectral_norm": false,
39 | "gin_channels": 256,
40 | "spk_embed_dim": 109
41 | }
42 | }
43 |
--------------------------------------------------------------------------------
/rvc/configs/config.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import json
3 | import os
4 |
5 | version_config_paths = [
6 | os.path.join("48000.json"),
7 | os.path.join("40000.json"),
8 | os.path.join("32000.json"),
9 | ]
10 |
11 |
12 | def singleton(cls):
13 | instances = {}
14 |
15 | def get_instance(*args, **kwargs):
16 | if cls not in instances:
17 | instances[cls] = cls(*args, **kwargs)
18 | return instances[cls]
19 |
20 | return get_instance
21 |
22 |
23 | @singleton
24 | class Config:
25 | def __init__(self):
26 | self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
27 | self.gpu_name = (
28 | torch.cuda.get_device_name(int(self.device.split(":")[-1]))
29 | if self.device.startswith("cuda")
30 | else None
31 | )
32 | self.json_config = self.load_config_json()
33 | self.gpu_mem = None
34 | self.x_pad, self.x_query, self.x_center, self.x_max = self.device_config()
35 |
36 | def load_config_json(self):
37 | configs = {}
38 | for config_file in version_config_paths:
39 | config_path = os.path.join("rvc", "configs", config_file)
40 | with open(config_path, "r") as f:
41 | configs[config_file] = json.load(f)
42 | return configs
43 |
44 | def device_config(self):
45 | if self.device.startswith("cuda"):
46 | self.set_cuda_config()
47 | else:
48 | self.device = "cpu"
49 |
 50 |         # Default windowing configuration (roughly 6 GB of GPU memory or more)
 51 |         x_pad, x_query, x_center, x_max = (1, 6, 38, 41)
 52 |         if self.gpu_mem is not None and self.gpu_mem <= 4:
 53 |             # Reduced configuration for GPUs with 4 GB of memory or less
 54 |             x_pad, x_query, x_center, x_max = (1, 5, 30, 32)
55 |
56 | return x_pad, x_query, x_center, x_max
57 |
58 | def set_cuda_config(self):
59 | i_device = int(self.device.split(":")[-1])
60 | self.gpu_name = torch.cuda.get_device_name(i_device)
61 | self.gpu_mem = torch.cuda.get_device_properties(i_device).total_memory // (
62 | 1024**3
63 | )
64 |
65 |
66 | def max_vram_gpu(gpu):
67 | if torch.cuda.is_available():
68 | gpu_properties = torch.cuda.get_device_properties(gpu)
69 | total_memory_gb = round(gpu_properties.total_memory / 1024 / 1024 / 1024)
70 | return total_memory_gb
71 | else:
72 | return "8"
73 |
74 |
75 | def get_gpu_info():
76 | ngpu = torch.cuda.device_count()
77 | gpu_infos = []
78 | if torch.cuda.is_available() or ngpu != 0:
79 | for i in range(ngpu):
80 | gpu_name = torch.cuda.get_device_name(i)
81 | mem = int(
82 | torch.cuda.get_device_properties(i).total_memory / 1024 / 1024 / 1024
83 | + 0.4
84 | )
85 | gpu_infos.append(f"{i}: {gpu_name} ({mem} GB)")
86 | if len(gpu_infos) > 0:
87 | gpu_info = "\n".join(gpu_infos)
88 | else:
89 | gpu_info = "Unfortunately, there is no compatible GPU available to support your training."
90 | return gpu_info
91 |
92 |
93 | def get_number_of_gpus():
94 | if torch.cuda.is_available():
95 | num_gpus = torch.cuda.device_count()
96 | return "-".join(map(str, range(num_gpus)))
97 | else:
98 | return "-"
99 |
--------------------------------------------------------------------------------
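
Because `Config` is wrapped in the `singleton` decorator defined above, every call to `Config()` returns the same instance, and the per-sample-rate JSON presets are loaded once into `json_config`, keyed by file name. A minimal usage sketch, assuming it is executed from the repository root so the relative `rvc/configs/*.json` paths resolve:

```python
from rvc.configs.config import Config, get_gpu_info

config = Config()           # first call builds and caches the instance
assert config is Config()   # later calls return the same object

# Device string plus the windowing values chosen by device_config()
print(config.device, config.x_pad, config.x_query, config.x_center, config.x_max)

# Training hyperparameters from the 48 kHz preset (rvc/configs/48000.json)
print(config.json_config["48000.json"]["train"]["learning_rate"])

# Human-readable summary of any CUDA GPUs detected on this machine
print(get_gpu_info())
```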
/rvc/lib/algorithm/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/blaisewf/rvc-cli/a732f87a54780e26e4a0e3c91b2d93f848eef6d2/rvc/lib/algorithm/__init__.py
--------------------------------------------------------------------------------
/rvc/lib/algorithm/commons.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from typing import Optional
3 |
4 |
5 | def init_weights(m, mean=0.0, std=0.01):
6 | """
7 | Initialize the weights of a module.
8 |
9 | Args:
10 | m: The module to initialize.
11 | mean: The mean of the normal distribution.
12 | std: The standard deviation of the normal distribution.
13 | """
14 | classname = m.__class__.__name__
15 | if classname.find("Conv") != -1:
16 | m.weight.data.normal_(mean, std)
17 |
18 |
19 | def get_padding(kernel_size, dilation=1):
20 | """
21 | Calculate the padding needed for a convolution.
22 |
23 | Args:
24 | kernel_size: The size of the kernel.
25 | dilation: The dilation of the convolution.
26 | """
27 | return int((kernel_size * dilation - dilation) / 2)
28 |
29 |
30 | def convert_pad_shape(pad_shape):
31 | """
32 | Convert the pad shape to a list of integers.
33 |
34 | Args:
 35 |         pad_shape: The pad shape.
36 | """
37 | l = pad_shape[::-1]
38 | pad_shape = [item for sublist in l for item in sublist]
39 | return pad_shape
40 |
41 |
42 | def slice_segments(
43 | x: torch.Tensor, ids_str: torch.Tensor, segment_size: int = 4, dim: int = 2
44 | ):
45 | """
46 | Slice segments from a tensor, handling tensors with different numbers of dimensions.
47 |
48 | Args:
49 | x (torch.Tensor): The tensor to slice.
50 | ids_str (torch.Tensor): The starting indices of the segments.
51 | segment_size (int, optional): The size of each segment. Defaults to 4.
52 | dim (int, optional): The dimension to slice across (2D or 3D tensors). Defaults to 2.
53 | """
54 | if dim == 2:
55 | ret = torch.zeros_like(x[:, :segment_size])
56 | elif dim == 3:
57 | ret = torch.zeros_like(x[:, :, :segment_size])
58 |
59 | for i in range(x.size(0)):
60 | idx_str = ids_str[i].item()
61 | idx_end = idx_str + segment_size
62 | if dim == 2:
63 | ret[i] = x[i, idx_str:idx_end]
64 | else:
65 | ret[i] = x[i, :, idx_str:idx_end]
66 |
67 | return ret
68 |
69 |
70 | def rand_slice_segments(x, x_lengths=None, segment_size=4):
71 | """
72 | Randomly slice segments from a tensor.
73 |
74 | Args:
75 | x: The tensor to slice.
76 | x_lengths: The lengths of the sequences.
77 | segment_size: The size of each segment.
78 | """
79 | b, d, t = x.size()
80 | if x_lengths is None:
81 | x_lengths = t
82 | ids_str_max = x_lengths - segment_size + 1
83 | ids_str = (torch.rand([b], device=x.device) * ids_str_max).to(dtype=torch.long)
84 | ret = slice_segments(x, ids_str, segment_size, dim=3)
85 | return ret, ids_str
86 |
87 |
88 | @torch.jit.script
89 | def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
90 | """
91 | Fused add tanh sigmoid multiply operation.
92 |
93 | Args:
94 | input_a: The first input tensor.
95 | input_b: The second input tensor.
96 | n_channels: The number of channels.
97 | """
98 | n_channels_int = n_channels[0]
99 | in_act = input_a + input_b
100 | t_act = torch.tanh(in_act[:, :n_channels_int, :])
101 | s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
102 | acts = t_act * s_act
103 | return acts
104 |
105 |
106 | def sequence_mask(length: torch.Tensor, max_length: Optional[int] = None):
107 | """
108 | Generate a sequence mask.
109 |
110 | Args:
111 | length: The lengths of the sequences.
112 | max_length: The maximum length of the sequences.
113 | """
114 | if max_length is None:
115 | max_length = length.max()
116 | x = torch.arange(max_length, dtype=length.dtype, device=length.device)
117 | return x.unsqueeze(0) < length.unsqueeze(1)
118 |
119 |
120 | def grad_norm(parameters, norm_type: float = 2.0):
121 | """
122 | Calculates norm of parameter gradients
123 |
124 | Args:
 125 |         parameters: The parameters whose gradient norm is computed.
 126 |         norm_type: The type of norm to use.
127 | """
128 | if isinstance(parameters, torch.Tensor):
129 | parameters = [parameters]
130 |
131 | parameters = [p for p in parameters if p.grad is not None]
132 |
133 | if not parameters:
134 | return 0.0
135 |
136 | return torch.linalg.vector_norm(
137 | torch.stack([p.grad.norm(norm_type) for p in parameters]), ord=norm_type
138 | ).item()
139 |
--------------------------------------------------------------------------------
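
The helpers in `commons.py` are small tensor utilities, and their expected shapes are easiest to see with dummy data. The sketch below is purely illustrative (the tensor sizes are arbitrary) and assumes the repository root is on `PYTHONPATH`:

```python
import torch

from rvc.lib.algorithm.commons import get_padding, rand_slice_segments, sequence_mask

# "Same" padding for a kernel of size 5 with dilation 2: (5 * 2 - 2) / 2 = 4
print(get_padding(5, dilation=2))  # 4

# Boolean mask marking valid positions of a padded batch with lengths 3 and 5
lengths = torch.tensor([3, 5])
print(sequence_mask(lengths, max_length=5))
# tensor([[ True,  True,  True, False, False],
#         [ True,  True,  True,  True,  True]])

# Randomly crop fixed-size windows from a (batch, channels, time) tensor;
# returns the slices and the start index used for each batch element.
x = torch.randn(2, 4, 32)
segments, ids_str = rand_slice_segments(x, segment_size=8)
print(segments.shape, ids_str.shape)  # torch.Size([2, 4, 8]) torch.Size([2])
```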
/rvc/lib/algorithm/discriminators.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.checkpoint import checkpoint
3 | from torch.nn.utils.parametrizations import spectral_norm, weight_norm
4 |
5 | from rvc.lib.algorithm.commons import get_padding
6 | from rvc.lib.algorithm.residuals import LRELU_SLOPE
7 |
8 |
9 | class MultiPeriodDiscriminator(torch.nn.Module):
10 | """
11 | Multi-period discriminator.
12 |
13 | This class implements a multi-period discriminator, which is used to
14 | discriminate between real and fake audio signals. The discriminator
15 | is composed of a series of convolutional layers that are applied to
16 | the input signal at different periods.
17 |
18 | Args:
 19 |         use_spectral_norm (bool): Whether to use spectral normalization. Defaults to False.
 20 |         checkpointing (bool): Whether to apply gradient checkpointing to the sub-discriminators during training to reduce memory usage. Defaults to False.
 21 |     """
22 |
23 | def __init__(self, use_spectral_norm: bool = False, checkpointing: bool = False):
24 | super().__init__()
25 | periods = [2, 3, 5, 7, 11, 17, 23, 37]
26 | self.checkpointing = checkpointing
27 | self.discriminators = torch.nn.ModuleList(
28 | [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
29 | + [DiscriminatorP(p, use_spectral_norm=use_spectral_norm) for p in periods]
30 | )
31 |
32 | def forward(self, y, y_hat):
33 | y_d_rs, y_d_gs, fmap_rs, fmap_gs = [], [], [], []
34 | for d in self.discriminators:
35 | if self.training and self.checkpointing:
36 | y_d_r, fmap_r = checkpoint(d, y, use_reentrant=False)
37 | y_d_g, fmap_g = checkpoint(d, y_hat, use_reentrant=False)
38 | else:
39 | y_d_r, fmap_r = d(y)
40 | y_d_g, fmap_g = d(y_hat)
41 | y_d_rs.append(y_d_r)
42 | y_d_gs.append(y_d_g)
43 | fmap_rs.append(fmap_r)
44 | fmap_gs.append(fmap_g)
45 |
46 | return y_d_rs, y_d_gs, fmap_rs, fmap_gs
47 |
48 |
49 | class DiscriminatorS(torch.nn.Module):
50 | """
51 | Discriminator for the short-term component.
52 |
53 | This class implements a discriminator for the short-term component
54 | of the audio signal. The discriminator is composed of a series of
55 | convolutional layers that are applied to the input signal.
56 | """
57 |
58 | def __init__(self, use_spectral_norm: bool = False):
59 | super().__init__()
60 |
61 | norm_f = spectral_norm if use_spectral_norm else weight_norm
62 | self.convs = torch.nn.ModuleList(
63 | [
64 | norm_f(torch.nn.Conv1d(1, 16, 15, 1, padding=7)),
65 | norm_f(torch.nn.Conv1d(16, 64, 41, 4, groups=4, padding=20)),
66 | norm_f(torch.nn.Conv1d(64, 256, 41, 4, groups=16, padding=20)),
67 | norm_f(torch.nn.Conv1d(256, 1024, 41, 4, groups=64, padding=20)),
68 | norm_f(torch.nn.Conv1d(1024, 1024, 41, 4, groups=256, padding=20)),
69 | norm_f(torch.nn.Conv1d(1024, 1024, 5, 1, padding=2)),
70 | ]
71 | )
72 | self.conv_post = norm_f(torch.nn.Conv1d(1024, 1, 3, 1, padding=1))
73 | self.lrelu = torch.nn.LeakyReLU(LRELU_SLOPE)
74 |
75 | def forward(self, x):
76 | fmap = []
77 | for conv in self.convs:
78 | x = self.lrelu(conv(x))
79 | fmap.append(x)
80 | x = self.conv_post(x)
81 | fmap.append(x)
82 | x = torch.flatten(x, 1, -1)
83 | return x, fmap
84 |
85 |
86 | class DiscriminatorP(torch.nn.Module):
87 | """
88 | Discriminator for the long-term component.
89 |
90 | This class implements a discriminator for the long-term component
91 | of the audio signal. The discriminator is composed of a series of
92 | convolutional layers that are applied to the input signal at a given
93 | period.
94 |
95 | Args:
96 | period (int): Period of the discriminator.
97 | kernel_size (int): Kernel size of the convolutional layers. Defaults to 5.
98 | stride (int): Stride of the convolutional layers. Defaults to 3.
99 | use_spectral_norm (bool): Whether to use spectral normalization. Defaults to False.
100 | """
101 |
102 | def __init__(
103 | self,
104 | period: int,
105 | kernel_size: int = 5,
106 | stride: int = 3,
107 | use_spectral_norm: bool = False,
108 | ):
109 | super().__init__()
110 | self.period = period
111 | norm_f = spectral_norm if use_spectral_norm else weight_norm
112 |
113 | in_channels = [1, 32, 128, 512, 1024]
114 | out_channels = [32, 128, 512, 1024, 1024]
115 |
116 | self.convs = torch.nn.ModuleList(
117 | [
118 | norm_f(
119 | torch.nn.Conv2d(
120 | in_ch,
121 | out_ch,
122 | (kernel_size, 1),
123 | (stride, 1),
124 | padding=(get_padding(kernel_size, 1), 0),
125 | )
126 | )
127 | for in_ch, out_ch in zip(in_channels, out_channels)
128 | ]
129 | )
130 |
131 | self.conv_post = norm_f(torch.nn.Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
132 | self.lrelu = torch.nn.LeakyReLU(LRELU_SLOPE)
133 |
134 | def forward(self, x):
135 | fmap = []
136 | b, c, t = x.shape
137 | if t % self.period != 0:
138 | n_pad = self.period - (t % self.period)
139 | x = torch.nn.functional.pad(x, (0, n_pad), "reflect")
140 | x = x.view(b, c, -1, self.period)
141 |
142 | for conv in self.convs:
143 | x = self.lrelu(conv(x))
144 | fmap.append(x)
145 | x = self.conv_post(x)
146 | fmap.append(x)
147 | x = torch.flatten(x, 1, -1)
148 | return x, fmap
149 |
--------------------------------------------------------------------------------
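
As the docstrings above describe, `MultiPeriodDiscriminator` applies `DiscriminatorS` plus one `DiscriminatorP` per period to both the real and generated waveforms, returning the scores along with the intermediate feature maps used for the feature-matching loss. A minimal shape-check sketch, assuming the repository root is on `PYTHONPATH`; the batch size and waveform length are arbitrary:

```python
import torch

from rvc.lib.algorithm.discriminators import MultiPeriodDiscriminator

mpd = MultiPeriodDiscriminator(use_spectral_norm=False)

# Real and generated waveforms shaped (batch, 1 channel, samples).
y = torch.randn(2, 1, 8192)
y_hat = torch.randn(2, 1, 8192)

with torch.no_grad():
    y_d_rs, y_d_gs, fmap_rs, fmap_gs = mpd(y, y_hat)

# One score per sub-discriminator: DiscriminatorS + one DiscriminatorP per period.
print(len(y_d_rs))      # 9
print(y_d_rs[0].shape)  # flattened logits from DiscriminatorS
print(len(fmap_rs[0]))  # feature maps collected from its convolutional layers
```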
/rvc/lib/algorithm/encoders.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | from typing import Optional
4 |
5 | from rvc.lib.algorithm.commons import sequence_mask
6 | from rvc.lib.algorithm.modules import WaveNet
7 | from rvc.lib.algorithm.normalization import LayerNorm
8 | from rvc.lib.algorithm.attentions import FFN, MultiHeadAttention
9 |
10 |
11 | class Encoder(torch.nn.Module):
12 | """
13 | Encoder module for the Transformer model.
14 |
15 | Args:
16 | hidden_channels (int): Number of hidden channels in the encoder.
17 | filter_channels (int): Number of filter channels in the feed-forward network.
18 | n_heads (int): Number of attention heads.
19 | n_layers (int): Number of encoder layers.
20 | kernel_size (int, optional): Kernel size of the convolution layers in the feed-forward network. Defaults to 1.
21 | p_dropout (float, optional): Dropout probability. Defaults to 0.0.
22 | window_size (int, optional): Window size for relative positional encoding. Defaults to 10.
23 | """
24 |
25 | def __init__(
26 | self,
27 | hidden_channels: int,
28 | filter_channels: int,
29 | n_heads: int,
30 | n_layers: int,
31 | kernel_size: int = 1,
32 | p_dropout: float = 0.0,
33 | window_size: int = 10,
34 | ):
35 | super().__init__()
36 |
37 | self.hidden_channels = hidden_channels
38 | self.n_layers = n_layers
39 | self.drop = torch.nn.Dropout(p_dropout)
40 |
41 | self.attn_layers = torch.nn.ModuleList(
42 | [
43 | MultiHeadAttention(
44 | hidden_channels,
45 | hidden_channels,
46 | n_heads,
47 | p_dropout=p_dropout,
48 | window_size=window_size,
49 | )
50 | for _ in range(n_layers)
51 | ]
52 | )
53 | self.norm_layers_1 = torch.nn.ModuleList(
54 | [LayerNorm(hidden_channels) for _ in range(n_layers)]
55 | )
56 | self.ffn_layers = torch.nn.ModuleList(
57 | [
58 | FFN(
59 | hidden_channels,
60 | hidden_channels,
61 | filter_channels,
62 | kernel_size,
63 | p_dropout=p_dropout,
64 | )
65 | for _ in range(n_layers)
66 | ]
67 | )
68 | self.norm_layers_2 = torch.nn.ModuleList(
69 | [LayerNorm(hidden_channels) for _ in range(n_layers)]
70 | )
71 |
72 | def forward(self, x, x_mask):
73 | attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
74 | x = x * x_mask
75 |
76 | for i in range(self.n_layers):
77 | y = self.attn_layers[i](x, x, attn_mask)
78 | y = self.drop(y)
79 | x = self.norm_layers_1[i](x + y)
80 |
81 | y = self.ffn_layers[i](x, x_mask)
82 | y = self.drop(y)
83 | x = self.norm_layers_2[i](x + y)
84 |
85 | return x * x_mask
86 |
87 |
88 | class TextEncoder(torch.nn.Module):
89 | """
90 | Text Encoder with configurable embedding dimension.
91 |
92 | Args:
93 | out_channels (int): Output channels of the encoder.
94 | hidden_channels (int): Hidden channels of the encoder.
95 | filter_channels (int): Filter channels of the encoder.
96 | n_heads (int): Number of attention heads.
97 | n_layers (int): Number of encoder layers.
98 | kernel_size (int): Kernel size of the convolutional layers.
99 | p_dropout (float): Dropout probability.
100 | embedding_dim (int): Embedding dimension for phone embeddings (v1 = 256, v2 = 768).
101 | f0 (bool, optional): Whether to use F0 embedding. Defaults to True.
102 | """
103 |
104 | def __init__(
105 | self,
106 | out_channels: int,
107 | hidden_channels: int,
108 | filter_channels: int,
109 | n_heads: int,
110 | n_layers: int,
111 | kernel_size: int,
112 | p_dropout: float,
113 | embedding_dim: int,
114 | f0: bool = True,
115 | ):
116 | super().__init__()
117 | self.hidden_channels = hidden_channels
118 | self.out_channels = out_channels
119 | self.emb_phone = torch.nn.Linear(embedding_dim, hidden_channels)
120 | self.lrelu = torch.nn.LeakyReLU(0.1, inplace=True)
121 | self.emb_pitch = torch.nn.Embedding(256, hidden_channels) if f0 else None
122 |
123 | self.encoder = Encoder(
124 | hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
125 | )
126 | self.proj = torch.nn.Conv1d(hidden_channels, out_channels * 2, 1)
127 |
128 | def forward(
129 | self, phone: torch.Tensor, pitch: Optional[torch.Tensor], lengths: torch.Tensor
130 | ):
131 | x = self.emb_phone(phone)
132 | if pitch is not None and self.emb_pitch:
133 | x += self.emb_pitch(pitch)
134 |
135 | x *= math.sqrt(self.hidden_channels)
136 | x = self.lrelu(x)
137 | x = x.transpose(1, -1) # [B, H, T]
138 |
139 | x_mask = sequence_mask(lengths, x.size(2)).unsqueeze(1).to(x.dtype)
140 | x = self.encoder(x, x_mask)
141 | stats = self.proj(x) * x_mask
142 |
143 | m, logs = torch.split(stats, self.out_channels, dim=1)
144 | return m, logs, x_mask
145 |
146 |
147 | class PosteriorEncoder(torch.nn.Module):
148 | """
149 | Posterior Encoder for inferring latent representation.
150 |
151 | Args:
152 | in_channels (int): Number of channels in the input.
153 | out_channels (int): Number of channels in the output.
154 | hidden_channels (int): Number of hidden channels in the encoder.
155 | kernel_size (int): Kernel size of the convolutional layers.
156 | dilation_rate (int): Dilation rate of the convolutional layers.
157 | n_layers (int): Number of layers in the encoder.
158 | gin_channels (int, optional): Number of channels for the global conditioning input. Defaults to 0.
159 | """
160 |
161 | def __init__(
162 | self,
163 | in_channels: int,
164 | out_channels: int,
165 | hidden_channels: int,
166 | kernel_size: int,
167 | dilation_rate: int,
168 | n_layers: int,
169 | gin_channels: int = 0,
170 | ):
171 | super().__init__()
172 | self.out_channels = out_channels
173 | self.pre = torch.nn.Conv1d(in_channels, hidden_channels, 1)
174 | self.enc = WaveNet(
175 | hidden_channels,
176 | kernel_size,
177 | dilation_rate,
178 | n_layers,
179 | gin_channels=gin_channels,
180 | )
181 | self.proj = torch.nn.Conv1d(hidden_channels, out_channels * 2, 1)
182 |
183 | def forward(
184 | self, x: torch.Tensor, x_lengths: torch.Tensor, g: Optional[torch.Tensor] = None
185 | ):
186 | x_mask = sequence_mask(x_lengths, x.size(2)).unsqueeze(1).to(x.dtype)
187 |
188 | x = self.pre(x) * x_mask
189 | x = self.enc(x, x_mask, g=g)
190 |
191 | stats = self.proj(x) * x_mask
192 | m, logs = torch.split(stats, self.out_channels, dim=1)
193 |
194 | z = m + torch.randn_like(m) * torch.exp(logs)
195 | z *= x_mask
196 |
197 | return z, m, logs, x_mask
198 |
199 | def remove_weight_norm(self):
200 | self.enc.remove_weight_norm()
201 |
202 | def __prepare_scriptable__(self):
203 | for hook in self.enc._forward_pre_hooks.values():
204 | if (
205 | hook.__module__ == "torch.nn.utils.parametrizations.weight_norm"
206 | and hook.__class__.__name__ == "WeightNorm"
207 | ):
208 | torch.nn.utils.remove_weight_norm(self.enc)
209 | return self
210 |
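Usage sketch (illustrative; not part of the file above), assuming the repository root is on PYTHONPATH. The hyperparameters are chosen to mirror the bundled sample-rate configs (rvc/configs/*.json) and the 768-dimensional v2 ContentVec features; the input shapes are assumptions.

    import torch
    from rvc.lib.algorithm.encoders import TextEncoder

    encoder = TextEncoder(
        out_channels=192, hidden_channels=192, filter_channels=768,
        n_heads=2, n_layers=6, kernel_size=3, p_dropout=0.0,
        embedding_dim=768, f0=True,
    )
    phone = torch.randn(1, 120, 768)              # (B, T, embedding_dim) features
    pitch = torch.randint(0, 256, (1, 120))       # coarse F0 bucket per frame
    lengths = torch.tensor([120])
    m, logs, x_mask = encoder(phone, pitch, lengths)
    print(m.shape, logs.shape, x_mask.shape)      # (1, 192, 120) (1, 192, 120) (1, 1, 120)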
--------------------------------------------------------------------------------
/rvc/lib/algorithm/generators/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/blaisewf/rvc-cli/a732f87a54780e26e4a0e3c91b2d93f848eef6d2/rvc/lib/algorithm/generators/__init__.py
--------------------------------------------------------------------------------
/rvc/lib/algorithm/modules.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from rvc.lib.algorithm.commons import fused_add_tanh_sigmoid_multiply
3 |
4 |
5 | class WaveNet(torch.nn.Module):
6 | """
7 | WaveNet residual blocks as used in WaveGlow.
8 |
9 | Args:
10 | hidden_channels (int): Number of hidden channels.
11 | kernel_size (int): Size of the convolutional kernel.
12 | dilation_rate (int): Dilation rate of the convolution.
13 | n_layers (int): Number of convolutional layers.
14 | gin_channels (int, optional): Number of conditioning channels. Defaults to 0.
15 | p_dropout (float, optional): Dropout probability. Defaults to 0.
16 | """
17 |
18 | def __init__(
19 | self,
20 | hidden_channels: int,
21 | kernel_size: int,
22 |         dilation_rate: int,
23 |         n_layers: int,
24 |         gin_channels: int = 0,
25 |         p_dropout: float = 0.0,
26 | ):
27 | super().__init__()
28 | assert kernel_size % 2 == 1, "Kernel size must be odd for proper padding."
29 |
30 | self.hidden_channels = hidden_channels
31 |         self.kernel_size = kernel_size
32 | self.dilation_rate = dilation_rate
33 | self.n_layers = n_layers
34 | self.gin_channels = gin_channels
35 | self.p_dropout = p_dropout
36 | self.n_channels_tensor = torch.IntTensor([hidden_channels]) # Static tensor
37 |
38 | self.in_layers = torch.nn.ModuleList()
39 | self.res_skip_layers = torch.nn.ModuleList()
40 | self.drop = torch.nn.Dropout(p_dropout)
41 |
42 | # Conditional layer for global conditioning
43 | if gin_channels:
44 | self.cond_layer = torch.nn.utils.parametrizations.weight_norm(
45 | torch.nn.Conv1d(gin_channels, 2 * hidden_channels * n_layers, 1),
46 | name="weight",
47 | )
48 |
49 | # Precompute dilations and paddings
50 | dilations = [dilation_rate**i for i in range(n_layers)]
51 | paddings = [(kernel_size * d - d) // 2 for d in dilations]
52 |
53 | # Initialize layers
54 | for i in range(n_layers):
55 | self.in_layers.append(
56 | torch.nn.utils.parametrizations.weight_norm(
57 | torch.nn.Conv1d(
58 | hidden_channels,
59 | 2 * hidden_channels,
60 | kernel_size,
61 | dilation=dilations[i],
62 | padding=paddings[i],
63 | ),
64 | name="weight",
65 | )
66 | )
67 |
68 | res_skip_channels = (
69 | hidden_channels if i == n_layers - 1 else 2 * hidden_channels
70 | )
71 | self.res_skip_layers.append(
72 | torch.nn.utils.parametrizations.weight_norm(
73 | torch.nn.Conv1d(hidden_channels, res_skip_channels, 1),
74 | name="weight",
75 | )
76 | )
77 |
78 | def forward(self, x, x_mask, g=None):
79 | output = x.clone().zero_()
80 |
81 | # Apply conditional layer if global conditioning is provided
82 | g = self.cond_layer(g) if g is not None else None
83 |
84 | for i in range(self.n_layers):
85 | x_in = self.in_layers[i](x)
86 | g_l = (
87 | g[
88 | :,
89 | i * 2 * self.hidden_channels : (i + 1) * 2 * self.hidden_channels,
90 | :,
91 | ]
92 | if g is not None
93 | else 0
94 | )
95 |
96 | # Activation with fused Tanh-Sigmoid
97 | acts = fused_add_tanh_sigmoid_multiply(x_in, g_l, self.n_channels_tensor)
98 | acts = self.drop(acts)
99 |
100 | # Residual and skip connections
101 | res_skip_acts = self.res_skip_layers[i](acts)
102 | if i < self.n_layers - 1:
103 | res_acts = res_skip_acts[:, : self.hidden_channels, :]
104 | x = (x + res_acts) * x_mask
105 | output = output + res_skip_acts[:, self.hidden_channels :, :]
106 | else:
107 | output = output + res_skip_acts
108 |
109 | return output * x_mask
110 |
111 | def remove_weight_norm(self):
112 | if self.gin_channels:
113 | torch.nn.utils.remove_weight_norm(self.cond_layer)
114 | for layer in self.in_layers:
115 | torch.nn.utils.remove_weight_norm(layer)
116 | for layer in self.res_skip_layers:
117 | torch.nn.utils.remove_weight_norm(layer)
118 |
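Usage sketch (illustrative; not part of the file above): running the WaveNet block with channel sizes typical of the posterior encoder (192 hidden channels, 256 conditioning channels); all shapes here are assumptions.

    import torch
    from rvc.lib.algorithm.modules import WaveNet

    wn = WaveNet(hidden_channels=192, kernel_size=5, dilation_rate=1,
                 n_layers=16, gin_channels=256)
    x = torch.randn(2, 192, 400)       # (B, hidden_channels, T)
    x_mask = torch.ones(2, 1, 400)     # frame-level validity mask
    g = torch.randn(2, 256, 1)         # global (speaker) conditioning
    out = wn(x, x_mask, g=g)           # -> (2, 192, 400)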
--------------------------------------------------------------------------------
/rvc/lib/algorithm/normalization.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | class LayerNorm(torch.nn.Module):
5 | """
6 | Layer normalization module.
7 |
8 | Args:
9 | channels (int): Number of channels.
10 | eps (float, optional): Epsilon value for numerical stability. Defaults to 1e-5.
11 | """
12 |
13 | def __init__(self, channels: int, eps: float = 1e-5):
14 | super().__init__()
15 | self.eps = eps
16 | self.gamma = torch.nn.Parameter(torch.ones(channels))
17 | self.beta = torch.nn.Parameter(torch.zeros(channels))
18 |
19 | def forward(self, x):
20 | # Transpose to (batch_size, time_steps, channels) for layer_norm
21 | x = x.transpose(1, -1)
22 | x = torch.nn.functional.layer_norm(
23 | x, (x.size(-1),), self.gamma, self.beta, self.eps
24 | )
25 | # Transpose back to (batch_size, channels, time_steps)
26 | return x.transpose(1, -1)
27 |
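Usage sketch (illustrative; not part of the file above): the module normalizes a (batch, channels, time) tensor over its channel dimension, transposing internally so torch's layer_norm can be applied; sizes are arbitrary.

    import torch
    from rvc.lib.algorithm.normalization import LayerNorm

    norm = LayerNorm(channels=192)
    x = torch.randn(4, 192, 100)   # (batch, channels, time)
    y = norm(x)                    # same shape, normalized per frame across channels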
--------------------------------------------------------------------------------
/rvc/lib/predictors/F0Extractor.py:
--------------------------------------------------------------------------------
1 | import dataclasses
2 | import pathlib
3 | import librosa
4 | import numpy as np
5 | import resampy
6 | import torch
7 | import torchcrepe
8 | import torchfcpe
9 | import os
10 |
11 | # from tools.anyf0.rmvpe import RMVPE
12 | from rvc.lib.predictors.RMVPE import RMVPE0Predictor
13 | from rvc.configs.config import Config
14 |
15 | config = Config()
16 |
17 |
18 | @dataclasses.dataclass
19 | class F0Extractor:
20 | wav_path: pathlib.Path
21 | sample_rate: int = 44100
22 | hop_length: int = 512
23 | f0_min: int = 50
24 | f0_max: int = 1600
25 | method: str = "rmvpe"
26 | x: np.ndarray = dataclasses.field(init=False)
27 |
28 | def __post_init__(self):
29 | self.x, self.sample_rate = librosa.load(self.wav_path, sr=self.sample_rate)
30 |
31 | @property
32 | def hop_size(self):
33 | return self.hop_length / self.sample_rate
34 |
35 | @property
36 | def wav16k(self):
37 | return resampy.resample(self.x, self.sample_rate, 16000)
38 |
39 | def extract_f0(self):
40 | f0 = None
41 | method = self.method
42 | if method == "crepe":
43 | wav16k_torch = torch.FloatTensor(self.wav16k).unsqueeze(0).to(config.device)
44 | f0 = torchcrepe.predict(
45 | wav16k_torch,
46 | sample_rate=16000,
47 | hop_length=160,
48 | batch_size=512,
49 | fmin=self.f0_min,
50 | fmax=self.f0_max,
51 | device=config.device,
52 | )
53 | f0 = f0[0].cpu().numpy()
54 | elif method == "fcpe":
55 | audio = librosa.to_mono(self.x)
56 | audio_length = len(audio)
57 | f0_target_length = (audio_length // self.hop_length) + 1
58 | audio = (
59 | torch.from_numpy(audio)
60 | .float()
61 | .unsqueeze(0)
62 | .unsqueeze(-1)
63 | .to(config.device)
64 | )
65 | model = torchfcpe.spawn_bundled_infer_model(device=config.device)
66 |
67 | f0 = model.infer(
68 | audio,
69 | sr=self.sample_rate,
70 | decoder_mode="local_argmax",
71 | threshold=0.006,
72 | f0_min=self.f0_min,
73 | f0_max=self.f0_max,
74 | interp_uv=False,
75 | output_interp_target_length=f0_target_length,
76 | )
77 | f0 = f0.squeeze().cpu().numpy()
78 | elif method == "rmvpe":
79 | model_rmvpe = RMVPE0Predictor(
80 | os.path.join("rvc", "models", "predictors", "rmvpe.pt"),
81 | device=config.device,
82 | # hop_length=80
83 | )
84 | f0 = model_rmvpe.infer_from_audio(self.wav16k, thred=0.03)
85 |
86 | else:
87 | raise ValueError(f"Unknown method: {self.method}")
88 | return self.hz_to_cents(f0, librosa.midi_to_hz(0))
89 |
90 | def plot_f0(self, f0):
91 | from matplotlib import pyplot as plt
92 |
93 | plt.figure(figsize=(10, 4))
94 | plt.plot(f0)
95 | plt.title(self.method)
96 | plt.xlabel("Time (frames)")
97 | plt.ylabel("F0 (cents)")
98 | plt.show()
99 |
100 | @staticmethod
101 | def hz_to_cents(F, F_ref=55.0):
102 | F_temp = np.array(F).astype(float)
103 | F_temp[F_temp == 0] = np.nan
104 | F_cents = 1200 * np.log2(F_temp / F_ref)
105 | return F_cents
106 |
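Usage sketch ("example.wav" is a placeholder path): extracting an F0 contour with the RMVPE predictor, which assumes rvc/models/predictors/rmvpe.pt has already been fetched by the prerequisites downloader.

    from rvc.lib.predictors.F0Extractor import F0Extractor

    extractor = F0Extractor("example.wav", sample_rate=44100, method="rmvpe")
    f0_cents = extractor.extract_f0()   # per-frame F0 in cents (NaN where unvoiced)
    extractor.plot_f0(f0_cents)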
--------------------------------------------------------------------------------
/rvc/lib/tools/analyzer.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import matplotlib.pyplot as plt
3 | import librosa.display
4 | import librosa
5 |
6 |
7 | def calculate_features(y, sr):
8 | stft = np.abs(librosa.stft(y))
9 | duration = librosa.get_duration(y=y, sr=sr)
10 | cent = librosa.feature.spectral_centroid(S=stft, sr=sr)[0]
11 | bw = librosa.feature.spectral_bandwidth(S=stft, sr=sr)[0]
12 | rolloff = librosa.feature.spectral_rolloff(S=stft, sr=sr)[0]
13 | return stft, duration, cent, bw, rolloff
14 |
15 |
16 | def plot_title(title):
17 | plt.suptitle(title, fontsize=16, fontweight="bold")
18 |
19 |
20 | def plot_spectrogram(y, sr, stft, duration, cmap="inferno"):
21 | plt.subplot(3, 1, 1)
22 | plt.imshow(
23 | librosa.amplitude_to_db(stft, ref=np.max),
24 | origin="lower",
25 | extent=[0, duration, 0, sr / 1000],
26 | aspect="auto",
27 | cmap=cmap, # Change the colormap here
28 | )
29 | plt.colorbar(format="%+2.0f dB")
30 | plt.xlabel("Time (s)")
31 | plt.ylabel("Frequency (kHz)")
32 | plt.title("Spectrogram")
33 |
34 |
35 | def plot_waveform(y, sr, duration):
36 | plt.subplot(3, 1, 2)
37 | librosa.display.waveshow(y, sr=sr)
38 | plt.xlabel("Time (s)")
39 | plt.ylabel("Amplitude")
40 | plt.title("Waveform")
41 |
42 |
43 | def plot_features(times, cent, bw, rolloff, duration):
44 | plt.subplot(3, 1, 3)
45 |     plt.plot(times, cent, label="Spectral Centroid (Hz)", color="b")
46 |     plt.plot(times, bw, label="Spectral Bandwidth (Hz)", color="g")
47 |     plt.plot(times, rolloff, label="Spectral Rolloff (Hz)", color="r")
48 | plt.xlabel("Time (s)")
49 | plt.title("Spectral Features")
50 | plt.legend()
51 |
52 |
53 | def analyze_audio(audio_file, save_plot_path="logs/audio_analysis.png"):
54 | y, sr = librosa.load(audio_file)
55 | stft, duration, cent, bw, rolloff = calculate_features(y, sr)
56 |
57 | plt.figure(figsize=(12, 10))
58 |
59 | plot_title("Audio Analysis" + " - " + audio_file.split("/")[-1])
60 | plot_spectrogram(y, sr, stft, duration)
61 | plot_waveform(y, sr, duration)
62 | plot_features(librosa.times_like(cent), cent, bw, rolloff, duration)
63 |
64 | plt.tight_layout()
65 |
66 | if save_plot_path:
67 | plt.savefig(save_plot_path, bbox_inches="tight", dpi=300)
68 | plt.close()
69 |
70 | audio_info = f"""Sample Rate: {sr}\nDuration: {(
71 | str(round(duration, 2)) + " seconds"
72 | if duration < 60
73 | else str(round(duration / 60, 2)) + " minutes"
74 |     )}\nNumber of Samples: {len(y)}\nNative Sample Rate: {librosa.get_samplerate(audio_file)}\nChannels: {"Mono (1)" if y.ndim == 1 else "Stereo (2)"}"""
75 |
76 | return audio_info, save_plot_path
77 |
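Usage sketch ("example.wav" is a placeholder): the function writes the plot to the given path and returns a text summary plus that path.

    from rvc.lib.tools.analyzer import analyze_audio

    info, plot_path = analyze_audio("example.wav", save_plot_path="logs/audio_analysis.png")
    print(info)        # sample rate, duration, sample count, channel layout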
--------------------------------------------------------------------------------
/rvc/lib/tools/launch_tensorboard.py:
--------------------------------------------------------------------------------
1 | import time
2 | import logging
3 | from tensorboard import program
4 |
5 | log_path = "logs"
6 |
7 |
8 | def launch_tensorboard_pipeline():
9 | logging.getLogger("root").setLevel(logging.WARNING)
10 | logging.getLogger("tensorboard").setLevel(logging.WARNING)
11 |
12 | tb = program.TensorBoard()
13 | tb.configure(argv=[None, "--logdir", log_path])
14 | url = tb.launch()
15 |
16 | print(
17 | f"Access the tensorboard using the following link:\n{url}?pinnedCards=%5B%7B%22plugin%22%3A%22scalars%22%2C%22tag%22%3A%22loss%2Fg%2Ftotal%22%7D%2C%7B%22plugin%22%3A%22scalars%22%2C%22tag%22%3A%22loss%2Fd%2Ftotal%22%7D%2C%7B%22plugin%22%3A%22scalars%22%2C%22tag%22%3A%22loss%2Fg%2Fkl%22%7D%2C%7B%22plugin%22%3A%22scalars%22%2C%22tag%22%3A%22loss%2Fg%2Fmel%22%7D%5D"
18 | )
19 |
20 | while True:
21 | time.sleep(600)
22 |
--------------------------------------------------------------------------------
/rvc/lib/tools/model_download.py:
--------------------------------------------------------------------------------
1 | import os
2 | import re
3 | import sys
4 | import shutil
5 | import zipfile
6 | import requests
7 | from bs4 import BeautifulSoup
8 | from urllib.parse import unquote
9 | from tqdm import tqdm
10 |
11 | now_dir = os.getcwd()
12 | sys.path.append(now_dir)
13 |
14 | from rvc.lib.utils import format_title
15 | from rvc.lib.tools import gdown
16 |
17 |
18 | file_path = os.path.join(now_dir, "logs")
19 | zips_path = os.path.join(file_path, "zips")
20 | os.makedirs(zips_path, exist_ok=True)
21 |
22 |
23 | def search_pth_index(folder):
24 | pth_paths = [
25 | os.path.join(folder, file)
26 | for file in os.listdir(folder)
27 | if os.path.isfile(os.path.join(folder, file)) and file.endswith(".pth")
28 | ]
29 | index_paths = [
30 | os.path.join(folder, file)
31 | for file in os.listdir(folder)
32 | if os.path.isfile(os.path.join(folder, file)) and file.endswith(".index")
33 | ]
34 | return pth_paths, index_paths
35 |
36 |
37 | def download_from_url(url):
38 | os.chdir(zips_path)
39 |
40 | try:
41 | if "drive.google.com" in url:
42 | file_id = extract_google_drive_id(url)
43 | if file_id:
44 | gdown.download(
45 | url=f"https://drive.google.com/uc?id={file_id}",
46 | quiet=False,
47 | fuzzy=True,
48 | )
49 | elif "/blob/" in url or "/resolve/" in url:
50 | download_blob_or_resolve(url)
51 | elif "/tree/main" in url:
52 | download_from_huggingface(url)
53 | else:
54 | download_file(url)
55 |
56 | rename_downloaded_files()
57 | return "downloaded"
58 | except Exception as error:
59 | print(f"An error occurred downloading the file: {error}")
60 | return None
61 | finally:
62 | os.chdir(now_dir)
63 |
64 |
65 | def extract_google_drive_id(url):
66 | if "file/d/" in url:
67 | return url.split("file/d/")[1].split("/")[0]
68 | if "id=" in url:
69 | return url.split("id=")[1].split("&")[0]
70 | return None
71 |
72 |
73 | def download_blob_or_resolve(url):
74 | if "/blob/" in url:
75 | url = url.replace("/blob/", "/resolve/")
76 | response = requests.get(url, stream=True)
77 | if response.status_code == 200:
78 | save_response_content(response)
79 | else:
80 | raise ValueError(
81 | "Download failed with status code: " + str(response.status_code)
82 | )
83 |
84 |
85 | def save_response_content(response):
86 | content_disposition = unquote(response.headers.get("Content-Disposition", ""))
87 | file_name = (
88 | re.search(r'filename="([^"]+)"', content_disposition)
89 | .groups()[0]
90 | .replace(os.path.sep, "_")
91 | if content_disposition
92 | else "downloaded_file"
93 | )
94 |
95 | total_size = int(response.headers.get("Content-Length", 0))
96 | chunk_size = 1024
97 |
98 | with open(os.path.join(zips_path, file_name), "wb") as file, tqdm(
99 | total=total_size, unit="B", unit_scale=True, desc=file_name
100 | ) as progress_bar:
101 | for data in response.iter_content(chunk_size):
102 | file.write(data)
103 | progress_bar.update(len(data))
104 |
105 |
106 | def download_from_huggingface(url):
107 | response = requests.get(url)
108 | soup = BeautifulSoup(response.content, "html.parser")
109 | temp_url = next(
110 | (
111 | link["href"]
112 | for link in soup.find_all("a", href=True)
113 | if link["href"].endswith(".zip")
114 | ),
115 | None,
116 | )
117 | if temp_url:
118 | url = temp_url.replace("blob", "resolve")
119 | if "huggingface.co" not in url:
120 | url = "https://huggingface.co" + url
121 | download_file(url)
122 | else:
123 | raise ValueError("No zip file found in Huggingface URL")
124 |
125 |
126 | def download_file(url):
127 | response = requests.get(url, stream=True)
128 | if response.status_code == 200:
129 | save_response_content(response)
130 | else:
131 | raise ValueError(
132 | "Download failed with status code: " + str(response.status_code)
133 | )
134 |
135 |
136 | def rename_downloaded_files():
137 | for currentPath, _, zipFiles in os.walk(zips_path):
138 | for file in zipFiles:
139 | file_name, extension = os.path.splitext(file)
140 | real_path = os.path.join(currentPath, file)
141 | os.rename(real_path, file_name.replace(os.path.sep, "_") + extension)
142 |
143 |
144 | def extract(zipfile_path, unzips_path):
145 | try:
146 | with zipfile.ZipFile(zipfile_path, "r") as zip_ref:
147 | zip_ref.extractall(unzips_path)
148 | os.remove(zipfile_path)
149 | return True
150 | except Exception as error:
151 | print(f"An error occurred extracting the zip file: {error}")
152 | return False
153 |
154 |
155 | def unzip_file(zip_path, zip_file_name):
156 | zip_file_path = os.path.join(zip_path, zip_file_name + ".zip")
157 | extract_path = os.path.join(file_path, zip_file_name)
158 | with zipfile.ZipFile(zip_file_path, "r") as zip_ref:
159 | zip_ref.extractall(extract_path)
160 | os.remove(zip_file_path)
161 |
162 |
163 | def model_download_pipeline(url: str):
164 | try:
165 | result = download_from_url(url)
166 | if result == "downloaded":
167 | return handle_extraction_process()
168 | else:
169 | return "Error"
170 | except Exception as error:
171 | print(f"An unexpected error occurred: {error}")
172 | return "Error"
173 |
174 |
175 | def handle_extraction_process():
176 | extract_folder_path = ""
177 | for filename in os.listdir(zips_path):
178 | if filename.endswith(".zip"):
179 | zipfile_path = os.path.join(zips_path, filename)
180 | model_name = format_title(os.path.basename(zipfile_path).split(".zip")[0])
181 | extract_folder_path = os.path.join("logs", os.path.normpath(model_name))
182 | success = extract(zipfile_path, extract_folder_path)
183 | clean_extracted_files(extract_folder_path, model_name)
184 |
185 | if success:
186 | print(f"Model {model_name} downloaded!")
187 | else:
188 | print(f"Error downloading {model_name}")
189 | return "Error"
190 | if not extract_folder_path:
191 | print("Zip file was not found.")
192 | return "Error"
193 | return search_pth_index(extract_folder_path)
194 |
195 |
196 | def clean_extracted_files(extract_folder_path, model_name):
197 | macosx_path = os.path.join(extract_folder_path, "__MACOSX")
198 | if os.path.exists(macosx_path):
199 | shutil.rmtree(macosx_path)
200 |
201 | subfolders = [
202 | f
203 | for f in os.listdir(extract_folder_path)
204 | if os.path.isdir(os.path.join(extract_folder_path, f))
205 | ]
206 | if len(subfolders) == 1:
207 | subfolder_path = os.path.join(extract_folder_path, subfolders[0])
208 | for item in os.listdir(subfolder_path):
209 | shutil.move(
210 | os.path.join(subfolder_path, item),
211 | os.path.join(extract_folder_path, item),
212 | )
213 | os.rmdir(subfolder_path)
214 |
215 | for item in os.listdir(extract_folder_path):
216 | source_path = os.path.join(extract_folder_path, item)
217 | if ".pth" in item:
218 | new_file_name = model_name + ".pth"
219 | elif ".index" in item:
220 | new_file_name = model_name + ".index"
221 | else:
222 | continue
223 |
224 | destination_path = os.path.join(extract_folder_path, new_file_name)
225 | if not os.path.exists(destination_path):
226 | os.rename(source_path, destination_path)
227 |
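Usage sketch (the URL is a placeholder): the pipeline accepts Google Drive links, Hugging Face /blob/ or /resolve/ file links, and /tree/main folder links, and returns the extracted .pth/.index paths or the string "Error".

    from rvc.lib.tools.model_download import model_download_pipeline

    result = model_download_pipeline("https://huggingface.co/user/model/resolve/main/model.zip")
    if result != "Error":
        pth_paths, index_paths = result
        print(pth_paths, index_paths)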
--------------------------------------------------------------------------------
/rvc/lib/tools/prerequisites_download.py:
--------------------------------------------------------------------------------
1 | import os
2 | from concurrent.futures import ThreadPoolExecutor
3 | from tqdm import tqdm
4 | import requests
5 |
6 | url_base = "https://huggingface.co/IAHispano/Applio/resolve/main/Resources"
7 |
8 | pretraineds_hifigan_list = [
9 | (
10 | "pretrained_v2/",
11 | [
12 | "f0D32k.pth",
13 | "f0D40k.pth",
14 | "f0D48k.pth",
15 | "f0G32k.pth",
16 | "f0G40k.pth",
17 | "f0G48k.pth",
18 | ],
19 | )
20 | ]
21 | models_list = [("predictors/", ["rmvpe.pt", "fcpe.pt"])]
22 | embedders_list = [("embedders/contentvec/", ["pytorch_model.bin", "config.json"])]
23 | executables_list = [
24 | ("", ["ffmpeg.exe", "ffprobe.exe"]),
25 | ]
26 |
27 | folder_mapping_list = {
28 | "pretrained_v2/": "rvc/models/pretraineds/hifi-gan/",
29 | "embedders/contentvec/": "rvc/models/embedders/contentvec/",
30 | "predictors/": "rvc/models/predictors/",
31 | "formant/": "rvc/models/formant/",
32 | }
33 |
34 |
35 | def get_file_size_if_missing(file_list):
36 | """
37 | Calculate the total size of files to be downloaded only if they do not exist locally.
38 | """
39 | total_size = 0
40 | for remote_folder, files in file_list:
41 | local_folder = folder_mapping_list.get(remote_folder, "")
42 | for file in files:
43 | destination_path = os.path.join(local_folder, file)
44 | if not os.path.exists(destination_path):
45 | url = f"{url_base}/{remote_folder}{file}"
46 |                 response = requests.head(url, allow_redirects=True)
47 | total_size += int(response.headers.get("content-length", 0))
48 | return total_size
49 |
50 |
51 | def download_file(url, destination_path, global_bar):
52 | """
53 | Download a file from the given URL to the specified destination path,
54 | updating the global progress bar as data is downloaded.
55 | """
56 |
57 | dir_name = os.path.dirname(destination_path)
58 | if dir_name:
59 | os.makedirs(dir_name, exist_ok=True)
60 | response = requests.get(url, stream=True)
61 | block_size = 1024
62 | with open(destination_path, "wb") as file:
63 | for data in response.iter_content(block_size):
64 | file.write(data)
65 | global_bar.update(len(data))
66 |
67 |
68 | def download_mapping_files(file_mapping_list, global_bar):
69 | """
70 | Download all files in the provided file mapping list using a thread pool executor,
71 | and update the global progress bar as downloads progress.
72 | """
73 | with ThreadPoolExecutor() as executor:
74 | futures = []
75 | for remote_folder, file_list in file_mapping_list:
76 | local_folder = folder_mapping_list.get(remote_folder, "")
77 | for file in file_list:
78 | destination_path = os.path.join(local_folder, file)
79 | if not os.path.exists(destination_path):
80 | url = f"{url_base}/{remote_folder}{file}"
81 | futures.append(
82 | executor.submit(
83 | download_file, url, destination_path, global_bar
84 | )
85 | )
86 | for future in futures:
87 | future.result()
88 |
89 |
90 | def split_pretraineds(pretrained_list):
91 | f0_list = []
92 | non_f0_list = []
93 | for folder, files in pretrained_list:
94 | f0_files = [f for f in files if f.startswith("f0")]
95 | non_f0_files = [f for f in files if not f.startswith("f0")]
96 | if f0_files:
97 | f0_list.append((folder, f0_files))
98 | if non_f0_files:
99 | non_f0_list.append((folder, non_f0_files))
100 | return f0_list, non_f0_list
101 |
102 |
103 | pretraineds_hifigan_list, _ = split_pretraineds(pretraineds_hifigan_list)
104 |
105 |
106 | def calculate_total_size(
107 | pretraineds_hifigan,
108 | models,
109 | exe,
110 | ):
111 | """
112 | Calculate the total size of all files to be downloaded based on selected categories.
113 | """
114 | total_size = 0
115 | if models:
116 | total_size += get_file_size_if_missing(models_list)
117 | total_size += get_file_size_if_missing(embedders_list)
118 | if exe and os.name == "nt":
119 | total_size += get_file_size_if_missing(executables_list)
120 | total_size += get_file_size_if_missing(pretraineds_hifigan)
121 | return total_size
122 |
123 |
124 | def prequisites_download_pipeline(
125 | pretraineds_hifigan,
126 | models,
127 | exe,
128 | ):
129 | """
130 | Manage the download pipeline for different categories of files.
131 | """
132 | total_size = calculate_total_size(
133 | pretraineds_hifigan_list if pretraineds_hifigan else [],
134 | models,
135 | exe,
136 | )
137 |
138 | if total_size > 0:
139 | with tqdm(
140 | total=total_size, unit="iB", unit_scale=True, desc="Downloading all files"
141 | ) as global_bar:
142 | if models:
143 | download_mapping_files(models_list, global_bar)
144 | download_mapping_files(embedders_list, global_bar)
145 | if exe:
146 | if os.name == "nt":
147 | download_mapping_files(executables_list, global_bar)
148 | else:
149 | print("No executables needed")
150 | if pretraineds_hifigan:
151 | download_mapping_files(pretraineds_hifigan_list, global_bar)
152 |     else:
153 |         pass  # nothing missing; every selected file already exists locally
154 |
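Usage sketch: downloading the predictors, embedders and HiFi-GAN pretrained weights (note the function name is spelled "prequisites" in the source).

    from rvc.lib.tools.prerequisites_download import prequisites_download_pipeline

    prequisites_download_pipeline(pretraineds_hifigan=True, models=True, exe=False)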
--------------------------------------------------------------------------------
/rvc/lib/tools/pretrained_selector.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 |
4 | def pretrained_selector(vocoder, sample_rate):
5 | base_path = os.path.join("rvc", "models", "pretraineds", f"{vocoder.lower()}")
6 |
7 | path_g = os.path.join(base_path, f"f0G{str(sample_rate)[:2]}k.pth")
8 | path_d = os.path.join(base_path, f"f0D{str(sample_rate)[:2]}k.pth")
9 |
10 | if os.path.exists(path_g) and os.path.exists(path_d):
11 | return path_g, path_d
12 | else:
13 | return "", ""
14 |
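Usage sketch: resolving the pretrained generator/discriminator pair for a vocoder and sample rate; empty strings are returned when the files have not been downloaded yet.

    from rvc.lib.tools.pretrained_selector import pretrained_selector

    path_g, path_d = pretrained_selector("HiFi-GAN", 48000)
    # e.g. rvc/models/pretraineds/hifi-gan/f0G48k.pth and f0D48k.pth when present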
--------------------------------------------------------------------------------
/rvc/lib/tools/split_audio.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import librosa
3 |
4 |
5 | def process_audio(audio, sr=16000, silence_thresh=-60, min_silence_len=250):
6 | """
7 |     Splits an audio signal into non-silent segments by detecting silence.
8 | 
9 |     Parameters:
10 |     - audio (np.ndarray): The audio signal to split.
11 |     - sr (int): The sample rate of the input audio (default is 16000).
12 |     - silence_thresh (int): Silence threshold in dB (default is -60 dB).
13 |     - min_silence_len (int): Minimum silence duration in ms (default is 250 ms).
14 |
15 | Returns:
16 | - list of np.ndarray: A list of audio segments.
17 | - np.ndarray: The intervals where the audio was split.
18 | """
19 | frame_length = int(min_silence_len / 1000 * sr)
20 | hop_length = frame_length // 2
21 | intervals = librosa.effects.split(
22 | audio, top_db=-silence_thresh, frame_length=frame_length, hop_length=hop_length
23 | )
24 | audio_segments = [audio[start:end] for start, end in intervals]
25 |
26 | return audio_segments, intervals
27 |
28 |
29 | def merge_audio(audio_segments_org, audio_segments_new, intervals, sr_orig, sr_new):
30 | """
31 | Merges audio segments back into a single audio signal, filling gaps with silence.
32 | Assumes audio segments are already at sr_new.
33 |
34 | Parameters:
35 | - audio_segments_org (list of np.ndarray): The non-silent audio segments (at sr_orig).
36 | - audio_segments_new (list of np.ndarray): The non-silent audio segments (at sr_new).
37 | - intervals (np.ndarray): The intervals used for splitting the original audio.
38 | - sr_orig (int): The sample rate of the original audio
39 | - sr_new (int): The sample rate of the model
40 | Returns:
41 | - np.ndarray: The merged audio signal with silent gaps restored.
42 | """
43 | merged_audio = np.array([], dtype=audio_segments_new[0].dtype)
44 | sr_ratio = sr_new / sr_orig
45 |
46 | for i, (start, end) in enumerate(intervals):
47 |
48 | start_new = int(start * sr_ratio)
49 | end_new = int(end * sr_ratio)
50 |
51 | original_duration = len(audio_segments_org[i]) / sr_orig
52 | new_duration = len(audio_segments_new[i]) / sr_new
53 | duration_diff = new_duration - original_duration
54 |
55 | silence_samples = int(abs(duration_diff) * sr_new)
56 | silence_compensation = np.zeros(
57 | silence_samples, dtype=audio_segments_new[0].dtype
58 | )
59 |
60 | if i == 0 and start_new > 0:
61 | initial_silence = np.zeros(start_new, dtype=audio_segments_new[0].dtype)
62 | merged_audio = np.concatenate((merged_audio, initial_silence))
63 |
64 | if duration_diff > 0:
65 | merged_audio = np.concatenate((merged_audio, silence_compensation))
66 |
67 | merged_audio = np.concatenate((merged_audio, audio_segments_new[i]))
68 |
69 | if duration_diff < 0:
70 | merged_audio = np.concatenate((merged_audio, silence_compensation))
71 |
72 | if i < len(intervals) - 1:
73 | next_start_new = int(intervals[i + 1][0] * sr_ratio)
74 | silence_duration = next_start_new - end_new
75 | if silence_duration > 0:
76 | silence = np.zeros(silence_duration, dtype=audio_segments_new[0].dtype)
77 | merged_audio = np.concatenate((merged_audio, silence))
78 |
79 | return merged_audio
80 |
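Usage sketch with a synthetic signal: splitting on silence and then re-inserting the silent gaps. Here the segments are passed through unchanged and both sample rates are equal, which is an assumption made only for illustration.

    import numpy as np
    from rvc.lib.tools.split_audio import process_audio, merge_audio

    sr = 16000
    audio = np.concatenate(
        [np.zeros(sr, dtype=np.float32),
         np.random.randn(sr).astype(np.float32) * 0.5,
         np.zeros(sr, dtype=np.float32)]
    )
    segments, intervals = process_audio(audio, sr=sr)
    restored = merge_audio(segments, segments, intervals, sr_orig=sr, sr_new=sr)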
--------------------------------------------------------------------------------
/rvc/lib/tools/tts.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import asyncio
3 | import edge_tts
4 | import os
5 |
6 |
7 | async def main():
8 | # Parse command line arguments
9 | tts_file = str(sys.argv[1])
10 | text = str(sys.argv[2])
11 | voice = str(sys.argv[3])
12 | rate = int(sys.argv[4])
13 | output_file = str(sys.argv[5])
14 |
15 | rates = f"+{rate}%" if rate >= 0 else f"{rate}%"
16 | if tts_file and os.path.exists(tts_file):
17 | text = ""
18 | try:
19 | with open(tts_file, "r", encoding="utf-8") as file:
20 | text = file.read()
21 | except UnicodeDecodeError:
22 | with open(tts_file, "r") as file:
23 | text = file.read()
24 | await edge_tts.Communicate(text, voice, rate=rates).save(output_file)
25 | # print(f"TTS with {voice} completed. Output TTS file: '{output_file}'")
26 |
27 |
28 | if __name__ == "__main__":
29 | asyncio.run(main())
30 |
--------------------------------------------------------------------------------
/rvc/lib/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import soxr
4 | import librosa
5 | import soundfile as sf
6 | import numpy as np
7 | import re
8 | import unicodedata
9 | import wget
10 | from torch import nn
11 |
12 | import logging
13 | from transformers import HubertModel
14 | import warnings
15 |
16 | # Remove this to see warnings about transformers models
17 | warnings.filterwarnings("ignore")
18 |
19 | logging.getLogger("fairseq").setLevel(logging.ERROR)
20 | logging.getLogger("faiss.loader").setLevel(logging.ERROR)
21 | logging.getLogger("transformers").setLevel(logging.ERROR)
22 | logging.getLogger("torch").setLevel(logging.ERROR)
23 |
24 | now_dir = os.getcwd()
25 | sys.path.append(now_dir)
26 |
27 | base_path = os.path.join(now_dir, "rvc", "models", "formant", "stftpitchshift")
28 | stft = base_path + ".exe" if sys.platform == "win32" else base_path
29 |
30 |
31 | class HubertModelWithFinalProj(HubertModel):
32 | def __init__(self, config):
33 | super().__init__(config)
34 | self.final_proj = nn.Linear(config.hidden_size, config.classifier_proj_size)
35 |
36 |
37 | def load_audio(file, sample_rate):
38 | try:
39 | file = file.strip(" ").strip('"').strip("\n").strip('"').strip(" ")
40 | audio, sr = sf.read(file)
41 | if len(audio.shape) > 1:
42 | audio = librosa.to_mono(audio.T)
43 | if sr != sample_rate:
44 | audio = librosa.resample(
45 | audio, orig_sr=sr, target_sr=sample_rate, res_type="soxr_vhq"
46 | )
47 | except Exception as error:
48 | raise RuntimeError(f"An error occurred loading the audio: {error}")
49 |
50 | return audio.flatten()
51 |
52 |
53 | def load_audio_infer(
54 | file,
55 | sample_rate,
56 | **kwargs,
57 | ):
58 | formant_shifting = kwargs.get("formant_shifting", False)
59 | try:
60 | file = file.strip(" ").strip('"').strip("\n").strip('"').strip(" ")
61 | if not os.path.isfile(file):
62 | raise FileNotFoundError(f"File not found: {file}")
63 | audio, sr = sf.read(file)
64 | if len(audio.shape) > 1:
65 | audio = librosa.to_mono(audio.T)
66 | if sr != sample_rate:
67 | audio = librosa.resample(
68 | audio, orig_sr=sr, target_sr=sample_rate, res_type="soxr_vhq"
69 | )
70 | if formant_shifting:
71 | formant_qfrency = kwargs.get("formant_qfrency", 0.8)
72 | formant_timbre = kwargs.get("formant_timbre", 0.8)
73 |
74 | from stftpitchshift import StftPitchShift
75 |
76 | pitchshifter = StftPitchShift(1024, 32, sample_rate)
77 | audio = pitchshifter.shiftpitch(
78 | audio,
79 | factors=1,
80 | quefrency=formant_qfrency * 1e-3,
81 | distortion=formant_timbre,
82 | )
83 | except Exception as error:
84 | raise RuntimeError(f"An error occurred loading the audio: {error}")
85 | return np.array(audio).flatten()
86 |
87 |
88 | def format_title(title):
89 | formatted_title = unicodedata.normalize("NFC", title)
90 | formatted_title = re.sub(r"[\u2500-\u257F]+", "", formatted_title)
91 | formatted_title = re.sub(r"[^\w\s.-]", "", formatted_title, flags=re.UNICODE)
92 | formatted_title = re.sub(r"\s+", "_", formatted_title)
93 | return formatted_title
94 |
95 |
96 | def load_embedding(embedder_model, custom_embedder=None):
97 | embedder_root = os.path.join(now_dir, "rvc", "models", "embedders")
98 | embedding_list = {
99 | "contentvec": os.path.join(embedder_root, "contentvec"),
100 | "chinese-hubert-base": os.path.join(embedder_root, "chinese_hubert_base"),
101 | "japanese-hubert-base": os.path.join(embedder_root, "japanese_hubert_base"),
102 | "korean-hubert-base": os.path.join(embedder_root, "korean_hubert_base"),
103 | }
104 |
105 | online_embedders = {
106 | "contentvec": "https://huggingface.co/IAHispano/Applio/resolve/main/Resources/embedders/contentvec/pytorch_model.bin",
107 | "chinese-hubert-base": "https://huggingface.co/IAHispano/Applio/resolve/main/Resources/embedders/chinese_hubert_base/pytorch_model.bin",
108 | "japanese-hubert-base": "https://huggingface.co/IAHispano/Applio/resolve/main/Resources/embedders/japanese_hubert_base/pytorch_model.bin",
109 | "korean-hubert-base": "https://huggingface.co/IAHispano/Applio/resolve/main/Resources/embedders/korean_hubert_base/pytorch_model.bin",
110 | }
111 |
112 | config_files = {
113 | "contentvec": "https://huggingface.co/IAHispano/Applio/resolve/main/Resources/embedders/contentvec/config.json",
114 | "chinese-hubert-base": "https://huggingface.co/IAHispano/Applio/resolve/main/Resources/embedders/chinese_hubert_base/config.json",
115 | "japanese-hubert-base": "https://huggingface.co/IAHispano/Applio/resolve/main/Resources/embedders/japanese_hubert_base/config.json",
116 | "korean-hubert-base": "https://huggingface.co/IAHispano/Applio/resolve/main/Resources/embedders/korean_hubert_base/config.json",
117 | }
118 |
119 | if embedder_model == "custom":
120 | if os.path.exists(custom_embedder):
121 | model_path = custom_embedder
122 | else:
123 | print(f"Custom embedder not found: {custom_embedder}, using contentvec")
124 | model_path = embedding_list["contentvec"]
125 | else:
126 | model_path = embedding_list[embedder_model]
127 | bin_file = os.path.join(model_path, "pytorch_model.bin")
128 | json_file = os.path.join(model_path, "config.json")
129 | os.makedirs(model_path, exist_ok=True)
130 | if not os.path.exists(bin_file):
131 | url = online_embedders[embedder_model]
132 | print(f"Downloading {url} to {model_path}...")
133 | wget.download(url, out=bin_file)
134 | if not os.path.exists(json_file):
135 | url = config_files[embedder_model]
136 | print(f"Downloading {url} to {model_path}...")
137 | wget.download(url, out=json_file)
138 |
139 | models = HubertModelWithFinalProj.from_pretrained(model_path)
140 | return models
141 |
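Usage sketch ("input.wav" is a placeholder): loading audio resampled to 16 kHz and normalizing a model title for filesystem use.

    from rvc.lib.utils import load_audio, format_title

    audio = load_audio("input.wav", 16000)        # mono float array at 16 kHz
    print(format_title("My Vocal – Take #1"))     # -> "My_Vocal_Take_1"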
--------------------------------------------------------------------------------
/rvc/lib/zluda.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | if torch.cuda.is_available() and torch.cuda.get_device_name().endswith("[ZLUDA]"):
4 |
5 | class STFT:
6 | def __init__(self):
7 | self.device = "cuda"
8 | self.fourier_bases = {} # Cache for Fourier bases
9 |
10 | def _get_fourier_basis(self, n_fft):
11 | # Check if the basis for this n_fft is already cached
12 | if n_fft in self.fourier_bases:
13 | return self.fourier_bases[n_fft]
14 | fourier_basis = torch.fft.fft(torch.eye(n_fft, device="cpu")).to(
15 | self.device
16 | )
17 | # stack separated real and imaginary components and convert to torch tensor
18 | cutoff = n_fft // 2 + 1
19 | fourier_basis = torch.cat(
20 | [fourier_basis.real[:cutoff], fourier_basis.imag[:cutoff]], dim=0
21 | )
22 | # cache the tensor and return
23 | self.fourier_bases[n_fft] = fourier_basis
24 | return fourier_basis
25 |
26 | def transform(self, input, n_fft, hop_length, window):
27 | # fetch cached Fourier basis
28 | fourier_basis = self._get_fourier_basis(n_fft)
29 | # apply hann window to Fourier basis
30 | fourier_basis = fourier_basis * window
31 | # pad input to center with reflect
32 | pad_amount = n_fft // 2
33 | input = torch.nn.functional.pad(
34 | input, (pad_amount, pad_amount), mode="reflect"
35 | )
36 | # separate input into n_fft-sized frames
37 | input_frames = input.unfold(1, n_fft, hop_length).permute(0, 2, 1)
38 | # apply fft to each frame
39 | fourier_transform = torch.matmul(fourier_basis, input_frames)
40 | cutoff = n_fft // 2 + 1
41 | return torch.complex(
42 | fourier_transform[:, :cutoff, :], fourier_transform[:, cutoff:, :]
43 | )
44 |
45 | stft = STFT()
46 | _torch_stft = torch.stft
47 |
48 | def z_stft(input: torch.Tensor, window: torch.Tensor, *args, **kwargs):
49 | # only optimizing a specific call from rvc.train.mel_processing.MultiScaleMelSpectrogramLoss
50 |         if (
51 |             kwargs.get("win_length") is None
52 |             and kwargs.get("center") is None
53 |             and kwargs.get("return_complex") is True
54 |         ):
55 | # use GPU accelerated calculation
56 | return stft.transform(
57 | input, kwargs.get("n_fft"), kwargs.get("hop_length"), window
58 | )
59 | else:
60 | # simply do the operation on CPU
61 | return _torch_stft(
62 | input=input.cpu(), window=window.cpu(), *args, **kwargs
63 | ).to(input.device)
64 |
65 | def z_jit(f, *_, **__):
66 | f.graph = torch._C.Graph()
67 | return f
68 |
69 | # hijacks
70 | torch.stft = z_stft
71 | torch.jit.script = z_jit
72 | # disabling unsupported cudnn
73 | torch.backends.cudnn.enabled = False
74 | torch.backends.cuda.enable_flash_sdp(False)
75 | torch.backends.cuda.enable_math_sdp(True)
76 | torch.backends.cuda.enable_mem_efficient_sdp(False)
77 |
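The module is imported purely for its side effects; a minimal sketch of pulling it in before any torch.stft call. The patches only activate when torch reports a CUDA device whose name ends in "[ZLUDA]".

    import rvc.lib.zluda  # noqa: F401  (patches torch.stft and torch.jit.script on ZLUDA)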
--------------------------------------------------------------------------------
/rvc/models/embedders/.gitkeep:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/rvc/models/embedders/embedders_custom/.gitkeep:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/rvc/models/formant/.gitkeep:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/rvc/models/predictors/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/blaisewf/rvc-cli/a732f87a54780e26e4a0e3c91b2d93f848eef6d2/rvc/models/predictors/.gitkeep
--------------------------------------------------------------------------------
/rvc/models/pretraineds/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/blaisewf/rvc-cli/a732f87a54780e26e4a0e3c91b2d93f848eef6d2/rvc/models/pretraineds/.gitkeep
--------------------------------------------------------------------------------
/rvc/models/pretraineds/custom/.gitkeep:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/rvc/models/pretraineds/hifi-gan/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/blaisewf/rvc-cli/a732f87a54780e26e4a0e3c91b2d93f848eef6d2/rvc/models/pretraineds/hifi-gan/.gitkeep
--------------------------------------------------------------------------------
/rvc/train/extract/preparing_files.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 | from random import shuffle
4 | from rvc.configs.config import Config
5 | import json
6 |
7 | config = Config()
8 | current_directory = os.getcwd()
9 |
10 |
11 | def generate_config(sample_rate: int, model_path: str):
12 | config_path = os.path.join("rvc", "configs", f"{sample_rate}.json")
13 | config_save_path = os.path.join(model_path, "config.json")
14 | if not os.path.exists(config_save_path):
15 | shutil.copyfile(config_path, config_save_path)
16 |
17 |
18 | def generate_filelist(model_path: str, sample_rate: int, include_mutes: int = 2):
19 | gt_wavs_dir = os.path.join(model_path, "sliced_audios")
20 |     feature_dir = os.path.join(model_path, "extracted")
21 | 
22 |     # F0 and voiced-F0 features live alongside the extracted embeddings
23 |     f0_dir = os.path.join(model_path, "f0")
24 |     f0nsf_dir = os.path.join(model_path, "f0_voiced")
25 |
26 | gt_wavs_files = set(name.split(".")[0] for name in os.listdir(gt_wavs_dir))
27 | feature_files = set(name.split(".")[0] for name in os.listdir(feature_dir))
28 |
29 | f0_files = set(name.split(".")[0] for name in os.listdir(f0_dir))
30 | f0nsf_files = set(name.split(".")[0] for name in os.listdir(f0nsf_dir))
31 | names = gt_wavs_files & feature_files & f0_files & f0nsf_files
32 |
33 | options = []
34 | mute_base_path = os.path.join(current_directory, "logs", "mute")
35 | sids = []
36 | for name in names:
37 | sid = name.split("_")[0]
38 | if sid not in sids:
39 | sids.append(sid)
40 | options.append(
41 | f"{os.path.join(gt_wavs_dir, name)}.wav|{os.path.join(feature_dir, name)}.npy|{os.path.join(f0_dir, name)}.wav.npy|{os.path.join(f0nsf_dir, name)}.wav.npy|{sid}"
42 | )
43 |
44 | if include_mutes > 0:
45 | mute_audio_path = os.path.join(
46 | mute_base_path, "sliced_audios", f"mute{sample_rate}.wav"
47 | )
48 |         mute_feature_path = os.path.join(mute_base_path, "extracted", "mute.npy")
49 | mute_f0_path = os.path.join(mute_base_path, "f0", "mute.wav.npy")
50 | mute_f0nsf_path = os.path.join(mute_base_path, "f0_voiced", "mute.wav.npy")
51 |
52 |         # add `include_mutes` mute entries for every speaker id
53 | for sid in sids * include_mutes:
54 | options.append(
55 | f"{mute_audio_path}|{mute_feature_path}|{mute_f0_path}|{mute_f0nsf_path}|{sid}"
56 | )
57 |
58 | file_path = os.path.join(model_path, "model_info.json")
59 | if os.path.exists(file_path):
60 | with open(file_path, "r") as f:
61 | data = json.load(f)
62 | else:
63 | data = {}
64 | data.update(
65 | {
66 | "speakers_id": len(sids),
67 | }
68 | )
69 | with open(file_path, "w") as f:
70 | json.dump(data, f, indent=4)
71 |
72 | shuffle(options)
73 |
74 | with open(os.path.join(model_path, "filelist.txt"), "w") as f:
75 | f.write("\n".join(options))
76 |
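Usage sketch ("logs/my_model" is a placeholder experiment folder that is assumed to already contain the sliced_audios, extracted, f0 and f0_voiced subfolders produced by preprocessing and extraction):

    from rvc.train.extract.preparing_files import generate_config, generate_filelist

    generate_config(48000, "logs/my_model")
    generate_filelist("logs/my_model", 48000, include_mutes=2)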
--------------------------------------------------------------------------------
/rvc/train/losses.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | def feature_loss(fmap_r, fmap_g):
5 | """
6 | Compute the feature loss between reference and generated feature maps.
7 |
8 | Args:
9 | fmap_r (list of torch.Tensor): List of reference feature maps.
10 | fmap_g (list of torch.Tensor): List of generated feature maps.
11 | """
12 | return 2 * sum(
13 | torch.mean(torch.abs(rl - gl))
14 | for dr, dg in zip(fmap_r, fmap_g)
15 | for rl, gl in zip(dr, dg)
16 | )
17 |
18 |
19 | def discriminator_loss(disc_real_outputs, disc_generated_outputs):
20 | """
21 | Compute the discriminator loss for real and generated outputs.
22 |
23 | Args:
24 | disc_real_outputs (list of torch.Tensor): List of discriminator outputs for real samples.
25 | disc_generated_outputs (list of torch.Tensor): List of discriminator outputs for generated samples.
26 | """
27 | loss = 0
28 | r_losses = []
29 | g_losses = []
30 | for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
31 | r_loss = torch.mean((1 - dr.float()) ** 2)
32 | g_loss = torch.mean(dg.float() ** 2)
33 |
34 | # r_losses.append(r_loss.item())
35 | # g_losses.append(g_loss.item())
36 | loss += r_loss + g_loss
37 |
38 | return loss, r_losses, g_losses
39 |
40 |
41 | def generator_loss(disc_outputs):
42 | """
43 | Compute the generator loss based on discriminator outputs.
44 |
45 | Args:
46 | disc_outputs (list of torch.Tensor): List of discriminator outputs for generated samples.
47 | """
48 | loss = 0
49 | gen_losses = []
50 | for dg in disc_outputs:
51 | l = torch.mean((1 - dg.float()) ** 2)
52 | # gen_losses.append(l.item())
53 | loss += l
54 |
55 | return loss, gen_losses
56 |
57 |
58 | def discriminator_loss_scaled(disc_real, disc_fake, scale=1.0):
59 |     """
60 |     Compute the scaled discriminator loss for real and generated outputs.
61 | 
62 |     Args:
63 |         disc_real (list of torch.Tensor): List of discriminator outputs for real samples.
64 |         disc_fake (list of torch.Tensor): List of discriminator outputs for generated samples.
65 |         scale (float, optional): Scaling factor applied to losses beyond the midpoint. Default is 1.0.
66 |     """
67 |     midpoint = len(disc_real) // 2
68 |     losses = []
69 |     for i, (d_real, d_fake) in enumerate(zip(disc_real, disc_fake)):
70 |         real_loss = (1 - d_real).pow(2).mean()
71 |         fake_loss = d_fake.pow(2).mean()
72 |         total_loss = real_loss + fake_loss
73 |         if i >= midpoint:
74 |             total_loss *= scale
75 |         losses.append(total_loss)
76 |     loss = sum(losses)
77 |     return loss, None, None
78 | 
79 | 
80 | def generator_loss_scaled(disc_outputs, scale=1.0):
81 |     """
82 |     Compute the scaled generator loss based on discriminator outputs.
83 | 
84 |     Args:
85 |         disc_outputs (list of torch.Tensor): List of discriminator outputs for generated samples.
86 |         scale (float, optional): Scaling factor applied to losses beyond the midpoint. Default is 1.0.
87 |     """
88 |     midpoint = len(disc_outputs) // 2
89 |     losses = []
90 |     for i, d_fake in enumerate(disc_outputs):
91 |         loss_value = (1 - d_fake).pow(2).mean()
92 |         if i >= midpoint:
93 |             loss_value *= scale
94 |         losses.append(loss_value)
95 |     loss = sum(losses)
96 |     return loss, None, None
97 | 
98 | 
99 | def kl_loss(z_p, logs_q, m_p, logs_p, z_mask):
100 |     """
101 |     Compute the Kullback-Leibler divergence loss.
102 | 
103 |     Args:
104 |         z_p (torch.Tensor): Latent variable z_p [b, h, t_t].
105 |         logs_q (torch.Tensor): Log variance of q [b, h, t_t].
106 |         m_p (torch.Tensor): Mean of p [b, h, t_t].
107 |         logs_p (torch.Tensor): Log variance of p [b, h, t_t].
108 |         z_mask (torch.Tensor): Mask for the latent variables [b, h, t_t].
109 |     """
110 |     kl = logs_p - logs_q - 0.5 + 0.5 * ((z_p - m_p) ** 2) * torch.exp(-2 * logs_p)
111 |     kl = (kl * z_mask).sum()
112 |     loss = kl / z_mask.sum()
113 |     return loss
114 | 
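Usage sketch with dummy discriminator outputs and feature maps; the shapes and number of sub-discriminators are arbitrary.

    import torch
    from rvc.train.losses import discriminator_loss, generator_loss, feature_loss

    d_real = [torch.rand(2, 1) for _ in range(3)]     # one tensor per sub-discriminator
    d_fake = [torch.rand(2, 1) for _ in range(3)]
    fmap_r = [[torch.rand(2, 8, 16)] for _ in range(3)]
    fmap_g = [[torch.rand(2, 8, 16)] for _ in range(3)]

    d_loss, _, _ = discriminator_loss(d_real, d_fake)
    g_loss, _ = generator_loss(d_fake)
    fm_loss = feature_loss(fmap_r, fmap_g)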
--------------------------------------------------------------------------------
/rvc/train/mel_processing.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.utils.data
3 | from librosa.filters import mel as librosa_mel_fn
4 |
5 |
6 | def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
7 | """
8 | Dynamic range compression using log10.
9 |
10 | Args:
11 | x (torch.Tensor): Input tensor.
12 | C (float, optional): Scaling factor. Defaults to 1.
13 | clip_val (float, optional): Minimum value for clamping. Defaults to 1e-5.
14 | """
15 | return torch.log(torch.clamp(x, min=clip_val) * C)
16 |
17 |
18 | def dynamic_range_decompression_torch(x, C=1):
19 | """
20 | Dynamic range decompression using exp.
21 |
22 | Args:
23 | x (torch.Tensor): Input tensor.
24 | C (float, optional): Scaling factor. Defaults to 1.
25 | """
26 | return torch.exp(x) / C
27 |
28 |
29 | def spectral_normalize_torch(magnitudes):
30 | """
31 | Spectral normalization using dynamic range compression.
32 |
33 | Args:
34 | magnitudes (torch.Tensor): Magnitude spectrogram.
35 | """
36 | return dynamic_range_compression_torch(magnitudes)
37 |
38 |
39 | def spectral_de_normalize_torch(magnitudes):
40 | """
41 | Spectral de-normalization using dynamic range decompression.
42 |
43 | Args:
44 | magnitudes (torch.Tensor): Normalized spectrogram.
45 | """
46 | return dynamic_range_decompression_torch(magnitudes)
47 |
48 |
49 | mel_basis = {}
50 | hann_window = {}
51 |
52 |
53 | def spectrogram_torch(y, n_fft, hop_size, win_size, center=False):
54 | """
55 | Compute the spectrogram of a signal using STFT.
56 |
57 | Args:
58 | y (torch.Tensor): Input signal.
59 | n_fft (int): FFT window size.
60 | hop_size (int): Hop size between frames.
61 | win_size (int): Window size.
62 | center (bool, optional): Whether to center the window. Defaults to False.
63 | """
64 | global hann_window
65 | dtype_device = str(y.dtype) + "_" + str(y.device)
66 | wnsize_dtype_device = str(win_size) + "_" + dtype_device
67 | if wnsize_dtype_device not in hann_window:
68 | hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(
69 | dtype=y.dtype, device=y.device
70 | )
71 |
72 | y = torch.nn.functional.pad(
73 | y.unsqueeze(1),
74 | (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
75 | mode="reflect",
76 | )
77 | y = y.squeeze(1)
78 |
79 | spec = torch.stft(
80 | y,
81 | n_fft=n_fft,
82 | hop_length=hop_size,
83 | win_length=win_size,
84 | window=hann_window[wnsize_dtype_device],
85 | center=center,
86 | pad_mode="reflect",
87 | normalized=False,
88 | onesided=True,
89 | return_complex=True,
90 | )
91 |
92 | spec = torch.sqrt(spec.real.pow(2) + spec.imag.pow(2) + 1e-6)
93 |
94 | return spec
95 |
96 |
97 | def spec_to_mel_torch(spec, n_fft, num_mels, sample_rate, fmin, fmax):
98 | """
99 | Convert a spectrogram to a mel-spectrogram.
100 |
101 | Args:
102 | spec (torch.Tensor): Magnitude spectrogram.
103 | n_fft (int): FFT window size.
104 | num_mels (int): Number of mel frequency bins.
105 | sample_rate (int): Sampling rate of the audio signal.
106 | fmin (float): Minimum frequency.
107 | fmax (float): Maximum frequency.
108 | """
109 | global mel_basis
110 | dtype_device = str(spec.dtype) + "_" + str(spec.device)
111 | fmax_dtype_device = str(fmax) + "_" + dtype_device
112 | if fmax_dtype_device not in mel_basis:
113 | mel = librosa_mel_fn(
114 | sr=sample_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax
115 | )
116 | mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(
117 | dtype=spec.dtype, device=spec.device
118 | )
119 |
120 | melspec = torch.matmul(mel_basis[fmax_dtype_device], spec)
121 | melspec = spectral_normalize_torch(melspec)
122 | return melspec
123 |
124 |
125 | def mel_spectrogram_torch(
126 | y, n_fft, num_mels, sample_rate, hop_size, win_size, fmin, fmax, center=False
127 | ):
128 | """
129 | Compute the mel-spectrogram of a signal.
130 |
131 | Args:
132 | y (torch.Tensor): Input signal.
133 | n_fft (int): FFT window size.
134 | num_mels (int): Number of mel frequency bins.
135 | sample_rate (int): Sampling rate of the audio signal.
136 | hop_size (int): Hop size between frames.
137 | win_size (int): Window size.
138 | fmin (float): Minimum frequency.
139 | fmax (float): Maximum frequency.
140 | center (bool, optional): Whether to center the window. Defaults to False.
141 | """
142 | spec = spectrogram_torch(y, n_fft, hop_size, win_size, center)
143 |
144 | melspec = spec_to_mel_torch(spec, n_fft, num_mels, sample_rate, fmin, fmax)
145 |
146 | return melspec
147 |
148 |
149 | def compute_window_length(n_mels: int, sample_rate: int):
150 | f_min = 0
151 | f_max = sample_rate / 2
152 | window_length_seconds = 8 * n_mels / (f_max - f_min)
153 | window_length = int(window_length_seconds * sample_rate)
154 | return 2 ** (window_length.bit_length() - 1)
155 |
156 |
157 | class MultiScaleMelSpectrogramLoss(torch.nn.Module):
158 |
159 | def __init__(
160 | self,
161 | sample_rate: int = 24000,
162 | n_mels: list[int] = [5, 10, 20, 40, 80, 160, 320, 480],
163 | loss_fn=torch.nn.L1Loss(),
164 | ):
165 | super().__init__()
166 | self.sample_rate = sample_rate
167 | self.loss_fn = loss_fn
168 | self.log_base = torch.log(torch.tensor(10.0))
169 | self.stft_params: list[tuple] = []
170 | self.hann_window: dict[int, torch.Tensor] = {}
171 | self.mel_banks: dict[int, torch.Tensor] = {}
172 |
173 | self.stft_params = [
174 | (mel, compute_window_length(mel, sample_rate), self.sample_rate // 100)
175 | for mel in n_mels
176 | ]
177 |
178 | def mel_spectrogram(
179 | self,
180 | wav: torch.Tensor,
181 | n_mels: int,
182 | window_length: int,
183 | hop_length: int,
184 | ):
185 | # IDs for caching
186 | dtype_device = str(wav.dtype) + "_" + str(wav.device)
187 | win_dtype_device = str(window_length) + "_" + dtype_device
188 | mel_dtype_device = str(n_mels) + "_" + dtype_device
189 | # caching hann window
190 | if win_dtype_device not in self.hann_window:
191 | self.hann_window[win_dtype_device] = torch.hann_window(
192 | window_length, device=wav.device, dtype=torch.float32
193 | )
194 |
195 | wav = wav.squeeze(1) # -> torch(B, T)
196 |
197 | stft = torch.stft(
198 | wav.float(),
199 | n_fft=window_length,
200 | hop_length=hop_length,
201 | window=self.hann_window[win_dtype_device],
202 | return_complex=True,
203 | ) # -> torch (B, window_length // 2 + 1, (T - window_length)/hop_length + 1)
204 |
205 | magnitude = torch.sqrt(stft.real.pow(2) + stft.imag.pow(2) + 1e-6)
206 |
207 | # caching mel filter
208 | if mel_dtype_device not in self.mel_banks:
209 | self.mel_banks[mel_dtype_device] = torch.from_numpy(
210 | librosa_mel_fn(
211 | sr=self.sample_rate,
212 | n_mels=n_mels,
213 | n_fft=window_length,
214 | fmin=0,
215 | fmax=None,
216 | )
217 | ).to(device=wav.device, dtype=torch.float32)
218 |
219 | mel_spectrogram = torch.matmul(
220 | self.mel_banks[mel_dtype_device], magnitude
221 | ) # torch(B, n_mels, stft.frames)
222 | return mel_spectrogram
223 |
224 | def forward(
225 | self, real: torch.Tensor, fake: torch.Tensor
226 | ): # real: torch(B, 1, T) , fake: torch(B, 1, T)
227 | loss = 0.0
228 | for p in self.stft_params:
229 | real_mels = self.mel_spectrogram(real, *p)
230 | fake_mels = self.mel_spectrogram(fake, *p)
231 | real_logmels = torch.log(real_mels.clamp(min=1e-5)) / self.log_base
232 | fake_logmels = torch.log(fake_mels.clamp(min=1e-5)) / self.log_base
233 | loss += self.loss_fn(real_logmels, fake_logmels)
234 | return loss
235 |
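A minimal usage sketch for the helpers above, assuming the repository root is on sys.path and using illustrative STFT settings and a synthetic 40 kHz batch (none of these values are taken from the repository's configs):

import torch

from rvc.train.mel_processing import MultiScaleMelSpectrogramLoss, mel_spectrogram_torch

# Synthetic stand-in for a batch of audio: (batch, 1, samples).
audio = torch.randn(2, 1, 40000)

# Linear spectrogram followed by the mel projection, as wired together above.
mel = mel_spectrogram_torch(
    audio.squeeze(1), n_fft=2048, num_mels=128, sample_rate=40000,
    hop_size=400, win_size=2048, fmin=0, fmax=None,
)
print(mel.shape)  # (2, 128, frames)

# Multi-scale log-mel L1 loss between a "real" and a "generated" waveform.
msml = MultiScaleMelSpectrogramLoss(sample_rate=40000)
print(msml(audio, torch.randn_like(audio)))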
--------------------------------------------------------------------------------
/rvc/train/process/change_info.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 |
4 |
5 | def change_info(path, info, name):
6 | try:
7 | ckpt = torch.load(path, map_location="cpu", weights_only=True)
8 | ckpt["info"] = info
9 |
10 | if not name:
11 | name = os.path.splitext(os.path.basename(path))[0]
12 |
13 | target_dir = os.path.join("logs", name)
14 | os.makedirs(target_dir, exist_ok=True)
15 |
16 | torch.save(ckpt, os.path.join(target_dir, f"{name}.pth"))
17 |
18 | return "Success."
19 |
20 | except Exception as error:
21 | print(f"An error occurred while changing the info: {error}")
22 | return f"Error: {error}"
23 |
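A minimal usage sketch, assuming the repository root is on sys.path; the checkpoint path and description below are illustrative:

from rvc.train.process.change_info import change_info

# "logs/my_model/my_model.pth" is a hypothetical checkpoint produced by training/extraction.
result = change_info("logs/my_model/my_model.pth", "Trained on ~2 hours of clean vocals", "my_model")
print(result)  # "Success." on success, otherwise an "Error: ..." string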
--------------------------------------------------------------------------------
/rvc/train/process/extract_index.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | from multiprocessing import cpu_count
4 |
5 | import faiss
6 | import numpy as np
7 | from sklearn.cluster import MiniBatchKMeans
8 |
9 | # Parse command line arguments
10 | exp_dir = str(sys.argv[1])
11 | index_algorithm = str(sys.argv[2])
12 |
13 | try:
14 |     feature_dir = os.path.join(exp_dir, "extracted")
15 | model_name = os.path.basename(exp_dir)
16 |
17 | if not os.path.exists(feature_dir):
18 | print(
19 |             f"Features to generate the index file were not found at {feature_dir}. Did you run the preprocessing and feature extraction steps?"
20 | )
21 | sys.exit(1)
22 |
23 | index_filename_added = f"{model_name}.index"
24 | index_filepath_added = os.path.join(exp_dir, index_filename_added)
25 |
26 | if os.path.exists(index_filepath_added):
27 | pass
28 | else:
29 | npys = []
30 | listdir_res = sorted(os.listdir(feature_dir))
31 |
32 | for name in listdir_res:
33 | file_path = os.path.join(feature_dir, name)
34 | phone = np.load(file_path)
35 | npys.append(phone)
36 |
37 | big_npy = np.concatenate(npys, axis=0)
38 |
39 | big_npy_idx = np.arange(big_npy.shape[0])
40 | np.random.shuffle(big_npy_idx)
41 | big_npy = big_npy[big_npy_idx]
42 |
43 | if big_npy.shape[0] > 2e5 and (
44 | index_algorithm == "Auto" or index_algorithm == "KMeans"
45 | ):
46 | big_npy = (
47 | MiniBatchKMeans(
48 | n_clusters=10000,
49 | verbose=True,
50 | batch_size=256 * cpu_count(),
51 | compute_labels=False,
52 | init="random",
53 | )
54 | .fit(big_npy)
55 | .cluster_centers_
56 | )
57 |
58 | n_ivf = min(int(16 * np.sqrt(big_npy.shape[0])), big_npy.shape[0] // 39)
59 |
60 | # index_added
61 | index_added = faiss.index_factory(768, f"IVF{n_ivf},Flat")
62 | index_ivf_added = faiss.extract_index_ivf(index_added)
63 | index_ivf_added.nprobe = 1
64 | index_added.train(big_npy)
65 |
66 | batch_size_add = 8192
67 | for i in range(0, big_npy.shape[0], batch_size_add):
68 | index_added.add(big_npy[i : i + batch_size_add])
69 |
70 | faiss.write_index(index_added, index_filepath_added)
71 | print(f"Saved index file '{index_filepath_added}'")
72 |
73 | except Exception as error:
74 | print(f"An error occurred extracting the index: {error}")
75 | print(
76 |         "If you are running this code in a virtual environment, make sure you have enough GPU memory available to generate the index file."
77 | )
78 |
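Because the module runs at import time and reads its two positional arguments from sys.argv, a usage sketch (with an illustrative experiment directory) launches it as a subprocess:

import subprocess

# Arguments: experiment directory, then index algorithm ("Auto" or "KMeans" enables the
# MiniBatchKMeans reduction for feature sets larger than 200k vectors).
subprocess.run(
    ["python", "rvc/train/process/extract_index.py", "logs/my_model", "Auto"],
    check=True,
)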
--------------------------------------------------------------------------------
/rvc/train/process/extract_model.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import hashlib
3 | import json
4 | import os
5 | import sys
6 | from collections import OrderedDict
7 |
8 | import torch
9 |
10 | now_dir = os.getcwd()
11 | sys.path.append(now_dir)
12 |
13 |
14 | def replace_keys_in_dict(d, old_key_part, new_key_part):
15 | if isinstance(d, OrderedDict):
16 | updated_dict = OrderedDict()
17 | else:
18 | updated_dict = {}
19 | for key, value in d.items():
20 | new_key = key.replace(old_key_part, new_key_part)
21 | if isinstance(value, dict):
22 | value = replace_keys_in_dict(value, old_key_part, new_key_part)
23 | updated_dict[new_key] = value
24 | return updated_dict
25 |
26 |
27 | def extract_model(
28 | ckpt,
29 | sr,
30 | name,
31 | model_path,
32 | epoch,
33 | step,
34 | hps,
35 | overtrain_info,
36 | vocoder,
37 | pitch_guidance=True,
38 | version="v2",
39 | ):
40 | try:
41 | model_dir = os.path.dirname(model_path)
42 | os.makedirs(model_dir, exist_ok=True)
43 |
44 | if os.path.exists(os.path.join(model_dir, "model_info.json")):
45 | with open(os.path.join(model_dir, "model_info.json"), "r") as f:
46 | data = json.load(f)
47 | dataset_length = data.get("total_dataset_duration", None)
48 | embedder_model = data.get("embedder_model", None)
49 | speakers_id = data.get("speakers_id", 1)
50 |         else:
51 |             dataset_length, embedder_model, speakers_id = None, None, 1
52 |
53 | with open(os.path.join(now_dir, "assets", "config.json"), "r") as f:
54 | data = json.load(f)
55 | model_author = data.get("model_author", None)
56 |
57 | opt = OrderedDict(
58 | weight={
59 | key: value.half() for key, value in ckpt.items() if "enc_q" not in key
60 | }
61 | )
62 | opt["config"] = [
63 | hps.data.filter_length // 2 + 1,
64 | 32,
65 | hps.model.inter_channels,
66 | hps.model.hidden_channels,
67 | hps.model.filter_channels,
68 | hps.model.n_heads,
69 | hps.model.n_layers,
70 | hps.model.kernel_size,
71 | hps.model.p_dropout,
72 | hps.model.resblock,
73 | hps.model.resblock_kernel_sizes,
74 | hps.model.resblock_dilation_sizes,
75 | hps.model.upsample_rates,
76 | hps.model.upsample_initial_channel,
77 | hps.model.upsample_kernel_sizes,
78 | hps.model.spk_embed_dim,
79 | hps.model.gin_channels,
80 | hps.data.sample_rate,
81 | ]
82 |
83 | opt["epoch"] = epoch
84 | opt["step"] = step
85 | opt["sr"] = sr
86 | opt["f0"] = pitch_guidance
87 | opt["version"] = version
88 | opt["creation_date"] = datetime.datetime.now().isoformat()
89 |
90 | hash_input = f"{name}-{epoch}-{step}-{sr}-{version}-{opt['config']}"
91 | opt["model_hash"] = hashlib.sha256(hash_input.encode()).hexdigest()
92 | opt["overtrain_info"] = overtrain_info
93 | opt["dataset_length"] = dataset_length
94 | opt["model_name"] = name
95 | opt["author"] = model_author
96 | opt["embedder_model"] = embedder_model
97 | opt["speakers_id"] = speakers_id
98 | opt["vocoder"] = vocoder
99 |
100 | torch.save(
101 | replace_keys_in_dict(
102 | replace_keys_in_dict(
103 | opt, ".parametrizations.weight.original1", ".weight_v"
104 | ),
105 | ".parametrizations.weight.original0",
106 | ".weight_g",
107 | ),
108 | model_path,
109 | )
110 |
111 | print(f"Saved model '{model_path}' (epoch {epoch} and step {step})")
112 |
113 | except Exception as error:
114 | print(f"An error occurred extracting the model: {error}")
115 |
--------------------------------------------------------------------------------
/rvc/train/process/model_blender.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | from collections import OrderedDict
4 |
5 |
6 | def extract(ckpt):
7 | a = ckpt["model"]
8 | opt = OrderedDict()
9 | opt["weight"] = {}
10 | for key in a.keys():
11 | if "enc_q" in key:
12 | continue
13 | opt["weight"][key] = a[key]
14 | return opt
15 |
16 |
17 | def model_blender(name, path1, path2, ratio):
18 | try:
19 | message = f"Model {path1} and {path2} are merged with alpha {ratio}."
20 | ckpt1 = torch.load(path1, map_location="cpu", weights_only=True)
21 | ckpt2 = torch.load(path2, map_location="cpu", weights_only=True)
22 |
23 | if ckpt1["sr"] != ckpt2["sr"]:
24 | return "The sample rates of the two models are not the same."
25 |
26 | cfg = ckpt1["config"]
27 | cfg_f0 = ckpt1["f0"]
28 | cfg_version = ckpt1["version"]
29 | cfg_sr = ckpt1["sr"]
30 | vocoder = ckpt1.get("vocoder", "HiFi-GAN")
31 |
32 | if "model" in ckpt1:
33 | ckpt1 = extract(ckpt1)
34 | else:
35 | ckpt1 = ckpt1["weight"]
36 | if "model" in ckpt2:
37 | ckpt2 = extract(ckpt2)
38 | else:
39 | ckpt2 = ckpt2["weight"]
40 |
41 | if sorted(list(ckpt1.keys())) != sorted(list(ckpt2.keys())):
42 | return "Fail to merge the models. The model architectures are not the same."
43 |
44 | opt = OrderedDict()
45 | opt["weight"] = {}
46 | for key in ckpt1.keys():
47 | if key == "emb_g.weight" and ckpt1[key].shape != ckpt2[key].shape:
48 | min_shape0 = min(ckpt1[key].shape[0], ckpt2[key].shape[0])
49 | opt["weight"][key] = (
50 | ratio * (ckpt1[key][:min_shape0].float())
51 | + (1 - ratio) * (ckpt2[key][:min_shape0].float())
52 | ).half()
53 | else:
54 | opt["weight"][key] = (
55 | ratio * (ckpt1[key].float()) + (1 - ratio) * (ckpt2[key].float())
56 | ).half()
57 |
58 | opt["config"] = cfg
59 | opt["sr"] = cfg_sr
60 | opt["f0"] = cfg_f0
61 | opt["version"] = cfg_version
62 | opt["info"] = message
63 | opt["vocoder"] = vocoder
64 |
65 | torch.save(opt, os.path.join("logs", f"{name}.pth"))
66 | print(message)
67 | return message, os.path.join("logs", f"{name}.pth")
68 | except Exception as error:
69 | print(f"An error occurred blending the models: {error}")
70 | return error
71 |
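A minimal usage sketch, assuming the repository root is on sys.path; the checkpoint paths are illustrative and both models must share the same sample rate and architecture (on failure the function returns a single message instead of a tuple):

from rvc.train.process.model_blender import model_blender

# Blend 60% of model_a with 40% of model_b and write logs/my_blend.pth.
message, blended_path = model_blender(
    "my_blend", "logs/model_a.pth", "logs/model_b.pth", 0.6
)
print(message, blended_path)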
--------------------------------------------------------------------------------
/rvc/train/process/model_information.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from datetime import datetime
3 |
4 |
5 | def prettify_date(date_str):
6 | if date_str is None:
7 | return "None"
8 | try:
9 | date_time_obj = datetime.strptime(date_str, "%Y-%m-%dT%H:%M:%S.%f")
10 | return date_time_obj.strftime("%Y-%m-%d %H:%M:%S")
11 | except ValueError:
12 | return "Invalid date format"
13 |
14 |
15 | def model_information(path):
16 | model_data = torch.load(path, map_location="cpu", weights_only=True)
17 |
18 | print(f"Loaded model from {path}")
19 |
20 | model_name = model_data.get("model_name", "None")
21 | epochs = model_data.get("epoch", "None")
22 | steps = model_data.get("step", "None")
23 | sr = model_data.get("sr", "None")
24 | f0 = model_data.get("f0", "None")
25 | dataset_length = model_data.get("dataset_length", "None")
26 | vocoder = model_data.get("vocoder", "None")
27 | creation_date = model_data.get("creation_date", "None")
28 | model_hash = model_data.get("model_hash", None)
29 | overtrain_info = model_data.get("overtrain_info", "None")
30 | model_author = model_data.get("author", "None")
31 | embedder_model = model_data.get("embedder_model", "None")
32 | speakers_id = model_data.get("speakers_id", 0)
33 |
34 | creation_date_str = prettify_date(creation_date) if creation_date else "None"
35 |
36 | return (
37 | f"Model Name: {model_name}\n"
38 | f"Model Creator: {model_author}\n"
39 | f"Epochs: {epochs}\n"
40 | f"Steps: {steps}\n"
41 | f"Vocoder: {vocoder}\n"
42 | f"Sampling Rate: {sr}\n"
43 | f"Dataset Length: {dataset_length}\n"
44 | f"Creation Date: {creation_date_str}\n"
45 | f"Overtrain Info: {overtrain_info}\n"
46 | f"Embedder Model: {embedder_model}\n"
47 |         f"Max Speakers ID: {speakers_id}\n"
48 | f"Hash: {model_hash}\n"
49 | )
50 |
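A minimal usage sketch, assuming the repository root is on sys.path and an illustrative checkpoint path:

from rvc.train.process.model_information import model_information

# Prints the metadata block written by extract_model.
print(model_information("logs/my_model/my_model.pth"))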
--------------------------------------------------------------------------------
/uvr/__init__.py:
--------------------------------------------------------------------------------
1 | from .separator import Separator
2 |
--------------------------------------------------------------------------------
/uvr/architectures/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/blaisewf/rvc-cli/a732f87a54780e26e4a0e3c91b2d93f848eef6d2/uvr/architectures/__init__.py
--------------------------------------------------------------------------------
/uvr/uvr_lib_v5/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/blaisewf/rvc-cli/a732f87a54780e26e4a0e3c91b2d93f848eef6d2/uvr/uvr_lib_v5/__init__.py
--------------------------------------------------------------------------------
/uvr/uvr_lib_v5/attend.py:
--------------------------------------------------------------------------------
1 | from functools import wraps
2 | from packaging import version
3 | from collections import namedtuple
4 |
5 | import torch
6 | from torch import nn, einsum
7 | import torch.nn.functional as F
8 |
9 | from einops import rearrange, reduce
10 |
11 | # constants
12 |
13 | FlashAttentionConfig = namedtuple(
14 | "FlashAttentionConfig", ["enable_flash", "enable_math", "enable_mem_efficient"]
15 | )
16 |
17 | # helpers
18 |
19 |
20 | def exists(val):
21 | return val is not None
22 |
23 |
24 | def once(fn):
25 | called = False
26 |
27 | @wraps(fn)
28 | def inner(x):
29 | nonlocal called
30 | if called:
31 | return
32 | called = True
33 | return fn(x)
34 |
35 | return inner
36 |
37 |
38 | print_once = once(print)
39 |
40 | # main class
41 |
42 |
43 | class Attend(nn.Module):
44 | def __init__(self, dropout=0.0, flash=False):
45 | super().__init__()
46 | self.dropout = dropout
47 | self.attn_dropout = nn.Dropout(dropout)
48 |
49 | self.flash = flash
50 | assert not (
51 | flash and version.parse(torch.__version__) < version.parse("2.0.0")
52 | ), "in order to use flash attention, you must be using pytorch 2.0 or above"
53 |
54 | # determine efficient attention configs for cuda and cpu
55 |
56 | self.cpu_config = FlashAttentionConfig(True, True, True)
57 | self.cuda_config = None
58 |
59 | if not torch.cuda.is_available() or not flash:
60 | return
61 |
62 | device_properties = torch.cuda.get_device_properties(torch.device("cuda"))
63 |
64 | if device_properties.major == 8 and device_properties.minor == 0:
65 | print_once(
66 | "A100 GPU detected, using flash attention if input tensor is on cuda"
67 | )
68 | self.cuda_config = FlashAttentionConfig(True, False, False)
69 | else:
70 | self.cuda_config = FlashAttentionConfig(False, True, True)
71 |
72 | def flash_attn(self, q, k, v):
73 | _, heads, q_len, _, k_len, is_cuda, device = (
74 | *q.shape,
75 | k.shape[-2],
76 | q.is_cuda,
77 | q.device,
78 | )
79 |
80 | # Check if there is a compatible device for flash attention
81 |
82 | config = self.cuda_config if is_cuda else self.cpu_config
83 |
84 | # pytorch 2.0 flash attn: q, k, v, mask, dropout, softmax_scale
85 |
86 | with torch.backends.cuda.sdp_kernel(**config._asdict()):
87 | out = F.scaled_dot_product_attention(
88 | q, k, v, dropout_p=self.dropout if self.training else 0.0
89 | )
90 |
91 | return out
92 |
93 | def forward(self, q, k, v):
94 | """
95 | einstein notation
96 | b - batch
97 | h - heads
98 | n, i, j - sequence length (base sequence length, source, target)
99 | d - feature dimension
100 | """
101 |
102 | q_len, k_len, device = q.shape[-2], k.shape[-2], q.device
103 |
104 | scale = q.shape[-1] ** -0.5
105 |
106 | if self.flash:
107 | return self.flash_attn(q, k, v)
108 |
109 | # similarity
110 |
111 | sim = einsum(f"b h i d, b h j d -> b h i j", q, k) * scale
112 |
113 | # attention
114 |
115 | attn = sim.softmax(dim=-1)
116 | attn = self.attn_dropout(attn)
117 |
118 | # aggregate values
119 |
120 | out = einsum(f"b h i j, b h j d -> b h i d", attn, v)
121 |
122 | return out
123 |
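A minimal usage sketch of the non-flash path, with illustrative tensor shapes (batch, heads, sequence, head_dim); flash=True additionally requires PyTorch 2.0+ per the assertion above:

import torch

from uvr.uvr_lib_v5.attend import Attend

attend = Attend(dropout=0.0, flash=False)
q = torch.randn(1, 8, 128, 64)
k = torch.randn(1, 8, 128, 64)
v = torch.randn(1, 8, 128, 64)
out = attend(q, k, v)
print(out.shape)  # torch.Size([1, 8, 128, 64])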
--------------------------------------------------------------------------------
/uvr/uvr_lib_v5/demucs/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/uvr/uvr_lib_v5/demucs/model_v2.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import math
8 |
9 | import julius
10 | from torch import nn
11 | from .tasnet_v2 import ConvTasNet
12 |
13 | from .utils import capture_init, center_trim
14 |
15 |
16 | class BLSTM(nn.Module):
17 | def __init__(self, dim, layers=1):
18 | super().__init__()
19 | self.lstm = nn.LSTM(
20 | bidirectional=True, num_layers=layers, hidden_size=dim, input_size=dim
21 | )
22 | self.linear = nn.Linear(2 * dim, dim)
23 |
24 | def forward(self, x):
25 | x = x.permute(2, 0, 1)
26 | x = self.lstm(x)[0]
27 | x = self.linear(x)
28 | x = x.permute(1, 2, 0)
29 | return x
30 |
31 |
32 | def rescale_conv(conv, reference):
33 | std = conv.weight.std().detach()
34 | scale = (std / reference) ** 0.5
35 | conv.weight.data /= scale
36 | if conv.bias is not None:
37 | conv.bias.data /= scale
38 |
39 |
40 | def rescale_module(module, reference):
41 | for sub in module.modules():
42 | if isinstance(sub, (nn.Conv1d, nn.ConvTranspose1d)):
43 | rescale_conv(sub, reference)
44 |
45 |
46 | def auto_load_demucs_model_v2(sources, demucs_model_name):
47 |
48 | if "48" in demucs_model_name:
49 | channels = 48
50 | elif "unittest" in demucs_model_name:
51 | channels = 4
52 | else:
53 | channels = 64
54 |
55 | if "tasnet" in demucs_model_name:
56 | init_demucs_model = ConvTasNet(sources, X=10)
57 | else:
58 | init_demucs_model = Demucs(sources, channels=channels)
59 |
60 | return init_demucs_model
61 |
62 |
63 | class Demucs(nn.Module):
64 | @capture_init
65 | def __init__(
66 | self,
67 | sources,
68 | audio_channels=2,
69 | channels=64,
70 | depth=6,
71 | rewrite=True,
72 | glu=True,
73 | rescale=0.1,
74 | resample=True,
75 | kernel_size=8,
76 | stride=4,
77 | growth=2.0,
78 | lstm_layers=2,
79 | context=3,
80 | normalize=False,
81 | samplerate=44100,
82 | segment_length=4 * 10 * 44100,
83 | ):
84 | """
85 | Args:
86 | sources (list[str]): list of source names
87 | audio_channels (int): stereo or mono
88 | channels (int): first convolution channels
89 | depth (int): number of encoder/decoder layers
90 | rewrite (bool): add 1x1 convolution to each encoder layer
91 | and a convolution to each decoder layer.
92 | For the decoder layer, `context` gives the kernel size.
93 | glu (bool): use glu instead of ReLU
94 |             resample (bool): upsample x2 the input and downsample /2 the output.
95 |             rescale (float): rescale initial weights of convolutions
96 | to get their standard deviation closer to `rescale`
97 | kernel_size (int): kernel size for convolutions
98 | stride (int): stride for convolutions
99 | growth (float): multiply (resp divide) number of channels by that
100 | for each layer of the encoder (resp decoder)
101 | lstm_layers (int): number of lstm layers, 0 = no lstm
102 | context (int): kernel size of the convolution in the
103 | decoder before the transposed convolution. If > 1,
104 | will provide some context from neighboring time
105 | steps.
106 | samplerate (int): stored as meta information for easing
107 | future evaluations of the model.
108 | segment_length (int): stored as meta information for easing
109 | future evaluations of the model. Length of the segments on which
110 | the model was trained.
111 | """
112 |
113 | super().__init__()
114 | self.audio_channels = audio_channels
115 | self.sources = sources
116 | self.kernel_size = kernel_size
117 | self.context = context
118 | self.stride = stride
119 | self.depth = depth
120 | self.resample = resample
121 | self.channels = channels
122 | self.normalize = normalize
123 | self.samplerate = samplerate
124 | self.segment_length = segment_length
125 |
126 | self.encoder = nn.ModuleList()
127 | self.decoder = nn.ModuleList()
128 |
129 | if glu:
130 | activation = nn.GLU(dim=1)
131 | ch_scale = 2
132 | else:
133 | activation = nn.ReLU()
134 | ch_scale = 1
135 | in_channels = audio_channels
136 | for index in range(depth):
137 | encode = []
138 | encode += [nn.Conv1d(in_channels, channels, kernel_size, stride), nn.ReLU()]
139 | if rewrite:
140 | encode += [nn.Conv1d(channels, ch_scale * channels, 1), activation]
141 | self.encoder.append(nn.Sequential(*encode))
142 |
143 | decode = []
144 | if index > 0:
145 | out_channels = in_channels
146 | else:
147 | out_channels = len(self.sources) * audio_channels
148 | if rewrite:
149 | decode += [
150 | nn.Conv1d(channels, ch_scale * channels, context),
151 | activation,
152 | ]
153 | decode += [nn.ConvTranspose1d(channels, out_channels, kernel_size, stride)]
154 | if index > 0:
155 | decode.append(nn.ReLU())
156 | self.decoder.insert(0, nn.Sequential(*decode))
157 | in_channels = channels
158 | channels = int(growth * channels)
159 |
160 | channels = in_channels
161 |
162 | if lstm_layers:
163 | self.lstm = BLSTM(channels, lstm_layers)
164 | else:
165 | self.lstm = None
166 |
167 | if rescale:
168 | rescale_module(self, reference=rescale)
169 |
170 | def valid_length(self, length):
171 | """
172 | Return the nearest valid length to use with the model so that
173 |         there are no time steps left over in the convolutions, i.e. for all
174 |         layers, (size of the input - kernel_size) % stride = 0.
175 |
176 | If the mixture has a valid length, the estimated sources
177 | will have exactly the same length when context = 1. If context > 1,
178 | the two signals can be center trimmed to match.
179 |
180 |         For training, extracts should have a valid length. For evaluation
181 | on full tracks we recommend passing `pad = True` to :method:`forward`.
182 | """
183 | if self.resample:
184 | length *= 2
185 | for _ in range(self.depth):
186 | length = math.ceil((length - self.kernel_size) / self.stride) + 1
187 | length = max(1, length)
188 | length += self.context - 1
189 | for _ in range(self.depth):
190 | length = (length - 1) * self.stride + self.kernel_size
191 |
192 | if self.resample:
193 | length = math.ceil(length / 2)
194 | return int(length)
195 |
196 | def forward(self, mix):
197 | x = mix
198 |
199 | if self.normalize:
200 | mono = mix.mean(dim=1, keepdim=True)
201 | mean = mono.mean(dim=-1, keepdim=True)
202 | std = mono.std(dim=-1, keepdim=True)
203 | else:
204 | mean = 0
205 | std = 1
206 |
207 | x = (x - mean) / (1e-5 + std)
208 |
209 | if self.resample:
210 | x = julius.resample_frac(x, 1, 2)
211 |
212 | saved = []
213 | for encode in self.encoder:
214 | x = encode(x)
215 | saved.append(x)
216 | if self.lstm:
217 | x = self.lstm(x)
218 | for decode in self.decoder:
219 | skip = center_trim(saved.pop(-1), x)
220 | x = x + skip
221 | x = decode(x)
222 |
223 | if self.resample:
224 | x = julius.resample_frac(x, 2, 1)
225 | x = x * std + mean
226 | x = x.view(x.size(0), len(self.sources), self.audio_channels, x.size(-1))
227 | return x
228 |
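A minimal usage sketch with the small channels=4 configuration and illustrative source names; the mixture is padded to a valid length before separation (with the default context=3 the output is slightly shorter than the padded input and would normally be center-trimmed):

import torch

from uvr.uvr_lib_v5.demucs.model_v2 import Demucs

model = Demucs(sources=["drums", "bass", "other", "vocals"], channels=4)
mix = torch.randn(1, 2, 44100)                                   # (batch, stereo, samples)
length = model.valid_length(mix.shape[-1])
mix = torch.nn.functional.pad(mix, (0, length - mix.shape[-1]))
with torch.no_grad():
    sources = model(mix)
print(sources.shape)  # (batch, n_sources, audio_channels, time)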
--------------------------------------------------------------------------------
/uvr/uvr_lib_v5/demucs/pretrained.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | """Loading pretrained models.
7 | """
8 |
9 | import logging
10 | from pathlib import Path
11 | import typing as tp
12 |
13 | # from dora.log import fatal
14 |
15 | import logging
16 |
17 | from diffq import DiffQuantizer
18 | import torch.hub
19 |
20 | from .model import Demucs
21 | from .tasnet_v2 import ConvTasNet
22 | from .utils import set_state
23 |
24 | from .hdemucs import HDemucs
25 | from .repo import (
26 | RemoteRepo,
27 | LocalRepo,
28 | ModelOnlyRepo,
29 | BagOnlyRepo,
30 | AnyModelRepo,
31 | ModelLoadingError,
32 | ) # noqa
33 |
34 | logger = logging.getLogger(__name__)
35 | ROOT_URL = "https://dl.fbaipublicfiles.com/demucs/mdx_final/"
36 | REMOTE_ROOT = Path(__file__).parent / "remote"
37 |
38 | SOURCES = ["drums", "bass", "other", "vocals"]
39 |
40 |
41 | def demucs_unittest():
42 | model = HDemucs(channels=4, sources=SOURCES)
43 | return model
44 |
45 |
46 | def add_model_flags(parser):
47 | group = parser.add_mutually_exclusive_group(required=False)
48 | group.add_argument("-s", "--sig", help="Locally trained XP signature.")
49 | group.add_argument(
50 | "-n",
51 | "--name",
52 | default="mdx_extra_q",
53 | help="Pretrained model name or signature. Default is mdx_extra_q.",
54 | )
55 | parser.add_argument(
56 | "--repo",
57 | type=Path,
58 | help="Folder containing all pre-trained models for use with -n.",
59 | )
60 |
61 |
62 | def _parse_remote_files(remote_file_list) -> tp.Dict[str, str]:
63 | root: str = ""
64 | models: tp.Dict[str, str] = {}
65 | for line in remote_file_list.read_text().split("\n"):
66 | line = line.strip()
67 | if line.startswith("#"):
68 | continue
69 | elif line.startswith("root:"):
70 | root = line.split(":", 1)[1].strip()
71 | else:
72 | sig = line.split("-", 1)[0]
73 | assert sig not in models
74 | models[sig] = ROOT_URL + root + line
75 | return models
76 |
77 |
78 | def get_model(name: str, repo: tp.Optional[Path] = None):
79 | """`name` must be a bag of models name or a pretrained signature
80 | from the remote AWS model repo or the specified local repo if `repo` is not None.
81 | """
82 | if name == "demucs_unittest":
83 | return demucs_unittest()
84 | model_repo: ModelOnlyRepo
85 | if repo is None:
86 | models = _parse_remote_files(REMOTE_ROOT / "files.txt")
87 | model_repo = RemoteRepo(models)
88 | bag_repo = BagOnlyRepo(REMOTE_ROOT, model_repo)
89 | else:
90 | if not repo.is_dir():
91 |             raise ModelLoadingError(f"{repo} must exist and be a directory.")
92 | model_repo = LocalRepo(repo)
93 | bag_repo = BagOnlyRepo(repo, model_repo)
94 | any_repo = AnyModelRepo(model_repo, bag_repo)
95 | model = any_repo.get_model(name)
96 | model.eval()
97 | return model
98 |
99 |
100 | def get_model_from_args(args):
101 | """
102 | Load local model package or pre-trained model.
103 | """
104 | return get_model(name=args.name, repo=args.repo)
105 |
106 |
107 | logger = logging.getLogger(__name__)
108 | ROOT = "https://dl.fbaipublicfiles.com/demucs/v3.0/"
109 |
110 | PRETRAINED_MODELS = {
111 | "demucs": "e07c671f",
112 | "demucs48_hq": "28a1282c",
113 | "demucs_extra": "3646af93",
114 | "demucs_quantized": "07afea75",
115 | "tasnet": "beb46fac",
116 | "tasnet_extra": "df3777b2",
117 | "demucs_unittest": "09ebc15f",
118 | }
119 |
120 | SOURCES = ["drums", "bass", "other", "vocals"]
121 |
122 |
123 | def get_url(name):
124 | sig = PRETRAINED_MODELS[name]
125 | return ROOT + name + "-" + sig[:8] + ".th"
126 |
127 |
128 | def is_pretrained(name):
129 | return name in PRETRAINED_MODELS
130 |
131 |
132 | def load_pretrained(name):
133 | if name == "demucs":
134 | return demucs(pretrained=True)
135 | elif name == "demucs48_hq":
136 | return demucs(pretrained=True, hq=True, channels=48)
137 | elif name == "demucs_extra":
138 | return demucs(pretrained=True, extra=True)
139 | elif name == "demucs_quantized":
140 | return demucs(pretrained=True, quantized=True)
141 | elif name == "demucs_unittest":
142 | return demucs_unittest(pretrained=True)
143 | elif name == "tasnet":
144 | return tasnet(pretrained=True)
145 | elif name == "tasnet_extra":
146 | return tasnet(pretrained=True, extra=True)
147 | else:
148 | raise ValueError(f"Invalid pretrained name {name}")
149 |
150 |
151 | def _load_state(name, model, quantizer=None):
152 | url = get_url(name)
153 | state = torch.hub.load_state_dict_from_url(url, map_location="cpu", check_hash=True)
154 | set_state(model, quantizer, state)
155 | if quantizer:
156 | quantizer.detach()
157 |
158 |
159 | def demucs_unittest(pretrained=True):
160 | model = Demucs(channels=4, sources=SOURCES)
161 | if pretrained:
162 | _load_state("demucs_unittest", model)
163 | return model
164 |
165 |
166 | def demucs(pretrained=True, extra=False, quantized=False, hq=False, channels=64):
167 | if not pretrained and (extra or quantized or hq):
168 | raise ValueError("if extra or quantized is True, pretrained must be True.")
169 | model = Demucs(sources=SOURCES, channels=channels)
170 | if pretrained:
171 | name = "demucs"
172 | if channels != 64:
173 | name += str(channels)
174 | quantizer = None
175 | if sum([extra, quantized, hq]) > 1:
176 | raise ValueError("Only one of extra, quantized, hq, can be True.")
177 | if quantized:
178 | quantizer = DiffQuantizer(model, group_size=8, min_size=1)
179 | name += "_quantized"
180 | if extra:
181 | name += "_extra"
182 | if hq:
183 | name += "_hq"
184 | _load_state(name, model, quantizer)
185 | return model
186 |
187 |
188 | def tasnet(pretrained=True, extra=False):
189 | if not pretrained and extra:
190 | raise ValueError("if extra is True, pretrained must be True.")
191 | model = ConvTasNet(X=10, sources=SOURCES)
192 | if pretrained:
193 | name = "tasnet"
194 | if extra:
195 | name = "tasnet_extra"
196 | _load_state(name, model)
197 | return model
198 |
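A minimal usage sketch that resolves a pretrained download URL without fetching anything (the name comes from the PRETRAINED_MODELS table above):

from uvr.uvr_lib_v5.demucs.pretrained import get_url, is_pretrained

assert is_pretrained("demucs48_hq")
print(get_url("demucs48_hq"))
# https://dl.fbaipublicfiles.com/demucs/v3.0/demucs48_hq-28a1282c.th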
--------------------------------------------------------------------------------
/uvr/uvr_lib_v5/demucs/repo.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | """Represents a model repository, including pre-trained models and bags of models.
7 | A repo can either be the main remote repository stored in AWS, or a local repository
8 | with your own models.
9 | """
10 |
11 | from hashlib import sha256
12 | from pathlib import Path
13 | import typing as tp
14 |
15 | import torch
16 | import yaml
17 |
18 | from .apply import BagOfModels, Model
19 | from .states import load_model
20 |
21 |
22 | AnyModel = tp.Union[Model, BagOfModels]
23 |
24 |
25 | class ModelLoadingError(RuntimeError):
26 | pass
27 |
28 |
29 | def check_checksum(path: Path, checksum: str):
30 | sha = sha256()
31 | with open(path, "rb") as file:
32 | while True:
33 | buf = file.read(2**20)
34 | if not buf:
35 | break
36 | sha.update(buf)
37 | actual_checksum = sha.hexdigest()[: len(checksum)]
38 | if actual_checksum != checksum:
39 | raise ModelLoadingError(
40 | f"Invalid checksum for file {path}, "
41 | f"expected {checksum} but got {actual_checksum}"
42 | )
43 |
44 |
45 | class ModelOnlyRepo:
46 | """Base class for all model only repos."""
47 |
48 | def has_model(self, sig: str) -> bool:
49 | raise NotImplementedError()
50 |
51 | def get_model(self, sig: str) -> Model:
52 | raise NotImplementedError()
53 |
54 |
55 | class RemoteRepo(ModelOnlyRepo):
56 | def __init__(self, models: tp.Dict[str, str]):
57 | self._models = models
58 |
59 | def has_model(self, sig: str) -> bool:
60 | return sig in self._models
61 |
62 | def get_model(self, sig: str) -> Model:
63 | try:
64 | url = self._models[sig]
65 | except KeyError:
66 | raise ModelLoadingError(
67 | f"Could not find a pre-trained model with signature {sig}."
68 | )
69 | pkg = torch.hub.load_state_dict_from_url(
70 | url, map_location="cpu", check_hash=True
71 | )
72 | return load_model(pkg)
73 |
74 |
75 | class LocalRepo(ModelOnlyRepo):
76 | def __init__(self, root: Path):
77 | self.root = root
78 | self.scan()
79 |
80 | def scan(self):
81 | self._models = {}
82 | self._checksums = {}
83 | for file in self.root.iterdir():
84 | if file.suffix == ".th":
85 | if "-" in file.stem:
86 | xp_sig, checksum = file.stem.split("-")
87 | self._checksums[xp_sig] = checksum
88 | else:
89 | xp_sig = file.stem
90 | if xp_sig in self._models:
91 | print("Whats xp? ", xp_sig)
92 | raise ModelLoadingError(
93 | f"Duplicate pre-trained model exist for signature {xp_sig}. "
94 | "Please delete all but one."
95 | )
96 | self._models[xp_sig] = file
97 |
98 | def has_model(self, sig: str) -> bool:
99 | return sig in self._models
100 |
101 | def get_model(self, sig: str) -> Model:
102 | try:
103 | file = self._models[sig]
104 | except KeyError:
105 | raise ModelLoadingError(
106 | f"Could not find pre-trained model with signature {sig}."
107 | )
108 | if sig in self._checksums:
109 | check_checksum(file, self._checksums[sig])
110 | return load_model(file)
111 |
112 |
113 | class BagOnlyRepo:
114 | """Handles only YAML files containing bag of models, leaving the actual
115 | model loading to some Repo.
116 | """
117 |
118 | def __init__(self, root: Path, model_repo: ModelOnlyRepo):
119 | self.root = root
120 | self.model_repo = model_repo
121 | self.scan()
122 |
123 | def scan(self):
124 | self._bags = {}
125 | for file in self.root.iterdir():
126 | if file.suffix == ".yaml":
127 | self._bags[file.stem] = file
128 |
129 | def has_model(self, name: str) -> bool:
130 | return name in self._bags
131 |
132 | def get_model(self, name: str) -> BagOfModels:
133 | try:
134 | yaml_file = self._bags[name]
135 | except KeyError:
136 | raise ModelLoadingError(
137 |                 f"{name} is neither a single pre-trained model nor " "a bag of models."
138 | )
139 | bag = yaml.safe_load(open(yaml_file))
140 | signatures = bag["models"]
141 | models = [self.model_repo.get_model(sig) for sig in signatures]
142 | weights = bag.get("weights")
143 | segment = bag.get("segment")
144 | return BagOfModels(models, weights, segment)
145 |
146 |
147 | class AnyModelRepo:
148 | def __init__(self, model_repo: ModelOnlyRepo, bag_repo: BagOnlyRepo):
149 | self.model_repo = model_repo
150 | self.bag_repo = bag_repo
151 |
152 | def has_model(self, name_or_sig: str) -> bool:
153 | return self.model_repo.has_model(name_or_sig) or self.bag_repo.has_model(
154 | name_or_sig
155 | )
156 |
157 | def get_model(self, name_or_sig: str) -> AnyModel:
158 | # print('name_or_sig: ', name_or_sig)
159 | if self.model_repo.has_model(name_or_sig):
160 | return self.model_repo.get_model(name_or_sig)
161 | else:
162 | return self.bag_repo.get_model(name_or_sig)
163 |
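A minimal usage sketch, assuming an illustrative local folder containing .th weight files and optional .yaml bag descriptions; the model name/signature is likewise illustrative:

from pathlib import Path

from uvr.uvr_lib_v5.demucs.repo import AnyModelRepo, BagOnlyRepo, LocalRepo

repo_dir = Path("models/demucs")
model_repo = LocalRepo(repo_dir)
bag_repo = BagOnlyRepo(repo_dir, model_repo)
any_repo = AnyModelRepo(model_repo, bag_repo)
if any_repo.has_model("my_model"):
    model = any_repo.get_model("my_model")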
--------------------------------------------------------------------------------
/uvr/uvr_lib_v5/demucs/spec.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | """Convenience wrapper to perform STFT and iSTFT"""
7 |
8 | import torch as th
9 |
10 |
11 | def spectro(x, n_fft=512, hop_length=None, pad=0):
12 | *other, length = x.shape
13 | x = x.reshape(-1, length)
14 |
15 | device_type = x.device.type
16 | is_other_gpu = not device_type in ["cuda", "cpu"]
17 |
18 | if is_other_gpu:
19 | x = x.cpu()
20 | z = th.stft(
21 | x,
22 | n_fft * (1 + pad),
23 | hop_length or n_fft // 4,
24 | window=th.hann_window(n_fft).to(x),
25 | win_length=n_fft,
26 | normalized=True,
27 | center=True,
28 | return_complex=True,
29 | pad_mode="reflect",
30 | )
31 | _, freqs, frame = z.shape
32 | return z.view(*other, freqs, frame)
33 |
34 |
35 | def ispectro(z, hop_length=None, length=None, pad=0):
36 | *other, freqs, frames = z.shape
37 | n_fft = 2 * freqs - 2
38 | z = z.view(-1, freqs, frames)
39 | win_length = n_fft // (1 + pad)
40 |
41 | device_type = z.device.type
42 | is_other_gpu = not device_type in ["cuda", "cpu"]
43 |
44 | if is_other_gpu:
45 | z = z.cpu()
46 | x = th.istft(
47 | z,
48 | n_fft,
49 | hop_length,
50 | window=th.hann_window(win_length).to(z.real),
51 | win_length=win_length,
52 | normalized=True,
53 | length=length,
54 | center=True,
55 | )
56 | _, length = x.shape
57 | return x.view(*other, length)
58 |
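A minimal round-trip sketch on synthetic stereo audio (shapes and STFT settings are illustrative):

import torch as th

from uvr.uvr_lib_v5.demucs.spec import ispectro, spectro

x = th.randn(2, 44100)
z = spectro(x, n_fft=4096, hop_length=1024)            # complex, (2, 2049, frames)
y = ispectro(z, hop_length=1024, length=x.shape[-1])
print((x - y).abs().max())                              # small reconstruction error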
--------------------------------------------------------------------------------
/uvr/uvr_lib_v5/demucs/states.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | """
7 | Utilities to save and load models.
8 | """
9 | from contextlib import contextmanager
10 |
11 | import functools
12 | import hashlib
13 | import inspect
14 | import io
15 | from pathlib import Path
16 | import warnings
17 |
18 | from diffq import DiffQuantizer, UniformQuantizer, restore_quantized_state
19 | import torch
20 |
21 |
22 | def get_quantizer(model, args, optimizer=None):
23 | """Return the quantizer given the XP quantization args."""
24 | quantizer = None
25 | if args.diffq:
26 | quantizer = DiffQuantizer(
27 | model, min_size=args.min_size, group_size=args.group_size
28 | )
29 | if optimizer is not None:
30 | quantizer.setup_optimizer(optimizer)
31 | elif args.qat:
32 | quantizer = UniformQuantizer(model, bits=args.qat, min_size=args.min_size)
33 | return quantizer
34 |
35 |
36 | def load_model(path_or_package, strict=False):
37 | """Load a model from the given serialized model, either given as a dict (already loaded)
38 | or a path to a file on disk."""
39 | if isinstance(path_or_package, dict):
40 | package = path_or_package
41 | elif isinstance(path_or_package, (str, Path)):
42 | with warnings.catch_warnings():
43 | warnings.simplefilter("ignore")
44 | path = path_or_package
45 | package = torch.load(path, "cpu")
46 | else:
47 | raise ValueError(f"Invalid type for {path_or_package}.")
48 |
49 | klass = package["klass"]
50 | args = package["args"]
51 | kwargs = package["kwargs"]
52 |
53 | if strict:
54 | model = klass(*args, **kwargs)
55 | else:
56 | sig = inspect.signature(klass)
57 | for key in list(kwargs):
58 | if key not in sig.parameters:
59 |                 warnings.warn("Dropping nonexistent parameter " + key)
60 | del kwargs[key]
61 | model = klass(*args, **kwargs)
62 |
63 | state = package["state"]
64 |
65 | set_state(model, state)
66 | return model
67 |
68 |
69 | def get_state(model, quantizer, half=False):
70 | """Get the state from a model, potentially with quantization applied.
71 |     If `half` is True, the state is stored in half precision, which shouldn't impact performance
72 |     but halves the state size."""
73 | if quantizer is None:
74 | dtype = torch.half if half else None
75 | state = {
76 | k: p.data.to(device="cpu", dtype=dtype)
77 | for k, p in model.state_dict().items()
78 | }
79 | else:
80 | state = quantizer.get_quantized_state()
81 | state["__quantized"] = True
82 | return state
83 |
84 |
85 | def set_state(model, state, quantizer=None):
86 | """Set the state on a given model."""
87 | if state.get("__quantized"):
88 | if quantizer is not None:
89 | quantizer.restore_quantized_state(model, state["quantized"])
90 | else:
91 | restore_quantized_state(model, state)
92 | else:
93 | model.load_state_dict(state)
94 | return state
95 |
96 |
97 | def save_with_checksum(content, path):
98 | """Save the given value on disk, along with a sha256 hash.
99 | Should be used with the output of either `serialize_model` or `get_state`."""
100 | buf = io.BytesIO()
101 | torch.save(content, buf)
102 | sig = hashlib.sha256(buf.getvalue()).hexdigest()[:8]
103 |
104 | path = path.parent / (path.stem + "-" + sig + path.suffix)
105 | path.write_bytes(buf.getvalue())
106 |
107 |
108 | def copy_state(state):
109 | return {k: v.cpu().clone() for k, v in state.items()}
110 |
111 |
112 | @contextmanager
113 | def swap_state(model, state):
114 | """
115 | Context manager that swaps the state of a model, e.g:
116 |
117 | # model is in old state
118 | with swap_state(model, new_state):
119 | # model in new state
120 | # model back to old state
121 | """
122 | old_state = copy_state(model.state_dict())
123 | model.load_state_dict(state, strict=False)
124 | try:
125 | yield
126 | finally:
127 | model.load_state_dict(old_state)
128 |
129 |
130 | def capture_init(init):
131 | @functools.wraps(init)
132 | def __init__(self, *args, **kwargs):
133 | self._init_args_kwargs = (args, kwargs)
134 | init(self, *args, **kwargs)
135 |
136 | return __init__
137 |
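A minimal sketch of swap_state on an illustrative module: the zeroed weights are only active inside the context manager and the original parameters are restored afterwards:

import torch

from uvr.uvr_lib_v5.demucs.states import swap_state

model = torch.nn.Linear(4, 2)
zeros = {k: torch.zeros_like(v) for k, v in model.state_dict().items()}
with swap_state(model, zeros):
    print(model.weight.abs().sum())  # tensor(0.)
print(model.weight.abs().sum())      # original weights again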
--------------------------------------------------------------------------------
/uvr/uvr_lib_v5/mdxnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from .modules import TFC_TDF
4 | from pytorch_lightning import LightningModule
5 |
6 | dim_s = 4
7 |
8 |
9 | class AbstractMDXNet(LightningModule):
10 | def __init__(
11 | self,
12 | target_name,
13 | lr,
14 | optimizer,
15 | dim_c,
16 | dim_f,
17 | dim_t,
18 | n_fft,
19 | hop_length,
20 | overlap,
21 | ):
22 | super().__init__()
23 | self.target_name = target_name
24 | self.lr = lr
25 | self.optimizer = optimizer
26 | self.dim_c = dim_c
27 | self.dim_f = dim_f
28 | self.dim_t = dim_t
29 | self.n_fft = n_fft
30 | self.n_bins = n_fft // 2 + 1
31 | self.hop_length = hop_length
32 | self.window = nn.Parameter(
33 | torch.hann_window(window_length=self.n_fft, periodic=True),
34 | requires_grad=False,
35 | )
36 | self.freq_pad = nn.Parameter(
37 | torch.zeros([1, dim_c, self.n_bins - self.dim_f, self.dim_t]),
38 | requires_grad=False,
39 | )
40 |
41 | def get_optimizer(self):
42 | if self.optimizer == "rmsprop":
43 | return torch.optim.RMSprop(self.parameters(), self.lr)
44 |
45 | if self.optimizer == "adamw":
46 | return torch.optim.AdamW(self.parameters(), self.lr)
47 |
48 |
49 | class ConvTDFNet(AbstractMDXNet):
50 | def __init__(
51 | self,
52 | target_name,
53 | lr,
54 | optimizer,
55 | dim_c,
56 | dim_f,
57 | dim_t,
58 | n_fft,
59 | hop_length,
60 | num_blocks,
61 | l,
62 | g,
63 | k,
64 | bn,
65 | bias,
66 | overlap,
67 | ):
68 |
69 | super(ConvTDFNet, self).__init__(
70 | target_name, lr, optimizer, dim_c, dim_f, dim_t, n_fft, hop_length, overlap
71 | )
72 | # self.save_hyperparameters()
73 |
74 | self.num_blocks = num_blocks
75 | self.l = l
76 | self.g = g
77 | self.k = k
78 | self.bn = bn
79 | self.bias = bias
80 |
81 | if optimizer == "rmsprop":
82 | norm = nn.BatchNorm2d
83 |
84 | if optimizer == "adamw":
85 | norm = lambda input: nn.GroupNorm(2, input)
86 |
87 | self.n = num_blocks // 2
88 | scale = (2, 2)
89 |
90 | self.first_conv = nn.Sequential(
91 | nn.Conv2d(in_channels=self.dim_c, out_channels=g, kernel_size=(1, 1)),
92 | norm(g),
93 | nn.ReLU(),
94 | )
95 |
96 | f = self.dim_f
97 | c = g
98 | self.encoding_blocks = nn.ModuleList()
99 | self.ds = nn.ModuleList()
100 | for i in range(self.n):
101 | self.encoding_blocks.append(TFC_TDF(c, l, f, k, bn, bias=bias, norm=norm))
102 | self.ds.append(
103 | nn.Sequential(
104 | nn.Conv2d(
105 | in_channels=c,
106 | out_channels=c + g,
107 | kernel_size=scale,
108 | stride=scale,
109 | ),
110 | norm(c + g),
111 | nn.ReLU(),
112 | )
113 | )
114 | f = f // 2
115 | c += g
116 |
117 | self.bottleneck_block = TFC_TDF(c, l, f, k, bn, bias=bias, norm=norm)
118 |
119 | self.decoding_blocks = nn.ModuleList()
120 | self.us = nn.ModuleList()
121 | for i in range(self.n):
122 | self.us.append(
123 | nn.Sequential(
124 | nn.ConvTranspose2d(
125 | in_channels=c,
126 | out_channels=c - g,
127 | kernel_size=scale,
128 | stride=scale,
129 | ),
130 | norm(c - g),
131 | nn.ReLU(),
132 | )
133 | )
134 | f = f * 2
135 | c -= g
136 |
137 | self.decoding_blocks.append(TFC_TDF(c, l, f, k, bn, bias=bias, norm=norm))
138 |
139 | self.final_conv = nn.Sequential(
140 | nn.Conv2d(in_channels=c, out_channels=self.dim_c, kernel_size=(1, 1)),
141 | )
142 |
143 | def forward(self, x):
144 |
145 | x = self.first_conv(x)
146 |
147 | x = x.transpose(-1, -2)
148 |
149 | ds_outputs = []
150 | for i in range(self.n):
151 | x = self.encoding_blocks[i](x)
152 | ds_outputs.append(x)
153 | x = self.ds[i](x)
154 |
155 | x = self.bottleneck_block(x)
156 |
157 | for i in range(self.n):
158 | x = self.us[i](x)
159 | x *= ds_outputs[-i - 1]
160 | x = self.decoding_blocks[i](x)
161 |
162 | x = x.transpose(-1, -2)
163 |
164 | x = self.final_conv(x)
165 |
166 | return x
167 |
168 |
169 | class Mixer(nn.Module):
170 | def __init__(self, device, mixer_path):
171 |
172 | super(Mixer, self).__init__()
173 |
174 | self.linear = nn.Linear((dim_s + 1) * 2, dim_s * 2, bias=False)
175 |
176 | self.load_state_dict(torch.load(mixer_path, map_location=device))
177 |
178 | def forward(self, x):
179 | x = x.reshape(1, (dim_s + 1) * 2, -1).transpose(-1, -2)
180 | x = self.linear(x)
181 | return x.transpose(-1, -2).reshape(dim_s, 2, -1)
182 |
--------------------------------------------------------------------------------
/uvr/uvr_lib_v5/mixer.ckpt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/blaisewf/rvc-cli/a732f87a54780e26e4a0e3c91b2d93f848eef6d2/uvr/uvr_lib_v5/mixer.ckpt
--------------------------------------------------------------------------------
/uvr/uvr_lib_v5/modules.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | class TFC(nn.Module):
6 | def __init__(self, c, l, k, norm):
7 | super(TFC, self).__init__()
8 |
9 | self.H = nn.ModuleList()
10 | for i in range(l):
11 | self.H.append(
12 | nn.Sequential(
13 | nn.Conv2d(
14 | in_channels=c,
15 | out_channels=c,
16 | kernel_size=k,
17 | stride=1,
18 | padding=k // 2,
19 | ),
20 | norm(c),
21 | nn.ReLU(),
22 | )
23 | )
24 |
25 | def forward(self, x):
26 | for h in self.H:
27 | x = h(x)
28 | return x
29 |
30 |
31 | class DenseTFC(nn.Module):
32 | def __init__(self, c, l, k, norm):
33 | super(DenseTFC, self).__init__()
34 |
35 | self.conv = nn.ModuleList()
36 | for i in range(l):
37 | self.conv.append(
38 | nn.Sequential(
39 | nn.Conv2d(
40 | in_channels=c,
41 | out_channels=c,
42 | kernel_size=k,
43 | stride=1,
44 | padding=k // 2,
45 | ),
46 | norm(c),
47 | nn.ReLU(),
48 | )
49 | )
50 |
51 | def forward(self, x):
52 | for layer in self.conv[:-1]:
53 | x = torch.cat([layer(x), x], 1)
54 | return self.conv[-1](x)
55 |
56 |
57 | class TFC_TDF(nn.Module):
58 | def __init__(self, c, l, f, k, bn, dense=False, bias=True, norm=nn.BatchNorm2d):
59 |
60 | super(TFC_TDF, self).__init__()
61 |
62 | self.use_tdf = bn is not None
63 |
64 | self.tfc = DenseTFC(c, l, k, norm) if dense else TFC(c, l, k, norm)
65 |
66 | if self.use_tdf:
67 | if bn == 0:
68 | self.tdf = nn.Sequential(nn.Linear(f, f, bias=bias), norm(c), nn.ReLU())
69 | else:
70 | self.tdf = nn.Sequential(
71 | nn.Linear(f, f // bn, bias=bias),
72 | norm(c),
73 | nn.ReLU(),
74 | nn.Linear(f // bn, f, bias=bias),
75 | norm(c),
76 | nn.ReLU(),
77 | )
78 |
79 | def forward(self, x):
80 | x = self.tfc(x)
81 | return x + self.tdf(x) if self.use_tdf else x
82 |
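A minimal sketch with illustrative hyperparameters; the block is shape-preserving on (batch, channels, time, frequency) input, which is why it can be stacked freely in the MDX encoder/decoder:

import torch
import torch.nn as nn

from uvr.uvr_lib_v5.modules import TFC_TDF

block = TFC_TDF(c=24, l=3, f=128, k=3, bn=8, norm=nn.BatchNorm2d)
x = torch.randn(1, 24, 256, 128)
print(block(x).shape)  # torch.Size([1, 24, 256, 128])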
--------------------------------------------------------------------------------
/uvr/uvr_lib_v5/pyrb.py:
--------------------------------------------------------------------------------
1 | import os
2 | import subprocess
3 | import tempfile
4 | import six
5 | import numpy as np
6 | import soundfile as sf
7 | import sys
8 |
9 | if getattr(sys, "frozen", False):
10 | BASE_PATH_RUB = sys._MEIPASS
11 | else:
12 | BASE_PATH_RUB = os.path.dirname(os.path.abspath(__file__))
13 |
14 | __all__ = ["time_stretch", "pitch_shift"]
15 |
16 | __RUBBERBAND_UTIL = os.path.join(BASE_PATH_RUB, "rubberband")
17 |
18 | if six.PY2:
19 | DEVNULL = open(os.devnull, "w")
20 | else:
21 | DEVNULL = subprocess.DEVNULL
22 |
23 |
24 | def __rubberband(y, sr, **kwargs):
25 |
26 | assert sr > 0
27 |
28 | # Get the input and output tempfile
29 | fd, infile = tempfile.mkstemp(suffix=".wav")
30 | os.close(fd)
31 | fd, outfile = tempfile.mkstemp(suffix=".wav")
32 | os.close(fd)
33 |
34 | # dump the audio
35 | sf.write(infile, y, sr)
36 |
37 | try:
38 | # Execute rubberband
39 | arguments = [__RUBBERBAND_UTIL, "-q"]
40 |
41 | for key, value in six.iteritems(kwargs):
42 | arguments.append(str(key))
43 | arguments.append(str(value))
44 |
45 | arguments.extend([infile, outfile])
46 |
47 | subprocess.check_call(arguments, stdout=DEVNULL, stderr=DEVNULL)
48 |
49 | # Load the processed audio.
50 | y_out, _ = sf.read(outfile, always_2d=True)
51 |
52 | # make sure that output dimensions matches input
53 | if y.ndim == 1:
54 | y_out = np.squeeze(y_out)
55 |
56 | except OSError as exc:
57 | six.raise_from(
58 | RuntimeError(
59 | "Failed to execute rubberband. "
60 | "Please verify that rubberband-cli "
61 | "is installed."
62 | ),
63 | exc,
64 | )
65 |
66 | finally:
67 | # Remove temp files
68 | os.unlink(infile)
69 | os.unlink(outfile)
70 |
71 | return y_out
72 |
73 |
74 | def time_stretch(y, sr, rate, rbargs=None):
75 | if rate <= 0:
76 | raise ValueError("rate must be strictly positive")
77 |
78 | if rate == 1.0:
79 | return y
80 |
81 | if rbargs is None:
82 | rbargs = dict()
83 |
84 | rbargs.setdefault("--tempo", rate)
85 |
86 | return __rubberband(y, sr, **rbargs)
87 |
88 |
89 | def pitch_shift(y, sr, n_steps, rbargs=None):
90 |
91 | if n_steps == 0:
92 | return y
93 |
94 | if rbargs is None:
95 | rbargs = dict()
96 |
97 | rbargs.setdefault("--pitch", n_steps)
98 |
99 | return __rubberband(y, sr, **rbargs)
100 |
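A minimal sketch on a synthetic tone; this wrapper shells out to the rubberband command-line tool, so that binary must be present next to this module or bundled with a frozen build:

import numpy as np

from uvr.uvr_lib_v5 import pyrb

sr = 44100
y = np.sin(2 * np.pi * 440 * np.arange(sr) / sr).astype(np.float32)
stretched = pyrb.time_stretch(y, sr, rate=1.25)  # change tempo by the given ratio
shifted = pyrb.pitch_shift(y, sr, n_steps=2)     # shift up by two semitones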
--------------------------------------------------------------------------------
/uvr/uvr_lib_v5/results.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | """
4 | Matchering - Audio Matching and Mastering Python Library
5 | Copyright (C) 2016-2022 Sergree
6 |
7 | This program is free software: you can redistribute it and/or modify
8 | it under the terms of the GNU General Public License as published by
9 | the Free Software Foundation, either version 3 of the License, or
10 | (at your option) any later version.
11 |
12 | This program is distributed in the hope that it will be useful,
13 | but WITHOUT ANY WARRANTY; without even the implied warranty of
14 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 | GNU General Public License for more details.
16 |
17 | You should have received a copy of the GNU General Public License
18 | along with this program. If not, see <https://www.gnu.org/licenses/>.
19 | """
20 |
21 | import os
22 | import soundfile as sf
23 |
24 |
25 | class Result:
26 | def __init__(
27 | self, file: str, subtype: str, use_limiter: bool = True, normalize: bool = True
28 | ):
29 | _, file_ext = os.path.splitext(file)
30 | file_ext = file_ext[1:].upper()
31 | if not sf.check_format(file_ext):
32 | raise TypeError(f"{file_ext} format is not supported")
33 | if not sf.check_format(file_ext, subtype):
34 | raise TypeError(f"{file_ext} format does not have {subtype} subtype")
35 | self.file = file
36 | self.subtype = subtype
37 | self.use_limiter = use_limiter
38 | self.normalize = normalize
39 |
40 |
41 | def pcm16(file: str) -> Result:
42 | return Result(file, "PCM_16")
43 |
44 |
45 | def pcm24(file: str) -> Result:
46 | return Result(file, "FLOAT")
47 |
48 |
49 | def save_audiofile(file: str, wav_set="PCM_16") -> Result:
50 | return Result(file, wav_set)
51 |
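A minimal sketch describing output targets; the filenames are illustrative and only the format/subtype pair is validated here:

from uvr.uvr_lib_v5.results import pcm16, save_audiofile

target = pcm16("output/vocals.wav")
print(target.file, target.subtype)   # output/vocals.wav PCM_16
flac_target = save_audiofile("output/vocals.flac", wav_set="PCM_16")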
--------------------------------------------------------------------------------
/uvr/uvr_lib_v5/stft.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | class STFT:
5 | """
6 | This class performs the Short-Time Fourier Transform (STFT) and its inverse (ISTFT).
7 | These functions are essential for converting the audio between the time domain and the frequency domain,
8 | which is a crucial aspect of audio processing in neural networks.
9 | """
10 |
11 | def __init__(self, logger, n_fft, hop_length, dim_f, device):
12 | self.logger = logger
13 | self.n_fft = n_fft
14 | self.hop_length = hop_length
15 | self.dim_f = dim_f
16 | self.device = device
17 | # Create a Hann window tensor for use in the STFT.
18 | self.hann_window = torch.hann_window(window_length=self.n_fft, periodic=True)
19 |
20 | def __call__(self, input_tensor):
21 | # Determine if the input tensor's device is not a standard computing device (i.e., not CPU or CUDA).
22 | is_non_standard_device = not input_tensor.device.type in ["cuda", "cpu"]
23 |
24 | # If on a non-standard device, temporarily move the tensor to CPU for processing.
25 | if is_non_standard_device:
26 | input_tensor = input_tensor.cpu()
27 |
28 | # Transfer the pre-defined window tensor to the same device as the input tensor.
29 | stft_window = self.hann_window.to(input_tensor.device)
30 |
31 | # Extract batch dimensions (all dimensions except the last two which are channel and time).
32 | batch_dimensions = input_tensor.shape[:-2]
33 |
34 | # Extract channel and time dimensions (last two dimensions of the tensor).
35 | channel_dim, time_dim = input_tensor.shape[-2:]
36 |
37 | # Reshape the tensor to merge batch and channel dimensions for STFT processing.
38 | reshaped_tensor = input_tensor.reshape([-1, time_dim])
39 |
40 | # Perform the Short-Time Fourier Transform (STFT) on the reshaped tensor.
41 | stft_output = torch.stft(
42 | reshaped_tensor,
43 | n_fft=self.n_fft,
44 | hop_length=self.hop_length,
45 | window=stft_window,
46 | center=True,
47 | return_complex=False,
48 | )
49 |
50 | # Rearrange the dimensions of the STFT output to bring the frequency dimension forward.
51 | permuted_stft_output = stft_output.permute([0, 3, 1, 2])
52 |
53 | # Reshape the output to restore the original batch and channel dimensions, while keeping the newly formed frequency and time dimensions.
54 | final_output = permuted_stft_output.reshape(
55 | [*batch_dimensions, channel_dim, 2, -1, permuted_stft_output.shape[-1]]
56 | ).reshape(
57 | [*batch_dimensions, channel_dim * 2, -1, permuted_stft_output.shape[-1]]
58 | )
59 |
60 | # If the original tensor was on a non-standard device, move the processed tensor back to that device.
61 | if is_non_standard_device:
62 | final_output = final_output.to(self.device)
63 |
64 | # Return the transformed tensor, sliced to retain only the required frequency dimension (`dim_f`).
65 | return final_output[..., : self.dim_f, :]
66 |
67 | def pad_frequency_dimension(
68 | self,
69 | input_tensor,
70 | batch_dimensions,
71 | channel_dim,
72 | freq_dim,
73 | time_dim,
74 | num_freq_bins,
75 | ):
76 | """
77 | Adds zero padding to the frequency dimension of the input tensor.
78 | """
79 | # Create a padding tensor for the frequency dimension
80 | freq_padding = torch.zeros(
81 | [*batch_dimensions, channel_dim, num_freq_bins - freq_dim, time_dim]
82 | ).to(input_tensor.device)
83 |
84 | # Concatenate the padding to the input tensor along the frequency dimension.
85 | padded_tensor = torch.cat([input_tensor, freq_padding], -2)
86 |
87 | return padded_tensor
88 |
89 | def calculate_inverse_dimensions(self, input_tensor):
90 | # Extract batch dimensions and frequency-time dimensions.
91 | batch_dimensions = input_tensor.shape[:-3]
92 | channel_dim, freq_dim, time_dim = input_tensor.shape[-3:]
93 |
94 | # Calculate the number of frequency bins for the inverse STFT.
95 | num_freq_bins = self.n_fft // 2 + 1
96 |
97 | return batch_dimensions, channel_dim, freq_dim, time_dim, num_freq_bins
98 |
99 | def prepare_for_istft(
100 | self, padded_tensor, batch_dimensions, channel_dim, num_freq_bins, time_dim
101 | ):
102 | """
103 | Prepares the tensor for Inverse Short-Time Fourier Transform (ISTFT) by reshaping
104 | and creating a complex tensor from the real and imaginary parts.
105 | """
106 | # Reshape the tensor to separate real and imaginary parts and prepare for ISTFT.
107 | reshaped_tensor = padded_tensor.reshape(
108 | [*batch_dimensions, channel_dim // 2, 2, num_freq_bins, time_dim]
109 | )
110 |
111 | # Flatten batch dimensions and rearrange for ISTFT.
112 | flattened_tensor = reshaped_tensor.reshape([-1, 2, num_freq_bins, time_dim])
113 |
114 | # Rearrange the dimensions of the tensor to bring the frequency dimension forward.
115 | permuted_tensor = flattened_tensor.permute([0, 2, 3, 1])
116 |
117 | # Combine real and imaginary parts into a complex tensor.
118 | complex_tensor = permuted_tensor[..., 0] + permuted_tensor[..., 1] * 1.0j
119 |
120 | return complex_tensor
121 |
122 | def inverse(self, input_tensor):
123 | # Determine if the input tensor's device is not a standard computing device (i.e., not CPU or CUDA).
124 |         is_non_standard_device = input_tensor.device.type not in ["cuda", "cpu"]
125 |
126 | # If on a non-standard device, temporarily move the tensor to CPU for processing.
127 | if is_non_standard_device:
128 | input_tensor = input_tensor.cpu()
129 |
130 | # Transfer the pre-defined Hann window tensor to the same device as the input tensor.
131 | stft_window = self.hann_window.to(input_tensor.device)
132 |
133 | batch_dimensions, channel_dim, freq_dim, time_dim, num_freq_bins = (
134 | self.calculate_inverse_dimensions(input_tensor)
135 | )
136 |
137 | padded_tensor = self.pad_frequency_dimension(
138 | input_tensor,
139 | batch_dimensions,
140 | channel_dim,
141 | freq_dim,
142 | time_dim,
143 | num_freq_bins,
144 | )
145 |
146 | complex_tensor = self.prepare_for_istft(
147 | padded_tensor, batch_dimensions, channel_dim, num_freq_bins, time_dim
148 | )
149 |
150 | # Perform the Inverse Short-Time Fourier Transform (ISTFT).
151 | istft_result = torch.istft(
152 | complex_tensor,
153 | n_fft=self.n_fft,
154 | hop_length=self.hop_length,
155 | window=stft_window,
156 | center=True,
157 | )
158 |
159 | # Reshape ISTFT result to restore original batch and channel dimensions.
160 | final_output = istft_result.reshape([*batch_dimensions, 2, -1])
161 |
162 | # If the original tensor was on a non-standard device, move the processed tensor back to that device.
163 | if is_non_standard_device:
164 | final_output = final_output.to(self.device)
165 |
166 | return final_output
167 |
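
A minimal round-trip sketch for the STFT class above (not part of the file). The shapes follow from the reshaping logic in __call__ and inverse; the logger argument is only stored, so any standard logger will do:

    import logging
    import torch
    from uvr.uvr_lib_v5.stft import STFT

    stft = STFT(logger=logging.getLogger(__name__),
                n_fft=2048, hop_length=512, dim_f=1024, device="cpu")

    audio = torch.randn(1, 2, 44100)   # (batch, channels, samples)
    spec = stft(audio)                 # (1, 4, 1024, 87): real/imag stacked on the channel axis
    restored = stft.inverse(spec)      # (1, 2, 44032): length rounded down to whole hops

Because __call__ keeps only the first dim_f of the n_fft // 2 + 1 frequency bins and inverse zero-pads them back, the round trip is slightly lossy at the top of the spectrum.
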
--------------------------------------------------------------------------------
/uvr/uvr_lib_v5/vr_network/__init__.py:
--------------------------------------------------------------------------------
1 | # VR init.
2 |
--------------------------------------------------------------------------------
/uvr/uvr_lib_v5/vr_network/layers_new.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import torch.nn.functional as F
4 |
5 | from uvr.uvr_lib_v5 import spec_utils
6 |
7 |
8 | class Conv2DBNActiv(nn.Module):
9 | """
10 | Conv2DBNActiv Class:
11 | This class implements a convolutional layer followed by batch normalization and an activation function.
12 | It is a fundamental building block for constructing neural networks, especially useful in image and audio processing tasks.
13 | The class encapsulates the pattern of applying a convolution, normalizing the output, and then applying a non-linear activation.
14 | """
15 |
16 | def __init__(self, nin, nout, ksize=3, stride=1, pad=1, dilation=1, activ=nn.ReLU):
17 | super(Conv2DBNActiv, self).__init__()
18 |
19 | # Sequential model combining Conv2D, BatchNorm, and activation function into a single module
20 | self.conv = nn.Sequential(
21 | nn.Conv2d(
22 | nin,
23 | nout,
24 | kernel_size=ksize,
25 | stride=stride,
26 | padding=pad,
27 | dilation=dilation,
28 | bias=False,
29 | ),
30 | nn.BatchNorm2d(nout),
31 | activ(),
32 | )
33 |
34 | def __call__(self, input_tensor):
35 | # Forward pass through the sequential model
36 | return self.conv(input_tensor)
37 |
38 |
39 | class Encoder(nn.Module):
40 | """
41 | Encoder Class:
42 | This class defines an encoder module typically used in autoencoder architectures.
43 | It consists of two convolutional layers, each followed by batch normalization and an activation function.
44 | """
45 |
46 | def __init__(self, nin, nout, ksize=3, stride=1, pad=1, activ=nn.LeakyReLU):
47 | super(Encoder, self).__init__()
48 |
49 | # First convolutional layer of the encoder
50 | self.conv1 = Conv2DBNActiv(nin, nout, ksize, stride, pad, activ=activ)
51 | # Second convolutional layer of the encoder
52 | self.conv2 = Conv2DBNActiv(nout, nout, ksize, 1, pad, activ=activ)
53 |
54 | def __call__(self, input_tensor):
55 | # Applying the first and then the second convolutional layers
56 | hidden = self.conv1(input_tensor)
57 | hidden = self.conv2(hidden)
58 |
59 | return hidden
60 |
61 |
62 | class Decoder(nn.Module):
63 | """
64 | Decoder Class:
65 | This class defines a decoder module, which is the counterpart of the Encoder class in autoencoder architectures.
66 |     It upsamples the input by a factor of two, optionally fuses a skip connection, and then applies a convolutional layer followed by batch normalization and an activation function, with an optional dropout layer for regularization.
67 | """
68 |
69 | def __init__(
70 | self, nin, nout, ksize=3, stride=1, pad=1, activ=nn.ReLU, dropout=False
71 | ):
72 | super(Decoder, self).__init__()
73 | # Convolutional layer with optional dropout for regularization
74 | self.conv1 = Conv2DBNActiv(nin, nout, ksize, 1, pad, activ=activ)
75 | # self.conv2 = Conv2DBNActiv(nout, nout, ksize, 1, pad, activ=activ)
76 | self.dropout = nn.Dropout2d(0.1) if dropout else None
77 |
78 | def __call__(self, input_tensor, skip=None):
79 |         # Upsample by a factor of two before optionally fusing the skip connection.
80 | input_tensor = F.interpolate(
81 | input_tensor, scale_factor=2, mode="bilinear", align_corners=True
82 | )
83 |
84 | if skip is not None:
85 | skip = spec_utils.crop_center(skip, input_tensor)
86 | input_tensor = torch.cat([input_tensor, skip], dim=1)
87 |
88 | hidden = self.conv1(input_tensor)
89 | # hidden = self.conv2(hidden)
90 |
91 | if self.dropout is not None:
92 | hidden = self.dropout(hidden)
93 |
94 | return hidden
95 |
96 |
97 | class ASPPModule(nn.Module):
98 | """
99 | ASPPModule Class:
100 |     This class implements the Atrous Spatial Pyramid Pooling (ASPP) module, originally introduced for semantic image segmentation and applied here to spectrogram feature maps.
101 | It captures multi-scale contextual information by applying convolutions at multiple dilation rates.
102 | """
103 |
104 | def __init__(self, nin, nout, dilations=(4, 8, 12), activ=nn.ReLU, dropout=False):
105 | super(ASPPModule, self).__init__()
106 |
107 | # Global context convolution captures the overall context
108 | self.conv1 = nn.Sequential(
109 | nn.AdaptiveAvgPool2d((1, None)),
110 | Conv2DBNActiv(nin, nout, 1, 1, 0, activ=activ),
111 | )
112 | self.conv2 = Conv2DBNActiv(nin, nout, 1, 1, 0, activ=activ)
113 | self.conv3 = Conv2DBNActiv(
114 | nin, nout, 3, 1, dilations[0], dilations[0], activ=activ
115 | )
116 | self.conv4 = Conv2DBNActiv(
117 | nin, nout, 3, 1, dilations[1], dilations[1], activ=activ
118 | )
119 | self.conv5 = Conv2DBNActiv(
120 | nin, nout, 3, 1, dilations[2], dilations[2], activ=activ
121 | )
122 | self.bottleneck = Conv2DBNActiv(nout * 5, nout, 1, 1, 0, activ=activ)
123 | self.dropout = nn.Dropout2d(0.1) if dropout else None
124 |
125 | def forward(self, input_tensor):
126 | _, _, h, w = input_tensor.size()
127 |
128 | # Upsample global context to match input size and combine with local and multi-scale features
129 | feat1 = F.interpolate(
130 | self.conv1(input_tensor), size=(h, w), mode="bilinear", align_corners=True
131 | )
132 | feat2 = self.conv2(input_tensor)
133 | feat3 = self.conv3(input_tensor)
134 | feat4 = self.conv4(input_tensor)
135 | feat5 = self.conv5(input_tensor)
136 | out = torch.cat((feat1, feat2, feat3, feat4, feat5), dim=1)
137 | out = self.bottleneck(out)
138 |
139 | if self.dropout is not None:
140 | out = self.dropout(out)
141 |
142 | return out
143 |
144 |
145 | class LSTMModule(nn.Module):
146 | """
147 | LSTMModule Class:
148 | This class defines a module that combines convolutional feature extraction with a bidirectional LSTM for sequence modeling.
149 | It is useful for tasks that require understanding temporal dynamics in data, such as speech and audio processing.
150 | """
151 |
152 | def __init__(self, nin_conv, nin_lstm, nout_lstm):
153 | super(LSTMModule, self).__init__()
154 | # Convolutional layer for initial feature extraction
155 | self.conv = Conv2DBNActiv(nin_conv, 1, 1, 1, 0)
156 |
157 | # Bidirectional LSTM for capturing temporal dynamics
158 | self.lstm = nn.LSTM(
159 | input_size=nin_lstm, hidden_size=nout_lstm // 2, bidirectional=True
160 | )
161 |
162 | # Dense layer for output dimensionality matching
163 | self.dense = nn.Sequential(
164 | nn.Linear(nout_lstm, nin_lstm), nn.BatchNorm1d(nin_lstm), nn.ReLU()
165 | )
166 |
167 | def forward(self, input_tensor):
168 | N, _, nbins, nframes = input_tensor.size()
169 |
170 | # Extract features and prepare for LSTM
171 | hidden = self.conv(input_tensor)[:, 0] # N, nbins, nframes
172 | hidden = hidden.permute(2, 0, 1) # nframes, N, nbins
173 | hidden, _ = self.lstm(hidden)
174 |
175 | # Apply dense layer and reshape to match expected output format
176 | hidden = self.dense(hidden.reshape(-1, hidden.size()[-1])) # nframes * N, nbins
177 | hidden = hidden.reshape(nframes, N, 1, nbins)
178 | hidden = hidden.permute(1, 2, 3, 0)
179 |
180 | return hidden
181 |
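
A small shape-check sketch for the building blocks above (not part of the file). Encoder halves both spatial dimensions when stride=2, and Decoder upsamples by two and fuses a skip connection before its convolution:

    import torch
    from uvr.uvr_lib_v5.vr_network import layers_new as layers

    x = torch.rand(1, 16, 128, 64)            # (batch, channels, freq_bins, frames)

    enc = layers.Encoder(16, 32, ksize=3, stride=2, pad=1)
    down = enc(x)                             # (1, 32, 64, 32)

    dec = layers.Decoder(32 + 16, 16)
    up = dec(down, skip=x)                    # (1, 16, 128, 64)
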
--------------------------------------------------------------------------------
/uvr/uvr_lib_v5/vr_network/model_param_init.py:
--------------------------------------------------------------------------------
1 | import json
2 |
3 | default_param = {}
4 | default_param["bins"] = -1
5 | default_param["unstable_bins"] = -1 # training only
6 | default_param["stable_bins"] = -1 # training only
7 | default_param["sr"] = 44100
8 | default_param["pre_filter_start"] = -1
9 | default_param["pre_filter_stop"] = -1
10 | default_param["band"] = {}
11 |
12 | N_BINS = "n_bins"
13 |
14 |
15 | def int_keys(d):
16 | """
17 |     Converts string keys that represent integers into actual integer keys while building a dictionary from a list of (key, value) pairs.
18 |
19 | This function is particularly useful when dealing with JSON data that may represent
20 | integer keys as strings due to the nature of JSON encoding. By converting these keys
21 | back to integers, it ensures that the data can be used in a manner consistent with
22 | its original representation, especially in contexts where the distinction between
23 | string and integer keys is important.
24 |
25 | Args:
26 |         d (list of tuples): A list of (key, value) pairs where keys are strings
27 |                             that may represent integers.
28 |
29 | Returns:
30 | dict: A dictionary with keys converted to integers where applicable.
31 | """
32 | # Initialize an empty dictionary to hold the converted key-value pairs.
33 | result_dict = {}
34 | # Iterate through each key-value pair in the input list.
35 | for key, value in d:
36 | # Check if the key is a digit (i.e., represents an integer).
37 | if key.isdigit():
38 | # Convert the key from a string to an integer.
39 | key = int(key)
40 |         result_dict[key] = value  # keep every pair; only digit keys are converted
41 | return result_dict
42 |
43 |
44 | class ModelParameters(object):
45 | """
46 | A class to manage model parameters, including loading from a configuration file.
47 |
48 | Attributes:
49 | param (dict): Dictionary holding all parameters for the model.
50 | """
51 |
52 | def __init__(self, config_path=""):
53 | """
54 | Initializes the ModelParameters object by loading parameters from a JSON configuration file.
55 |
56 | Args:
57 | config_path (str): Path to the JSON configuration file.
58 | """
59 |
60 | # Load parameters from the given configuration file path.
61 | with open(config_path, "r") as f:
62 | self.param = json.loads(f.read(), object_pairs_hook=int_keys)
63 |
64 | # Ensure certain parameters are set to False if not specified in the configuration.
65 | for k in [
66 | "mid_side",
67 | "mid_side_b",
68 | "mid_side_b2",
69 | "stereo_w",
70 | "stereo_n",
71 | "reverse",
72 | ]:
73 |             if k not in self.param:
74 | self.param[k] = False
75 |
76 | # If 'n_bins' is specified in the parameters, it's used as the value for 'bins'.
77 | if N_BINS in self.param:
78 | self.param["bins"] = self.param[N_BINS]
79 |
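
A minimal loading sketch for ModelParameters (not part of the file), using one of the bundled configuration files and assuming the path is relative to the repository root:

    from uvr.uvr_lib_v5.vr_network.model_param_init import ModelParameters

    params = ModelParameters("uvr/uvr_lib_v5/vr_network/modelparams/4band_v2.json")

    print(params.param["bins"])            # 672
    print(sorted(params.param["band"]))    # [1, 2, 3, 4] -- band keys converted to int by int_keys
    print(params.param["band"][1]["sr"])   # 7350
    print(params.param["mid_side"])        # False -- filled in because the JSON omits it
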
--------------------------------------------------------------------------------
/uvr/uvr_lib_v5/vr_network/modelparams/1band_sr16000_hl512.json:
--------------------------------------------------------------------------------
1 | {
2 | "bins": 1024,
3 | "unstable_bins": 0,
4 | "reduction_bins": 0,
5 | "band": {
6 | "1": {
7 | "sr": 16000,
8 | "hl": 512,
9 | "n_fft": 2048,
10 | "crop_start": 0,
11 | "crop_stop": 1024,
12 | "hpf_start": -1,
13 | "res_type": "sinc_best"
14 | }
15 | },
16 | "sr": 16000,
17 | "pre_filter_start": 1023,
18 | "pre_filter_stop": 1024
19 | }
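
For orientation, the per-band fields in these configuration files can be read as standard STFT settings. Assuming "hl" is the hop length and "n_fft" the FFT size (as the names suggest), the figures for this single-band config work out as in the sketch below; note that plain json.load keeps the band key as the string "1", unlike ModelParameters:

    import json

    with open("uvr/uvr_lib_v5/vr_network/modelparams/1band_sr16000_hl512.json") as f:
        cfg = json.load(f)

    band = cfg["band"]["1"]
    frames_per_second = band["sr"] / band["hl"]          # 16000 / 512  = 31.25
    hz_per_bin = band["sr"] / band["n_fft"]              # 16000 / 2048 = 7.8125 Hz
    kept_bins = band["crop_stop"] - band["crop_start"]   # 1024 of the 1025 STFT bins
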
--------------------------------------------------------------------------------
/uvr/uvr_lib_v5/vr_network/modelparams/1band_sr32000_hl512.json:
--------------------------------------------------------------------------------
1 | {
2 | "bins": 1024,
3 | "unstable_bins": 0,
4 | "reduction_bins": 0,
5 | "band": {
6 | "1": {
7 | "sr": 32000,
8 | "hl": 512,
9 | "n_fft": 2048,
10 | "crop_start": 0,
11 | "crop_stop": 1024,
12 | "hpf_start": -1,
13 | "res_type": "kaiser_fast"
14 | }
15 | },
16 | "sr": 32000,
17 | "pre_filter_start": 1000,
18 | "pre_filter_stop": 1021
19 | }
--------------------------------------------------------------------------------
/uvr/uvr_lib_v5/vr_network/modelparams/1band_sr33075_hl384.json:
--------------------------------------------------------------------------------
1 | {
2 | "bins": 1024,
3 | "unstable_bins": 0,
4 | "reduction_bins": 0,
5 | "band": {
6 | "1": {
7 | "sr": 33075,
8 | "hl": 384,
9 | "n_fft": 2048,
10 | "crop_start": 0,
11 | "crop_stop": 1024,
12 | "hpf_start": -1,
13 | "res_type": "sinc_best"
14 | }
15 | },
16 | "sr": 33075,
17 | "pre_filter_start": 1000,
18 | "pre_filter_stop": 1021
19 | }
--------------------------------------------------------------------------------
/uvr/uvr_lib_v5/vr_network/modelparams/1band_sr44100_hl1024.json:
--------------------------------------------------------------------------------
1 | {
2 | "bins": 1024,
3 | "unstable_bins": 0,
4 | "reduction_bins": 0,
5 | "band": {
6 | "1": {
7 | "sr": 44100,
8 | "hl": 1024,
9 | "n_fft": 2048,
10 | "crop_start": 0,
11 | "crop_stop": 1024,
12 | "hpf_start": -1,
13 | "res_type": "sinc_best"
14 | }
15 | },
16 | "sr": 44100,
17 | "pre_filter_start": 1023,
18 | "pre_filter_stop": 1024
19 | }
--------------------------------------------------------------------------------
/uvr/uvr_lib_v5/vr_network/modelparams/1band_sr44100_hl256.json:
--------------------------------------------------------------------------------
1 | {
2 | "bins": 256,
3 | "unstable_bins": 0,
4 | "reduction_bins": 0,
5 | "band": {
6 | "1": {
7 | "sr": 44100,
8 | "hl": 256,
9 | "n_fft": 512,
10 | "crop_start": 0,
11 | "crop_stop": 256,
12 | "hpf_start": -1,
13 | "res_type": "sinc_best"
14 | }
15 | },
16 | "sr": 44100,
17 | "pre_filter_start": 256,
18 | "pre_filter_stop": 256
19 | }
--------------------------------------------------------------------------------
/uvr/uvr_lib_v5/vr_network/modelparams/1band_sr44100_hl512.json:
--------------------------------------------------------------------------------
1 | {
2 | "bins": 1024,
3 | "unstable_bins": 0,
4 | "reduction_bins": 0,
5 | "band": {
6 | "1": {
7 | "sr": 44100,
8 | "hl": 512,
9 | "n_fft": 2048,
10 | "crop_start": 0,
11 | "crop_stop": 1024,
12 | "hpf_start": -1,
13 | "res_type": "sinc_best"
14 | }
15 | },
16 | "sr": 44100,
17 | "pre_filter_start": 1023,
18 | "pre_filter_stop": 1024
19 | }
--------------------------------------------------------------------------------
/uvr/uvr_lib_v5/vr_network/modelparams/1band_sr44100_hl512_cut.json:
--------------------------------------------------------------------------------
1 | {
2 | "bins": 1024,
3 | "unstable_bins": 0,
4 | "reduction_bins": 0,
5 | "band": {
6 | "1": {
7 | "sr": 44100,
8 | "hl": 512,
9 | "n_fft": 2048,
10 | "crop_start": 0,
11 | "crop_stop": 700,
12 | "hpf_start": -1,
13 | "res_type": "sinc_best"
14 | }
15 | },
16 | "sr": 44100,
17 | "pre_filter_start": 1023,
18 | "pre_filter_stop": 700
19 | }
--------------------------------------------------------------------------------
/uvr/uvr_lib_v5/vr_network/modelparams/1band_sr44100_hl512_nf1024.json:
--------------------------------------------------------------------------------
1 | {
2 | "bins": 512,
3 | "unstable_bins": 0,
4 | "reduction_bins": 0,
5 | "band": {
6 | "1": {
7 | "sr": 44100,
8 | "hl": 512,
9 | "n_fft": 1024,
10 | "crop_start": 0,
11 | "crop_stop": 512,
12 | "hpf_start": -1,
13 | "res_type": "sinc_best"
14 | }
15 | },
16 | "sr": 44100,
17 | "pre_filter_start": 511,
18 | "pre_filter_stop": 512
19 | }
--------------------------------------------------------------------------------
/uvr/uvr_lib_v5/vr_network/modelparams/2band_32000.json:
--------------------------------------------------------------------------------
1 | {
2 | "bins": 768,
3 | "unstable_bins": 7,
4 | "reduction_bins": 705,
5 | "band": {
6 | "1": {
7 | "sr": 6000,
8 | "hl": 66,
9 | "n_fft": 512,
10 | "crop_start": 0,
11 | "crop_stop": 240,
12 | "lpf_start": 60,
13 | "lpf_stop": 118,
14 | "res_type": "sinc_fastest"
15 | },
16 | "2": {
17 | "sr": 32000,
18 | "hl": 352,
19 | "n_fft": 1024,
20 | "crop_start": 22,
21 | "crop_stop": 505,
22 | "hpf_start": 44,
23 | "hpf_stop": 23,
24 | "res_type": "sinc_medium"
25 | }
26 | },
27 | "sr": 32000,
28 | "pre_filter_start": 710,
29 | "pre_filter_stop": 731
30 | }
31 |
--------------------------------------------------------------------------------
/uvr/uvr_lib_v5/vr_network/modelparams/2band_44100_lofi.json:
--------------------------------------------------------------------------------
1 | {
2 | "bins": 512,
3 | "unstable_bins": 7,
4 | "reduction_bins": 510,
5 | "band": {
6 | "1": {
7 | "sr": 11025,
8 | "hl": 160,
9 | "n_fft": 768,
10 | "crop_start": 0,
11 | "crop_stop": 192,
12 | "lpf_start": 41,
13 | "lpf_stop": 139,
14 | "res_type": "sinc_fastest"
15 | },
16 | "2": {
17 | "sr": 44100,
18 | "hl": 640,
19 | "n_fft": 1024,
20 | "crop_start": 10,
21 | "crop_stop": 320,
22 | "hpf_start": 47,
23 | "hpf_stop": 15,
24 | "res_type": "sinc_medium"
25 | }
26 | },
27 | "sr": 44100,
28 | "pre_filter_start": 510,
29 | "pre_filter_stop": 512
30 | }
31 |
--------------------------------------------------------------------------------
/uvr/uvr_lib_v5/vr_network/modelparams/2band_48000.json:
--------------------------------------------------------------------------------
1 | {
2 | "bins": 768,
3 | "unstable_bins": 7,
4 | "reduction_bins": 705,
5 | "band": {
6 | "1": {
7 | "sr": 6000,
8 | "hl": 66,
9 | "n_fft": 512,
10 | "crop_start": 0,
11 | "crop_stop": 240,
12 | "lpf_start": 60,
13 | "lpf_stop": 240,
14 | "res_type": "sinc_fastest"
15 | },
16 | "2": {
17 | "sr": 48000,
18 | "hl": 528,
19 | "n_fft": 1536,
20 | "crop_start": 22,
21 | "crop_stop": 505,
22 | "hpf_start": 82,
23 | "hpf_stop": 22,
24 | "res_type": "sinc_medium"
25 | }
26 | },
27 | "sr": 48000,
28 | "pre_filter_start": 710,
29 | "pre_filter_stop": 731
30 | }
--------------------------------------------------------------------------------
/uvr/uvr_lib_v5/vr_network/modelparams/3band_44100.json:
--------------------------------------------------------------------------------
1 | {
2 | "bins": 768,
3 | "unstable_bins": 5,
4 | "reduction_bins": 733,
5 | "band": {
6 | "1": {
7 | "sr": 11025,
8 | "hl": 128,
9 | "n_fft": 768,
10 | "crop_start": 0,
11 | "crop_stop": 278,
12 | "lpf_start": 28,
13 | "lpf_stop": 140,
14 | "res_type": "polyphase"
15 | },
16 | "2": {
17 | "sr": 22050,
18 | "hl": 256,
19 | "n_fft": 768,
20 | "crop_start": 14,
21 | "crop_stop": 322,
22 | "hpf_start": 70,
23 | "hpf_stop": 14,
24 | "lpf_start": 283,
25 | "lpf_stop": 314,
26 | "res_type": "polyphase"
27 | },
28 | "3": {
29 | "sr": 44100,
30 | "hl": 512,
31 | "n_fft": 768,
32 | "crop_start": 131,
33 | "crop_stop": 313,
34 | "hpf_start": 154,
35 | "hpf_stop": 141,
36 | "res_type": "sinc_medium"
37 | }
38 | },
39 | "sr": 44100,
40 | "pre_filter_start": 757,
41 | "pre_filter_stop": 768
42 | }
43 |
--------------------------------------------------------------------------------
/uvr/uvr_lib_v5/vr_network/modelparams/3band_44100_mid.json:
--------------------------------------------------------------------------------
1 | {
2 | "mid_side": true,
3 | "bins": 768,
4 | "unstable_bins": 5,
5 | "reduction_bins": 733,
6 | "band": {
7 | "1": {
8 | "sr": 11025,
9 | "hl": 128,
10 | "n_fft": 768,
11 | "crop_start": 0,
12 | "crop_stop": 278,
13 | "lpf_start": 28,
14 | "lpf_stop": 140,
15 | "res_type": "polyphase"
16 | },
17 | "2": {
18 | "sr": 22050,
19 | "hl": 256,
20 | "n_fft": 768,
21 | "crop_start": 14,
22 | "crop_stop": 322,
23 | "hpf_start": 70,
24 | "hpf_stop": 14,
25 | "lpf_start": 283,
26 | "lpf_stop": 314,
27 | "res_type": "polyphase"
28 | },
29 | "3": {
30 | "sr": 44100,
31 | "hl": 512,
32 | "n_fft": 768,
33 | "crop_start": 131,
34 | "crop_stop": 313,
35 | "hpf_start": 154,
36 | "hpf_stop": 141,
37 | "res_type": "sinc_medium"
38 | }
39 | },
40 | "sr": 44100,
41 | "pre_filter_start": 757,
42 | "pre_filter_stop": 768
43 | }
44 |
--------------------------------------------------------------------------------
/uvr/uvr_lib_v5/vr_network/modelparams/3band_44100_msb2.json:
--------------------------------------------------------------------------------
1 | {
2 | "mid_side_b2": true,
3 | "bins": 640,
4 | "unstable_bins": 7,
5 | "reduction_bins": 565,
6 | "band": {
7 | "1": {
8 | "sr": 11025,
9 | "hl": 108,
10 | "n_fft": 1024,
11 | "crop_start": 0,
12 | "crop_stop": 187,
13 | "lpf_start": 92,
14 | "lpf_stop": 186,
15 | "res_type": "polyphase"
16 | },
17 | "2": {
18 | "sr": 22050,
19 | "hl": 216,
20 | "n_fft": 768,
21 | "crop_start": 0,
22 | "crop_stop": 212,
23 | "hpf_start": 68,
24 | "hpf_stop": 34,
25 | "lpf_start": 174,
26 | "lpf_stop": 209,
27 | "res_type": "polyphase"
28 | },
29 | "3": {
30 | "sr": 44100,
31 | "hl": 432,
32 | "n_fft": 640,
33 | "crop_start": 66,
34 | "crop_stop": 307,
35 | "hpf_start": 86,
36 | "hpf_stop": 72,
37 | "res_type": "kaiser_fast"
38 | }
39 | },
40 | "sr": 44100,
41 | "pre_filter_start": 639,
42 | "pre_filter_stop": 640
43 | }
44 |
--------------------------------------------------------------------------------
/uvr/uvr_lib_v5/vr_network/modelparams/4band_44100.json:
--------------------------------------------------------------------------------
1 | {
2 | "bins": 768,
3 | "unstable_bins": 7,
4 | "reduction_bins": 668,
5 | "band": {
6 | "1": {
7 | "sr": 11025,
8 | "hl": 128,
9 | "n_fft": 1024,
10 | "crop_start": 0,
11 | "crop_stop": 186,
12 | "lpf_start": 37,
13 | "lpf_stop": 73,
14 | "res_type": "polyphase"
15 | },
16 | "2": {
17 | "sr": 11025,
18 | "hl": 128,
19 | "n_fft": 512,
20 | "crop_start": 4,
21 | "crop_stop": 185,
22 | "hpf_start": 36,
23 | "hpf_stop": 18,
24 | "lpf_start": 93,
25 | "lpf_stop": 185,
26 | "res_type": "polyphase"
27 | },
28 | "3": {
29 | "sr": 22050,
30 | "hl": 256,
31 | "n_fft": 512,
32 | "crop_start": 46,
33 | "crop_stop": 186,
34 | "hpf_start": 93,
35 | "hpf_stop": 46,
36 | "lpf_start": 164,
37 | "lpf_stop": 186,
38 | "res_type": "polyphase"
39 | },
40 | "4": {
41 | "sr": 44100,
42 | "hl": 512,
43 | "n_fft": 768,
44 | "crop_start": 121,
45 | "crop_stop": 382,
46 | "hpf_start": 138,
47 | "hpf_stop": 123,
48 | "res_type": "sinc_medium"
49 | }
50 | },
51 | "sr": 44100,
52 | "pre_filter_start": 740,
53 | "pre_filter_stop": 768
54 | }
55 |
--------------------------------------------------------------------------------
/uvr/uvr_lib_v5/vr_network/modelparams/4band_44100_mid.json:
--------------------------------------------------------------------------------
1 | {
2 | "bins": 768,
3 | "unstable_bins": 7,
4 | "mid_side": true,
5 | "reduction_bins": 668,
6 | "band": {
7 | "1": {
8 | "sr": 11025,
9 | "hl": 128,
10 | "n_fft": 1024,
11 | "crop_start": 0,
12 | "crop_stop": 186,
13 | "lpf_start": 37,
14 | "lpf_stop": 73,
15 | "res_type": "polyphase"
16 | },
17 | "2": {
18 | "sr": 11025,
19 | "hl": 128,
20 | "n_fft": 512,
21 | "crop_start": 4,
22 | "crop_stop": 185,
23 | "hpf_start": 36,
24 | "hpf_stop": 18,
25 | "lpf_start": 93,
26 | "lpf_stop": 185,
27 | "res_type": "polyphase"
28 | },
29 | "3": {
30 | "sr": 22050,
31 | "hl": 256,
32 | "n_fft": 512,
33 | "crop_start": 46,
34 | "crop_stop": 186,
35 | "hpf_start": 93,
36 | "hpf_stop": 46,
37 | "lpf_start": 164,
38 | "lpf_stop": 186,
39 | "res_type": "polyphase"
40 | },
41 | "4": {
42 | "sr": 44100,
43 | "hl": 512,
44 | "n_fft": 768,
45 | "crop_start": 121,
46 | "crop_stop": 382,
47 | "hpf_start": 138,
48 | "hpf_stop": 123,
49 | "res_type": "sinc_medium"
50 | }
51 | },
52 | "sr": 44100,
53 | "pre_filter_start": 740,
54 | "pre_filter_stop": 768
55 | }
56 |
--------------------------------------------------------------------------------
/uvr/uvr_lib_v5/vr_network/modelparams/4band_44100_msb.json:
--------------------------------------------------------------------------------
1 | {
2 | "mid_side_b": true,
3 | "bins": 768,
4 | "unstable_bins": 7,
5 | "reduction_bins": 668,
6 | "band": {
7 | "1": {
8 | "sr": 11025,
9 | "hl": 128,
10 | "n_fft": 1024,
11 | "crop_start": 0,
12 | "crop_stop": 186,
13 | "lpf_start": 37,
14 | "lpf_stop": 73,
15 | "res_type": "polyphase"
16 | },
17 | "2": {
18 | "sr": 11025,
19 | "hl": 128,
20 | "n_fft": 512,
21 | "crop_start": 4,
22 | "crop_stop": 185,
23 | "hpf_start": 36,
24 | "hpf_stop": 18,
25 | "lpf_start": 93,
26 | "lpf_stop": 185,
27 | "res_type": "polyphase"
28 | },
29 | "3": {
30 | "sr": 22050,
31 | "hl": 256,
32 | "n_fft": 512,
33 | "crop_start": 46,
34 | "crop_stop": 186,
35 | "hpf_start": 93,
36 | "hpf_stop": 46,
37 | "lpf_start": 164,
38 | "lpf_stop": 186,
39 | "res_type": "polyphase"
40 | },
41 | "4": {
42 | "sr": 44100,
43 | "hl": 512,
44 | "n_fft": 768,
45 | "crop_start": 121,
46 | "crop_stop": 382,
47 | "hpf_start": 138,
48 | "hpf_stop": 123,
49 | "res_type": "sinc_medium"
50 | }
51 | },
52 | "sr": 44100,
53 | "pre_filter_start": 740,
54 | "pre_filter_stop": 768
55 | }
--------------------------------------------------------------------------------
/uvr/uvr_lib_v5/vr_network/modelparams/4band_44100_msb2.json:
--------------------------------------------------------------------------------
1 | {
2 | "mid_side_b": true,
3 | "bins": 768,
4 | "unstable_bins": 7,
5 | "reduction_bins": 668,
6 | "band": {
7 | "1": {
8 | "sr": 11025,
9 | "hl": 128,
10 | "n_fft": 1024,
11 | "crop_start": 0,
12 | "crop_stop": 186,
13 | "lpf_start": 37,
14 | "lpf_stop": 73,
15 | "res_type": "polyphase"
16 | },
17 | "2": {
18 | "sr": 11025,
19 | "hl": 128,
20 | "n_fft": 512,
21 | "crop_start": 4,
22 | "crop_stop": 185,
23 | "hpf_start": 36,
24 | "hpf_stop": 18,
25 | "lpf_start": 93,
26 | "lpf_stop": 185,
27 | "res_type": "polyphase"
28 | },
29 | "3": {
30 | "sr": 22050,
31 | "hl": 256,
32 | "n_fft": 512,
33 | "crop_start": 46,
34 | "crop_stop": 186,
35 | "hpf_start": 93,
36 | "hpf_stop": 46,
37 | "lpf_start": 164,
38 | "lpf_stop": 186,
39 | "res_type": "polyphase"
40 | },
41 | "4": {
42 | "sr": 44100,
43 | "hl": 512,
44 | "n_fft": 768,
45 | "crop_start": 121,
46 | "crop_stop": 382,
47 | "hpf_start": 138,
48 | "hpf_stop": 123,
49 | "res_type": "sinc_medium"
50 | }
51 | },
52 | "sr": 44100,
53 | "pre_filter_start": 740,
54 | "pre_filter_stop": 768
55 | }
--------------------------------------------------------------------------------
/uvr/uvr_lib_v5/vr_network/modelparams/4band_44100_reverse.json:
--------------------------------------------------------------------------------
1 | {
2 | "reverse": true,
3 | "bins": 768,
4 | "unstable_bins": 7,
5 | "reduction_bins": 668,
6 | "band": {
7 | "1": {
8 | "sr": 11025,
9 | "hl": 128,
10 | "n_fft": 1024,
11 | "crop_start": 0,
12 | "crop_stop": 186,
13 | "lpf_start": 37,
14 | "lpf_stop": 73,
15 | "res_type": "polyphase"
16 | },
17 | "2": {
18 | "sr": 11025,
19 | "hl": 128,
20 | "n_fft": 512,
21 | "crop_start": 4,
22 | "crop_stop": 185,
23 | "hpf_start": 36,
24 | "hpf_stop": 18,
25 | "lpf_start": 93,
26 | "lpf_stop": 185,
27 | "res_type": "polyphase"
28 | },
29 | "3": {
30 | "sr": 22050,
31 | "hl": 256,
32 | "n_fft": 512,
33 | "crop_start": 46,
34 | "crop_stop": 186,
35 | "hpf_start": 93,
36 | "hpf_stop": 46,
37 | "lpf_start": 164,
38 | "lpf_stop": 186,
39 | "res_type": "polyphase"
40 | },
41 | "4": {
42 | "sr": 44100,
43 | "hl": 512,
44 | "n_fft": 768,
45 | "crop_start": 121,
46 | "crop_stop": 382,
47 | "hpf_start": 138,
48 | "hpf_stop": 123,
49 | "res_type": "sinc_medium"
50 | }
51 | },
52 | "sr": 44100,
53 | "pre_filter_start": 740,
54 | "pre_filter_stop": 768
55 | }
--------------------------------------------------------------------------------
/uvr/uvr_lib_v5/vr_network/modelparams/4band_44100_sw.json:
--------------------------------------------------------------------------------
1 | {
2 | "stereo_w": true,
3 | "bins": 768,
4 | "unstable_bins": 7,
5 | "reduction_bins": 668,
6 | "band": {
7 | "1": {
8 | "sr": 11025,
9 | "hl": 128,
10 | "n_fft": 1024,
11 | "crop_start": 0,
12 | "crop_stop": 186,
13 | "lpf_start": 37,
14 | "lpf_stop": 73,
15 | "res_type": "polyphase"
16 | },
17 | "2": {
18 | "sr": 11025,
19 | "hl": 128,
20 | "n_fft": 512,
21 | "crop_start": 4,
22 | "crop_stop": 185,
23 | "hpf_start": 36,
24 | "hpf_stop": 18,
25 | "lpf_start": 93,
26 | "lpf_stop": 185,
27 | "res_type": "polyphase"
28 | },
29 | "3": {
30 | "sr": 22050,
31 | "hl": 256,
32 | "n_fft": 512,
33 | "crop_start": 46,
34 | "crop_stop": 186,
35 | "hpf_start": 93,
36 | "hpf_stop": 46,
37 | "lpf_start": 164,
38 | "lpf_stop": 186,
39 | "res_type": "polyphase"
40 | },
41 | "4": {
42 | "sr": 44100,
43 | "hl": 512,
44 | "n_fft": 768,
45 | "crop_start": 121,
46 | "crop_stop": 382,
47 | "hpf_start": 138,
48 | "hpf_stop": 123,
49 | "res_type": "sinc_medium"
50 | }
51 | },
52 | "sr": 44100,
53 | "pre_filter_start": 740,
54 | "pre_filter_stop": 768
55 | }
--------------------------------------------------------------------------------
/uvr/uvr_lib_v5/vr_network/modelparams/4band_v2.json:
--------------------------------------------------------------------------------
1 | {
2 | "bins": 672,
3 | "unstable_bins": 8,
4 | "reduction_bins": 637,
5 | "band": {
6 | "1": {
7 | "sr": 7350,
8 | "hl": 80,
9 | "n_fft": 640,
10 | "crop_start": 0,
11 | "crop_stop": 85,
12 | "lpf_start": 25,
13 | "lpf_stop": 53,
14 | "res_type": "polyphase"
15 | },
16 | "2": {
17 | "sr": 7350,
18 | "hl": 80,
19 | "n_fft": 320,
20 | "crop_start": 4,
21 | "crop_stop": 87,
22 | "hpf_start": 25,
23 | "hpf_stop": 12,
24 | "lpf_start": 31,
25 | "lpf_stop": 62,
26 | "res_type": "polyphase"
27 | },
28 | "3": {
29 | "sr": 14700,
30 | "hl": 160,
31 | "n_fft": 512,
32 | "crop_start": 17,
33 | "crop_stop": 216,
34 | "hpf_start": 48,
35 | "hpf_stop": 24,
36 | "lpf_start": 139,
37 | "lpf_stop": 210,
38 | "res_type": "polyphase"
39 | },
40 | "4": {
41 | "sr": 44100,
42 | "hl": 480,
43 | "n_fft": 960,
44 | "crop_start": 78,
45 | "crop_stop": 383,
46 | "hpf_start": 130,
47 | "hpf_stop": 86,
48 | "res_type": "kaiser_fast"
49 | }
50 | },
51 | "sr": 44100,
52 | "pre_filter_start": 668,
53 | "pre_filter_stop": 672
54 | }
--------------------------------------------------------------------------------
/uvr/uvr_lib_v5/vr_network/modelparams/4band_v2_sn.json:
--------------------------------------------------------------------------------
1 | {
2 | "bins": 672,
3 | "unstable_bins": 8,
4 | "reduction_bins": 637,
5 | "band": {
6 | "1": {
7 | "sr": 7350,
8 | "hl": 80,
9 | "n_fft": 640,
10 | "crop_start": 0,
11 | "crop_stop": 85,
12 | "lpf_start": 25,
13 | "lpf_stop": 53,
14 | "res_type": "polyphase"
15 | },
16 | "2": {
17 | "sr": 7350,
18 | "hl": 80,
19 | "n_fft": 320,
20 | "crop_start": 4,
21 | "crop_stop": 87,
22 | "hpf_start": 25,
23 | "hpf_stop": 12,
24 | "lpf_start": 31,
25 | "lpf_stop": 62,
26 | "res_type": "polyphase"
27 | },
28 | "3": {
29 | "sr": 14700,
30 | "hl": 160,
31 | "n_fft": 512,
32 | "crop_start": 17,
33 | "crop_stop": 216,
34 | "hpf_start": 48,
35 | "hpf_stop": 24,
36 | "lpf_start": 139,
37 | "lpf_stop": 210,
38 | "res_type": "polyphase"
39 | },
40 | "4": {
41 | "sr": 44100,
42 | "hl": 480,
43 | "n_fft": 960,
44 | "crop_start": 78,
45 | "crop_stop": 383,
46 | "hpf_start": 130,
47 | "hpf_stop": 86,
48 | "convert_channels": "stereo_n",
49 | "res_type": "kaiser_fast"
50 | }
51 | },
52 | "sr": 44100,
53 | "pre_filter_start": 668,
54 | "pre_filter_stop": 672
55 | }
--------------------------------------------------------------------------------
/uvr/uvr_lib_v5/vr_network/modelparams/4band_v3.json:
--------------------------------------------------------------------------------
1 | {
2 | "bins": 672,
3 | "unstable_bins": 8,
4 | "reduction_bins": 530,
5 | "band": {
6 | "1": {
7 | "sr": 7350,
8 | "hl": 80,
9 | "n_fft": 640,
10 | "crop_start": 0,
11 | "crop_stop": 85,
12 | "lpf_start": 25,
13 | "lpf_stop": 53,
14 | "res_type": "polyphase"
15 | },
16 | "2": {
17 | "sr": 7350,
18 | "hl": 80,
19 | "n_fft": 320,
20 | "crop_start": 4,
21 | "crop_stop": 87,
22 | "hpf_start": 25,
23 | "hpf_stop": 12,
24 | "lpf_start": 31,
25 | "lpf_stop": 62,
26 | "res_type": "polyphase"
27 | },
28 | "3": {
29 | "sr": 14700,
30 | "hl": 160,
31 | "n_fft": 512,
32 | "crop_start": 17,
33 | "crop_stop": 216,
34 | "hpf_start": 48,
35 | "hpf_stop": 24,
36 | "lpf_start": 139,
37 | "lpf_stop": 210,
38 | "res_type": "polyphase"
39 | },
40 | "4": {
41 | "sr": 44100,
42 | "hl": 480,
43 | "n_fft": 960,
44 | "crop_start": 78,
45 | "crop_stop": 383,
46 | "hpf_start": 130,
47 | "hpf_stop": 86,
48 | "res_type": "kaiser_fast"
49 | }
50 | },
51 | "sr": 44100,
52 | "pre_filter_start": 668,
53 | "pre_filter_stop": 672
54 | }
--------------------------------------------------------------------------------
/uvr/uvr_lib_v5/vr_network/modelparams/4band_v3_sn.json:
--------------------------------------------------------------------------------
1 | {
2 | "n_bins": 672,
3 | "unstable_bins": 8,
4 | "stable_bins": 530,
5 | "band": {
6 | "1": {
7 | "sr": 7350,
8 | "hl": 80,
9 | "n_fft": 640,
10 | "crop_start": 0,
11 | "crop_stop": 85,
12 | "lpf_start": 25,
13 | "lpf_stop": 53,
14 | "res_type": "polyphase"
15 | },
16 | "2": {
17 | "sr": 7350,
18 | "hl": 80,
19 | "n_fft": 320,
20 | "crop_start": 4,
21 | "crop_stop": 87,
22 | "hpf_start": 25,
23 | "hpf_stop": 12,
24 | "lpf_start": 31,
25 | "lpf_stop": 62,
26 | "res_type": "polyphase"
27 | },
28 | "3": {
29 | "sr": 14700,
30 | "hl": 160,
31 | "n_fft": 512,
32 | "crop_start": 17,
33 | "crop_stop": 216,
34 | "hpf_start": 48,
35 | "hpf_stop": 24,
36 | "lpf_start": 139,
37 | "lpf_stop": 210,
38 | "res_type": "polyphase"
39 | },
40 | "4": {
41 | "sr": 44100,
42 | "hl": 480,
43 | "n_fft": 960,
44 | "crop_start": 78,
45 | "crop_stop": 383,
46 | "hpf_start": 130,
47 | "hpf_stop": 86,
48 | "convert_channels": "stereo_n",
49 | "res_type": "kaiser_fast"
50 | }
51 | },
52 | "sr": 44100,
53 | "pre_filter_start": 668,
54 | "pre_filter_stop": 672
55 | }
--------------------------------------------------------------------------------
/uvr/uvr_lib_v5/vr_network/modelparams/ensemble.json:
--------------------------------------------------------------------------------
1 | {
2 | "mid_side_b2": true,
3 | "bins": 1280,
4 | "unstable_bins": 7,
5 | "reduction_bins": 565,
6 | "band": {
7 | "1": {
8 | "sr": 11025,
9 | "hl": 108,
10 | "n_fft": 2048,
11 | "crop_start": 0,
12 | "crop_stop": 374,
13 | "lpf_start": 92,
14 | "lpf_stop": 186,
15 | "res_type": "polyphase"
16 | },
17 | "2": {
18 | "sr": 22050,
19 | "hl": 216,
20 | "n_fft": 1536,
21 | "crop_start": 0,
22 | "crop_stop": 424,
23 | "hpf_start": 68,
24 | "hpf_stop": 34,
25 | "lpf_start": 348,
26 | "lpf_stop": 418,
27 | "res_type": "polyphase"
28 | },
29 | "3": {
30 | "sr": 44100,
31 | "hl": 432,
32 | "n_fft": 1280,
33 | "crop_start": 132,
34 | "crop_stop": 614,
35 | "hpf_start": 172,
36 | "hpf_stop": 144,
37 | "res_type": "polyphase"
38 | }
39 | },
40 | "sr": 44100,
41 | "pre_filter_start": 1280,
42 | "pre_filter_stop": 1280
43 | }
--------------------------------------------------------------------------------
/uvr/uvr_lib_v5/vr_network/nets_new.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import torch.nn.functional as F
4 | from . import layers_new as layers
5 |
6 |
7 | class BaseNet(nn.Module):
8 | """
9 | BaseNet Class:
10 | This class defines the base network architecture for vocal removal. It includes a series of encoders for feature extraction,
11 | an ASPP module for capturing multi-scale context, and a series of decoders for reconstructing the output. Additionally,
12 | it incorporates an LSTM module for capturing temporal dependencies.
13 | """
14 |
15 | def __init__(
16 | self, nin, nout, nin_lstm, nout_lstm, dilations=((4, 2), (8, 4), (12, 6))
17 | ):
18 | super(BaseNet, self).__init__()
19 | # Initialize the encoder layers with increasing output channels for hierarchical feature extraction.
20 | self.enc1 = layers.Conv2DBNActiv(nin, nout, 3, 1, 1)
21 | self.enc2 = layers.Encoder(nout, nout * 2, 3, 2, 1)
22 | self.enc3 = layers.Encoder(nout * 2, nout * 4, 3, 2, 1)
23 | self.enc4 = layers.Encoder(nout * 4, nout * 6, 3, 2, 1)
24 | self.enc5 = layers.Encoder(nout * 6, nout * 8, 3, 2, 1)
25 |
26 | # ASPP module for capturing multi-scale features with different dilation rates.
27 | self.aspp = layers.ASPPModule(nout * 8, nout * 8, dilations, dropout=True)
28 |
29 | # Decoder layers for upscaling and merging features from different levels of the encoder and ASPP module.
30 | self.dec4 = layers.Decoder(nout * (6 + 8), nout * 6, 3, 1, 1)
31 | self.dec3 = layers.Decoder(nout * (4 + 6), nout * 4, 3, 1, 1)
32 | self.dec2 = layers.Decoder(nout * (2 + 4), nout * 2, 3, 1, 1)
33 |
34 | # LSTM module for capturing temporal dependencies in the sequence of features.
35 | self.lstm_dec2 = layers.LSTMModule(nout * 2, nin_lstm, nout_lstm)
36 | self.dec1 = layers.Decoder(nout * (1 + 2) + 1, nout * 1, 3, 1, 1)
37 |
38 | def __call__(self, input_tensor):
39 | # Sequentially pass the input through the encoder layers.
40 | encoded1 = self.enc1(input_tensor)
41 | encoded2 = self.enc2(encoded1)
42 | encoded3 = self.enc3(encoded2)
43 | encoded4 = self.enc4(encoded3)
44 | encoded5 = self.enc5(encoded4)
45 |
46 | # Pass the deepest encoder output through the ASPP module.
47 | bottleneck = self.aspp(encoded5)
48 |
49 | # Sequentially upscale and merge the features using the decoder layers.
50 | bottleneck = self.dec4(bottleneck, encoded4)
51 | bottleneck = self.dec3(bottleneck, encoded3)
52 | bottleneck = self.dec2(bottleneck, encoded2)
53 | # Concatenate the LSTM module output for temporal feature enhancement.
54 | bottleneck = torch.cat([bottleneck, self.lstm_dec2(bottleneck)], dim=1)
55 | bottleneck = self.dec1(bottleneck, encoded1)
56 |
57 | return bottleneck
58 |
59 |
60 | class CascadedNet(nn.Module):
61 | """
62 | CascadedNet Class:
63 | This class defines a cascaded network architecture that processes input in multiple stages, each stage focusing on different frequency bands.
64 | It utilizes the BaseNet for processing, and combines outputs from different stages to produce the final mask for vocal removal.
65 | """
66 |
67 | def __init__(self, n_fft, nn_arch_size=51000, nout=32, nout_lstm=128):
68 | super(CascadedNet, self).__init__()
69 | # Calculate frequency bins based on FFT size.
70 | self.max_bin = n_fft // 2
71 | self.output_bin = n_fft // 2 + 1
72 | self.nin_lstm = self.max_bin // 2
73 | self.offset = 64
74 | # Adjust output channels based on the architecture size.
75 | nout = 64 if nn_arch_size == 218409 else nout
76 |
77 | # print(nout, nout_lstm, n_fft)
78 |
79 | # Initialize the network stages, each focusing on different frequency bands and progressively refining the output.
80 | self.stg1_low_band_net = nn.Sequential(
81 | BaseNet(2, nout // 2, self.nin_lstm // 2, nout_lstm),
82 | layers.Conv2DBNActiv(nout // 2, nout // 4, 1, 1, 0),
83 | )
84 | self.stg1_high_band_net = BaseNet(
85 | 2, nout // 4, self.nin_lstm // 2, nout_lstm // 2
86 | )
87 |
88 | self.stg2_low_band_net = nn.Sequential(
89 | BaseNet(nout // 4 + 2, nout, self.nin_lstm // 2, nout_lstm),
90 | layers.Conv2DBNActiv(nout, nout // 2, 1, 1, 0),
91 | )
92 | self.stg2_high_band_net = BaseNet(
93 | nout // 4 + 2, nout // 2, self.nin_lstm // 2, nout_lstm // 2
94 | )
95 |
96 | self.stg3_full_band_net = BaseNet(
97 | 3 * nout // 4 + 2, nout, self.nin_lstm, nout_lstm
98 | )
99 |
100 | # Output layer for generating the final mask.
101 | self.out = nn.Conv2d(nout, 2, 1, bias=False)
102 | # Auxiliary output layer for intermediate supervision during training.
103 | self.aux_out = nn.Conv2d(3 * nout // 4, 2, 1, bias=False)
104 |
105 | def forward(self, input_tensor):
106 | # Preprocess input tensor to match the maximum frequency bin.
107 | input_tensor = input_tensor[:, :, : self.max_bin]
108 |
109 | # Split the input into low and high frequency bands.
110 | bandw = input_tensor.size()[2] // 2
111 | l1_in = input_tensor[:, :, :bandw]
112 | h1_in = input_tensor[:, :, bandw:]
113 |
114 | # Process each band through the first stage networks.
115 | l1 = self.stg1_low_band_net(l1_in)
116 | h1 = self.stg1_high_band_net(h1_in)
117 |
118 | # Combine the outputs for auxiliary supervision.
119 | aux1 = torch.cat([l1, h1], dim=2)
120 |
121 | # Prepare inputs for the second stage by concatenating the original and processed bands.
122 | l2_in = torch.cat([l1_in, l1], dim=1)
123 | h2_in = torch.cat([h1_in, h1], dim=1)
124 |
125 | # Process through the second stage networks.
126 | l2 = self.stg2_low_band_net(l2_in)
127 | h2 = self.stg2_high_band_net(h2_in)
128 |
129 | # Combine the outputs for auxiliary supervision.
130 | aux2 = torch.cat([l2, h2], dim=2)
131 |
132 | # Prepare input for the third stage by concatenating all previous outputs with the original input.
133 | f3_in = torch.cat([input_tensor, aux1, aux2], dim=1)
134 |
135 | # Process through the third stage network.
136 | f3 = self.stg3_full_band_net(f3_in)
137 |
138 | # Apply the output layer to generate the final mask and apply sigmoid for normalization.
139 | mask = torch.sigmoid(self.out(f3))
140 |
141 | # Pad the mask to match the output frequency bin size.
142 | mask = F.pad(
143 | input=mask,
144 | pad=(0, 0, 0, self.output_bin - mask.size()[2]),
145 | mode="replicate",
146 | )
147 |
148 | # During training, generate and pad the auxiliary output for additional supervision.
149 | if self.training:
150 | aux = torch.cat([aux1, aux2], dim=1)
151 | aux = torch.sigmoid(self.aux_out(aux))
152 | aux = F.pad(
153 | input=aux,
154 | pad=(0, 0, 0, self.output_bin - aux.size()[2]),
155 | mode="replicate",
156 | )
157 | return mask, aux
158 | else:
159 | return mask
160 |
161 | # Method for predicting the mask given an input tensor.
162 | def predict_mask(self, input_tensor):
163 | mask = self.forward(input_tensor)
164 |
165 | # If an offset is specified, crop the mask to remove edge artifacts.
166 | if self.offset > 0:
167 | mask = mask[:, :, :, self.offset : -self.offset]
168 | assert mask.size()[3] > 0
169 |
170 | return mask
171 |
172 | # Method for applying the predicted mask to the input tensor to obtain the predicted magnitude.
173 | def predict(self, input_tensor):
174 | mask = self.forward(input_tensor)
175 | pred_mag = input_tensor * mask
176 |
177 | # If an offset is specified, crop the predicted magnitude to remove edge artifacts.
178 | if self.offset > 0:
179 | pred_mag = pred_mag[:, :, :, self.offset : -self.offset]
180 | assert pred_mag.size()[3] > 0
181 |
182 | return pred_mag
183 |
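
A smoke-test sketch for CascadedNet (not part of the file). With n_fft=2048 the network expects a stereo magnitude spectrogram and crops it to n_fft // 2 bins internally; 256 frames keeps every stride-2 / upsample-2 stage an exact power of two:

    import torch
    from uvr.uvr_lib_v5.vr_network.nets_new import CascadedNet

    model = CascadedNet(n_fft=2048).eval()

    spec = torch.rand(1, 2, 1025, 256)      # (batch, channels, freq_bins, frames)

    with torch.no_grad():
        mask = model(spec)                  # (1, 2, 1025, 256) -- eval mode returns only the mask
        crop = model.predict_mask(spec)     # (1, 2, 1025, 128) -- offset trims 64 frames per edge
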
--------------------------------------------------------------------------------